import math
import torch
import torch.nn as nn
"""
Basic blocks for ECAPA-TDNN.
Code from https://github.com/TaoRuijie/ECAPA-TDNN/blob/main/model.py
"""
class SEModule(nn.Module):
    """Squeeze-and-excitation gate for 1-D feature maps.

    The input is averaged down to a single step along its last axis. That
    summary goes through a 1x1-convolution bottleneck (conv -> ReLU ->
    BatchNorm -> conv) and a sigmoid, giving a per-channel weight in [0, 1].
    The weight is multiplied back onto the input.

    Args:
        channels: Number of input (and output) channels.
        bottleneck: Number of hidden channels inside the gate.
    """

    def __init__(self, channels: int, bottleneck: int = 128):
        super(SEModule, self).__init__()
        # Layer order (and therefore the ``se.<index>`` state_dict keys) must
        # stay fixed so existing checkpoints keep loading.
        layers = [
            nn.AdaptiveAvgPool1d(1),  # squeeze: mean over the last axis
            nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.BatchNorm1d(bottleneck),
            nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0),
            nn.Sigmoid(),  # excitation: weights in [0, 1]
        ]
        self.se = nn.Sequential(*layers)

    def forward(self, input):
        """Return ``input`` rescaled channel-wise by the learned gate."""
        gate = self.se(input)
        return input * gate
class EcapaBlock(nn.Module):
    """Res2Net-style dilated block with squeeze-excitation and a residual add.

    Pipeline (see :meth:`forward`):

    1. A 1x1 conv maps the input to ``width * scale`` channels.
    2. The result is split into ``scale`` groups of ``width`` channels.
    3. The first ``scale - 1`` groups go through hierarchical dilated
       convolutions: each group is added to the previous group's output
       before its conv. The last group passes through untouched.
    4. The groups are concatenated.
    5. A 1x1 conv maps them to ``planes`` channels.
    6. :class:`SEModule` gates the result.
    7. The block input is added back.

    Args:
        inplanes: Channels of the input tensor.
        planes: Channels of the output tensor. The final residual addition is
            done in place on a ``planes``-channel tensor, so the input must
            broadcast to that shape -- in practice ``inplanes == planes``.
        kernel_size: Kernel size of the group convolutions. Required despite
            the ``None`` default. Use an odd value: the padding
            ``floor(kernel_size / 2) * dilation`` only preserves the time
            length for odd kernels. The convolved groups are concatenated with
            the untouched last group, so their lengths must match.
        dilation: Dilation of the group convolutions. Required despite the
            ``None`` default.
        scale: Number of channel groups. ``width = floor(planes / scale)``.
            When ``planes`` is not a multiple of ``scale``, the inner width
            ``width * scale`` is smaller than ``planes``.

    Raises:
        TypeError: If ``kernel_size`` or ``dilation`` is ``None``.
    """

    def __init__(self, inplanes, planes, kernel_size=None, dilation=None, scale=8):
        super().__init__()
        # The ``None`` defaults are placeholders only; fail early with a clear
        # message instead of an opaque arithmetic TypeError further down.
        if kernel_size is None or dilation is None:
            raise TypeError("EcapaBlock requires explicit kernel_size and dilation")
        width = int(math.floor(planes / scale))
        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(width * scale)
        # Number of groups that get a convolution; the last group is passed through.
        self.nums = scale - 1
        num_pad = math.floor(kernel_size / 2) * dilation
        # Convs are created in the same order as before, so seeded
        # initialisation is unchanged (BatchNorm init does not draw from the RNG).
        self.convs = nn.ModuleList(
            [
                nn.Conv1d(
                    width,
                    width,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=num_pad,
                )
                for _ in range(self.nums)
            ]
        )
        self.bns = nn.ModuleList([nn.BatchNorm1d(width) for _ in range(self.nums)])
        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.width = width
        self.se = SEModule(planes)

    def forward(self, x):
        """Apply the block.

        Args:
            x: Batched input with channels on axis 1, presumably
                ``(batch, inplanes, time)`` -- the split below is along dim 1.

        Returns:
            Tensor with ``planes`` channels (the SE-gated branch plus ``x``).
        """
        residual = x
        out = self.bn1(self.relu(self.conv1(x)))
        # conv1 emits width * scale channels, so this yields exactly `scale` groups.
        spx = torch.split(out, self.width, 1)

        pieces = []
        sp = None
        for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            # Hierarchical connection: every group after the first also sees
            # the previous group's processed output.
            sp = spx[i] if i == 0 else sp + spx[i]
            sp = bn(self.relu(conv(sp)))
            pieces.append(sp)
        # The last group bypasses the convolutions. Collecting pieces and
        # concatenating once also makes scale == 1 work (no convs at all).
        pieces.append(spx[self.nums])
        out = torch.cat(pieces, 1)

        out = self.bn3(self.relu(self.conv3(out)))
        out = self.se(out)
        out += residual
        return out