espnet2.asr.maskctc_model.MaskCTCModel
class espnet2.asr.maskctc_model.MaskCTCModel(vocab_size: int, token_list: Tuple[str, ...] | List[str], frontend: AbsFrontend | None, specaug: AbsSpecAug | None, normalize: AbsNormalize | None, preencoder: AbsPreEncoder | None, encoder: AbsEncoder, postencoder: AbsPostEncoder | None, decoder: MLMDecoder, ctc: CTC, joint_network: Module | None = None, ctc_weight: float = 0.5, interctc_weight: float = 0.0, ignore_id: int = -1, lsm_weight: float = 0.0, length_normalized_loss: bool = False, report_cer: bool = True, report_wer: bool = True, sym_space: str = '<space>', sym_blank: str = '<blank>', sym_mask: str = '<mask>', extract_feats_in_collect_stats: bool = True)
Bases: ESPnetASRModel
Hybrid CTC/Masked LM Encoder-Decoder model (Mask-CTC)
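The hybrid objective can be sketched as a linear interpolation between the CTC branch and the masked-LM decoder branch, weighted by `ctc_weight`, following the weighting described in the Mask-CTC paper. How `interctc_weight` blends intermediate-layer CTC losses into the CTC branch is an assumption here, and `hybrid_ctc_mlm_loss` is a hypothetical helper that works on plain floats rather than tensors. It is not ESPnet code.

```python
from typing import Sequence


def hybrid_ctc_mlm_loss(
    loss_ctc: float,
    loss_mlm: float,
    ctc_weight: float = 0.5,
    interctc_losses: Sequence[float] = (),
    interctc_weight: float = 0.0,
) -> float:
    """Hypothetical helper: assumed linear CTC/MLM interpolation (not ESPnet code)."""
    if not 0.0 <= ctc_weight <= 1.0:
        raise ValueError("ctc_weight must lie in [0, 1]")
    # Assumption: blend the final CTC loss with the mean intermediate-layer CTC loss.
    if interctc_weight != 0.0 and len(interctc_losses) > 0:
        mean_inter = sum(interctc_losses) / len(interctc_losses)
        loss_ctc = (1.0 - interctc_weight) * loss_ctc + interctc_weight * mean_inter
    # Interpolate the CTC branch and the masked-LM decoder branch.
    return ctc_weight * loss_ctc + (1.0 - ctc_weight) * loss_mlm
```

With the default `ctc_weight=0.5`, a CTC loss of 2.0 and an MLM loss of 4.0 combine to 3.0; setting `ctc_weight=1.0` keeps only the CTC branch.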
batchify_nll(encoder_out: Tensor, encoder_out_lens: Tensor, ys_pad: Tensor, ys_pad_lens: Tensor, batch_size: int = 100)
Compute the negative log likelihood (NLL) from the Transformer decoder.
To avoid out-of-memory (OOM) errors, this function splits the input into mini-batches, calls nll on each mini-batch, and combines and returns the results.
- Parameters:
- encoder_out – (Batch, Length, Dim)
- encoder_out_lens – (Batch,)
- ys_pad – (Batch, Length)
- ys_pad_lens – (Batch,)
- batch_size – int, the number of samples in each mini-batch when computing nll; decrease it to avoid OOM, or increase it to make fuller use of GPU memory.
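The chunking strategy described above can be sketched as a loop over the batch dimension. This is a minimal sketch, assuming a per-chunk NLL function that returns one value per utterance; `batchify_nll_sketch` and `NllFn` are hypothetical names, and plain Python lists stand in for tensors (a tensor version would collect the chunk outputs and join them with `torch.cat`).

```python
from typing import Callable, List, Sequence

# Hypothetical signature: a per-chunk NLL function returning one value per utterance.
NllFn = Callable[[Sequence, Sequence, Sequence, Sequence], List[float]]


def batchify_nll_sketch(
    nll_fn: NllFn,
    encoder_out: Sequence,
    encoder_out_lens: Sequence,
    ys_pad: Sequence,
    ys_pad_lens: Sequence,
    batch_size: int = 100,
) -> List[float]:
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    total = len(encoder_out)
    results: List[float] = []
    # Walk over the batch dimension in chunks of at most `batch_size` utterances.
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        chunk = nll_fn(
            encoder_out[start:end],
            encoder_out_lens[start:end],
            ys_pad[start:end],
            ys_pad_lens[start:end],
        )
        results.extend(chunk)  # a tensor version would torch.cat the chunk outputs
    # Every utterance should receive exactly one NLL value.
    assert len(results) == total
    return results
```

Peak memory then scales with `batch_size` rather than with the full batch, at the cost of more calls to the per-chunk function.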
forward(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Frontend + Encoder + Decoder + loss calculation.
- Parameters:
- speech – (Batch, Length, …)
- speech_lengths – (Batch,)
- text – (Batch, Length)
- text_lengths – (Batch,)
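The `text` and `text_lengths` arguments are a padded (Batch, Length) grid of token IDs plus the true (Batch,) lengths. The sketch below builds them from variable-length sequences; padding with -1 to match the default `ignore_id` is an assumption about how batches are collated, and `pad_token_batch` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def pad_token_batch(
    sequences: Sequence[Sequence[int]], pad_value: int = -1
) -> Tuple[List[List[int]], List[int]]:
    """Hypothetical helper: pad token IDs into a (Batch, Length) grid plus (Batch,) lengths."""
    lengths = [len(seq) for seq in sequences]
    max_len = max(lengths, default=0)
    # Right-pad every sequence to the longest one with `pad_value` (assumed to equal ignore_id).
    padded = [list(seq) + [pad_value] * (max_len - len(seq)) for seq in sequences]
    return padded, lengths
```

For example, `pad_token_batch([[5, 6, 7], [8]])` yields `[[5, 6, 7], [8, -1, -1]]` and `[3, 1]`; wrapping each in `torch.tensor` would give tensors of the documented shapes.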
nll(encoder_out: Tensor, encoder_out_lens: Tensor, ys_pad: Tensor, ys_pad_lens: Tensor) → Tensor
Compute the negative log likelihood (NLL) from the Transformer decoder.
Normally, this function is called by batchify_nll.
- Parameters:
- encoder_out – (Batch, Length, Dim)
- encoder_out_lens – (Batch,)
- ys_pad – (Batch, Length)
- ys_pad_lens – (Batch,)
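Per utterance, the NLL is the negated sum of the log-probabilities assigned to the reference tokens, counting only the first `ys_pad_lens[b]` positions so that padding is ignored. The sketch below assumes the decoder has already produced (Batch, Length, Vocab) log-probabilities; how this model's decoder actually scores a sequence is not shown here, and `sequence_nll` is a hypothetical helper over nested lists.

```python
import math
from typing import List, Sequence


def sequence_nll(
    token_log_probs: Sequence[Sequence[Sequence[float]]],
    ys_pad: Sequence[Sequence[int]],
    ys_pad_lens: Sequence[int],
) -> List[float]:
    """Hypothetical helper: per-utterance NLL from (Batch, Length, Vocab) log-probabilities."""
    nll = []
    for log_probs, targets, length in zip(token_log_probs, ys_pad, ys_pad_lens):
        # Only the first `length` positions hold real tokens; the rest are padding.
        nll.append(-sum(log_probs[t][targets[t]] for t in range(length)))
    return nll


if __name__ == "__main__":
    # One utterance, vocabulary of 2, reference tokens [0, 1].
    lp = [[[math.log(0.5), math.log(0.5)], [math.log(0.25), math.log(0.75)]]]
    print(sequence_nll(lp, [[0, 1]], [2]))
```

Skipping padded positions by length, rather than by looking for a pad ID, means the padding value never has to be a valid vocabulary index.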