espnet2.enh.separator.dptnet_separator.DPTNetSeparator
class espnet2.enh.separator.dptnet_separator.DPTNetSeparator(input_dim: int, post_enc_relu: bool = True, rnn_type: str = 'lstm', bidirectional: bool = True, num_spk: int = 2, predict_noise: bool = False, unit: int = 256, att_heads: int = 4, dropout: float = 0.0, activation: str = 'relu', norm_type: str = 'gLN', layer: int = 6, segment_size: int = 20, nonlinear: str = 'relu')
Bases: AbsSeparator
Dual-Path Transformer Network (DPTNet) Separator
- Parameters:
- input_dim – int, input feature dimension.
- post_enc_relu – bool, whether to apply ReLU to the encoder output. Default is True.
- rnn_type – string, select from ‘RNN’, ‘LSTM’ and ‘GRU’. Default is ‘lstm’.
- bidirectional – bool, whether the inter-chunk RNN layers are bidirectional.
- num_spk – int, number of speakers. Default is 2.
- predict_noise – bool, whether to output the estimated noise signal. Default is False.
- unit – int, dimension of the hidden state.
- att_heads – number of attention heads.
- dropout – float, dropout ratio. Default is 0.
- activation – activation function applied to the output of the RNN. Default is ‘relu’.
- norm_type – type of normalization to use after each inter- or intra-chunk Transformer block. Default is ‘gLN’.
- nonlinear – the nonlinear function for mask estimation, select from ‘relu’, ‘tanh’, ‘sigmoid’. Default is ‘relu’.
- layer – int, number of stacked dual-path (intra- and inter-chunk) Transformer layers. Default is 6.
- segment_size – int, dual-path segment (chunk) size. Default is 20.
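The `nonlinear` and `predict_noise` parameters together shape what the mask estimator produces. The pure-Python sketch below is illustrative only: the `NONLINEAR` table and `estimate_masks` are hypothetical, and so are the assumptions that `predict_noise=True` adds one extra (noise) mask after the speaker masks and that the raw outputs are split into contiguous per-mask blocks. It is not ESPnet's implementation.

```python
import math
from typing import Callable, Dict, List

# Hypothetical mapping from the documented `nonlinear` choices to scalar functions.
NONLINEAR: Dict[str, Callable[[float], float]] = {
    "relu": lambda v: max(0.0, v),
    "tanh": math.tanh,
    "sigmoid": lambda v: 1.0 / (1.0 + math.exp(-v)),
}


def estimate_masks(
    logits: List[float],
    nonlinear: str = "relu",
    num_spk: int = 2,
    predict_noise: bool = False,
) -> List[List[float]]:
    """Split flat logits into one block per output mask and apply `nonlinear`.

    Assumption (hypothetical): predict_noise=True adds one noise mask after the speakers.
    """
    if nonlinear not in NONLINEAR:
        raise ValueError(f"nonlinear must be one of {sorted(NONLINEAR)}, got {nonlinear!r}")
    n_masks = num_spk + (1 if predict_noise else 0)
    if len(logits) % n_masks != 0:
        raise ValueError("logits length must be divisible by the number of masks")
    size = len(logits) // n_masks
    f = NONLINEAR[nonlinear]
    # Block k holds the pre-activation values of mask k.
    return [[f(v) for v in logits[k * size:(k + 1) * size]] for k in range(n_masks)]


masks = estimate_masks([-1.0, 0.5, 2.0, -3.0], nonlinear="relu", num_spk=2)
print(masks)  # [[0.0, 0.5], [2.0, 0.0]]
```

With ‘relu’ the masks are unbounded above, while ‘sigmoid’ keeps every mask value in (0, 1).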
forward(input: Tensor | ComplexTensor, ilens: Tensor, additional: Dict | None = None) → Tuple[List[Tensor | ComplexTensor], Tensor, OrderedDict]
Forward.
Parameters:
- input (torch.Tensor or ComplexTensor) – Encoded feature [B, T, N]
- ilens (torch.Tensor) – input lengths [Batch]
- additional (Dict or None) – other data included in the model. NOTE: not used in this model.
Returns:
- masked (List[Union[torch.Tensor, ComplexTensor]]) – [(B, T, N), …]
- ilens (torch.Tensor) – (B,)
- others (OrderedDict) – other predicted data, e.g. masks:
  OrderedDict[
    ‘mask_spk1’: torch.Tensor(Batch, Frames, Freq),
    ‘mask_spk2’: torch.Tensor(Batch, Frames, Freq),
    …
    ‘mask_spkn’: torch.Tensor(Batch, Frames, Freq),
  ]
Return type: Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]
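To make the return structure concrete, the pure-Python sketch below builds the `masked` list and the `others` dictionary for a single utterance (batch dimension dropped). The helper `apply_masks` is hypothetical, and it assumes that each masked output is the element-wise product of the input feature and that speaker's mask, as is usual for mask-based separation; only the key naming (‘mask_spk1’, ‘mask_spk2’, …) comes from the documentation above.

```python
from collections import OrderedDict
from typing import Dict, List, Tuple

Matrix = List[List[float]]  # one utterance, [T, N] (batch dimension omitted)


def apply_masks(feature: Matrix, masks: List[Matrix]) -> Tuple[List[Matrix], Dict[str, Matrix]]:
    """Return per-speaker masked features and an OrderedDict of the masks.

    Assumption: each masked output is the element-wise product feature * mask.
    """
    masked = [
        [[f * m for f, m in zip(f_row, m_row)] for f_row, m_row in zip(feature, mask)]
        for mask in masks
    ]
    # Keys follow the documented naming: 'mask_spk1', 'mask_spk2', ...
    others: Dict[str, Matrix] = OrderedDict(
        (f"mask_spk{i + 1}", mask) for i, mask in enumerate(masks)
    )
    return masked, others


feature = [[1.0, 2.0], [3.0, 4.0]]   # T=2 frames, N=2 feature bins
masks = [[[1.0, 0.0], [0.5, 0.5]],   # speaker 1
         [[0.0, 1.0], [0.5, 0.5]]]   # speaker 2
masked, others = apply_masks(feature, masks)
print(list(others))  # ['mask_spk1', 'mask_spk2']
print(masked[0])     # [[1.0, 0.0], [1.5, 2.0]]
```

In this example the two masks sum to 1 at every bin, so the two masked outputs add back up to the input feature.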
merge_feature(x, length=None)
property num_spk
split_feature(x)
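`split_feature` and `merge_feature` are presumably the segmentation and overlap-add steps of the dual-path scheme controlled by `segment_size`. The pure-Python sketch below works on a 1-D list and assumes 50% overlap (a hop of `segment_size // 2`) with zero padding at the tail only. ESPnet's methods operate on batched tensors and may pad differently, so both functions here are simplified, hypothetical stand-ins that share only the method names.

```python
from typing import List


def split_feature(x: List[float], segment_size: int) -> List[List[float]]:
    """Cut `x` into chunks of `segment_size` with 50% overlap, zero-padding the tail."""
    if segment_size < 2:
        raise ValueError("segment_size must be at least 2")
    hop = segment_size // 2
    if len(x) <= segment_size:
        n_chunks = 1
    else:
        # ceil((len(x) - segment_size) / hop) extra chunks after the first one
        n_chunks = -(-(len(x) - segment_size) // hop) + 1
    padded_len = (n_chunks - 1) * hop + segment_size
    padded = list(x) + [0.0] * (padded_len - len(x))
    return [padded[i * hop:i * hop + segment_size] for i in range(n_chunks)]


def merge_feature(chunks: List[List[float]], length: int) -> List[float]:
    """Overlap-add `chunks`, divide by how many chunks cover each frame, trim to `length`."""
    segment_size = len(chunks[0])
    hop = segment_size // 2
    total = (len(chunks) - 1) * hop + segment_size
    acc = [0.0] * total
    count = [0] * total
    for i, chunk in enumerate(chunks):
        for j, value in enumerate(chunk):
            acc[i * hop + j] += value
            count[i * hop + j] += 1
    return [acc[t] / count[t] for t in range(length)]


x = [float(t) for t in range(25)]
chunks = split_feature(x, segment_size=4)
print(len(chunks), chunks[0], chunks[1])  # 12 [0.0, 1.0, 2.0, 3.0] [2.0, 3.0, 4.0, 5.0]
```

Dividing by the coverage count means unmodified chunks reconstruct the input exactly, which is why the merged sequence can be trimmed back to the original `length`.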