espnet2.enh.separator.asteroid_models.AsteroidModel_Converter
class espnet2.enh.separator.asteroid_models.AsteroidModel_Converter(encoder_output_dim: int, model_name: str, num_spk: int, pretrained_path: str = '', loss_type: str = 'si_snr', **model_related_kwargs)
Bases: AbsSeparator
Converts Asteroid models to the AbsSeparator interface.
- Parameters:
- encoder_output_dim – input feature dimension; this is 1 after the NullEncoder
- num_spk – number of speakers
- loss_type – loss type of enhancement
- model_name – Asteroid model name, e.g. ConvTasNet, DPTNet. See https://github.com/asteroid-team/asteroid/blob/master/asteroid/models/__init__.py
- pretrained_path – the name of a pretrained Asteroid model on the HF hub. See https://github.com/asteroid-team/asteroid/blob/master/docs/source/readmes/pretrained_models.md and https://huggingface.co/models?filter=asteroid
- model_related_kwargs – additional keyword arguments for the specific Asteroid model.
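The parameters above can be gathered as in the following minimal sketch. The helper names `converter_kwargs` and `build_separator` are hypothetical and not part of ESPnet; the chosen values (two speakers, ConvTasNet) are illustrative only. Instantiation is kept inside a function because it assumes both espnet2 and asteroid are installed.

```python
def converter_kwargs(num_spk: int = 2, model_name: str = "ConvTasNet") -> dict:
    """Collect the documented constructor arguments (hypothetical helper)."""
    return {
        "encoder_output_dim": 1,   # raw-waveform input after the NullEncoder
        "model_name": model_name,  # an Asteroid model name, e.g. ConvTasNet, DPTNet
        "num_spk": num_spk,        # number of speakers
        "pretrained_path": "",     # default from the signature
        "loss_type": "si_snr",     # default from the signature
    }


def build_separator(**model_related_kwargs):
    """Instantiate the converter; assumes espnet2 and asteroid are installed."""
    # Imported lazily so this sketch can be loaded without espnet2 present.
    from espnet2.enh.separator.asteroid_models import AsteroidModel_Converter

    return AsteroidModel_Converter(**converter_kwargs(), **model_related_kwargs)


kwargs = converter_kwargs()
print(kwargs["model_name"], kwargs["num_spk"])
```

A call such as `build_separator(n_src=2)` would forward `n_src` through `model_related_kwargs`; `n_src` is assumed here to be an argument of the chosen Asteroid model and should be checked against the Asteroid documentation.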
forward(input: Tensor, ilens: Tensor | None = None, additional: Dict | None = None)
Runs the full forward pass of the Asteroid model.
Parameters:
- input (torch.Tensor) – Raw Waveforms [B, T]
- ilens (torch.Tensor) – input lengths [B]
- additional (Dict or None) – other data included in model
Returns:
- estimated waveforms (List[torch.Tensor]) – [(B, T), …]
- ilens (torch.Tensor) – (B,)
- others (OrderedDict) – predicted data, e.g. masks:
  OrderedDict[
    'mask_spk1': torch.Tensor(Batch, T),
    'mask_spk2': torch.Tensor(Batch, T),
    …
    'mask_spkn': torch.Tensor(Batch, T),
  ]
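The three-part return value described above could be unpacked as in this sketch. The helper `unpack_forward_output` is hypothetical, and nested Python lists stand in for `torch.Tensor` objects so that the example runs without torch or espnet2; only the tuple layout and the `mask_spk1` … `mask_spkn` key pattern are taken from the documentation.

```python
from collections import OrderedDict


def unpack_forward_output(output, num_spk):
    """Split a forward() result into waveforms, lengths, and per-speaker masks."""
    waveforms, ilens, others = output
    if len(waveforms) != num_spk:
        raise ValueError(f"expected {num_spk} waveforms, got {len(waveforms)}")
    # Mask keys follow the documented pattern: mask_spk1, ..., mask_spkn.
    masks = [others[f"mask_spk{i + 1}"] for i in range(num_spk)]
    return waveforms, ilens, masks


# Stand-in data with B=1, T=3 and two speakers (plain lists, not tensors).
spk1 = [[0.1, 0.2, 0.3]]
spk2 = [[0.4, 0.5, 0.6]]
fake_output = (
    [spk1, spk2],
    [3],
    OrderedDict([("mask_spk1", spk1), ("mask_spk2", spk2)]),
)

waveforms, ilens, masks = unpack_forward_output(fake_output, num_spk=2)
print(len(waveforms), ilens, len(masks))
```

With a real separator, `fake_output` would be replaced by the result of calling `forward` on a batch of raw waveforms of shape [B, T].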
forward_rawwav(input: Tensor, ilens: Tensor | None = None) → Tuple[Tensor, Tensor]
Returns the output as waveforms.
property num_spk