crested.tl.zoo.utils.mha_block_enf


crested.tl.zoo.utils.mha_block_enf(inputs, num_heads, key_query_dim, value_dim, scaling=True, attn_dropout=0.05, pos_dropout=0.01, final_dropout=0.4, symmetric_pos_encoding=False, pos_encoding_funs='borzoi', pos_encoding_abs=True, num_pos_feats=None, zero_init=True, residual=True, ln_epsilon=1e-05, name_prefix=None)

Construct an MHA (multi-head attention) block for Enformer/Borzoi, consisting of Residual(LayerNorm + MHSelfAttention + Dropout).
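The Residual(LayerNorm + self-attention + Dropout) structure can be illustrated with a small NumPy sketch. This is not CREsted's implementation: it omits the relative positional encodings and the learned layer-norm parameters, treats dropout as the identity (inference mode), and `toy_mha_block` with its weight arguments is a hypothetical name introduced here. It also assumes that `zero_init` corresponds to a zero-initialized output projection, as in the Enformer reference code, in which case the residual block starts out as the identity.

```python
import numpy as np


def layer_norm(x, epsilon=1e-5):
    """Normalize over the channel (last) axis; learned scale/offset omitted."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def toy_mha_block(x, w_q, w_k, w_v, w_out, num_heads,
                  scaling=True, residual=True, ln_epsilon=1e-5):
    """Hypothetical sketch: Residual(LayerNorm + self-attention), no dropout."""
    seq_len, _ = x.shape
    h = layer_norm(x, ln_epsilon)
    # Project, then split into heads: (num_heads, seq_len, dim_per_head).
    q = (h @ w_q).reshape(seq_len, num_heads, -1).transpose(1, 0, 2)
    k = (h @ w_k).reshape(seq_len, num_heads, -1).transpose(1, 0, 2)
    v = (h @ w_v).reshape(seq_len, num_heads, -1).transpose(1, 0, 2)
    if scaling:
        q = q / np.sqrt(q.shape[-1])  # scaled dot-product attention
    weights = softmax(q @ k.transpose(0, 2, 1))  # (num_heads, seq_len, seq_len)
    # Merge heads back into one channel axis, then project to input channels.
    attended = (weights @ v).transpose(1, 0, 2).reshape(seq_len, -1)
    out = attended @ w_out
    return x + out if residual else out


rng = np.random.default_rng(0)
seq_len, channels, num_heads, key_query_dim, value_dim = 16, 32, 4, 8, 8
x = rng.normal(size=(seq_len, channels))
w_q = rng.normal(size=(channels, num_heads * key_query_dim))
w_k = rng.normal(size=(channels, num_heads * key_query_dim))
w_v = rng.normal(size=(channels, num_heads * value_dim))
# Assumption: zero-initialized output projection (the toy analogue of zero_init).
w_out = np.zeros((num_heads * value_dim, channels))

y = toy_mha_block(x, w_q, w_k, w_v, w_out, num_heads)
print(y.shape)            # (16, 32)
print(np.allclose(y, x))  # True: with a zero projection the block is the identity
```

In this sketch the output has the same shape as the input, which is what allows the residual addition.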

Parameters:
  • inputs – Input tensor.

  • num_heads (int) – Number of attention heads to use.

  • key_query_dim (int) – Number of k (key) and q (query) dimensions in the attention mechanism.

  • value_dim (int) – Number of v (value) dimensions in the attention mechanism.

  • scaling (bool (default: True)) – Whether to scale the attention scores, as in scaled dot-product attention.

  • attn_dropout (float (default: 0.05)) – Attention dropout rate.

  • pos_dropout (float (default: 0.01)) – Positional embedding dropout rate.

  • final_dropout (float (default: 0.4)) – Dropout rate applied after the MHA layer, inside the block.

  • symmetric_pos_encoding (bool (default: False)) – Whether to make positional encodings symmetric. Only relevant if pos_encoding = True.

  • pos_encoding_funs (str (default: 'borzoi')) – Either 'enformer' or 'borzoi'. The 'enformer' setting uses all three position functions (exponential, central_mask_enf and gamma); the 'borzoi' setting uses only its own version of central_mask.

  • pos_encoding_abs (bool (default: True)) – Whether to take absolute values before calculating the relative position encoding.

  • num_pos_feats (int | None (default: None)) – Number of positional features. If not supplied, it is calculated from value_dim and the number of position encoding functions. Minimum of 6 for the default relative_position_functions, minimum of 12 for positional_features_sin_cos.

  • zero_init (bool (default: True)) – Whether to zero-initialize the MHA layer.

  • residual (bool (default: True)) – Whether to wrap the entire block in a residual connection.

  • ln_epsilon (float (default: 1e-05)) – Epsilon to use in the layer normalisation layer.

  • name_prefix (str | None (default: None)) – Prefix for layer names.

Return type:

KerasTensor

Returns:

Output tensor after applying the MHA block.
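
The num_pos_feats description says the default is derived from value_dim and the number of position encoding functions, without giving the rule. In the published Enformer reference code, the default is value_dim rounded down to the nearest multiple of twice the number of position functions. The sketch below assumes the same rule applies here, which this page does not confirm; `default_num_pos_feats` is a hypothetical helper, not part of CREsted.

```python
def default_num_pos_feats(value_dim, num_pos_funs):
    """Hypothetical helper: largest multiple of 2 * num_pos_funs that is <= value_dim.

    Assumption: mirrors the Enformer reference rule, where each position
    function contributes a symmetric and an asymmetric feature group.
    """
    divisible_by = 2 * num_pos_funs
    return (value_dim // divisible_by) * divisible_by


print(default_num_pos_feats(192, 3))  # 192 (three functions, as with 'enformer')
print(default_num_pos_feats(64, 3))   # 60
print(default_num_pos_feats(64, 1))   # 64 (one function, as with 'borzoi')
```

Under this assumption, three position functions give a smallest non-zero value of 6, which is consistent with the minimum stated for the default relative_position_functions.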