espnet2.asr_transducer.decoder.mega_decoder.MEGADecoder
class espnet2.asr_transducer.decoder.mega_decoder.MEGADecoder(vocab_size: int, block_size: int = 512, linear_size: int = 1024, qk_size: int = 128, v_size: int = 1024, num_heads: int = 4, rel_pos_bias_type: str = 'simple', max_positions: int = 2048, truncation_length: int | None = None, normalization_type: str = 'layer_norm', normalization_args: Dict = {}, activation_type: str = 'swish', activation_args: Dict = {}, chunk_size: int = -1, num_blocks: int = 4, dropout_rate: float = 0.0, embed_dropout_rate: float = 0.0, att_dropout_rate: float = 0.0, ema_dropout_rate: float = 0.0, ffn_dropout_rate: float = 0.0, embed_pad: int = 0)
Bases: AbsDecoder
MEGA (Moving Average Equipped Gated Attention) decoder module.
Based on https://arxiv.org/pdf/2209.10655.pdf.
- Parameters:
- vocab_size – Vocabulary size.
- block_size – Input/Output size.
- linear_size – NormalizedPositionwiseFeedForward hidden size.
- qk_size – Shared query and key size for attention module.
- v_size – Value size for attention module.
- num_heads – Number of EMA heads.
- rel_pos_bias_type – Type of relative position bias in attention module.
- max_positions – Maximum number of positions for RelativePositionBias.
- truncation_length – Maximum length for truncation in EMA module.
- normalization_type – Normalization layer type.
- normalization_args – Normalization layer arguments.
- activation_type – Activation function type.
- activation_args – Activation function arguments.
- chunk_size – Chunk size for attention computation (-1 = full context).
- num_blocks – Number of MEGA blocks.
- dropout_rate – Dropout rate for MEGA internal modules.
- embed_dropout_rate – Dropout rate for embedding layer.
- att_dropout_rate – Dropout rate for the attention module.
- ema_dropout_rate – Dropout rate for the EMA module.
- ffn_dropout_rate – Dropout rate for the feed-forward module.
- embed_pad – Embedding padding symbol ID.
Construct a MEGADecoder object.
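At the core of each MEGA block, a multi-dimensional damped exponential moving average (EMA) supplies the positional inductive bias that gated attention then refines (see the paper linked above). A minimal single-head sketch of damped EMA in plain Python — names and scalar values here are illustrative, not the ESPnet API, where the gates are learned per EMA head:

```python
def damped_ema(xs, alpha=0.5, delta=0.8):
    """Damped EMA: y_t = alpha * x_t + (1 - alpha * delta) * y_{t-1}.

    alpha is the update gate and delta the damping factor; with
    delta = 1 this reduces to a standard EMA. In MEGA both are
    learned parameters, one pair per EMA head.
    """
    y, out = 0.0, []
    for x in xs:
        y = alpha * x + (1.0 - alpha * delta) * y
        out.append(y)
    return out

# An impulse input shows the geometric decay of the smoothed signal.
smoothed = damped_ema([1.0, 0.0, 0.0, 0.0])
```

The `truncation_length` parameter bounds how far back this recurrence is unrolled during training; during inference the EMA state carried in the decoder state dictionaries makes the step incremental.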
batch_score(hyps: List[Hypothesis]) → Tuple[Tensor, List[Dict[str, Tensor]]]
One-step forward for a batch of hypotheses.
- Parameters:hyps – Hypotheses.
- Returns: Decoder output sequences. (B, D_dec) states: Decoder hidden states. [B x Dict]
- Return type: out
create_batch_states(new_states: List[List[Dict[str, Tensor]]]) → List[Dict[str, Tensor]]
Create batch of decoder hidden states given a list of new states.
- Parameters:new_states – Decoder hidden states. [B x [N x Dict]]
- Returns: Decoder hidden states. [N x Dict]
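The reshaping this performs is a transpose from a per-hypothesis layout to a per-block layout. A sketch with plain lists standing in for tensors (hypothetical names, not the ESPnet implementation, which would stack tensors along the batch axis):

```python
def create_batch_states(new_states):
    """Turn [B x [N x Dict]] into [N x Dict] by batching each key's values."""
    num_blocks = len(new_states[0])
    batched = []
    for n in range(num_blocks):
        keys = new_states[0][n].keys()
        # With real tensors this would be a torch.stack over the batch axis.
        batched.append(
            {k: [hyp_states[n][k] for hyp_states in new_states] for k in keys}
        )
    return batched

# Two hypotheses (B=2), one MEGA block (N=1):
states = create_batch_states([[{"ema": 1}], [{"ema": 2}]])
```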
forward(labels: Tensor) → Tensor
Encode source label sequences.
- Parameters:labels – Decoder input sequences. (B, L)
- Returns: Decoder output sequences. (B, U, D_dec)
- Return type: out
inference(labels: Tensor, states: List[Dict[str, Tensor]]) → Tuple[Tensor, List[Dict[str, Tensor]]]
Encode source label sequences.
- Parameters:
- labels – Decoder input sequences. (B, L)
- states – Decoder hidden states. [B x Dict]
- Returns: Decoder output sequences. (B, U, D_dec) new_states: Decoder hidden states. [B x Dict]
- Return type: out
init_state(batch_size: int = 0) → List[Dict[str, Tensor]]
Initialize MEGADecoder states.
- Parameters:batch_size – Batch size.
- Returns: Decoder hidden states. [N x Dict]
- Return type: states
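Conceptually this allocates one state dictionary per MEGA block. A sketch under the assumption that fresh states start empty and are filled lazily during decoding (the real method also handles device placement):

```python
def init_state(num_blocks, batch_size=0):
    """One state dict per MEGA block; empty until the first decoding step."""
    # batch_size is kept to mirror the documented signature, though this
    # simplified sketch does not need it.
    return [{} for _ in range(num_blocks)]

states = init_state(num_blocks=4)
```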
score(label_sequence: List[int], states: List[Dict[str, Tensor]]) → Tuple[Tensor, List[Dict[str, Tensor]]]
One-step forward for a single hypothesis.
- Parameters:
- label_sequence – Current label sequence.
- states – Decoder hidden states. [N x Dict]
- Returns: Decoder output sequence. (D_dec) states: Decoder hidden states. [N x Dict]
select_state(states: List[Dict[str, Tensor]], idx: int) → List[Dict[str, Tensor]]
Select ID state from batch of decoder hidden states.
- Parameters:
- states – Decoder hidden states. [N x Dict]
- idx – ID of the state to select.
- Returns: Decoder hidden states for given ID. [N x Dict]
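This is the per-hypothesis inverse of create_batch_states: pick index idx out of every batched value, one dictionary per block. A list-based sketch (illustrative only; real states hold tensors):

```python
def select_state(states, idx):
    """Extract hypothesis idx from batched decoder states ([N x Dict])."""
    return [
        {k: v[idx] for k, v in block_state.items()} for block_state in states
    ]

# One block, batch of three hypotheses; keep the second one.
single = select_state([{"ema": [1, 2, 3]}], idx=1)
```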
set_device(device: device) → None
Set GPU device to use.
- Parameters:device – Device ID.
stack_qk_states(state_list: List[Tensor], dim: int) → List[Tensor]
Stack query or key states with different lengths.
- Parameters:
- state_list – List of query or key states.
- dim – Dimension along which to stack.
- Returns: Query/Key state.
- Return type: new_state
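Because hypotheses can have cached query/key states of different lengths, they must be padded to a common length before stacking. A sketch of that idea with lists of floats standing in for tensors; zero right-padding is an assumption here, not a statement about the ESPnet implementation:

```python
def stack_qk_states(state_list, pad=0.0):
    """Pad query/key states to the longest length, then stack them.

    The real method operates on tensors along a given dim (e.g. with
    torch padding utilities); flat lists stand in for tensors here.
    """
    max_len = max(len(s) for s in state_list)
    return [s + [pad] * (max_len - len(s)) for s in state_list]

stacked = stack_qk_states([[1.0], [2.0, 3.0]])
```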