espnet2.enh.layers.dnn_wpe.DNN_WPE
class espnet2.enh.layers.dnn_wpe.DNN_WPE(wtype: str = 'blstmp', widim: int = 257, wlayers: int = 3, wunits: int = 300, wprojs: int = 320, dropout_rate: float = 0.0, taps: int = 5, delay: int = 3, use_dnn_mask: bool = True, nmask: int = 1, nonlinear: str = 'sigmoid', iterations: int = 1, normalization: bool = False, eps: float = 1e-06, diagonal_loading: bool = True, diag_eps: float = 1e-07, mask_flooring: bool = False, flooring_thres: float = 1e-06, use_torch_solver: bool = True)
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
forward(data: Tensor | ComplexTensor, ilens: LongTensor) → Tuple[Tensor | ComplexTensor, LongTensor, Tensor | ComplexTensor]
DNN_WPE forward function.
Notation:
- B – Batch
- C – Channel
- T – Time or sequence length
- F – Frequency, or some dimension of the feature vector
- Parameters:
- data – (B, T, C, F)
- ilens – (B,)
- Returns:
- enhanced (torch.Tensor or List[torch.Tensor]) – (B, T, C, F)
- ilens – (B,)
- masks (torch.Tensor or List[torch.Tensor]) – (B, T, C, F)
- power (List[torch.Tensor]) – (B, F, T)
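The shape conventions above can be illustrated with a minimal sketch. The helper names `expected_shapes` and `run_forward` are hypothetical and not part of ESPnet. `expected_shapes` only restates the documented shapes for the single-mask case (`nmask=1`). `run_forward` needs `torch` and `espnet` installed, and it assumes that `widim` should match the feature dimension F and that the installed ESPnet version accepts native `torch.complex64` input.

```python
def expected_shapes(B, T, C, F):
    """Shapes documented for forward() outputs, assuming nmask=1."""
    return {
        "enhanced": (B, T, C, F),
        "ilens": (B,),
        "masks": (B, T, C, F),
        "power": (B, F, T),  # documented as a list of tensors of this shape
    }


def run_forward(B=2, T=50, C=4, F=257):
    """Call DNN_WPE.forward on random complex input (requires torch and espnet)."""
    import torch
    from espnet2.enh.layers.dnn_wpe import DNN_WPE

    # Assumption: widim is the mask estimator's input size, so it is set to F.
    model = DNN_WPE(widim=F, taps=5, delay=3)
    model.eval()

    data = torch.randn(B, T, C, F, dtype=torch.complex64)
    ilens = torch.LongTensor([T] * B)  # every utterance uses all T frames

    with torch.no_grad():
        outputs = model(data, ilens)

    enhanced = outputs[0]  # first returned item, per the Returns list
    return tuple(enhanced.shape)


shapes = expected_shapes(2, 50, 4, 257)
print(shapes["enhanced"], shapes["power"])
```

If ESPnet is available, the result of `run_forward()` can be compared against `expected_shapes(2, 50, 4, 257)["enhanced"]` as a quick sanity check on the input layout.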
predict_mask(data: Tensor | ComplexTensor, ilens: LongTensor) → Tuple[Tensor, LongTensor]
Predict mask for WPE dereverberation.
- Parameters:
- data (torch.complex64/ComplexTensor) – (B, T, C, F), double precision
- ilens (torch.Tensor) – (B,)
- Returns:
- masks (torch.Tensor or List[torch.Tensor]) – (B, T, C, F)
- ilens (torch.Tensor) – (B,)