Source code for espnet2.enh.layers.tcn

# Implementation of the TCN proposed in
# Luo et al. "Conv-TasNet: Surpassing ideal time–frequency
# magnitude masking for speech separation."
#
# The code is based on:
# https://github.com/kaituoxu/Conv-TasNet/blob/master/src/conv_tasnet.py
# Licensed under MIT.
#


import torch
import torch.nn as nn
import torch.nn.functional as F

from espnet2.enh.layers.adapt_layers import make_adapt_layer

EPS = torch.finfo(torch.get_default_dtype()).eps


class TemporalConvNet(nn.Module):
    def __init__(
        self,
        N,
        B,
        H,
        P,
        X,
        R,
        C,
        Sc=None,
        out_channel=None,
        norm_type="gLN",
        causal=False,
        pre_mask_nonlinear="linear",
        mask_nonlinear="relu",
    ):
        """Basic module of TasNet.

        Args:
            N: Number of filters in autoencoder
            B: Number of channels in the bottleneck 1x1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            Sc: Number of channels in skip-connection paths' 1x1-conv blocks
            out_channel: Number of output channels.
                If None, `N` will be used instead.
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            pre_mask_nonlinear: non-linear function applied before the mask network
            mask_nonlinear: non-linear function used to generate the mask
        """
        super().__init__()
        # Hyper-parameters
        self.C = C
        self.mask_nonlinear = mask_nonlinear
        self.skip_connection = Sc is not None
        self.out_channel = N if out_channel is None else out_channel
        if self.skip_connection:
            assert Sc == B, (Sc, B)
        # Components
        # [M, N, K] -> [M, N, K]
        layer_norm = ChannelwiseLayerNorm(N)
        # [M, N, K] -> [M, B, K]
        bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
        # [M, B, K] -> [M, B, K]
        repeats = []
        self.receptive_field = 0
        for r in range(R):
            blocks = []
            for x in range(X):
                dilation = 2**x
                if r == 0 and x == 0:
                    self.receptive_field += P
                else:
                    self.receptive_field += (P - 1) * dilation
                padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
                blocks += [
                    TemporalBlock(
                        B,
                        H,
                        Sc,
                        P,
                        stride=1,
                        padding=padding,
                        dilation=dilation,
                        norm_type=norm_type,
                        causal=causal,
                    )
                ]
            repeats += [nn.Sequential(*blocks)]
        temporal_conv_net = nn.Sequential(*repeats)
        # [M, B, K] -> [M, C*out_channel, K]
        mask_conv1x1 = nn.Conv1d(B, C * self.out_channel, 1, bias=False)
        # Put together (for compatibility with older versions)
        if pre_mask_nonlinear == "linear":
            self.network = nn.Sequential(
                layer_norm, bottleneck_conv1x1, temporal_conv_net, mask_conv1x1
            )
        else:
            activ = {
                "prelu": nn.PReLU(),
                "relu": nn.ReLU(),
                "tanh": nn.Tanh(),
                "sigmoid": nn.Sigmoid(),
            }[pre_mask_nonlinear]
            self.network = nn.Sequential(
                layer_norm, bottleneck_conv1x1, temporal_conv_net, activ, mask_conv1x1
            )

    def forward(self, mixture_w):
        """Keep this API the same as TasNet.

        Args:
            mixture_w: [M, N, K], M is batch size

        Returns:
            est_mask: [M, C, N, K]
        """
        M, N, K = mixture_w.size()
        bottleneck = self.network[:2]
        tcns = self.network[2]
        masknet = self.network[3:]
        output = bottleneck(mixture_w)
        skip_conn = 0.0
        for block in tcns:
            for layer in block:
                tcn_out = layer(output)
                if self.skip_connection:
                    residual, skip = tcn_out
                    skip_conn = skip_conn + skip
                else:
                    residual = tcn_out
                output = output + residual
        # Use residual output when no skip connection
        if self.skip_connection:
            score = masknet(skip_conn)
        else:
            score = masknet(output)
        # [M, C*self.out_channel, K] -> [M, C, self.out_channel, K]
        score = score.view(M, self.C, self.out_channel, K)
        if self.mask_nonlinear == "softmax":
            est_mask = torch.softmax(score, dim=1)
        elif self.mask_nonlinear == "relu":
            est_mask = torch.relu(score)
        elif self.mask_nonlinear == "sigmoid":
            est_mask = torch.sigmoid(score)
        elif self.mask_nonlinear == "tanh":
            est_mask = torch.tanh(score)
        elif self.mask_nonlinear == "linear":
            est_mask = score
        else:
            raise ValueError("Unsupported mask non-linear function")
        return est_mask
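

# Usage sketch (illustrative only; the hyper-parameter values below are
# assumptions, not module defaults):
#
# >>> tcn = TemporalConvNet(N=256, B=256, H=512, P=3, X=8, R=4, C=2, Sc=256)
# >>> mixture_w = torch.randn(4, 256, 100)  # [M, N, K]
# >>> est_mask = tcn(mixture_w)             # [M, C, N, K]
# >>> est_mask.shape
# torch.Size([4, 2, 256, 100])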


class TemporalConvNetInformed(TemporalConvNet):
    def __init__(
        self,
        N,
        B,
        H,
        P,
        X,
        R,
        Sc=None,
        out_channel=None,
        norm_type="gLN",
        causal=False,
        pre_mask_nonlinear="prelu",
        mask_nonlinear="relu",
        i_adapt_layer: int = 7,
        adapt_layer_type: str = "mul",
        adapt_enroll_dim: int = 128,
        **adapt_layer_kwargs
    ):
        """Basic module of TasNet with adaptation layers.

        Args:
            N: Number of filters in autoencoder
            B: Number of channels in the bottleneck 1x1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            Sc: Number of channels in skip-connection paths' 1x1-conv blocks
            out_channel: Number of output channels.
                If None, `N` will be used instead.
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            pre_mask_nonlinear: non-linear function applied before the mask network
            mask_nonlinear: non-linear function used to generate the mask
            i_adapt_layer: int, index of the adaptation layer
            adapt_layer_type: str, type of adaptation layer;
                see espnet2.enh.layers.adapt_layers for options
            adapt_enroll_dim: int, dimensionality of the speaker embedding
        """
        super().__init__(
            N,
            B,
            H,
            P,
            X,
            R,
            1,
            Sc=Sc,
            out_channel=out_channel,
            norm_type=norm_type,
            causal=causal,
            pre_mask_nonlinear=pre_mask_nonlinear,
            mask_nonlinear=mask_nonlinear,
        )
        self.i_adapt_layer = i_adapt_layer
        self.adapt_enroll_dim = adapt_enroll_dim
        self.adapt_layer_type = adapt_layer_type
        self.adapt_layer = make_adapt_layer(
            adapt_layer_type,
            indim=B,
            enrolldim=adapt_enroll_dim,
            ninputs=2 if self.skip_connection else 1,
            **adapt_layer_kwargs
        )

    def forward(self, mixture_w, enroll_emb):
        """TasNet forward with adaptation layers.

        Args:
            mixture_w: [M, N, K], M is batch size
            enroll_emb: [M, 2*adapt_enroll_dim] if self.skip_connection
                        [M, adapt_enroll_dim] if not self.skip_connection

        Returns:
            est_mask: [M, N, K]
        """
        M, N, K = mixture_w.size()
        bottleneck = self.network[:2]
        tcns = self.network[2]
        masknet = self.network[3:]
        output = bottleneck(mixture_w)
        skip_conn = 0.0
        for i, block in enumerate(tcns):
            for j, layer in enumerate(block):
                idx = i * len(block) + j
                is_adapt_layer = idx == self.i_adapt_layer
                tcn_out = layer(output)
                if self.skip_connection:
                    residual, skip = tcn_out
                    if is_adapt_layer:
                        residual, skip = self.adapt_layer(
                            (residual, skip), torch.chunk(enroll_emb, 2, dim=1)
                        )
                    skip_conn = skip_conn + skip
                else:
                    residual = tcn_out
                    if is_adapt_layer:
                        residual = self.adapt_layer(residual, enroll_emb)
                output = output + residual
        # Use residual output when no skip connection
        if self.skip_connection:
            score = masknet(skip_conn)
        else:
            score = masknet(output)
        # [M, self.out_channel, K]
        if self.mask_nonlinear == "softmax":
            est_mask = F.softmax(score, dim=1)
        elif self.mask_nonlinear == "relu":
            est_mask = F.relu(score)
        elif self.mask_nonlinear == "sigmoid":
            est_mask = torch.sigmoid(score)
        elif self.mask_nonlinear == "tanh":
            est_mask = torch.tanh(score)
        elif self.mask_nonlinear == "linear":
            est_mask = score
        else:
            raise ValueError("Unsupported mask non-linear function")
        return est_mask
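

# Usage sketch for the informed (target-speaker) variant; the values are
# illustrative assumptions. The exact layout of `enroll_emb` follows the
# docstring above and the adaptation layer chosen from
# espnet2.enh.layers.adapt_layers (here "mul", so `adapt_enroll_dim` is set
# equal to `B`):
#
# >>> tcn = TemporalConvNetInformed(
# ...     N=256, B=256, H=512, P=3, X=8, R=4, Sc=256,
# ...     adapt_layer_type="mul", adapt_enroll_dim=256, i_adapt_layer=7,
# ... )
# >>> mixture_w = torch.randn(4, 256, 100)   # [M, N, K]
# >>> enroll_emb = ...                       # see docstring for the expected shape
# >>> est_mask = tcn(mixture_w, enroll_emb)  # [M, N, K], single target speaker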


class TemporalBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        skip_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        norm_type="gLN",
        causal=False,
    ):
        super().__init__()
        self.skip_connection = skip_channels is not None
        # [M, B, K] -> [M, H, K]
        conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        prelu = nn.PReLU()
        norm = choose_norm(norm_type, out_channels)
        # [M, H, K] -> [M, B, K]
        dsconv = DepthwiseSeparableConv(
            out_channels,
            in_channels,
            skip_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            norm_type,
            causal,
        )
        # Put together
        self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)

    def forward(self, x):
        """Forward.

        Args:
            x: [M, B, K]

        Returns:
            [M, B, K]
        """
        if self.skip_connection:
            res_out, skip_out = self.net(x)
            return res_out, skip_out
        else:
            res_out = self.net(x)
            return res_out


class DepthwiseSeparableConv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        skip_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        norm_type="gLN",
        causal=False,
    ):
        super().__init__()
        # Use `groups` option to implement depthwise convolution
        # [M, H, K] -> [M, H, K]
        depthwise_conv = nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=False,
        )
        if causal:
            chomp = Chomp1d(padding)
        prelu = nn.PReLU()
        norm = choose_norm(norm_type, in_channels)
        # [M, H, K] -> [M, B, K]
        pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        # Put together
        if causal:
            self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
        else:
            self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
        # skip connection
        if skip_channels is not None:
            self.skip_conv = nn.Conv1d(in_channels, skip_channels, 1, bias=False)
        else:
            self.skip_conv = None

    def forward(self, x):
        """Forward.

        Args:
            x: [M, H, K]

        Returns:
            res_out: [M, B, K]
            skip_out: [M, Sc, K]
        """
        shared_block = self.net[:-1]
        shared = shared_block(x)
        res_out = self.net[-1](shared)
        if self.skip_conv is None:
            return res_out
        skip_out = self.skip_conv(shared)
        return res_out, skip_out
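

# Shape sketch (illustrative assumptions): when `skip_channels` is given,
# forward returns both the residual output and the skip-connection output.
#
# >>> dsconv = DepthwiseSeparableConv(
# ...     in_channels=512, out_channels=256, skip_channels=256,
# ...     kernel_size=3, stride=1, padding=2, dilation=2,
# ... )
# >>> res_out, skip_out = dsconv(torch.randn(4, 512, 100))
# >>> res_out.shape, skip_out.shape
# (torch.Size([4, 256, 100]), torch.Size([4, 256, 100]))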


class Chomp1d(nn.Module):
    """To ensure the output length is the same as the input."""

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Forward.

        Args:
            x: [M, H, Kpad]

        Returns:
            [M, H, K]
        """
        return x[:, :, : -self.chomp_size].contiguous()
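

# Example (illustrative): a causal depthwise convolution above pads by
# (kernel_size - 1) * dilation, so Chomp1d trims that many trailing samples
# to restore the original length.
#
# >>> x = torch.randn(1, 8, 108)
# >>> Chomp1d(8)(x).shape
# torch.Size([1, 8, 100])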


def check_nonlinear(nonlinear_type):
    if nonlinear_type not in ["softmax", "relu"]:
        raise ValueError("Unsupported nonlinear type")


def choose_norm(norm_type, channel_size, shape="BDT"):
    """The input of normalization will be (M, C, K), where M is the batch size,

    C is the channel size, and K is the sequence length.
    """
    if norm_type == "gLN":
        return GlobalLayerNorm(channel_size, shape=shape)
    elif norm_type == "cLN":
        return ChannelwiseLayerNorm(channel_size, shape=shape)
    elif norm_type == "BN":
        # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statistics
        # along M and K, so this BN usage is right.
        return nn.BatchNorm1d(channel_size)
    elif norm_type == "GN":
        return nn.GroupNorm(1, channel_size, eps=1e-8)
    else:
        raise ValueError("Unsupported normalization type")
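

# Example (illustrative): pick a normalization layer by name and apply it to
# a [M, C, K] tensor.
#
# >>> norm = choose_norm("gLN", 128)
# >>> norm(torch.randn(4, 128, 100)).shape
# torch.Size([4, 128, 100])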


class ChannelwiseLayerNorm(nn.Module):
    """Channel-wise Layer Normalization (cLN)."""

    def __init__(self, channel_size, shape="BDT"):
        super().__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.reset_parameters()
        assert shape in ["BDT", "BTD"]
        self.shape = shape

    def reset_parameters(self):
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, y):
        """Forward.

        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length

        Returns:
            cLN_y: [M, N, K]
        """
        assert y.dim() == 3

        if self.shape == "BTD":
            y = y.transpose(1, 2).contiguous()

        mean = torch.mean(y, dim=1, keepdim=True)  # [M, 1, K]
        var = torch.var(y, dim=1, keepdim=True, unbiased=False)  # [M, 1, K]
        cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta

        if self.shape == "BTD":
            cLN_y = cLN_y.transpose(1, 2).contiguous()
        return cLN_y


class GlobalLayerNorm(nn.Module):
    """Global Layer Normalization (gLN)."""

    def __init__(self, channel_size, shape="BDT"):
        super().__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.reset_parameters()
        assert shape in ["BDT", "BTD"]
        self.shape = shape

    def reset_parameters(self):
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, y):
        """Forward.

        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length

        Returns:
            gLN_y: [M, N, K]
        """
        if self.shape == "BTD":
            y = y.transpose(1, 2).contiguous()

        mean = y.mean(dim=(1, 2), keepdim=True)  # [M, 1, 1]
        var = (torch.pow(y - mean, 2)).mean(dim=(1, 2), keepdim=True)
        gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta

        if self.shape == "BTD":
            gLN_y = gLN_y.transpose(1, 2).contiguous()
        return gLN_y
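

# Note (comments only): cLN normalizes each time step over the channel
# dimension, so it remains usable in the causal setting, while gLN normalizes
# over both channels and time and therefore uses utterance-level statistics.
# Quick sanity check (illustrative):
#
# >>> y = torch.randn(2, 16, 50)
# >>> GlobalLayerNorm(16)(y).mean().abs() < 1e-5
# tensor(True)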