# Source code for espnet2.enh.separator.tcn_separator

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import torch
from packaging.version import parse as V
from torch_complex.tensor import ComplexTensor

from espnet2.enh.layers.complex_utils import is_complex
from espnet2.enh.layers.tcn import TemporalConvNet
from espnet2.enh.separator.abs_separator import AbsSeparator

# Flag that is True when the installed PyTorch version is at least 1.9.0.
# NOTE(review): not referenced by the code visible in this module; possibly
# kept for external importers — confirm before removing.
is_torch_1_9_plus = V("1.9.0") <= V(torch.__version__)

class TCNSeparator(AbsSeparator):
    """Mask-based separator built on a Temporal Convolutional Network.

    The encoded feature (its magnitude, when the input is complex) is passed
    through ``TemporalConvNet``, which predicts one mask per speaker, plus one
    extra mask for noise when ``predict_noise`` is True. Each mask is
    multiplied with the input to obtain the separated outputs.
    """

    def __init__(
        self,
        input_dim: int,
        num_spk: int = 2,
        predict_noise: bool = False,
        layer: int = 8,
        stack: int = 3,
        bottleneck_dim: int = 128,
        hidden_dim: int = 512,
        kernel: int = 3,
        causal: bool = False,
        norm_type: str = "gLN",
        nonlinear: str = "relu",
    ):
        """Temporal Convolution Separator.

        Args:
            input_dim: input feature dimension
            num_spk: number of speakers
            predict_noise: whether to output the estimated noise signal
            layer: int, number of layers in each stack.
            stack: int, number of stacks
            bottleneck_dim: bottleneck dimension
            hidden_dim: number of convolution channels
            kernel: int, kernel size.
            causal: bool, default False.
            norm_type: str, choose from 'BN', 'gLN', 'cLN'
            nonlinear: the nonlinear function for mask estimation,
                select from 'relu', 'tanh', 'sigmoid'

        Raises:
            ValueError: if ``nonlinear`` is not one of the supported names.
        """
        super().__init__()

        self._num_spk = num_spk
        self.predict_noise = predict_noise

        if nonlinear not in ("sigmoid", "relu", "tanh"):
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))

        self.tcn = TemporalConvNet(
            N=input_dim,
            B=bottleneck_dim,
            H=hidden_dim,
            P=kernel,
            X=layer,
            R=stack,
            # One additional output mask is reserved for noise.
            C=num_spk + 1 if predict_noise else num_spk,
            norm_type=norm_type,
            causal=causal,
            mask_nonlinear=nonlinear,
        )

    def forward(
        self,
        input: Union[torch.Tensor, ComplexTensor],
        ilens: torch.Tensor,
        additional: Optional[Dict] = None,
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
            ilens (torch.Tensor): input lengths [Batch]
            additional (Dict or None): other data included in model
                NOTE: not used in this model

        Returns:
            masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...],
                one entry per speaker
            ilens (torch.Tensor): (B,), returned unchanged
            others predicted data, e.g. masks: OrderedDict[
                'mask_spk1': torch.Tensor(Batch, Frames, Freq),
                'mask_spk2': torch.Tensor(Batch, Frames, Freq),
                ...
                'mask_spkn': torch.Tensor(Batch, Frames, Freq),
                'noise1': input * noise mask (only when predict_noise=True),
            ]
        """
        # Masks are estimated from the magnitude when the input is complex.
        feature = abs(input) if is_complex(input) else input

        # The unpacking also enforces a 3-D (B, T, N) feature.
        B, L, N = feature.shape

        feature = feature.transpose(1, 2)  # B, N, L
        masks = self.tcn(feature)  # B, C, N, L  (C = num_spk [+1 for noise])
        masks = masks.transpose(2, 3)  # B, C, L, N

        if self.predict_noise:
            *masks, mask_noise = masks.unbind(dim=1)  # List[B, L, N]
        else:
            masks = masks.unbind(dim=1)  # List[B, L, N]

        masked = [input * m for m in masks]

        others = OrderedDict(
            ("mask_spk{}".format(i + 1), m) for i, m in enumerate(masks)
        )
        if self.predict_noise:
            others["noise1"] = input * mask_noise

        return masked, ilens, others

    def forward_streaming(
        self, input_frame: torch.Tensor, buffer: Optional[torch.Tensor] = None
    ):
        """Separate a single frame using a sliding context buffer.

        Args:
            input_frame: current frame, shape (B, 1, N)
            buffer: context of shape (B, self.tcn.receptive_field, N) returned
                by the previous call, or None to start from an all-zero context

        Returns:
            masked: list with one (B, 1, N) tensor per speaker, i.e. the
                separated output for the newest frame only
            buffer: updated context buffer; pass it to the next call
            others: the ``others`` dict of ``forward`` computed over the whole
                buffer (its masks span all buffered frames, not just the
                newest one)
        """
        B, _, N = input_frame.shape

        receptive_field = self.tcn.receptive_field

        if buffer is None:
            # Use the frame's dtype so complex / non-float32 inputs are stored
            # without a lossy cast (e.g. dropping the imaginary part).
            buffer = torch.zeros(
                (B, receptive_field, N),
                dtype=input_frame.dtype,
                device=input_frame.device,
            )

        # torch.roll returns a new tensor, so the caller's buffer is untouched;
        # the oldest frame moves to the end and is overwritten by the new one.
        buffer = torch.roll(buffer, shifts=-1, dims=1)
        buffer[:, -1, :] = input_frame[:, 0, :]

        masked, _, others = self.forward(buffer, None)

        # Emit only the output corresponding to the newest frame.
        masked = [m[:, -1, :].unsqueeze(1) for m in masked]

        return masked, buffer, others

    @property
    def num_spk(self):
        """Number of speakers this separator outputs."""
        return self._num_spk