Source code for espnet2.spk.encoder.ska_tdnn_encoder

# SKA-TDNN, original code from: https://github.com/msh9184/ska-tdnn
# adapted for ESPnet-SPK by Jee-weon Jung
import math
from collections import OrderedDict

import torch
import torch.nn as nn
from typeguard import typechecked

from espnet2.asr.encoder.abs_encoder import AbsEncoder


class SEModule(nn.Module):
    def __init__(self, channels, bottleneck=128):
        super(SEModule, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.BatchNorm1d(bottleneck),
            nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, input):
        x = self.se(input)
        return input * x
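
# Usage sketch (illustrative, not part of the original module): SEModule is a
# squeeze-and-excitation gate that rescales channels and keeps the [B, C, T]
# shape; the sizes below are assumed example values.
#
#     >>> se = SEModule(channels=64)
#     >>> x = torch.randn(2, 64, 100)  # [B, C, T]
#     >>> se(x).shape
#     torch.Size([2, 64, 100])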


class Bottle2neck(nn.Module):
    def __init__(
        self,
        inplanes,
        planes,
        kernel_size=None,
        kernel_sizes=[5, 7],
        dilation=None,
        scale=8,
        group=1,
    ):
        super(Bottle2neck, self).__init__()
        width = int(math.floor(planes / scale))
        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(width * scale)
        self.nums = scale - 1
        self.skconvs = nn.ModuleList([])
        for i in range(self.nums):
            convs = nn.ModuleList([])
            for k in kernel_sizes:
                convs += [
                    nn.Sequential(
                        OrderedDict(
                            [
                                (
                                    "conv",
                                    nn.Conv1d(
                                        width,
                                        width,
                                        kernel_size=k,
                                        dilation=dilation,
                                        padding=k // 2 * dilation,
                                        groups=group,
                                    ),
                                ),
                                ("relu", nn.ReLU()),
                                ("bn", nn.BatchNorm1d(width)),
                            ]
                        )
                    )
                ]
            self.skconvs += [convs]
        self.skse = SKAttentionModule(
            channel=width, reduction=4, num_kernels=len(kernel_sizes)
        )
        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.se = SEModule(channels=planes)
        self.width = width

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.skse(sp, self.skconvs[i])
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        out = torch.cat((out, spx[self.nums]), 1)
        out = self.conv3(out)
        out = self.relu(out)
        out = self.bn3(out)
        out = self.se(out)
        out += residual
        return out
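
# Usage sketch (illustrative, not part of the original module): a Bottle2neck
# block, called with the same arguments the encoder uses, preserves the
# [B, C, T] shape (inplanes must equal planes for the residual addition).
#
#     >>> block = Bottle2neck(1024, 1024, kernel_size=3, dilation=2, scale=8)
#     >>> x = torch.randn(2, 1024, 100)  # [B, C, T]
#     >>> block(x).shape
#     torch.Size([2, 1024, 100])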


class ResBlock(nn.Module):
    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        reduction: int = 8,
        skfwse_freq: int = 40,
        skcwse_channel: int = 128,
    ):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.skfwse = fwSKAttention(
            freq=skfwse_freq,
            channel=skcwse_channel,
            kernels=[5, 7],
            receptive=[5, 7],
            dilations=[1, 1],
            reduction=reduction,
            groups=1,
        )
        self.skcwse = cwSKAttention(
            freq=skfwse_freq,
            channel=skcwse_channel,
            kernels=[5, 7],
            receptive=[5, 7],
            dilations=[1, 1],
            reduction=reduction,
            groups=1,
        )
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)
        out = self.skfwse(out)
        out = self.skcwse(out)
        out += residual
        out = self.relu(out)
        return out
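
# Usage sketch (illustrative, not part of the original module): ResBlock chains
# frequency-wise and channel-wise SK attention on 2-D feature maps and keeps
# the [B, C, F, T] shape (the defaults assume C=128, F=40).
#
#     >>> block = ResBlock(128, 128)
#     >>> x = torch.randn(2, 128, 40, 100)  # [B, C, F, T]
#     >>> block(x).shape
#     torch.Size([2, 128, 40, 100])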


class SKAttentionModule(nn.Module):
    def __init__(self, channel=128, reduction=4, L=16, num_kernels=2):
        super(SKAttentionModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.D = max(L, channel // reduction)
        self.fc = nn.Linear(channel, self.D)
        self.relu = nn.ReLU()
        self.fcs = nn.ModuleList([])
        for i in range(num_kernels):
            self.fcs += [nn.Linear(self.D, channel)]
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x, convs):
        """Forward function.

        Input: [B, C, T]
        Split: [K, B, C, T]
        Fuse: [B, C, T]
        Attention weight: [K, B, C, 1]
        Output: [B, C, T]
        """
        bs, c, t = x.size()
        conv_outs = []
        for conv in convs:
            conv_outs += [conv(x)]
        feats = torch.stack(conv_outs, 0)
        U = sum(conv_outs)
        S = self.avg_pool(U).view(bs, c)
        Z = self.fc(S)
        Z = self.relu(Z)
        weights = []
        for fc in self.fcs:
            weight = fc(Z)
            weights += [weight.view(bs, c, 1)]
        attention_weights = torch.stack(weights, 0)
        attention_weights = self.softmax(attention_weights)
        V = (attention_weights * feats).sum(0)
        return V
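
# Usage sketch (illustrative, not part of the original module): the selective
# kernel attention fuses the outputs of externally supplied conv branches (two
# hypothetical 1-D branches here) and keeps the [B, C, T] shape.
#
#     >>> width = 128
#     >>> branches = nn.ModuleList(
#     ...     nn.Sequential(
#     ...         nn.Conv1d(width, width, kernel_size=k, padding=k // 2),
#     ...         nn.ReLU(),
#     ...         nn.BatchNorm1d(width),
#     ...     )
#     ...     for k in (5, 7)
#     ... )
#     >>> ska = SKAttentionModule(channel=width, reduction=4, num_kernels=2)
#     >>> x = torch.randn(2, width, 100)  # [B, C, T]
#     >>> ska(x, branches).shape
#     torch.Size([2, 128, 100])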


class fwSKAttention(nn.Module):
    def __init__(
        self,
        freq=40,
        channel=128,
        kernels=[3, 5],
        receptive=[3, 5],
        dilations=[1, 1],
        reduction=8,
        groups=1,
        L=16,
    ):
        super(fwSKAttention, self).__init__()
        self.convs = nn.ModuleList([])
        for k, d, r in zip(kernels, dilations, receptive):
            self.convs += [
                nn.Sequential(
                    OrderedDict(
                        [
                            (
                                "conv",
                                nn.Conv2d(
                                    channel,
                                    channel,
                                    kernel_size=k,
                                    padding=r // 2,
                                    dilation=d,
                                    groups=groups,
                                ),
                            ),
                            ("relu", nn.ReLU()),
                            ("bn", nn.BatchNorm2d(channel)),
                        ]
                    )
                )
            ]
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.D = max(L, freq // reduction)
        self.fc = nn.Linear(freq, self.D)
        self.relu = nn.ReLU()
        self.fcs = nn.ModuleList([])
        for i in range(len(kernels)):
            self.fcs += [nn.Linear(self.D, freq)]
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x):
        """Forward function.

        Input: [B, C, F, T]
        Split: [K, B, C, F, T]
        Fuse: [B, C, F, T]
        Attention weight: [K, B, 1, F, 1]
        Output: [B, C, F, T]
        """
        bs, c, f, t = x.size()
        conv_outs = []
        for conv in self.convs:
            conv_outs += [conv(x)]
        feats = torch.stack(conv_outs, 0)
        U = sum(conv_outs).permute(0, 2, 3, 1)
        S = self.avg_pool(U).view(bs, f)
        Z = self.fc(S)
        Z = self.relu(Z)
        weights = []
        for fc in self.fcs:
            weight = fc(Z)
            weights += [weight.view(bs, 1, f, 1)]
        attention_weights = torch.stack(weights, 0)
        attention_weights = self.softmax(attention_weights)
        V = (attention_weights * feats).sum(0)
        return V
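
# Usage sketch (illustrative, not part of the original module): fwSKAttention
# learns one fusing weight per frequency bin ([K, B, 1, F, 1]), leaving the
# [B, C, F, T] shape unchanged.
#
#     >>> att = fwSKAttention(freq=40, channel=128)
#     >>> x = torch.randn(2, 128, 40, 100)  # [B, C, F, T]
#     >>> att(x).shape
#     torch.Size([2, 128, 40, 100])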


class cwSKAttention(nn.Module):
    def __init__(
        self,
        freq=40,
        channel=128,
        kernels=[3, 5],
        receptive=[3, 5],
        dilations=[1, 1],
        reduction=8,
        groups=1,
        L=16,
    ):
        super(cwSKAttention, self).__init__()
        self.convs = nn.ModuleList([])
        for k, d, r in zip(kernels, dilations, receptive):
            self.convs += [
                nn.Sequential(
                    OrderedDict(
                        [
                            (
                                "conv",
                                nn.Conv2d(
                                    channel,
                                    channel,
                                    kernel_size=k,
                                    padding=r // 2,
                                    dilation=d,
                                    groups=groups,
                                ),
                            ),
                            ("relu", nn.ReLU()),
                            ("bn", nn.BatchNorm2d(channel)),
                        ]
                    )
                )
            ]
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.D = max(L, channel // reduction)
        self.fc = nn.Linear(channel, self.D)
        self.relu = nn.ReLU()
        self.fcs = nn.ModuleList([])
        for i in range(len(kernels)):
            self.fcs += [nn.Linear(self.D, channel)]
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x):
        """Forward function.

        Input: [B, C, F, T]
        Split: [K, B, C, F, T]
        Fuse: [B, C, F, T]
        Attention weight: [K, B, C, 1, 1]
        Output: [B, C, F, T]
        """
        bs, c, f, t = x.size()
        conv_outs = []
        for conv in self.convs:
            conv_outs += [conv(x)]
        feats = torch.stack(conv_outs, 0)
        U = sum(conv_outs)
        S = self.avg_pool(U).view(bs, c)
        Z = self.fc(S)
        Z = self.relu(Z)
        weights = []
        for fc in self.fcs:
            weight = fc(Z)
            weights += [weight.view(bs, c, 1, 1)]
        attention_weights = torch.stack(weights, 0)
        attention_weights = self.softmax(attention_weights)
        V = (attention_weights * feats).sum(0)
        return V
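
# Usage sketch (illustrative, not part of the original module): cwSKAttention
# is the channel-wise counterpart, weighting the kernel branches per channel
# ([K, B, C, 1, 1]) rather than per frequency bin.
#
#     >>> att = cwSKAttention(freq=40, channel=128)
#     >>> x = torch.randn(2, 128, 40, 100)  # [B, C, F, T]
#     >>> att(x).shape
#     torch.Size([2, 128, 40, 100])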


class SkaTdnnEncoder(AbsEncoder):
    """SKA-TDNN encoder. Extracts frame-level SKA-TDNN embeddings from features.

    Paper: S. Mun, J. Jung et al., "Frequency and Multi-Scale Selective
        Kernel Attention for Speaker Verification," in Proc. IEEE SLT 2022.

    Args:
        input_size: input feature dimension.
        block: type of encoder block class to use.
        model_scale: scale value of the Res2Net architecture.
        ndim: dimensionality of the hidden representation.
        output_size: output embedding dimension.
    """

    @typechecked
    def __init__(
        self,
        input_size: int,
        block: str = "Bottle2neck",
        ndim: int = 1024,
        model_scale: int = 8,
        skablock: str = "ResBlock",
        ska_dim: int = 128,
        output_size: int = 1536,
        **kwargs,
    ):
        super().__init__()

        if block == "Bottle2neck":
            block: type = Bottle2neck
        else:
            raise ValueError(f"unsupported block, got: {block}")

        if skablock == "ResBlock":
            ska_block = ResBlock
        else:
            raise ValueError(f"unsupported block, got: {skablock}")

        self.frt_conv1 = nn.Conv2d(
            1, ska_dim, kernel_size=(3, 3), stride=(2, 1), padding=1
        )
        self.frt_bn1 = nn.BatchNorm2d(ska_dim)
        self.frt_block1 = ska_block(
            ska_dim,
            ska_dim,
            stride=(1, 1),
            skfwse_freq=input_size // 2,
            skcwse_channel=ska_dim,
        )
        self.frt_block2 = ska_block(
            ska_dim,
            ska_dim,
            stride=(1, 1),
            skfwse_freq=input_size // 2,
            skcwse_channel=ska_dim,
        )
        self.frt_conv2 = nn.Conv2d(
            ska_dim, ska_dim, kernel_size=(3, 3), stride=(2, 2), padding=1
        )
        self.frt_bn2 = nn.BatchNorm2d(ska_dim)

        self.conv1 = nn.Conv1d(
            ska_dim * input_size // 4, ndim, kernel_size=5, stride=1, padding=2
        )
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(ndim)
        self.layer1 = block(ndim, ndim, kernel_size=3, dilation=2, scale=model_scale)
        self.layer2 = block(ndim, ndim, kernel_size=3, dilation=3, scale=model_scale)
        self.layer3 = block(ndim, ndim, kernel_size=3, dilation=4, scale=model_scale)
        self.layer4 = nn.Conv1d(3 * ndim, output_size, kernel_size=1)
        self._output_size = output_size

    def output_size(self) -> int:
        return self._output_size

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B, S, D) -> (B, D, S)
        x = x.unsqueeze(1)  # (B, D, S) -> (B, 1, D, S)

        # the fcwSKA block
        x = self.frt_conv1(x)
        x = self.relu(x)
        x = self.frt_bn1(x)
        x = self.frt_block1(x)
        x = self.frt_block2(x)
        x = self.frt_conv2(x)
        x = self.relu(x)
        x = self.frt_bn2(x)
        x = x.reshape((x.size()[0], -1, x.size()[-1]))

        x = self.conv1(x)
        x = self.relu(x)
        x = self.bn1(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x + x1)
        x3 = self.layer3(x + x1 + x2)

        x = self.layer4(torch.cat((x1, x2, x3), dim=1))
        x = self.relu(x)

        return x
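
# End-to-end usage sketch (illustrative, not part of the original module); the
# 80-dim features, batch size, and frame count are assumed example values. The
# fcwSKA front-end halves the time axis and quarters the frequency axis, so
# input_size should be divisible by 4 for the flattened channels to match
# self.conv1.
#
#     >>> encoder = SkaTdnnEncoder(input_size=80)
#     >>> feats = torch.randn(4, 200, 80)  # (batch, frames, feature_dim)
#     >>> encoder(feats).shape  # frame-level output: (batch, output_size, frames // 2)
#     torch.Size([4, 1536, 100])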