espnet2.st package

espnet2.st.__init__

espnet2.st.espnet_model

class espnet2.st.espnet_model.ESPnetSTModel(vocab_size: int, token_list: Union[Tuple[str, ...], List[str]], frontend: Optional[espnet2.asr.frontend.abs_frontend.AbsFrontend], specaug: Optional[espnet2.asr.specaug.abs_specaug.AbsSpecAug], normalize: Optional[espnet2.layers.abs_normalize.AbsNormalize], preencoder: Optional[espnet2.asr.preencoder.abs_preencoder.AbsPreEncoder], encoder: espnet2.asr.encoder.abs_encoder.AbsEncoder, hier_encoder: Optional[espnet2.asr.encoder.abs_encoder.AbsEncoder], md_encoder: Optional[espnet2.asr.encoder.abs_encoder.AbsEncoder], extra_mt_encoder: Optional[espnet2.asr.encoder.abs_encoder.AbsEncoder], postencoder: Optional[espnet2.asr.postencoder.abs_postencoder.AbsPostEncoder], decoder: espnet2.asr.decoder.abs_decoder.AbsDecoder, extra_asr_decoder: Optional[espnet2.asr.decoder.abs_decoder.AbsDecoder], extra_mt_decoder: Optional[espnet2.asr.decoder.abs_decoder.AbsDecoder], ctc: Optional[espnet2.asr.ctc.CTC], st_ctc: Optional[espnet2.asr.ctc.CTC], st_joint_network: Optional[torch.nn.modules.module.Module], src_vocab_size: Optional[int], src_token_list: Union[Tuple[str, ...], List[str], None], asr_weight: float = 0.0, mt_weight: float = 0.0, mtlalpha: float = 0.0, st_mtlalpha: float = 0.0, ignore_id: int = -1, tgt_ignore_id: int = -1, lsm_weight: float = 0.0, length_normalized_loss: bool = False, report_cer: bool = True, report_wer: bool = True, report_bleu: bool = True, sym_space: str = '<space>', sym_blank: str = '<blank>', tgt_sym_space: str = '<space>', tgt_sym_blank: str = '<blank>', extract_feats_in_collect_stats: bool = True, ctc_sample_rate: float = 0.0, tgt_sym_sos: str = '<sos/eos>', tgt_sym_eos: str = '<sos/eos>', lang_token_id: int = -1)[source]

Bases: espnet2.train.abs_espnet_model.AbsESPnetModel

CTC-attention hybrid Encoder-Decoder model for speech translation (ST)
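
A minimal construction sketch follows. In practice this model is normally built from a training config by the ST task class rather than wired by hand; the component constructors below exist in ESPnet2, but the specific argument choices, the tiny token list, and the decision to disable every optional branch are illustrative assumptions, not a documented recipe.

# Minimal sketch (assumptions flagged above): pre-extracted 80-dim features
# (frontend=None) and all optional ASR/MT branches disabled.
from espnet2.asr.decoder.transformer_decoder import TransformerDecoder
from espnet2.asr.encoder.transformer_encoder import TransformerEncoder
from espnet2.st.espnet_model import ESPnetSTModel

# Toy target vocabulary; real recipes derive this from the training data.
token_list = ["<blank>", "<space>", "a", "b", "c", "<sos/eos>"]
vocab_size = len(token_list)

encoder = TransformerEncoder(input_size=80, output_size=256)
decoder = TransformerDecoder(vocab_size=vocab_size, encoder_output_size=256)

model = ESPnetSTModel(
    vocab_size=vocab_size,
    token_list=token_list,
    frontend=None,
    specaug=None,
    normalize=None,
    preencoder=None,
    encoder=encoder,
    hier_encoder=None,
    md_encoder=None,
    extra_mt_encoder=None,
    postencoder=None,
    decoder=decoder,
    extra_asr_decoder=None,
    extra_mt_decoder=None,
    ctc=None,
    st_ctc=None,
    st_joint_network=None,
    src_vocab_size=None,
    src_token_list=None,
)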

collect_feats(speech: torch.Tensor, speech_lengths: torch.Tensor, text: torch.Tensor, text_lengths: torch.Tensor, src_text: Optional[torch.Tensor] = None, src_text_lengths: Optional[torch.Tensor] = None, **kwargs) → Dict[str, torch.Tensor][source]
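
A hedged sketch of collect_feats, continuing from the model built above. The returned "feats"/"feats_lengths" keys follow the convention of the other ESPnet2 models and are an assumption to verify against the source.

import torch

speech = torch.randn(2, 100, 80)                  # (Batch, Length, Dim) features, since frontend=None
speech_lengths = torch.tensor([100, 80])          # (Batch,)
text = torch.randint(2, vocab_size - 1, (2, 10))  # (Batch, Length) target token ids
text[1, 8:] = -1                                  # pad past each length with ignore_id (-1)
text_lengths = torch.tensor([10, 8])              # (Batch,)

feats_dict = model.collect_feats(speech, speech_lengths, text, text_lengths)
# Assumed to contain "feats" and "feats_lengths" when
# extract_feats_in_collect_stats=True (the default).
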
encode(speech: torch.Tensor, speech_lengths: torch.Tensor, return_int_enc: bool = False) → Tuple[torch.Tensor, torch.Tensor][source]

Frontend + Encoder. Note that this method is used by st_inference.py; a usage sketch follows the parameter list.

Parameters:
  • speech – (Batch, Length, …)

  • speech_lengths – (Batch,)
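
A usage sketch under the same assumptions as above (frontend=None, so speech is already a feature tensor; with a waveform frontend, speech would instead be raw samples of shape (Batch, Samples)):

encoder_out, encoder_out_lens = model.encode(speech, speech_lengths)
# encoder_out: (Batch, Frames, 256), given output_size=256 above
# encoder_out_lens: (Batch,) valid frame counts after any subsampling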

forward(speech: torch.Tensor, speech_lengths: torch.Tensor, text: torch.Tensor, text_lengths: torch.Tensor, src_text: Optional[torch.Tensor] = None, src_text_lengths: Optional[torch.Tensor] = None, **kwargs) → Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor][source]

Frontend + Encoder + Decoder + loss calculation. A training-style sketch follows the parameter list.

Parameters:
  • speech – (Batch, Length, …)

  • speech_lengths – (Batch,)

  • text – (Batch, Length)

  • text_lengths – (Batch,)

  • src_text – (Batch, Length)

  • src_text_lengths – (Batch,)

  • kwargs – “utt_id” is among the input keys.
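
A training-style sketch, continuing from the tensors defined above. The interpretation of the returned triple follows the signature; the exact contents of stats are an assumption to check against the source.

loss, stats, weight = model(speech, speech_lengths, text, text_lengths)
loss.backward()
# loss: scalar training loss
# stats: dict of scalar monitoring tensors (e.g. the ST attention loss
# and accuracy; assumed key set)
# weight: batch size used for loss normalization across devices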