espnet2.enh.diffusion_enh.ESPnetDiffusionModel
class espnet2.enh.diffusion_enh.ESPnetDiffusionModel(encoder: AbsEncoder, diffusion: AbsDiffusion, decoder: AbsDecoder, num_spk: int = 1, normalize: bool = False, **kwargs)
Bases: ESPnetEnhancementModel
Diffusion-based speech enhancement model
Main entry of speech enhancement/separation model training.
- Parameters:
encoder – waveform encoder that converts waveforms to feature representations
separator – separator that enhances or separates the feature representations
decoder – waveform decoder that converts the feature back to waveforms
mask_module – mask module that converts the feature to masks. NOTE: Only used for compatibility with joint speaker diarization. See test/espnet2/enh/test_espnet_enh_s2t_model.py for details.
loss_wrappers – list of loss wrappers. Each loss wrapper contains a criterion for loss calculation and the corresponding loss weight. The losses will be calculated in the order of the list and summed up.
stft_consistency – (deprecated, kept for compatibility) whether to compute the TF-domain loss while enforcing STFT consistency. NOTE: STFT consistency is now always used for frequency-domain spectrum losses.
loss_type – (deprecated, kept for compatibility) loss type
mask_type – (deprecated, kept for compatibility) mask type in TF-domain model
flexible_numspk – whether to allow the model to predict a variable number of speakers in its output. NOTE: This should be used when training a speech separation model for an unknown number of speakers.
extract_feats_in_collect_stats – used in espnet2/tasks/abs_task.py to determine whether or not to skip model building in the collect_stats stage (stage 5 in egs2/*/enh1/enh.sh).
normalize_variance – whether to normalize the signal variance before model forward, and revert it back after.
normalize_variance_per_ch – whether to normalize the signal variance for each channel instead of the whole signal. NOTE: normalize_variance and normalize_variance_per_ch cannot be True at the same time.
categories – list of all possible categories of minibatches (order matters!) (e.g. [“1ch_8k_reverb”, “1ch_8k_both”] for multi-condition training). NOTE: this will be used to convert a category index to the corresponding name for logging in forward_loss. Different categories will have different loss-name suffixes.
category_weights – list of weights for each category. Used to set loss weights for batches of different categories.
always_forward_in_48k – whether to always upsample the input speech to 48kHz for forward, and then downsample to the original sample rate for loss calculation. NOTE: this can be useful to train a model capable of handling various sampling rates while unifying bandwidth extension + speech enhancement.
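The variance-normalization options above rescale the mixture to unit standard deviation before the model forward and restore the original scale afterwards. A minimal stdlib-only sketch of that idea, assuming 1-D signals as plain Python lists (the function names here are illustrative, not the ESPnet API):

```python
import statistics


def normalize_variance(signal):
    """Scale a 1-D signal to unit standard deviation; return the scaled
    signal and the scale so it can be reverted after the forward pass.

    Illustrative sketch of the normalize_variance idea, not ESPnet code.
    """
    std = statistics.pstdev(signal) or 1.0  # avoid division by zero on silence
    return [x / std for x in signal], std


def revert_variance(signal, std):
    """Undo the normalization after the model forward."""
    return [x * std for x in signal]


mix = [0.0, 2.0, -2.0, 4.0, -4.0]
normalized, std = normalize_variance(mix)
restored = revert_variance(normalized, std)
```

With normalize_variance_per_ch, the same scaling would be applied independently per channel rather than over the whole signal, which is why the two flags are mutually exclusive.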
collect_feats(speech_mix: Tensor, speech_mix_lengths: Tensor, **kwargs) → Dict[str, Tensor]
enhance(feature_mix)
forward(speech_mix: Tensor, speech_mix_lengths: Tensor | None = None, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Frontend + Encoder + Decoder + Calc loss
- Parameters:
- speech_mix – (Batch, samples) or (Batch, samples, channels)
- speech_ref1 – (Batch, samples) or (Batch, samples, channels)
- speech_ref2 – (Batch, samples) or (Batch, samples, channels)
- ...
- speech_mix_lengths – (Batch,); default None for the chunk iterator, because the chunk iterator does not return speech_lengths. See espnet2/iterators/chunk_iter_factory.py.
- enroll_ref1 – (Batch, samples_aux) enrollment (raw audio or embedding) for speaker 1
- enroll_ref2 – (Batch, samples_aux) enrollment (raw audio or embedding) for speaker 2
- ...
- kwargs – additional arguments; “utt_id” is among the inputs.
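References and enrollments arrive as numbered keyword arguments (speech_ref1, speech_ref2, ..., enroll_ref1, ...), one per speaker. A small sketch of how such keys can be collected in speaker order, assuming a hypothetical helper (this is the naming convention only, not the actual ESPnet implementation):

```python
def gather_numbered_refs(prefix, num_spk, **kwargs):
    """Collect prefix1..prefixN keyword arguments in speaker order.

    Hypothetical helper illustrating the speech_ref{n}/enroll_ref{n}
    naming convention used by forward(); not an ESPnet function.
    """
    return [kwargs[f"{prefix}{spk + 1}"] for spk in range(num_spk)]


batch = {
    "speech_ref1": [0.1, 0.2],  # reference signal for speaker 1
    "speech_ref2": [0.3, 0.4],  # reference signal for speaker 2
    "utt_id": ["utt-0"],        # extra kwargs such as utt_id are simply ignored
}
refs = gather_numbered_refs("speech_ref", 2, **batch)
```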
forward_loss(speech_ref, speech_mix, speech_lengths) → Tuple[Tensor, Dict[str, Tensor], Tensor]
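forward_loss evaluates the criteria held by the loss wrappers in list order, scales each loss by its wrapper weight, and sums the results into the total training loss. A stdlib-only sketch of that contract with pre-computed scalar losses (the function and stats keys are illustrative, not ESPnet code):

```python
def weighted_loss(losses, weights):
    """Sum per-criterion losses in list order, each scaled by its
    wrapper weight, and keep a per-loss stats dict for logging.

    Illustrative of the loss_wrappers contract; not ESPnet code.
    """
    assert len(losses) == len(weights)
    stats = {f"loss_{i}": loss for i, loss in enumerate(losses)}
    total = sum(w * loss for loss, w in zip(losses, weights))
    return total, stats


total, stats = weighted_loss([0.5, 2.0], [1.0, 0.25])
```

When categories/category_weights are configured, the real forward_loss additionally appends a category-specific suffix to each loss name and rescales the batch loss by that category's weight.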