# Source code for espnet2.spk.layers.rawnet_block

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AFMS(nn.Module):
    """Alpha-feature map scaling, applied to the output of each residual block [1, 2].

    A learnable per-channel offset ``alpha`` is added to the input, and the
    result is multiplied by a per-channel gate computed from the time-averaged
    input::

        out = (x + alpha) * sigmoid(fc(avg_pool_over_time(x)))

    Reference:
        [1] RawNet2 : https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1011.pdf
        [2] AMFS : https://www.koreascience.or.kr/article/JAKO202029757857763.page
    """

    def __init__(self, nb_dim: int) -> None:
        """Create the module.

        Args:
            nb_dim: Number of channels of the feature map being scaled.
        """
        super().__init__()
        # Offset initialised to one; shape (nb_dim, 1) broadcasts over the time axis.
        self.alpha = nn.Parameter(torch.ones((nb_dim, 1)))
        # Channel-to-channel projection that produces the gate logits.
        self.fc = nn.Linear(nb_dim, nb_dim)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Scale ``x`` channel-wise.

        Args:
            x: Tensor of shape ``(batch, nb_dim, time)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, channels = x.size(0), x.size(1)
        # Average over time -> (batch, nb_dim).
        pooled = F.adaptive_avg_pool1d(x, 1).view(batch, -1)
        # Gate in (0, 1), reshaped to (batch, nb_dim, 1) to broadcast over time.
        gate = self.sig(self.fc(pooled)).view(batch, channels, -1)
        return (x + self.alpha) * gate
class Bottle2neck(nn.Module):
    """Res2Net-style 1-D bottleneck block with a residual path and AFMS.

    Processing steps:

    1. The input is projected by a 1x1 convolution to ``width * scale``
       channels, where ``width = floor(planes / scale)``.
    2. The projection is split into ``scale`` groups of ``width`` channels.
    3. Group ``i`` (for ``i < scale - 1``) is summed with the output of group
       ``i - 1`` and passed through its own dilated convolution.
    4. The last group bypasses the dilated convolutions.
    5. The groups are concatenated and projected to ``planes`` channels.
    6. The result is added to the residual path and optionally max-pooled.
    7. The output is finally rescaled by ``AFMS``.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        kernel_size: Kernel size of the per-group dilated convolutions. It must
            be given despite the ``None`` default, because it is used to
            compute the padding. It should be odd: the padding
            ``floor(kernel_size / 2) * dilation`` preserves the time length
            only for odd kernels. The sums and concatenations between groups
            require equal lengths.
        dilation: Dilation of the per-group convolutions. It must be given
            despite the ``None`` default.
        scale: Number of channel groups. With ``scale == 1`` there are no
            dilated groups and the block reduces to
            1x1 conv -> 1x1 conv plus the residual path.
        pool: If truthy, used as the kernel size of a ``MaxPool1d`` applied
            after the residual addition. ``False`` disables pooling.
    """

    def __init__(
        self, inplanes, planes, kernel_size=None, dilation=None, scale=4, pool=False
    ):
        super().__init__()
        # Per-group channel count; rounds down when planes % scale != 0.
        width = int(math.floor(planes / scale))
        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(width * scale)
        # Number of groups that go through a dilated conv (the last one does not).
        self.nums = scale - 1
        convs = []
        bns = []
        # "Same"-length padding for odd kernel sizes.
        num_pad = math.floor(kernel_size / 2) * dilation
        for i in range(self.nums):
            convs.append(
                nn.Conv1d(
                    width,
                    width,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=num_pad,
                )
            )
            bns.append(nn.BatchNorm1d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.width = width
        # Kept as ``False`` (not ``None``) when pooling is disabled, for compatibility.
        self.mp = nn.MaxPool1d(pool) if pool else False
        self.afms = AFMS(planes)

        if inplanes != planes:  # if change in number of filters
            self.residual = nn.Sequential(
                nn.Conv1d(inplanes, planes, kernel_size=1, stride=1, bias=False)
            )
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        """Apply the block.

        Args:
            x: Input tensor, expected shape ``(batch, inplanes, time)``.

        Returns:
            Tensor of shape ``(batch, planes, time')``. ``time'`` equals
            ``time`` (for odd ``kernel_size``) without pooling, or
            ``floor(time / pool)`` with pooling.
        """
        residual = self.residual(x)

        # 1x1 projection; this block uses ReLU before BatchNorm throughout.
        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)

        spx = torch.split(out, self.width, 1)
        # With scale == 1 there are no dilated groups (self.nums == 0) and the
        # single split already equals ``out``. Re-concatenating it would double
        # the channel count and break conv3, so only rebuild when needed.
        if self.nums > 0:
            groups = []
            for i in range(self.nums):
                # Hierarchical connection: each group also receives the
                # previous group's output.
                sp = spx[i] if i == 0 else sp + spx[i]
                sp = self.convs[i](sp)
                sp = self.relu(sp)
                sp = self.bns[i](sp)
                groups.append(sp)
            # The last group bypasses the dilated convolutions.
            groups.append(spx[self.nums])
            out = torch.cat(groups, 1)

        out = self.conv3(out)
        out = self.relu(out)
        out = self.bn3(out)

        out += residual
        if self.mp:
            out = self.mp(out)
        out = self.afms(out)
        return out