espnet2.asr_transducer.espnet_transducer_model.ESPnetASRTransducerModel
class espnet2.asr_transducer.espnet_transducer_model.ESPnetASRTransducerModel(vocab_size: int, token_list: Tuple[str, ...] | List[str], frontend: AbsFrontend | None, specaug: AbsSpecAug | None, normalize: AbsNormalize | None, encoder: Encoder, decoder: AbsDecoder, joint_network: JointNetwork, transducer_weight: float = 1.0, use_k2_pruned_loss: bool = False, k2_pruned_loss_args: Dict = {}, warmup_steps: int = 25000, validation_nstep: int = 2, fastemit_lambda: float = 0.0, auxiliary_ctc_weight: float = 0.0, auxiliary_ctc_dropout_rate: float = 0.0, auxiliary_lm_loss_weight: float = 0.0, auxiliary_lm_loss_smoothing: float = 0.05, ignore_id: int = -1, sym_space: str = '<space>', sym_blank: str = '<blank>', report_cer: bool = False, report_wer: bool = False, extract_feats_in_collect_stats: bool = True)
Bases: AbsESPnetModel
ESPnetASRTransducerModel module definition.
- Parameters:
- vocab_size – Size of the complete vocabulary (including SOS/EOS and blank).
- token_list – List of tokens in vocabulary (minus reserved tokens).
- frontend – Frontend module.
- specaug – SpecAugment module.
- normalize – Normalization module.
- encoder – Encoder module.
- decoder – Decoder module.
- joint_network – Joint Network module.
- transducer_weight – Weight of the Transducer loss.
- use_k2_pruned_loss – Whether to use k2 pruned Transducer loss.
- k2_pruned_loss_args – Arguments for the k2 pruned Transducer loss.
- warmup_steps – Number of warmup steps, used to scale the pruned loss.
- validation_nstep – Maximum number of symbol expansions at each time step when reporting CER and/or WER with mAES.
- fastemit_lambda – FastEmit lambda value.
- auxiliary_ctc_weight – Weight of auxiliary CTC loss.
- auxiliary_ctc_dropout_rate – Dropout rate for auxiliary CTC loss inputs.
- auxiliary_lm_loss_weight – Weight of auxiliary LM loss.
- auxiliary_lm_loss_smoothing – Label smoothing rate for the auxiliary LM loss.
- ignore_id – Initial padding ID of the label sequences.
- sym_space – Space symbol.
- sym_blank – Blank symbol.
- report_cer – Whether to report Character Error Rate during validation.
- report_wer – Whether to report Word Error Rate during validation.
- extract_feats_in_collect_stats – Whether to run feature extraction during statistics collection.
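The report_cer, report_wer, sym_space and sym_blank parameters relate to how validation error rates are derived from token sequences. The following is a minimal pure-Python sketch, not ESPnet's error calculator. It assumes that blanks are dropped, that the space symbol maps to a literal space, that CER is computed over characters with spaces removed, and that WER is computed over whitespace-separated words. The helper names are hypothetical.

```python
from typing import Sequence, Tuple


def edit_distance(ref: Sequence, hyp: Sequence) -> int:
    """Levenshtein distance (substitutions, insertions, deletions) between two sequences."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion
                curr[j - 1] + 1,         # insertion
                prev[j - 1] + (r != h),  # substitution, or match at no cost
            )
        prev = curr
    return prev[-1]


def tokens_to_text(tokens: Sequence[str], sym_space: str = "<space>", sym_blank: str = "<blank>") -> str:
    """Join tokens into a string, dropping blanks and mapping the space symbol to ' '."""
    return "".join(" " if t == sym_space else t for t in tokens if t != sym_blank)


def error_rates(ref_tokens, hyp_tokens, sym_space="<space>", sym_blank="<blank>") -> Tuple[float, float]:
    ref = tokens_to_text(ref_tokens, sym_space, sym_blank)
    hyp = tokens_to_text(hyp_tokens, sym_space, sym_blank)
    # Assumed CER: character edits with spaces removed, over reference length.
    ref_chars, hyp_chars = ref.replace(" ", ""), hyp.replace(" ", "")
    cer = edit_distance(ref_chars, hyp_chars) / len(ref_chars)
    # Assumed WER: word edits over the number of reference words.
    wer = edit_distance(ref.split(), hyp.split()) / len(ref.split())
    return cer, wer
```

For example, the reference tokens `["a", "b", "<space>", "c"]` against the hypothesis `["a", "<blank>", "x", "<space>", "c"]` give one character error out of three and one word error out of two.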
Construct an ESPnetASRTransducerModel object.
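The transducer_weight, auxiliary_ctc_weight and auxiliary_lm_loss_weight parameters scale the individual losses. Below is a minimal sketch of how they could combine. It assumes the total loss is a plain weighted sum of the Transducer loss and the optional auxiliary CTC and LM losses. combine_losses is a hypothetical helper, and the sketch does not model the k2 pruned loss, its warmup scaling, or FastEmit regularization.

```python
def combine_losses(
    loss_transducer: float,
    loss_ctc: float = 0.0,
    loss_lm: float = 0.0,
    transducer_weight: float = 1.0,
    auxiliary_ctc_weight: float = 0.0,
    auxiliary_lm_loss_weight: float = 0.0,
) -> float:
    """Assumed weighted sum of the main Transducer loss and the auxiliary losses."""
    return (
        transducer_weight * loss_transducer
        + auxiliary_ctc_weight * loss_ctc
        + auxiliary_lm_loss_weight * loss_lm
    )


# With the default weights (auxiliary weights 0.0), only the Transducer loss contributes.
total = combine_losses(10.0, loss_ctc=4.0, loss_lm=2.0,
                       auxiliary_ctc_weight=0.5, auxiliary_lm_loss_weight=0.25)
```

Under this assumption, setting an auxiliary weight to 0.0 removes that loss's contribution entirely, which matches the defaults in the constructor signature.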
collect_feats(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, **kwargs) → Dict[str, Tensor]
Collect feature sequences and their lengths.
- Parameters:
- speech – Speech sequences. (B, S)
- speech_lengths – Speech sequences lengths. (B,)
- text – Label ID sequences. (B, L)
- text_lengths – Label ID sequences lengths. (B,)
- kwargs – Contains “utts_id”.
- Returns: Dictionary with “feats”: Feature sequences. (B, T, D_feats) and “feats_lengths”: Feature sequence lengths. (B,)
- Return type: Dict[str, Tensor]
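collect_feats returns a dictionary with "feats" and "feats_lengths". The sketch below assumes that extract_feats_in_collect_stats=True runs the feature extractor and that False passes the raw speech through unchanged as "features". collect_feats_sketch and toy_extract are hypothetical, and nested Python lists stand in for tensors.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Batch = List[List[float]]


def collect_feats_sketch(
    speech: Batch,
    speech_lengths: Sequence[int],
    extract_feats: Callable[[Batch, Sequence[int]], Tuple[object, Sequence[int]]],
    extract_feats_in_collect_stats: bool = True,
) -> Dict[str, object]:
    if extract_feats_in_collect_stats:
        feats, feats_lengths = extract_feats(speech, speech_lengths)
    else:
        # Skip feature extraction: raw speech is returned as "features".
        feats, feats_lengths = speech, speech_lengths
    return {"feats": feats, "feats_lengths": feats_lengths}


def toy_extract(speech: Batch, lengths: Sequence[int]):
    # Hypothetical extractor: non-overlapping 2-sample frames, feature = frame mean.
    feats = [[[sum(u[i:i + 2]) / 2] for i in range(0, n - 1, 2)] for u, n in zip(speech, lengths)]
    return feats, [n // 2 for n in lengths]
```

Returning raw speech when extraction is disabled keeps the dictionary keys identical in both cases, so downstream statistics code does not need to branch.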
encode(speech: Tensor, speech_lengths: Tensor) → Tuple[Tensor, Tensor]
Encode speech sequences.
- Parameters:
- speech – Speech sequences. (B, S)
- speech_lengths – Speech sequences lengths. (B,)
- Returns: encoder_out: Encoder outputs. (B, T, D_enc); encoder_out_lens: Encoder output lengths. (B,)
- Return type: Tuple[Tensor, Tensor]
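encode maps (B, S) raw samples to (B, T, D_enc) encoder outputs, so the sequence lengths shrink in two stages. This sketch covers only the length arithmetic. It assumes a centered STFT frontend (even n_fft) with hop length 128, followed by an encoder input block that subsamples by roughly 4 using two kernel-3, stride-2 convolutions without padding. These settings are assumptions. The real lengths depend on the configured frontend and encoder.

```python
from typing import List, Sequence, Tuple


def stft_num_frames(num_samples: int, hop_length: int = 128) -> int:
    """Frames from a centered STFT (input padded by n_fft // 2 on both sides, n_fft even)."""
    return num_samples // hop_length + 1


def conv_subsampled_length(num_frames: int, num_layers: int = 2) -> int:
    """Length after `num_layers` convolutions with kernel 3, stride 2 and no padding."""
    for _ in range(num_layers):
        num_frames = (num_frames - 3) // 2 + 1
    return num_frames


def encoder_lengths(
    speech_lengths: Sequence[int], hop_length: int = 128, num_layers: int = 2
) -> Tuple[List[int], List[int]]:
    # Stage 1: samples -> feature frames; stage 2: feature frames -> encoder frames.
    feats_lengths = [stft_num_frames(s, hop_length) for s in speech_lengths]
    return feats_lengths, [conv_subsampled_length(t, num_layers) for t in feats_lengths]
```

Under these assumptions, one second of 16 kHz audio (16000 samples) yields 126 feature frames and 30 encoder frames.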
forward(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Forward architecture and compute loss(es).
- Parameters:
- speech – Speech sequences. (B, S)
- speech_lengths – Speech sequences lengths. (B,)
- text – Label ID sequences. (B, L)
- text_lengths – Label ID sequences lengths. (B,)
- kwargs – Contains “utts_id”.
- Returns: loss: Main loss value; stats: Task statistics; weight: Task weights.
- Return type: Tuple[Tensor, Dict[str, Tensor], Tensor]
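forward receives padded label sequences text (B, L) together with text_lengths. This is a minimal sketch of one way such labels could be prepared for a Transducer decoder. It assumes three things: text is right-padded with ignore_id, the blank ID is 0, and the decoder consumes each label sequence with a blank prepended, which is a common Transducer setup. transducer_io_sketch is hypothetical, and Python lists stand in for tensors.

```python
from typing import List, Tuple


def transducer_io_sketch(
    labels: List[List[int]], ignore_id: int = -1, blank_id: int = 0
) -> Tuple[List[List[int]], List[List[int]], List[int]]:
    # Strip the ignore_id padding from each label sequence.
    unpadded = [[y for y in seq if y != ignore_id] for seq in labels]
    max_len = max(len(seq) for seq in unpadded)
    # Decoder input: blank prepended, then right-padded with blank_id.
    decoder_in = [[blank_id] + seq + [blank_id] * (max_len - len(seq)) for seq in unpadded]
    # Target: label IDs re-padded with blank_id instead of ignore_id.
    target = [seq + [blank_id] * (max_len - len(seq)) for seq in unpadded]
    target_lengths = [len(seq) for seq in unpadded]
    return decoder_in, target, target_lengths


decoder_in, target, target_lengths = transducer_io_sketch([[5, 6, -1], [7, -1, -1]])
```

Re-padding with blank_id keeps every entry a valid vocabulary index, which is convenient for embedding lookups. The true lengths are still carried separately.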