espnet2.enh.espnet_enh_s2t_model.ESPnetEnhS2TModel
class espnet2.enh.espnet_enh_s2t_model.ESPnetEnhS2TModel(enh_model: ESPnetEnhancementModel, s2t_model: ESPnetASRModel | ESPnetSTModel | ESPnetDiarizationModel, calc_enh_loss: bool = True, bypass_enh_prob: float = 0)
Bases: AbsESPnetModel
Joint model of Enhancement and Speech-to-Text.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
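A minimal construction sketch. The sub-models are assumed to be fully configured instances (building them requires complete enhancement/ASR configurations, which are omitted here), and the `bypass_enh_prob` comment reflects an assumption about its semantics based on the parameter name:

```python
from espnet2.enh.espnet_enh_s2t_model import ESPnetEnhS2TModel

# `enh_model` and `asr_model` are assumed to be fully configured
# ESPnetEnhancementModel / ESPnetASRModel instances (construction omitted).
model = ESPnetEnhS2TModel(
    enh_model=enh_model,
    s2t_model=asr_model,
    calc_enh_loss=True,    # also compute the enhancement loss during training
    bypass_enh_prob=0.2,   # assumption: skip the enhancement step for ~20% of batches
)
```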
asr_pit_loss(speech, speech_lengths, text, text_lengths)
batchify_nll(encoder_out: Tensor, encoder_out_lens: Tensor, ys_pad: Tensor, ys_pad_lens: Tensor, batch_size: int = 100)
Compute negative log likelihood (NLL) from the transformer decoder.
To avoid OOM, this function splits the input into batches, calls nll on each batch, and combines the results.
- Parameters:
- encoder_out – (Batch, Length, Dim)
- encoder_out_lens – (Batch,)
- ys_pad – (Batch, Length)
- ys_pad_lens – (Batch,)
- batch_size – int, number of samples per batch when computing nll; lower it to avoid OOM, or raise it to increase GPU memory usage
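A usage sketch, assuming `model` is a trained ESPnetEnhS2TModel and the tensors below stand in for real encoder outputs and padded target tokens:

```python
import torch

encoder_out = torch.randn(8, 50, 256)       # (Batch, Length, Dim)
encoder_out_lens = torch.full((8,), 50)     # (Batch,)
ys_pad = torch.randint(0, 100, (8, 20))     # (Batch, Length), placeholder token ids
ys_pad_lens = torch.full((8,), 20)          # (Batch,)

# A smaller batch_size lowers peak memory; a larger one uses more GPU memory.
nll = model.batchify_nll(
    encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, batch_size=4
)
```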
collect_feats(speech: Tensor, speech_lengths: Tensor, **kwargs) → Dict[str, Tensor]
encode(speech: Tensor, speech_lengths: Tensor) → Tuple[Tensor, Tensor]
Frontend + Encoder. Note that this method is used by asr_inference.py
- Parameters:
- speech – (Batch, Length, …)
- speech_lengths – (Batch, )
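A usage sketch, assuming `model` is a trained ESPnetEnhS2TModel; this is the entry point used by asr_inference.py:

```python
import torch

speech = torch.randn(2, 16000)                 # (Batch, Length)
speech_lengths = torch.tensor([16000, 12000])  # (Batch,)
encoder_out, encoder_out_lens = model.encode(speech, speech_lengths)
```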
encode_diar(speech: Tensor, speech_lengths: Tensor, num_spk: int) → Tuple[Tensor, Tensor]
Frontend + Encoder. Note that this method is used by diar_inference.py
- Parameters:
- speech – (Batch, Length, …)
- speech_lengths – (Batch, )
- num_spk – int
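A usage sketch, assuming `model` wraps an ESPnetDiarizationModel back-end; this is the entry point used by diar_inference.py:

```python
import torch

speech = torch.randn(2, 16000)                 # (Batch, Length)
speech_lengths = torch.tensor([16000, 16000])  # (Batch,)
encoder_out, encoder_out_lens = model.encode_diar(
    speech, speech_lengths, num_spk=2
)
```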
forward(speech: Tensor, speech_lengths: Tensor | None = None, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Frontend + Encoder + Decoder + Calc loss
- Parameters:
- speech – (Batch, Length, …)
- speech_lengths – (Batch,); default None for the chunk iterator, because the chunk iterator does not return speech_lengths (see espnet2/iterators/chunk_iter_factory.py)
- text_spk1, text_spk2, … (for the Enh+ASR task) – (Batch, Length); text_spk1_lengths, text_spk2_lengths, …: (Batch,)
- text (for other tasks) – (Batch, Length); default None just to keep the argument order
- text_lengths (for other tasks) – (Batch,); default None for the same reason as speech_lengths
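A training-style call sketch for the Enh+ASR task, assuming `model` wraps an ASR back-end; the vocabulary size (100) and token ids are placeholders:

```python
import torch

speech = torch.randn(2, 16000)                 # (Batch, Length)
speech_lengths = torch.tensor([16000, 16000])  # (Batch,)
text_spk1 = torch.randint(1, 100, (2, 12))     # (Batch, Length), placeholder tokens
text_spk1_lengths = torch.tensor([12, 10])     # (Batch,)

loss, stats, weight = model(
    speech,
    speech_lengths,
    text_spk1=text_spk1,
    text_spk1_lengths=text_spk1_lengths,
)
```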
inherite_attributes(inherite_enh_attrs: List[str] = [], inherite_s2t_attrs: List[str] = [])
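The source attaches no description to this method (note that the misspelled identifiers, inherite_*, are the actual names in the API). The sketch below shows the assumed usage, surfacing selected sub-model attributes on the joint wrapper so downstream code can read them from `model` directly; the attribute names are illustrative assumptions:

```python
# Assumption: copies each named attribute from the wrapped s2t_model
# (or enh_model) onto the joint model. Attribute names are illustrative.
model.inherite_attributes(
    inherite_s2t_attrs=["ctc", "sos", "eos"],
)
```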
nll(encoder_out: Tensor, encoder_out_lens: Tensor, ys_pad: Tensor, ys_pad_lens: Tensor) → Tensor
Compute negative log likelihood (NLL) from the transformer decoder.
Normally, this function is called in batchify_nll.
- Parameters:
- encoder_out – (Batch, Length, Dim)
- encoder_out_lens – (Batch,)
- ys_pad – (Batch, Length)
- ys_pad_lens – (Batch,)
permutation_invariant_training(losses: Tensor)
Compute PIT loss.
- Parameters: losses (torch.Tensor) – (batch, nref, nhyp)
- Returns:
- perm – list: (batch, n_spk)
- loss – torch.Tensor: (batch)
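A self-contained sketch of what permutation-invariant training (PIT) computes, matching the documented shapes but not necessarily the exact internal implementation: for every sample in the batch, choose the reference-to-hypothesis assignment with the smallest total loss.

```python
import itertools
import torch

def pit_loss(losses: torch.Tensor):
    """Pick, per sample, the speaker permutation with the smallest total loss.

    losses: (batch, nref, nhyp) pairwise losses between references and
    hypotheses; returns (perm, loss) matching the documented shapes.
    """
    batch, nref, nhyp = losses.shape
    assert nref == nhyp, "PIT assumes one hypothesis per reference"
    best_loss = torch.full((batch,), float("inf"))
    best_perm = [None] * batch
    for perm in itertools.permutations(range(nref)):
        # Total loss when reference r is paired with hypothesis perm[r].
        total = sum(losses[:, r, h] for r, h in enumerate(perm))
        improved = total < best_loss
        best_loss = torch.where(improved, total, best_loss)
        for b in range(batch):
            if improved[b]:
                best_perm[b] = perm
    return best_perm, best_loss

perm, loss = pit_loss(torch.rand(4, 2, 2))  # (batch, nref, nhyp)
```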