import math
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from espnet2.enh.decoder.stft_decoder import STFTDecoder
from espnet2.enh.encoder.stft_encoder import STFTEncoder
from espnet2.enh.layers.complex_utils import new_complex_like
from espnet2.enh.separator.abs_separator import AbsSeparator
from espnet2.torch_utils.get_layer_from_string import get_layer
class TFGridNetV2(AbsSeparator):
    """Offline TFGridNetV2.

    Compared with TFGridNet, TFGridNetV2 speeds up the code by vectorizing
    multiple heads in self-attention, and better dealing with Deconv1D in each
    intra- and inter-block when emb_ks == emb_hs.

    Reference:
    [1] Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe,
    "TF-GridNet: Integrating Full- and Sub-Band Modeling for Speech Separation",
    in TASLP, 2023.
    [2] Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe,
    "TF-GridNet: Making Time-Frequency Domain Models Great Again for Monaural
    Speaker Separation", in ICASSP, 2023.

    NOTES:
    As outlined in the Reference, this model works best when trained with variance
    normalized mixture input and target, e.g., with mixture of shape [batch, samples,
    microphones], you normalize it by dividing with torch.std(mixture, (1, 2)). You
    must do the same for the target signals. It is encouraged to do so when not using
    scale-invariant loss functions such as SI-SDR.
    Specifically, use:
        std_ = std(mix)
        mix = mix / std_
        tgt = tgt / std_

    Args:
        input_dim: placeholder, not used
        n_srcs: number of output sources/speakers.
        n_fft: stft window size (must be even).
        stride: stft stride.
        window: stft window type forwarded to the STFT encoder/decoder,
            e.g. 'hann' (the default) or 'hamming', or None.
        n_imics: number of microphones channels (only fixed-array geometry supported).
        n_layers: number of TFGridNetV2 blocks.
        lstm_hidden_units: number of hidden units in LSTM.
        attn_n_head: number of heads in self-attention
        attn_approx_qk_dim: approximate dimension of frame-level query and key
            tensors
        emb_dim: embedding dimension
        emb_ks: kernel size for unfolding and deconv1D
        emb_hs: hop size for unfolding and deconv1D
        activation: activation function to use in the whole TFGridNetV2 model,
            you can use any torch supported activation e.g. 'relu' or 'elu'.
        eps: small epsilon for normalization layers.
        use_builtin_complex: whether to use builtin complex type or not.
    """

    def __init__(
        self,
        input_dim,
        n_srcs=2,
        n_fft=128,
        stride=64,
        window="hann",
        n_imics=1,
        n_layers=6,
        lstm_hidden_units=192,
        attn_n_head=4,
        attn_approx_qk_dim=512,
        emb_dim=48,
        emb_ks=4,
        emb_hs=1,
        activation="prelu",
        eps=1.0e-5,
        use_builtin_complex=False,
    ):
        super().__init__()
        self.n_srcs = n_srcs
        self.n_layers = n_layers
        self.n_imics = n_imics
        assert n_fft % 2 == 0
        n_freqs = n_fft // 2 + 1

        self.enc = STFTEncoder(
            n_fft, n_fft, stride, window=window, use_builtin_complex=use_builtin_complex
        )
        self.dec = STFTDecoder(n_fft, n_fft, stride, window=window)

        # 3x3 convolution; the padding keeps the [T, F] grid size unchanged.
        t_ksize = 3
        ks, padding = (t_ksize, 3), (t_ksize // 2, 1)
        self.conv = nn.Sequential(
            nn.Conv2d(2 * n_imics, emb_dim, ks, padding=padding),
            nn.GroupNorm(1, emb_dim, eps=eps),
        )

        self.blocks = nn.ModuleList([])
        for _ in range(n_layers):
            self.blocks.append(
                GridNetV2Block(
                    emb_dim,
                    emb_ks,
                    emb_hs,
                    n_freqs,
                    lstm_hidden_units,
                    n_head=attn_n_head,
                    approx_qk_dim=attn_approx_qk_dim,
                    activation=activation,
                    eps=eps,
                )
            )

        # Two output channels (real, imaginary) per estimated source.
        self.deconv = nn.ConvTranspose2d(emb_dim, n_srcs * 2, ks, padding=padding)

    def forward(
        self,
        input: torch.Tensor,
        ilens: torch.Tensor,
        additional: Optional[Dict] = None,
    ) -> Tuple[List[torch.Tensor], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (torch.Tensor): batched multi-channel audio tensor with
                M audio channels and N samples [B, N, M]. When ``n_imics == 1``
                a 2-D tensor [B, N] is expected instead.
            ilens (torch.Tensor): input lengths [B]
            additional (Dict or None): other data, currently unused in this model.

        Returns:
            enhanced (List[Union(torch.Tensor)]):
                    [(B, T), ...] list of len n_srcs
                    of mono audio tensors with T samples.
            ilens (torch.Tensor): (B,)
            additional (Dict or None): other data, currently unused in this model,
                    we return it also in output.
        """
        n_samples = input.shape[1]
        if self.n_imics == 1:
            assert len(input.shape) == 2
            input = input[..., None]  # [B, N, M]

        mix_std_ = torch.std(input, dim=(1, 2), keepdim=True)  # [B, 1, 1]
        # A silent (or constant) mixture has zero std; dividing by it would turn
        # the whole utterance into NaN. Use a unit scale for exactly those items;
        # every non-zero std is left untouched.
        mix_std_ = mix_std_.masked_fill(mix_std_ == 0, 1.0)
        input = input / mix_std_  # RMS normalization

        batch = self.enc(input, ilens)[0]  # [B, T, M, F]
        batch0 = batch.transpose(1, 2)  # [B, M, T, F]
        batch = torch.cat((batch0.real, batch0.imag), dim=1)  # [B, 2*M, T, F]
        n_batch, _, n_frames, n_freqs = batch.shape

        batch = self.conv(batch)  # [B, -1, T, F]

        for ii in range(self.n_layers):
            batch = self.blocks[ii](batch)  # [B, -1, T, F]

        batch = self.deconv(batch)  # [B, n_srcs*2, T, F]

        # Split the channels into (source, real/imag) and rebuild a complex spectrum.
        batch = batch.view([n_batch, self.n_srcs, 2, n_frames, n_freqs])
        batch = new_complex_like(batch0, (batch[:, :, 0], batch[:, :, 1]))

        batch = self.dec(batch.view(-1, n_frames, n_freqs), ilens)[0]  # [B, n_srcs, -1]

        # Right-pad the decoded waveforms back to the input length.
        batch = self.pad2(batch.view([n_batch, self.num_spk, -1]), n_samples)

        batch = batch * mix_std_  # reverse the RMS normalization

        batch = [batch[:, src] for src in range(self.num_spk)]

        return batch, ilens, OrderedDict()

    @property
    def num_spk(self):
        """Number of output sources (same as ``n_srcs``)."""
        return self.n_srcs

    @staticmethod
    def pad2(input_tensor, target_len):
        """Pad the last dimension of ``input_tensor`` on the right to ``target_len``."""
        input_tensor = torch.nn.functional.pad(
            input_tensor, (0, target_len - input_tensor.shape[-1])
        )
        return input_tensor
class GridNetV2Block(nn.Module):
    """One TFGridNetV2 block.

    Applies, in order, a BLSTM path along the frequency axis (intra), a BLSTM
    path along the time axis (inter), and a multi-head self-attention over
    frames. Each stage is wrapped in a residual connection.

    Args:
        emb_dim: number of embedding channels C.
        emb_ks: kernel size used for unfolding and for the transposed conv.
        emb_hs: hop size used for unfolding and for the transposed conv.
        n_freqs: number of frequency bins.
        hidden_channels: hidden size of each LSTM direction.
        n_head: number of attention heads (``emb_dim`` must be divisible by it).
        approx_qk_dim: approximate per-frame query/key size; the per-head
            channel count is ``ceil(approx_qk_dim / n_freqs)``.
        activation: must be "prelu".
        eps: epsilon used by the normalization layers.
    """

    def __getitem__(self, key):
        # Allow sub-modules to be looked up by name, e.g. block["attn_conv_Q"].
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        emb_ks,
        emb_hs,
        n_freqs,
        hidden_channels,
        n_head=4,
        approx_qk_dim=512,
        activation="prelu",
        eps=1e-5,
    ):
        super().__init__()
        assert activation == "prelu"

        rnn_in = emb_dim * emb_ks
        rnn_out = hidden_channels * 2  # bidirectional

        def make_projection():
            # Without overlap (kernel == hop) a Linear layer is enough;
            # otherwise a transposed convolution overlap-adds the frames.
            if emb_ks == emb_hs:
                return nn.Linear(rnn_out, rnn_in)
            return nn.ConvTranspose1d(rnn_out, emb_dim, emb_ks, stride=emb_hs)

        # NOTE: the construction order below is kept on purpose so that the
        # parameter order and seeded initialisation stay the same.
        self.intra_norm = nn.LayerNorm(emb_dim, eps=eps)
        self.intra_rnn = nn.LSTM(
            rnn_in, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.intra_linear = make_projection()

        self.inter_norm = nn.LayerNorm(emb_dim, eps=eps)
        self.inter_rnn = nn.LSTM(
            rnn_in, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.inter_linear = make_projection()

        # approx_qk_dim is only approximate
        qk_channels = math.ceil(approx_qk_dim * 1.0 / n_freqs)
        assert emb_dim % n_head == 0

        self.attn_conv_Q = nn.Conv2d(emb_dim, n_head * qk_channels, 1)
        self.attn_norm_Q = AllHeadPReLULayerNormalization4DCF(
            (n_head, qk_channels, n_freqs), eps=eps
        )

        self.attn_conv_K = nn.Conv2d(emb_dim, n_head * qk_channels, 1)
        self.attn_norm_K = AllHeadPReLULayerNormalization4DCF(
            (n_head, qk_channels, n_freqs), eps=eps
        )

        self.attn_conv_V = nn.Conv2d(emb_dim, n_head * emb_dim // n_head, 1)
        self.attn_norm_V = AllHeadPReLULayerNormalization4DCF(
            (n_head, emb_dim // n_head, n_freqs), eps=eps
        )

        self.attn_concat_proj = nn.Sequential(
            nn.Conv2d(emb_dim, emb_dim, 1),
            get_layer(activation)(),
            LayerNormalization4DCF((emb_dim, n_freqs), eps=eps),
        )

        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs
        self.n_head = n_head

    def _rnn_path(self, feats, norm, rnn, proj):
        """Run norm -> (unfold) -> BLSTM -> projection along dim 2, plus residual.

        Args:
            feats: [B, A, S, C]; the sequence axis is S, A is folded into batch.
            norm / rnn / proj: the LayerNorm, LSTM and output layer to use.

        Returns:
            Tensor of shape [B, A, S, C].
        """
        n_batch, n_outer, seq_len, n_chan = feats.shape
        hidden = norm(feats)  # [B, A, S, C]
        if self.emb_ks == self.emb_hs:
            # Non-overlapping frames: group emb_ks neighbours by a plain reshape.
            hidden = hidden.view([n_batch * n_outer, -1, self.emb_ks * n_chan])
            hidden, _ = rnn(hidden)  # [BA, S//I, H]
            hidden = proj(hidden)  # [BA, S//I, I*C]
            hidden = hidden.view([n_batch, n_outer, seq_len, n_chan])
        else:
            hidden = hidden.view([n_batch * n_outer, seq_len, n_chan])
            hidden = hidden.transpose(1, 2)  # [BA, C, S]
            hidden = F.unfold(
                hidden[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)
            )  # [BA, C*I, -1]
            hidden = hidden.transpose(1, 2)  # [BA, -1, C*I]
            hidden, _ = rnn(hidden)  # [BA, -1, H]
            hidden = hidden.transpose(1, 2)  # [BA, H, -1]
            hidden = proj(hidden)  # [BA, C, S]
            hidden = hidden.view([n_batch, n_outer, n_chan, seq_len])
            hidden = hidden.transpose(-2, -1)  # [B, A, S, C]
        return hidden + feats

    def forward(self, x):
        """GridNetV2Block Forward.

        Args:
            x: [B, C, T, Q]
            out: [B, C, T, Q]
        """
        n_batch, _, n_frames, n_bins = x.shape
        ks, hs = self.emb_ks, self.emb_hs
        olp = ks - hs

        # Padded sizes so that the unfold windows tile both axes exactly.
        pad_frames = math.ceil((n_frames + 2 * olp - ks) / hs) * hs + ks
        pad_bins = math.ceil((n_bins + 2 * olp - ks) / hs) * hs + ks

        x = x.permute(0, 2, 3, 1)  # [B, old_T, old_Q, C]
        x = F.pad(
            x,
            (0, 0, olp, pad_bins - n_bins - olp, olp, pad_frames - n_frames - olp),
        )  # [B, T, Q, C]

        # intra RNN (sequence axis = frequency)
        intra_out = self._rnn_path(
            x, self.intra_norm, self.intra_rnn, self.intra_linear
        )  # [B, T, Q, C]
        intra_out = intra_out.transpose(1, 2)  # [B, Q, T, C]

        # inter RNN (sequence axis = time)
        inter_out = self._rnn_path(
            intra_out, self.inter_norm, self.inter_rnn, self.inter_linear
        )  # [B, Q, T, C]
        inter_out = inter_out.permute(0, 3, 2, 1)  # [B, C, T, Q]
        # Drop the padding that was added above.
        inter_out = inter_out[..., olp : olp + n_frames, olp : olp + n_bins]

        # full-band self-attention over frames
        query = self.attn_norm_Q(self.attn_conv_Q(inter_out))  # [B, n_head, C, T, Q]
        key = self.attn_norm_K(self.attn_conv_K(inter_out))  # [B, n_head, C, T, Q]
        value = self.attn_norm_V(self.attn_conv_V(inter_out))  # [B, n_head, C, T, Q]

        # Fold the heads into the batch axis.
        query = query.view(-1, *query.shape[2:])  # [B*n_head, C, T, Q]
        key = key.view(-1, *key.shape[2:])  # [B*n_head, C, T, Q]
        value = value.view(-1, *value.shape[2:])  # [B*n_head, C, T, Q]

        query = query.transpose(1, 2).flatten(start_dim=2)  # [B', T, C*Q]

        key = key.transpose(2, 3)
        key = key.contiguous().view([n_batch * self.n_head, -1, n_frames])  # [B', C*Q, T]

        value = value.transpose(1, 2)  # [B', T, C, Q]
        value_shape = value.shape
        value = value.flatten(start_dim=2)  # [B', T, C*Q]

        scale = query.shape[-1] ** 0.5
        scores = torch.matmul(query, key) / scale  # [B', T, T]
        weights = F.softmax(scores, dim=2)  # [B', T, T]

        attended = torch.matmul(weights, value)  # [B', T, C*Q]
        attended = attended.reshape(value_shape)  # [B', T, C, Q]
        attended = attended.transpose(1, 2)  # [B', C, T, Q]

        # Unfold the heads back into the channel axis.
        head_channels = attended.shape[1]
        merged = attended.contiguous().view(
            [n_batch, self.n_head * head_channels, n_frames, n_bins]
        )  # [B, C, T, Q]
        merged = self.attn_concat_proj(merged)  # [B, C, T, Q]

        return merged + inter_out
class LayerNormalization4DCF(nn.Module):
    """Layer normalization over the channel and frequency axes of a 4-D tensor.

    For an input of shape [B, C, T, F], the mean and (biased) variance are
    computed jointly over dims 1 and 3, i.e. one statistic per (batch, frame).
    A learnable scale ``gamma`` and shift ``beta`` of shape [1, C, 1, F] are
    then applied.

    Args:
        input_dimension: pair ``(C, F)``.
        eps: constant added to the variance before the square root.
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        assert len(input_dimension) == 2
        n_channels, n_freqs = input_dimension[0], input_dimension[1]

        scale = torch.Tensor(1, n_channels, 1, n_freqs).to(torch.float32)
        shift = torch.Tensor(1, n_channels, 1, n_freqs).to(torch.float32)
        # init.*_ fill in place and return the same tensor.
        self.gamma = Parameter(init.ones_(scale))
        self.beta = Parameter(init.zeros_(shift))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` of shape [B, C, T, F]; the output has the same shape.

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        if x.ndim != 4:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))

        reduce_dims = (1, 3)
        mean = x.mean(dim=reduce_dims, keepdim=True)  # [B,1,T,1]
        variance = x.var(dim=reduce_dims, unbiased=False, keepdim=True)  # [B,1,T,1]
        centered = x - mean
        normalized = centered / torch.sqrt(variance + self.eps)
        return normalized * self.gamma + self.beta
class AllHeadPReLULayerNormalization4DCF(nn.Module):
    """PReLU followed by per-head layer normalization for all heads at once.

    The 4-D input [B, H*E, T, F] is viewed as [B, H, E, T, F]. A PReLU with one
    slope per head is applied, then each (batch, head, frame) slice is
    normalized over its E and F axes and scaled/shifted by ``gamma``/``beta``
    of shape [1, H, E, 1, F].

    Args:
        input_dimension: triple ``(H, E, n_freqs)``.
        eps: constant added to the variance before the square root.
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        assert len(input_dimension) == 3
        n_heads, head_channels, n_freqs = input_dimension

        scale = torch.Tensor(1, n_heads, head_channels, 1, n_freqs).to(torch.float32)
        shift = torch.Tensor(1, n_heads, head_channels, 1, n_freqs).to(torch.float32)
        # init.*_ fill in place and return the same tensor.
        self.gamma = Parameter(init.ones_(scale))
        self.beta = Parameter(init.zeros_(shift))
        self.act = nn.PReLU(num_parameters=n_heads, init=0.25)
        self.eps = eps
        self.H = n_heads
        self.E = head_channels
        self.n_freqs = n_freqs

    def forward(self, x):
        """Map ``x`` of shape [B, H*E, T, F] to a tensor of shape [B, H, E, T, F]."""
        assert x.ndim == 4
        n_batch, n_frames = x.shape[0], x.shape[2]

        heads = x.view([n_batch, self.H, self.E, n_frames, self.n_freqs])
        heads = self.act(heads)  # [B,H,E,T,F]

        reduce_dims = (2, 4)
        mean = heads.mean(dim=reduce_dims, keepdim=True)  # [B,H,1,T,1]
        variance = heads.var(dim=reduce_dims, unbiased=False, keepdim=True)  # [B,H,1,T,1]
        normalized = (heads - mean) / torch.sqrt(variance + self.eps)
        return normalized * self.gamma + self.beta  # [B,H,E,T,F]