espnet2.s2t package

espnet2.s2t.espnet_model

class espnet2.s2t.espnet_model.ESPnetS2TModel(vocab_size: int, token_list: Union[Tuple[str, ...], List[str]], frontend: Optional[espnet2.asr.frontend.abs_frontend.AbsFrontend], specaug: Optional[espnet2.asr.specaug.abs_specaug.AbsSpecAug], normalize: Optional[espnet2.layers.abs_normalize.AbsNormalize], preencoder: Optional[espnet2.asr.preencoder.abs_preencoder.AbsPreEncoder], encoder: espnet2.asr.encoder.abs_encoder.AbsEncoder, postencoder: Optional[espnet2.asr.postencoder.abs_postencoder.AbsPostEncoder], decoder: Optional[espnet2.asr.decoder.abs_decoder.AbsDecoder], ctc: espnet2.asr.ctc.CTC, ctc_weight: float = 0.5, interctc_weight: float = 0.0, ignore_id: int = -1, lsm_weight: float = 0.0, length_normalized_loss: bool = False, report_cer: bool = True, report_wer: bool = True, sym_space: str = '<space>', sym_blank: str = '<blank>', sym_sos: str = '<sos>', sym_eos: str = '<eos>', sym_sop: str = '<sop>', sym_na: str = '<na>', extract_feats_in_collect_stats: bool = True)

Bases: espnet2.train.abs_espnet_model.AbsESPnetModel

CTC-attention hybrid Encoder-Decoder model
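
The ctc_weight argument suggests that the training loss interpolates between a CTC term and an attention-decoder term. A minimal sketch of that interpolation follows, assuming the usual convex combination used in hybrid CTC-attention training. This page does not state the formula, and hybrid_loss is a hypothetical helper, not part of espnet2.

```python
def hybrid_loss(loss_ctc: float, loss_att: float, ctc_weight: float = 0.5) -> float:
    """Combine CTC and attention losses (assumed convention, not taken from this page)."""
    if not 0.0 <= ctc_weight <= 1.0:
        raise ValueError("ctc_weight must lie in [0, 1]")
    # ctc_weight = 1.0 keeps only the CTC term; 0.0 keeps only the attention term.
    return ctc_weight * loss_ctc + (1.0 - ctc_weight) * loss_att
```

Under this assumption, the default ctc_weight of 0.5 gives both terms equal influence.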

collect_feats(speech: torch.Tensor, speech_lengths: torch.Tensor, text: torch.Tensor, text_lengths: torch.Tensor, text_prev: torch.Tensor, text_prev_lengths: torch.Tensor, text_ctc: torch.Tensor, text_ctc_lengths: torch.Tensor, **kwargs) → Dict[str, torch.Tensor]

encode(speech: torch.Tensor, speech_lengths: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]

Frontend + Encoder. Note that this method is used by s2t_inference.py.

Parameters:
  • speech – (Batch, Length, …)
  • speech_lengths – (Batch,)
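
The shapes above imply that variable-length utterances are padded to a common length, with the true lengths passed separately. A sketch in plain Python lists follows. Real inputs would be torch.Tensor objects; pad_batch is a hypothetical helper, and zero-padding of the waveform is an assumption.

```python
from typing import List, Sequence, Tuple


def pad_batch(
    seqs: Sequence[Sequence[float]], pad_value: float = 0.0
) -> Tuple[List[List[float]], List[int]]:
    """Pad sequences to the longest one and return (padded, lengths)."""
    lengths = [len(s) for s in seqs]
    max_len = max(lengths) if lengths else 0
    # Right-pad every sequence so all rows share the same length.
    padded = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    return padded, lengths


# Two utterances of 3 and 1 samples: speech is (Batch=2, Length=3).
speech, speech_lengths = pad_batch([[0.1, 0.2, 0.3], [0.4]])
```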

forward(speech: torch.Tensor, speech_lengths: torch.Tensor, text: torch.Tensor, text_lengths: torch.Tensor, text_prev: torch.Tensor, text_prev_lengths: torch.Tensor, text_ctc: torch.Tensor, text_ctc_lengths: torch.Tensor, **kwargs) → Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]

Frontend + Encoder + Decoder + loss calculation.

Parameters:
  • speech – (Batch, Length, …)
  • speech_lengths – (Batch,)
  • text – (Batch, Length)
  • text_lengths – (Batch,)
  • text_prev – (Batch, Length)
  • text_prev_lengths – (Batch,)
  • text_ctc – (Batch, Length)
  • text_ctc_lengths – (Batch,)
  • kwargs – additional keyword arguments; “utt_id” is among them.
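
Since all eight inputs share the same leading Batch dimension, a quick consistency check before calling forward can catch collation mistakes. The sketch below uses plain lists in place of tensors. check_batch and REQUIRED_KEYS are hypothetical names and do not reproduce espnet2's own validation; padding token IDs with -1 mirrors the ignore_id default and is an assumption about the data pipeline.

```python
from typing import Dict, Sequence

# Argument names taken from the forward() signature.
REQUIRED_KEYS = (
    "speech", "speech_lengths",
    "text", "text_lengths",
    "text_prev", "text_prev_lengths",
    "text_ctc", "text_ctc_lengths",
)


def check_batch(batch: Dict[str, Sequence]) -> int:
    """Return the shared batch size, or raise if inputs are missing or inconsistent."""
    missing = [k for k in REQUIRED_KEYS if k not in batch]
    if missing:
        raise KeyError(f"missing inputs: {missing}")
    sizes = {k: len(batch[k]) for k in REQUIRED_KEYS}
    if len(set(sizes.values())) != 1:
        raise ValueError(f"inconsistent batch sizes: {sizes}")
    return sizes["speech"]


batch = {
    "speech": [[0.1, 0.2, 0.3], [0.4, 0.0, 0.0]],
    "speech_lengths": [3, 1],
    "text": [[5, 6], [7, -1]],          # -1 marks padding (assumed ignore_id)
    "text_lengths": [2, 1],
    "text_prev": [[3], [4]],
    "text_prev_lengths": [1, 1],
    "text_ctc": [[8, 9], [10, -1]],
    "text_ctc_lengths": [2, 1],
}
batch_size = check_batch(batch)
```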

espnet2.s2t.__init__