espnet2.enh.espnet_model.ESPnetEnhancementModel
class espnet2.enh.espnet_model.ESPnetEnhancementModel(encoder: AbsEncoder, separator: AbsSeparator | None, decoder: AbsDecoder, mask_module: AbsMask | None, loss_wrappers: List[AbsLossWrapper] | None, stft_consistency: bool = False, loss_type: str = 'mask_mse', mask_type: str | None = None, flexible_numspk: bool = False, extract_feats_in_collect_stats: bool = False, normalize_variance: bool = False, normalize_variance_per_ch: bool = False, categories: list = [], category_weights: list = [], always_forward_in_48k: bool = False)
Bases: AbsESPnetModel
Speech enhancement or separation frontend model.
Main entry of speech enhancement/separation model training.
- Parameters:
encoder – waveform encoder that converts waveforms to feature representations
separator – separator that enhances or separates the feature representations
decoder – waveform decoder that converts the features back to waveforms
mask_module – mask module that converts the features to masks. NOTE: Only used for compatibility with joint speaker diarization. See test/espnet2/enh/test_espnet_enh_s2t_model.py for details.
loss_wrappers – list of loss wrappers. Each loss wrapper contains a criterion for loss calculation and the corresponding loss weight. The losses are calculated in the order of the list and summed up.
stft_consistency – (deprecated, kept for compatibility) whether to compute the TF-domain loss while enforcing STFT consistency. NOTE: STFT consistency is now always used for frequency-domain spectrum losses.
loss_type – (deprecated, kept for compatibility) loss type
mask_type – (deprecated, kept for compatibility) mask type in TF-domain model
flexible_numspk – whether to allow the model to predict a variable number of speakers in its output. NOTE: This should be used when training a speech separation model for an unknown number of speakers.
extract_feats_in_collect_stats – used in espnet2/tasks/abs_task.py to determine whether to skip model building in the collect_stats stage (stage 5 in egs2/*/enh1/enh.sh).
normalize_variance – whether to normalize the signal variance before the model forward pass, and revert it afterward.
normalize_variance_per_ch – whether to normalize the signal variance for each channel instead of the whole signal. NOTE: normalize_variance and normalize_variance_per_ch cannot be True at the same time.
categories – list of all possible categories of minibatches (order matters!) (e.g. ["1ch_8k_reverb", "1ch_8k_both"] for multi-condition training). NOTE: this will be used to convert the category index to the corresponding name for logging in forward_loss. Different categories will have different loss name suffixes.
category_weights – list of weights for each category. Used to set loss weights for batches of different categories.
always_forward_in_48k – whether to always upsample the input speech to 48 kHz for the forward pass, and then downsample to the original sample rate for loss calculation. NOTE: this can be useful for training a model that handles various sampling rates while unifying bandwidth extension + speech enhancement.
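A minimal construction sketch follows. The encoder/separator/decoder classes, their module paths, and their constructor arguments are assumptions based on common ESPnet2 layouts and may differ across versions; verify them against your installation.

```python
# Minimal construction sketch; imported classes/paths are assumptions.
from espnet2.enh.espnet_model import ESPnetEnhancementModel
from espnet2.enh.encoder.stft_encoder import STFTEncoder
from espnet2.enh.decoder.stft_decoder import STFTDecoder
from espnet2.enh.separator.rnn_separator import RNNSeparator
from espnet2.enh.loss.criterions.time_domain import SISNRLoss
from espnet2.enh.loss.wrappers.pit_solver import PITSolver

n_fft, hop = 512, 128
model = ESPnetEnhancementModel(
    encoder=STFTEncoder(n_fft=n_fft, hop_length=hop),
    separator=RNNSeparator(input_dim=n_fft // 2 + 1, num_spk=2),
    decoder=STFTDecoder(n_fft=n_fft, hop_length=hop),
    mask_module=None,  # only needed for joint diarization compatibility
    loss_wrappers=[PITSolver(criterion=SISNRLoss(), weight=1.0)],
)
```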
collect_feats(speech_mix: Tensor, speech_mix_lengths: Tensor, **kwargs) → Dict[str, Tensor]
forward(speech_mix: Tensor, speech_mix_lengths: Tensor | None = None, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Frontend + Encoder + Decoder + loss calculation.
- Parameters:
- speech_mix – (Batch, samples) or (Batch, samples, channels)
- speech_ref – (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
- speech_mix_lengths – (Batch,), default None for the chunk iterator, because the chunk iterator does not return speech_lengths; see espnet2/iterators/chunk_iter_factory.py
- kwargs – extra keyword arguments; "utt_id" is among the inputs.
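A hedged sketch of a training-style forward call, reusing model from the construction sketch above. The per-speaker reference keyword names (speech_ref1, speech_ref2, ...) follow the convention used by ESPnet2 enhancement recipes and are an assumption here:

```python
import torch

batch, samples, num_spk = 4, 16000, 2
speech_mix = torch.randn(batch, samples)
lengths = torch.full((batch,), samples, dtype=torch.long)
# Assumed keyword convention: one reference signal per speaker.
refs = {f"speech_ref{i + 1}": torch.randn(batch, samples) for i in range(num_spk)}

loss, stats, weight = model(speech_mix, speech_mix_lengths=lengths, **refs)
loss.backward()
```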
forward_enhance(speech_mix: Tensor, speech_lengths: Tensor, additional: Dict | None = None, fs: int | None = None) → Tuple[Tensor, Tensor, Tensor]
forward_loss(speech_pre: Tensor, speech_lengths: Tensor, feature_mix: Tensor, feature_pre: List[Tensor], others: OrderedDict, speech_ref: List[Tensor], noise_ref: List[Tensor] | None = None, dereverb_speech_ref: List[Tensor] | None = None, category: Tensor | None = None, num_spk: int | None = None, fs: int | None = None) → Tuple[Tensor, Dict[str, Tensor], Tensor]
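forward composes these two stages: forward_enhance produces the enhanced signals and intermediate features, and forward_loss scores them against the references. The sketch below (reusing the tensors from the previous example) follows the signatures documented above literally; the actual intermediate types (e.g. whether speech_pre is a single tensor or a per-speaker list) depend on the separator, so treat it as illustrative only.

```python
from collections import OrderedDict

# Illustrative two-stage call, per the documented signatures above.
speech_pre, feature_mix, feature_pre = model.forward_enhance(speech_mix, lengths)
loss, stats, weight = model.forward_loss(
    speech_pre,
    lengths,
    feature_mix,
    feature_pre,
    others=OrderedDict(),  # extra separator outputs; empty in this sketch
    speech_ref=[refs[f"speech_ref{i + 1}"] for i in range(num_spk)],
)
```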
static sort_by_perm(nn_output, perm)
Sort the input list of tensors by the specified permutation.
- Parameters:
- nn_output – List[torch.Tensor(Batch, …)], len(nn_output) == num_spk
- perm – (Batch, num_spk) or List[torch.Tensor(num_spk)]
- Returns: List[torch.Tensor(Batch, …)]
- Return type: nn_output_new
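For example, with two speakers, perm[b] == [1, 0] swaps the two speaker outputs for batch item b:

```python
import torch

num_spk, batch = 2, 3
nn_output = [torch.randn(batch, 16000) for _ in range(num_spk)]
perm = torch.tensor([[1, 0], [0, 1], [1, 0]])  # (Batch, num_spk)

reordered = ESPnetEnhancementModel.sort_by_perm(nn_output, perm)
assert len(reordered) == num_spk  # still a list of (Batch, ...) tensors
```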