espnet2.st.espnet_model.ESPnetSTModel
class espnet2.st.espnet_model.ESPnetSTModel(vocab_size: int, token_list: Tuple[str, ...] | List[str], frontend: AbsFrontend | None, specaug: AbsSpecAug | None, normalize: AbsNormalize | None, preencoder: AbsPreEncoder | None, encoder: AbsEncoder, hier_encoder: AbsEncoder | None, md_encoder: AbsEncoder | None, extra_mt_encoder: AbsEncoder | None, postencoder: AbsPostEncoder | None, decoder: AbsDecoder, extra_asr_decoder: AbsDecoder | None, extra_mt_decoder: AbsDecoder | None, ctc: CTC | None, st_ctc: CTC | None, st_joint_network: Module | None, src_vocab_size: int | None, src_token_list: Tuple[str, ...] | List[str] | None, asr_weight: float = 0.0, mt_weight: float = 0.0, mtlalpha: float = 0.0, st_mtlalpha: float = 0.0, ignore_id: int = -1, tgt_ignore_id: int = -1, lsm_weight: float = 0.0, length_normalized_loss: bool = False, report_cer: bool = True, report_wer: bool = True, report_bleu: bool = True, sym_space: str = '<space>', sym_blank: str = '<blank>', tgt_sym_space: str = '<space>', tgt_sym_blank: str = '<blank>', extract_feats_in_collect_stats: bool = True, ctc_sample_rate: float = 0.0, tgt_sym_sos: str = '<sos/eos>', tgt_sym_eos: str = '<sos/eos>', lang_token_id: int = -1)
Bases: AbsESPnetModel
CTC-attention hybrid Encoder-Decoder model
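The constructor arguments asr_weight, mt_weight, mtlalpha and st_mtlalpha control how the training loss is assembled. The sketch below assumes a standard multi-task interpolation. mtlalpha mixes CTC and attention losses on the auxiliary ASR branch, and st_mtlalpha does the same on the ST branch. asr_weight and mt_weight then mix the ST, ASR and MT losses. The function combine_st_losses and its argument names are hypothetical, and plain floats stand in for loss tensors. Check the exact formula against espnet2/st/espnet_model.py before relying on it.

```python
def combine_st_losses(
    loss_st_att: float,
    loss_st_ctc: float,
    loss_asr_att: float,
    loss_asr_ctc: float,
    loss_mt: float,
    asr_weight: float = 0.0,
    mt_weight: float = 0.0,
    mtlalpha: float = 0.0,
    st_mtlalpha: float = 0.0,
) -> float:
    """Hypothetical helper: interpolate ST, ASR and MT losses (assumed formula)."""
    # ST branch: mix the CTC and attention losses with st_mtlalpha.
    loss_st = st_mtlalpha * loss_st_ctc + (1.0 - st_mtlalpha) * loss_st_att
    # Auxiliary ASR branch: mix the CTC and attention losses with mtlalpha.
    loss_asr = mtlalpha * loss_asr_ctc + (1.0 - mtlalpha) * loss_asr_att
    # Weight not assigned to ASR or MT goes to the ST loss.
    return (
        (1.0 - asr_weight - mt_weight) * loss_st
        + asr_weight * loss_asr
        + mt_weight * loss_mt
    )


# With every weight at its default of 0.0, only the ST attention loss remains.
print(combine_st_losses(2.0, 4.0, 1.0, 3.0, 5.0))  # 2.0
print(
    combine_st_losses(
        2.0, 4.0, 1.0, 3.0, 5.0,
        asr_weight=0.3, mt_weight=0.2, mtlalpha=0.5, st_mtlalpha=0.5,
    )
)  # approximately 3.1
```

Under this assumption the three task weights sum to one. Raising asr_weight or mt_weight therefore shifts weight away from the ST objective instead of enlarging the total loss.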
collect_feats(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, src_text: Tensor | None = None, src_text_lengths: Tensor | None = None, **kwargs) → Dict[str, Tensor]
encode(speech: Tensor, speech_lengths: Tensor, return_int_enc: bool = False) → Tuple[Tensor, Tensor]
Frontend + Encoder. Note that this method is used by st_inference.py.
- Parameters:
  - speech – (Batch, Length, …)
  - speech_lengths – (Batch,)
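encode chains the optional front-end components passed to the constructor before the encoder runs. The sketch below rests on several assumptions. The stages run in constructor order (frontend, specaug, normalize, preencoder, encoder, postencoder). A component passed as None is skipped. specaug applies only during training. hier_encoder and the return_int_enc path are left out. run_encode_pipeline, the stage names and the toy halve_time encoder are hypothetical. Plain Python lists stand in for tensors, with one scalar per frame.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Feats = List[List[float]]  # (Batch, Length); one scalar per frame for brevity
Stage = Callable[[Feats, List[int]], Tuple[Feats, List[int]]]


def run_encode_pipeline(
    speech: Feats,
    speech_lengths: List[int],
    stages: Sequence[Tuple[str, Optional[Stage]]],
    training: bool = False,
) -> Tuple[Feats, List[int]]:
    """Apply the configured stages in order (hypothetical sketch of encode)."""
    feats, lengths = speech, speech_lengths
    for name, stage in stages:
        if stage is None:
            continue  # component not configured, e.g. normalize=None
        if name == "specaug" and not training:
            continue  # augmentation is assumed to be training-only
        feats, lengths = stage(feats, lengths)
    return feats, lengths


def halve_time(feats: Feats, lengths: List[int]) -> Tuple[Feats, List[int]]:
    """Toy encoder: keep every second frame, like 2x subsampling."""
    return [seq[::2] for seq in feats], [(n + 1) // 2 for n in lengths]


stages = [
    ("frontend", None),
    ("specaug", None),
    ("normalize", None),
    ("preencoder", None),
    ("encoder", halve_time),
    ("postencoder", None),
]
feats, lens = run_encode_pipeline(
    [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 0.0, 0.0]], [4, 2], stages
)
print(feats)  # [[1.0, 3.0], [5.0, 0.0]]
print(lens)   # [2, 1]
```

The sketch returns the updated lengths alongside the features because subsampling encoders shorten the time axis. The second tensor in encode's return tuple presumably plays that role for the real encoder output.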
forward(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, src_text: Tensor | None = None, src_text_lengths: Tensor | None = None, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Frontend + Encoder + Decoder + loss calculation.
- Parameters:
  - speech – (Batch, Length, …)
  - speech_lengths – (Batch,)
  - text – (Batch, Length)
  - text_lengths – (Batch,)
  - src_text – (Batch, Length)
  - src_text_lengths – (Batch,)
  - kwargs – additional keyword arguments; “utt_id” is among them.
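forward expects each sequence input as a right-padded batch, paired with a (Batch,) tensor of true lengths. The helper below illustrates that convention with plain lists. pad_batch is hypothetical. Token sequences are padded with -1 to mirror the constructor defaults ignore_id=-1 (source text) and tgt_ignore_id=-1 (target text). This assumes padded positions should be excluded from the loss.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T", int, float)


def pad_batch(
    seqs: Sequence[Sequence[T]], pad_value: T
) -> Tuple[List[List[T]], List[int]]:
    """Right-pad sequences to the longest one (hypothetical helper).

    Returns the padded batch, shaped (Batch, Length), and the true
    lengths, shaped (Batch,).
    """
    lengths = [len(s) for s in seqs]
    max_len = max(lengths, default=0)
    padded = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    return padded, lengths


# Target token IDs padded with -1, matching the default tgt_ignore_id.
text, text_lengths = pad_batch([[12, 7, 3], [5]], pad_value=-1)
print(text)          # [[12, 7, 3], [5, -1, -1]]
print(text_lengths)  # [3, 1]

# Raw waveforms padded with 0.0.
speech, speech_lengths = pad_batch([[0.1, 0.2], [0.3, 0.4, 0.5, 0.6]], pad_value=0.0)
print(speech_lengths)  # [2, 4]
```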