espnet2.asr_transducer.decoder.blocks.mega.MEGA
class espnet2.asr_transducer.decoder.blocks.mega.MEGA(size: int = 512, num_heads: int = 4, qk_size: int = 128, v_size: int = 1024, activation: torch.nn.Module = ReLU(), normalization: torch.nn.Module = torch.nn.LayerNorm, rel_pos_bias_type: str = 'simple', max_positions: int = 2048, truncation_length: int | None = None, chunk_size: int = -1, dropout_rate: float = 0.0, att_dropout_rate: float = 0.0, ema_dropout_rate: float = 0.0)
Bases: Module
MEGA (Moving average Equipped Gated Attention) module.
- Parameters:
- size – Input/Output size.
- num_heads – Number of EMA heads.
- qk_size – Shared query and key size for the attention module.
- v_size – Value size for the attention module.
- activation – Activation function type.
- normalization – Normalization module.
- rel_pos_bias_type – Type of relative position bias in attention module.
- max_positions – Maximum number of positions for RelativePositionBias.
- truncation_length – Maximum length for truncation in EMA module.
- chunk_size – Chunk size for attention computation (-1 = full context).
- dropout_rate – Dropout rate for inner modules.
- att_dropout_rate – Dropout rate for the attention module.
- ema_dropout_rate – Dropout rate for the EMA module.
Construct a MEGA object.
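A minimal construction sketch, assuming ESPnet2 and PyTorch are installed; the keyword values shown are the documented defaults:

```python
from espnet2.asr_transducer.decoder.blocks.mega import MEGA

# Build a MEGA block with the documented default sizes.
mega = MEGA(size=512, num_heads=4, qk_size=128, v_size=1024)
```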
forward(x: Tensor, mask: Tensor | None = None, attn_mask: Tensor | None = None, state: Dict[str, Tensor | None] | None = None) → Tuple[Tensor, Dict[str, Tensor | None] | None]
Compute moving average equipped gated attention.
- Parameters:
- x – MEGA input sequences. (L, B, size)
- mask – MEGA input sequence masks. (B, 1, L)
- attn_mask – MEGA attention mask. (1, L, L)
- state – Decoder hidden states.
- Returns:
  - x – MEGA output sequences. (B, L, size)
  - state – Decoder hidden states.
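A usage sketch of forward, continuing the construction example above and following the documented shapes; the boolean mask dtype and semantics are assumptions, not confirmed by this page:

```python
import torch

L, B, size = 16, 2, 512

x = torch.randn(L, B, size)                    # input: (L, B, size)
mask = torch.ones(B, 1, L, dtype=torch.bool)   # (B, 1, L); dtype/semantics assumed
attn_mask = None                               # no attention mask (full context)

out, state = mega(x, mask=mask, attn_mask=attn_mask, state=None)
print(out.shape)  # (B, L, size) per the docstring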
reset_parameters(val: float = 0.0, std: float = 0.02) → None
Reset module parameters.
- Parameters:
- val – Initialization value (mean of the normal distribution).
- std – Standard deviation of the normal distribution.
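For example, re-initializing the block's weights with the default values:

```python
mega.reset_parameters(val=0.0, std=0.02)
```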
softmax_attention(query: Tensor, key: Tensor, mask: Tensor | None = None, attn_mask: Tensor | None = None) → Tensor
Compute attention weights with softmax.
- Parameters:
- query – Query tensor. (B, 1, L, D)
- key – Key tensor. (B, 1, L, D)
- mask – Sequence mask. (B, 1, L)
- attn_mask – Attention mask. (1, L, L)
- Returns: attn_weights – Attention weights. (B, 1, L, L)
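A sketch of calling softmax_attention directly with the shapes above; it is normally invoked internally by forward, and D = 128 here assumes the qk_size used in the construction example:

```python
import torch

B, L, D = 2, 16, 128  # D matches qk_size

query = torch.randn(B, 1, L, D)
key = torch.randn(B, 1, L, D)

attn_weights = mega.softmax_attention(query, key, mask=None, attn_mask=None)
print(attn_weights.shape)  # (B, 1, L, L)
```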