Source code for espnet2.diar.layers.multi_mask

# This is an implementation of the multiple 1x1 convolution layer architecture
# in https://arxiv.org/pdf/2203.17068.pdf

from collections import OrderedDict
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_complex.tensor import ComplexTensor

from espnet2.diar.layers.abs_mask import AbsMask


class MultiMask(AbsMask):
    def __init__(
        self,
        input_dim: int,
        bottleneck_dim: int = 128,
        max_num_spk: int = 3,
        mask_nonlinear="relu",
    ):
        """Multiple 1x1 convolution layer Module.

        This module corresponds to the final 1x1 conv block and
        non-linear function in TCNSeparator.

        This module has multiple 1x1 conv blocks. One of them is selected
        according to the given num_spk to handle a flexible num_spk.

        Args:
            input_dim: Number of filters in autoencoder
            bottleneck_dim: Number of channels in the bottleneck 1x1-conv block
            max_num_spk: Number of mask_conv1x1 modules
                (>= max number of speakers in the dataset)
            mask_nonlinear: which non-linear function to use to generate masks
        """
        super().__init__()
        # Hyper-parameters
        self._max_num_spk = max_num_spk
        self.mask_nonlinear = mask_nonlinear
        # [M, B, K] -> [M, C*N, K]
        self.mask_conv1x1 = nn.ModuleList()
        for z in range(1, max_num_spk + 1):
            self.mask_conv1x1.append(
                nn.Conv1d(bottleneck_dim, z * input_dim, 1, bias=False)
            )

    @property
    def max_num_spk(self) -> int:
        return self._max_num_spk
    def forward(
        self,
        input: Union[torch.Tensor, ComplexTensor],
        ilens: torch.Tensor,
        bottleneck_feat: torch.Tensor,
        num_spk: int,
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
        """Keep this API the same as TasNet.

        Args:
            input: [M, K, N], M is batch size
            ilens (torch.Tensor): (M,)
            bottleneck_feat: [M, K, B]
            num_spk: number of speakers
                (Training: oracle,
                Inference: estimated by another module, e.g., EEND-EDA)

        Returns:
            masked (List[Union[torch.Tensor, ComplexTensor]]): [(M, K, N), ...]
            ilens (torch.Tensor): (M,)
            others predicted data, e.g. masks: OrderedDict[
                'mask_spk1': torch.Tensor(Batch, Frames, Freq),
                'mask_spk2': torch.Tensor(Batch, Frames, Freq),
                ...
                'mask_spkn': torch.Tensor(Batch, Frames, Freq),
            ]
        """
        M, K, N = input.size()
        bottleneck_feat = bottleneck_feat.transpose(1, 2)  # [M, B, K]
        score = self.mask_conv1x1[num_spk - 1](
            bottleneck_feat
        )  # [M, B, K] -> [M, num_spk*N, K]
        # Add the outputs of the unselected conv blocks scaled by 0.0 so that
        # all parameters join the computation graph; this keeps distributed
        # training (DDP) from failing on unused parameters.
        for z in range(self._max_num_spk):
            if z != num_spk - 1:
                score += 0.0 * F.interpolate(
                    self.mask_conv1x1[z](bottleneck_feat).transpose(1, 2),
                    size=num_spk * N,
                ).transpose(1, 2)
        score = score.view(M, num_spk, N, K)  # [M, num_spk*N, K] -> [M, num_spk, N, K]
        if self.mask_nonlinear == "softmax":
            est_mask = F.softmax(score, dim=1)
        elif self.mask_nonlinear == "relu":
            est_mask = F.relu(score)
        elif self.mask_nonlinear == "sigmoid":
            est_mask = torch.sigmoid(score)
        elif self.mask_nonlinear == "tanh":
            est_mask = torch.tanh(score)
        else:
            raise ValueError("Unsupported mask non-linear function")
        masks = est_mask.transpose(2, 3)  # [M, num_spk, K, N]
        masks = masks.unbind(dim=1)  # List[M, K, N]
        masked = [input * m for m in masks]
        others = OrderedDict(
            zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks)
        )
        return masked, ilens, others
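
# A minimal usage sketch (an illustration, not part of the original module):
# the sizes below (M=4 utterances, K=100 frames, N=64 filters, B=128
# bottleneck channels) and the variable names are assumed values chosen
# only to show the expected tensor shapes flowing through forward().
if __name__ == "__main__":
    M, K, N, B = 4, 100, 64, 128
    mask_module = MultiMask(input_dim=N, bottleneck_dim=B, max_num_spk=3)

    feature = torch.randn(M, K, N)          # encoder output [M, K, N]
    ilens = torch.full((M,), K)             # frame lengths (M,)
    bottleneck_feat = torch.randn(M, K, B)  # TCNSeparator bottleneck [M, K, B]

    # Select the 2-speaker head; at inference, num_spk would come from
    # another module such as EEND-EDA.
    masked, olens, others = mask_module(feature, ilens, bottleneck_feat, num_spk=2)

    assert len(masked) == 2 and masked[0].shape == (M, K, N)
    assert list(others.keys()) == ["mask_spk1", "mask_spk2"]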