espnet2.enh.layers.bsrnn.MaskDecoder
class espnet2.enh.layers.bsrnn.MaskDecoder(freq_dim, subbands, channels=128, num_spk=1, norm_type='GN')
Bases: Module
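Mask decoder of the band-split RNN (BSRNN) model. It maps per-subband features back to a mask and a residual over all `freq_dim` frequency bins. `subbands` gives the width of each band, `channels` is the feature dimension of the input, `num_spk` is the number of output speakers, and `norm_type` selects the normalization layer (default `'GN'`).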
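The `subbands` argument is assumed here to list the width (in frequency bins) of each band, with the widths summing to `freq_dim`. The hypothetical helper `subband_slices` below is a plain-Python sketch under that assumption, not part of ESPnet. It shows how such a partition assigns each band a contiguous slice of the F frequency bins.

```python
from itertools import accumulate
from typing import List, Sequence


def subband_slices(freq_dim: int, subbands: Sequence[int]) -> List[slice]:
    """Hypothetical helper: map subband widths to frequency-bin slices.

    Assumes `subbands` holds per-band widths that must sum to `freq_dim`.
    """
    if sum(subbands) != freq_dim:
        raise ValueError(f"subband widths sum to {sum(subbands)}, expected {freq_dim}")
    ends = list(accumulate(subbands))  # exclusive end index of each band
    starts = [0] + ends[:-1]           # each band starts where the previous one ends
    return [slice(s, e) for s, e in zip(starts, ends)]


# Example: 17 frequency bins split into bands of width 2, 5 and 10.
print(subband_slices(17, [2, 5, 10]))
# [slice(0, 2, None), slice(2, 7, None), slice(7, 17, None)]
```

Because the bands tile the spectrum without gaps or overlap, the per-band outputs can be concatenated along the frequency axis to cover all F bins.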
forward(x)
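MaskDecoder forward. Computes a mask and a residual for each speaker from the subband features.
- Parameters: x (torch.Tensor) – input tensor of shape (B, N, T, K): batch size B, N channels, T frames, K subbands
- Returns:
  - m (torch.Tensor) – output mask of shape (B, num_spk, T, F, 2), where F = freq_dim
  - r (torch.Tensor) – output residual of shape (B, num_spk, T, F, 2)
- Return type: Tuple[torch.Tensor, torch.Tensor]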
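The shape contract of `forward` can be summarized with the hypothetical helper `mask_decoder_shapes`, a plain-Python sketch that is not part of ESPnet. It assumes that N equals `channels` and that K equals `len(subbands)` (one feature vector per subband). It also assumes the trailing dimension of size 2 holds real and imaginary parts.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def mask_decoder_shapes(
    x_shape: Shape,
    freq_dim: int,
    subbands: Sequence[int],
    channels: int = 128,
    num_spk: int = 1,
) -> Tuple[Shape, Shape]:
    """Hypothetical helper: expected shapes of (m, r) for an input of x_shape.

    Assumes N == channels and K == len(subbands).
    """
    if len(x_shape) != 4:
        raise ValueError(f"expected a 4-D input (B, N, T, K), got {x_shape}")
    batch, n, frames, k = x_shape
    if n != channels:
        raise ValueError(f"N={n} does not match channels={channels}")
    if k != len(subbands):
        raise ValueError(f"K={k} does not match len(subbands)={len(subbands)}")
    # Trailing 2 is assumed to hold the real and imaginary parts.
    out = (batch, num_spk, frames, freq_dim, 2)
    return out, out  # m and r share the same shape


m_shape, r_shape = mask_decoder_shapes(
    (4, 128, 100, 3), freq_dim=17, subbands=[2, 5, 10], num_spk=2
)
print(m_shape, r_shape)  # (4, 2, 100, 17, 2) (4, 2, 100, 17, 2)
```

The frequency axis grows from K subbands to F bins, while the channel axis N is replaced by one entry per speaker.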