espnet2.enh.separator.skim_separator.SkiMSeparator
class espnet2.enh.separator.skim_separator.SkiMSeparator(input_dim: int, causal: bool = True, num_spk: int = 2, predict_noise: bool = False, nonlinear: str = 'relu', layer: int = 3, unit: int = 512, segment_size: int = 20, dropout: float = 0.0, mem_type: str = 'hc', seg_overlap: bool = False)
Bases: AbsSeparator
Skipping Memory (SkiM) Separator
- Parameters:
- input_dim – input feature dimension
- causal – bool, whether the system is causal.
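- num_spk – int, number of target speakers. Default is 2.
- predict_noise – bool, whether to also output an estimated noise signal. Default is False.
- nonlinear – the nonlinear function for mask estimation, selected from ‘relu’, ‘tanh’, ‘sigmoid’. Default is ‘relu’.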
- layer – int, number of SkiM blocks. Default is 3.
- unit – int, dimension of the hidden state.
- segment_size – segmentation size for splitting long features
- dropout – float, dropout ratio. Default is 0.
- mem_type – ‘hc’, ‘h’, ‘c’, ‘id’ or None. It controls whether the hidden (or cell) state of SegLSTM will be processed by MemLSTM. In ‘id’ mode, both the hidden and cell states will be identically returned. When mem_type is None, the MemLSTM will be removed.
- seg_overlap – bool, whether segmentation keeps a 50% overlap between adjacent segments. Default is False.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
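The effect of segment_size and seg_overlap can be pictured with a small pure-Python sketch. This is an illustration only, not the ESPnet implementation: the helper name split_into_segments is hypothetical, and the zero-padding of the last segment is an assumption about how a tail shorter than segment_size might be handled.

```python
from typing import List


def split_into_segments(
    frames: List[float], segment_size: int, seg_overlap: bool = False
) -> List[List[float]]:
    """Split a 1-D sequence into fixed-size segments (hypothetical helper)."""
    # The hop equals the segment size without overlap,
    # or half of it when adjacent segments share 50% of their frames.
    hop = segment_size // 2 if seg_overlap else segment_size
    segments = []
    start = 0
    while start < len(frames):
        chunk = frames[start:start + segment_size]
        # Assumption: zero-pad the last segment so all segments have equal length.
        chunk = chunk + [0.0] * (segment_size - len(chunk))
        segments.append(chunk)
        if start + segment_size >= len(frames):
            break  # this segment already reaches the end of the sequence
        start += hop
    return segments


frames = [float(t) for t in range(50)]
plain = split_into_segments(frames, segment_size=20)
overlapped = split_into_segments(frames, segment_size=20, seg_overlap=True)
print(len(plain), len(overlapped))
```

With overlap enabled, the same sequence yields more segments, so the segment-level recurrence sees each frame in two different contexts at the cost of extra computation.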
forward(input: Tensor | ComplexTensor, ilens: Tensor, additional: Dict | None = None) → Tuple[List[Tensor | ComplexTensor], Tensor, OrderedDict]
Forward.
Parameters:
- input (torch.Tensor or ComplexTensor) – Encoded feature [B, T, N]
- ilens (torch.Tensor) – input lengths [Batch]
- additional (Dict or None) – other data included in the model. NOTE: not used in this model.
Returns:
- masked (List[Union[torch.Tensor, ComplexTensor]]) – [(B, T, N), …]
- ilens (torch.Tensor) – (B,)
- others (OrderedDict) – other predicted data, e.g. masks: OrderedDict[‘mask_spk1’: torch.Tensor(Batch, Frames, Freq), ‘mask_spk2’: torch.Tensor(Batch, Frames, Freq), …, ‘mask_spkn’: torch.Tensor(Batch, Frames, Freq)]
forward_streaming(input_frame: Tensor, states=None)
property num_spk