espnet2.s2t package
espnet2.s2t.espnet_model
-
class espnet2.s2t.espnet_model.ESPnetS2TModel(vocab_size: int, token_list: Union[Tuple[str, ...], List[str]], frontend: Optional[espnet2.asr.frontend.abs_frontend.AbsFrontend], specaug: Optional[espnet2.asr.specaug.abs_specaug.AbsSpecAug], normalize: Optional[espnet2.layers.abs_normalize.AbsNormalize], preencoder: Optional[espnet2.asr.preencoder.abs_preencoder.AbsPreEncoder], encoder: espnet2.asr.encoder.abs_encoder.AbsEncoder, postencoder: Optional[espnet2.asr.postencoder.abs_postencoder.AbsPostEncoder], decoder: Optional[espnet2.asr.decoder.abs_decoder.AbsDecoder], ctc: espnet2.asr.ctc.CTC, ctc_weight: float = 0.5, interctc_weight: float = 0.0, ignore_id: int = -1, lsm_weight: float = 0.0, length_normalized_loss: bool = False, report_cer: bool = True, report_wer: bool = True, sym_space: str = '<space>', sym_blank: str = '<blank>', sym_sos: str = '<sos>', sym_eos: str = '<eos>', sym_sop: str = '<sop>', sym_na: str = '<na>', extract_feats_in_collect_stats: bool = True)
Bases: espnet2.train.abs_espnet_model.AbsESPnetModel
CTC-attention hybrid Encoder-Decoder model
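In a CTC-attention hybrid, the CTC loss and the attention-decoder loss are interpolated into one training objective. The sketch below assumes the weighting commonly used in ESPnet hybrid models: `ctc_weight` mixes the CTC and attention losses, and `interctc_weight` mixes the averaged intermediate-layer CTC losses into the CTC term. The helper `combine_losses` is hypothetical and works on plain floats rather than tensors.

```python
from typing import Sequence


def combine_losses(
    loss_ctc: float,
    loss_att: float,
    ctc_weight: float = 0.5,
    interctc_losses: Sequence[float] = (),
    interctc_weight: float = 0.0,
) -> float:
    """Hypothetical helper: interpolate CTC and attention losses (plain floats)."""
    if not 0.0 <= ctc_weight <= 1.0:
        raise ValueError("ctc_weight must be in [0, 1]")
    # Assumption: intermediate CTC losses are averaged, then mixed into the CTC term.
    if interctc_weight != 0.0 and interctc_losses:
        loss_inter = sum(interctc_losses) / len(interctc_losses)
        loss_ctc = (1 - interctc_weight) * loss_ctc + interctc_weight * loss_inter
    if ctc_weight == 0.0:
        return loss_att  # attention decoder only
    if ctc_weight == 1.0:
        return loss_ctc  # CTC only
    return ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att
```

With the constructor default `ctc_weight=0.5`, `combine_losses(2.0, 4.0)` gives the plain average of the two losses, 3.0.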
-
collect_feats(speech: torch.Tensor, speech_lengths: torch.Tensor, text: torch.Tensor, text_lengths: torch.Tensor, text_prev: torch.Tensor, text_prev_lengths: torch.Tensor, text_ctc: torch.Tensor, text_ctc_lengths: torch.Tensor, **kwargs) → Dict[str, torch.Tensor]
-
encode(speech: torch.Tensor, speech_lengths: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]
Frontend + Encoder. Note that this method is used by s2t_inference.py.
- Parameters:
speech – (Batch, Length, …)
speech_lengths – (Batch,)
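`speech` is a padded batch and `speech_lengths` records how many entries of each row are real, which is typically what lets later stages mask out padding. The sketch below, in plain Python rather than torch, shows how variable-length 1-D waveforms could be padded to (Batch, Length) and how the matching (Batch,) lengths and a validity mask relate. The helper names are hypothetical and the zero pad value is an assumption.

```python
from typing import List, Tuple


def pad_speech_batch(
    waveforms: List[List[float]], pad_value: float = 0.0
) -> Tuple[List[List[float]], List[int]]:
    """Hypothetical helper: pad 1-D waveforms to a (Batch, Length) batch."""
    lengths = [len(w) for w in waveforms]  # plays the role of speech_lengths, (Batch,)
    max_len = max(lengths)
    speech = [list(w) + [pad_value] * (max_len - len(w)) for w in waveforms]
    return speech, lengths


def make_valid_mask(lengths: List[int]) -> List[List[bool]]:
    """True for real samples, False for padding (one row per utterance)."""
    max_len = max(lengths)
    return [[t < n for t in range(max_len)] for n in lengths]
```

For two utterances of 3 and 1 samples, the batch has Length 3, the lengths are `[3, 1]`, and the second row of the mask is `[True, False, False]`.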
-
forward(speech: torch.Tensor, speech_lengths: torch.Tensor, text: torch.Tensor, text_lengths: torch.Tensor, text_prev: torch.Tensor, text_prev_lengths: torch.Tensor, text_ctc: torch.Tensor, text_ctc_lengths: torch.Tensor, **kwargs) → Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
Frontend + Encoder + Decoder + loss calculation.
- Parameters:
speech – (Batch, Length, …)
speech_lengths – (Batch,)
text – (Batch, Length)
text_lengths – (Batch,)
text_prev – (Batch, Length)
text_prev_lengths – (Batch,)
text_ctc – (Batch, Length)
text_ctc_lengths – (Batch,)
kwargs – additional inputs; “utt_id” is among them.
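Each target argument (`text`, `text_prev`, `text_ctc`) is a (Batch, Length) batch paired with a (Batch,) length vector. The sketch below assumes padded target positions hold the constructor's `ignore_id` value (default -1) so that losses and metrics can skip them; `pad_targets` and `token_accuracy` are hypothetical plain-Python helpers, not ESPnet functions.

```python
from typing import List, Sequence, Tuple

IGNORE_ID = -1  # matches the ignore_id default in the constructor signature


def pad_targets(
    seqs: Sequence[Sequence[int]], ignore_id: int = IGNORE_ID
) -> Tuple[List[List[int]], List[int]]:
    """Pad token-id sequences to (Batch, Length); return them with (Batch,) lengths."""
    lengths = [len(s) for s in seqs]
    max_len = max(lengths)
    padded = [list(s) + [ignore_id] * (max_len - len(s)) for s in seqs]
    return padded, lengths


def token_accuracy(
    pred: List[List[int]], target: List[List[int]], ignore_id: int = IGNORE_ID
) -> float:
    """Fraction of non-padding positions where pred equals target."""
    total = correct = 0
    for p_row, t_row in zip(pred, target):
        for p, t in zip(p_row, t_row):
            if t == ignore_id:
                continue  # padded position: excluded from the metric
            total += 1
            correct += int(p == t)
    return correct / total if total else 0.0
```

Padding `[[5, 6, 7], [8]]` gives `[[5, 6, 7], [8, -1, -1]]` with lengths `[3, 1]`; the two padded positions are then ignored, so only four tokens count toward accuracy.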
-