espnet2.enh.espnet_model_tse.ESPnetExtractionModel
class espnet2.enh.espnet_model_tse.ESPnetExtractionModel(encoder: AbsEncoder, extractor: AbsExtractor, decoder: AbsDecoder, loss_wrappers: List[AbsLossWrapper], num_spk: int = 1, flexible_numspk: bool = False, share_encoder: bool = True, extract_feats_in_collect_stats: bool = False)
Bases: AbsESPnetModel
Target Speaker Extraction Frontend model
Initializes internal Module state, shared by both nn.Module and ScriptModule.
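A minimal construction sketch, assuming `encoder`, `extractor`, `decoder`, and `loss_wrappers` were already instantiated from concrete subclasses of AbsEncoder / AbsExtractor / AbsDecoder / AbsLossWrapper (in practice these come from the training config); only keyword arguments shown in the signature above are used:

```python
# A minimal sketch, assuming the four component objects are built elsewhere;
# their constructor calls are intentionally omitted.
from espnet2.enh.espnet_model_tse import ESPnetExtractionModel

model = ESPnetExtractionModel(
    encoder=encoder,              # mixture (and enrollment) front-end
    extractor=extractor,          # extraction network conditioned on enrollment
    decoder=decoder,              # maps extracted features back to waveform
    loss_wrappers=loss_wrappers,  # list of wrapped training criteria
    num_spk=1,                    # number of target speakers per mixture
    flexible_numspk=False,        # fixed speaker count per batch
    share_encoder=True,           # encode enrollment with the same encoder
)
```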
collect_feats(speech_mix: Tensor, speech_mix_lengths: Tensor, **kwargs) → Dict[str, Tensor]
forward(speech_mix: Tensor, speech_mix_lengths: Tensor | None = None, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Frontend + Encoder + Decoder + loss calculation
- Parameters:
- speech_mix – (Batch, samples) or (Batch, samples, channels)
- speech_ref1 – (Batch, samples) or (Batch, samples, channels)
- speech_ref2 – (Batch, samples) or (Batch, samples, channels)
- ...
- speech_mix_lengths – (Batch,); defaults to None for the chunk iterator, because the chunk iterator does not return speech_lengths (see espnet2/iterators/chunk_iter_factory.py)
- enroll_ref1 – (Batch, samples_aux) enrollment (raw audio or embedding) for speaker 1
- enroll_ref2 – (Batch, samples_aux) enrollment (raw audio or embedding) for speaker 2
- ...
- kwargs – “utt_id” is among the inputs.
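A hedged call sketch for forward(): tensor shapes follow the parameter list above, and the `enroll_ref1_lengths` keyword mirrors the documented naming pattern (treat the exact `*_lengths` key names as assumptions):

```python
import torch

batch, samples, samples_aux = 2, 16000, 8000

speech_mix = torch.randn(batch, samples)            # (Batch, samples)
speech_mix_lengths = torch.full((batch,), samples)  # (Batch,)
speech_ref1 = torch.randn(batch, samples)           # reference for speaker 1
enroll_ref1 = torch.randn(batch, samples_aux)       # enrollment for speaker 1
enroll_ref1_lengths = torch.full((batch,), samples_aux)

loss, stats, weight = model(
    speech_mix,
    speech_mix_lengths,
    speech_ref1=speech_ref1,
    enroll_ref1=enroll_ref1,
    enroll_ref1_lengths=enroll_ref1_lengths,
    utt_id=["utt1", "utt2"],
)
```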
forward_enhance(speech_mix: Tensor, speech_lengths: Tensor, enroll_ref: Tensor, enroll_ref_lengths: Tensor, additional: Dict | None = None) → Tuple[Tensor, Tensor, Tensor]
forward_loss(speech_pre: Tensor, speech_lengths: Tensor, feature_mix: Tensor, feature_pre: Tensor, others: OrderedDict, speech_ref: Tensor) → Tuple[Tensor, Dict[str, Tensor], Tensor]
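The two halves of forward() can also be called separately, e.g. to inspect the extracted signal before scoring it. A sketch under stated assumptions: the signature above annotates forward_enhance as returning three tensors, so an empty OrderedDict is passed as the extractor's `others` output, and a single-speaker `speech_ref` layout is assumed:

```python
from collections import OrderedDict

speech_pre, feature_mix, feature_pre = model.forward_enhance(
    speech_mix, speech_mix_lengths, enroll_ref1, enroll_ref1_lengths
)
loss, stats, weight = model.forward_loss(
    speech_pre,          # extracted waveform(s)
    speech_mix_lengths,  # same lengths as the mixture
    feature_mix,         # encoder-domain mixture features
    feature_pre,         # encoder-domain extracted features
    OrderedDict(),       # assumed-empty auxiliary extractor outputs
    speech_ref1,         # reference signal for the target speaker
)
```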