Source code for espnet2.enh.layers.uses

import warnings

import torch
import torch.nn as nn

from espnet2.enh.layers.dptnet import ImprovedTransformerLayer as SingleTransformer
from espnet2.enh.layers.tcn import ChannelwiseLayerNorm
from espnet2.torch_utils.get_layer_from_string import get_layer


class USES(nn.Module):
    """Unconstrained Speech Enhancement and Separation (USES) Network.

    Reference:
        [1] W. Zhang, K. Saijo, Z.-Q. Wang, S. Watanabe, and Y. Qian,
        “Toward Universal Speech Enhancement for Diverse Input Conditions,”
        in Proc. ASRU, 2023.

    Args:
        input_size (int): dimension of the input feature.
        output_size (int): dimension of the output.
        bottleneck_size (int): dimension of the bottleneck feature.
            Must be a multiple of `att_heads`.
        num_blocks (int): number of processing blocks.
        num_spatial_blocks (int): number of processing blocks with channel modeling.
        segment_size (int): number of frames in each non-overlapping segment.
            This is used to segment long utterances into smaller segments for
            efficient processing.
        memory_size (int): group size of global memory tokens.
            The basic use of memory tokens is to store the history information
            from previous segments.
            The memory tokens are updated by the output of the last block after
            processing each segment.
        memory_types (int): number of memory token groups.
            Each group corresponds to a different type of processing, i.e.,
            the first group is used for denoising without dereverberation,
            the second group is used for denoising with dereverberation.
        rnn_type (str): type of the RNN cell in the improved Transformer layer.
        hidden_size (int): hidden dimension of the RNN cell.
        att_heads (int): number of attention heads in Transformer.
        dropout (float): dropout ratio. Default is 0.
        activation (str): non-linear activation function applied in each block.
        bidirectional (bool): whether the RNN layers are bidirectional.
        norm_type (str): normalization type in the improved Transformer layer.
        ch_mode (str): mode of channel modeling. Select from "att" and "tac".
        ch_att_dim (int): dimension of the channel attention.
        eps (float): epsilon for layer normalization.
    """

    def __init__(
        self,
        input_size,
        output_size,
        bottleneck_size=64,
        num_blocks=6,
        num_spatial_blocks=3,
        segment_size=64,
        memory_size=20,
        memory_types=1,
        # Transformer-related arguments
        rnn_type="lstm",
        hidden_size=128,
        att_heads=4,
        dropout=0.0,
        activation="relu",
        bidirectional=True,
        norm_type="cLN",
        ch_mode="att",
        ch_att_dim=256,
        eps=1e-5,
    ):
        super().__init__()

        # [B, input_size, T] -> [B, input_size, T]
        self.layer_norm = ChannelwiseLayerNorm(input_size)
        # [B, input_size, T] -> [B, bottleneck_size, T]
        self.bottleneck_conv1x1 = nn.Conv1d(input_size, bottleneck_size, 1, bias=False)

        self.input_size = input_size
        self.bottleneck_size = bottleneck_size
        self.output_size = output_size

        assert num_blocks >= num_spatial_blocks, (num_blocks, num_spatial_blocks)
        assert ch_mode in ("att", "tac"), ch_mode
        self.atf_blocks = nn.ModuleList()
        for i in range(num_blocks):
            self.atf_blocks.append(
                ATFBlock(
                    input_size=bottleneck_size,
                    rnn_type=rnn_type,
                    hidden_size=hidden_size,
                    att_heads=att_heads,
                    dropout=dropout,
                    activation=activation,
                    bidirectional=bidirectional,
                    norm_type=norm_type,
                    ch_mode=ch_mode,
                    ch_att_dim=ch_att_dim,
                    eps=eps,
                    with_channel_modeling=i < num_spatial_blocks,
                )
            )

        self.segment_size = segment_size
        self.memory_size = memory_size
        self.memory_types = memory_types
        if memory_types == 1:
            # single group of memory tokens (only used to provide history information)
            # (B=1, C=1, bottleneck_size, F=1, memory_size)
            self.memory_tokens = nn.Parameter(
                torch.randn(1, 1, bottleneck_size, 1, memory_size)
            )
        else:
            # >1 groups of memory tokens (used to also control processing behavior)
            self.memory_tokens = nn.ParameterList(
                [
                    nn.Parameter(torch.randn(1, 1, bottleneck_size, 1, memory_size))
                    for _ in range(memory_types)
                ]
            )

        self.output = nn.Sequential(
            nn.PReLU(), nn.Conv2d(bottleneck_size, output_size, 1)
        )

    def forward(self, input, ref_channel=None, mem_idx=None):
        """USES forward.

        Args:
            input (torch.Tensor): input feature (batch, mics, input_size, freq, time)
            ref_channel (None or int): index of the reference channel.
                if None, simply average all channels.
                if int, take the specified channel instead of averaging.
            mem_idx (None or int): index of the memory token group.
                if None, use the only group of memory tokens in the model.
                if int, use the specified group from multiple existing groups.

        Returns:
            output (torch.Tensor): output feature (batch, output_size, freq, time)
        """
        B, C, N, F, T = input.shape
        output = self.layer_norm(input.reshape(B * C, N, -1))
        # B, C, bn, F, T
        output = self.bottleneck_conv1x1(output).reshape(B, C, -1, F, T)
        bn = output.size(2)

        # Divide the input into non-overlapping segments
        num_seg, res = divmod(T, self.segment_size)
        if res > 0:
            # pad the last segment if necessary
            output = nn.functional.pad(output, (0, self.segment_size - res))
            num_seg += 1
        if self.training and num_seg < 2:
            warnings.warn(f"The input is too short for training: {T}")
        output = output.reshape(B, C, bn, F, num_seg, self.segment_size)

        # Segment-by-segment processing for memory-efficient processing
        ret = []
        mem = None
        for n in range(num_seg):
            out = output[..., n, :]
            if mem is None:
                # initialize memory tokens for the first segment
                if mem_idx is not None:
                    mem = self.memory_tokens[mem_idx].repeat(B, C, 1, F, 1)
                elif self.memory_types > 1:
                    mem = self.memory_tokens[0].repeat(B, C, 1, F, 1)
                else:
                    mem = self.memory_tokens.repeat(B, C, 1, F, 1)
                out = torch.cat([mem, out], dim=-1)
            else:
                # reuse memory tokens from the last segment
                if mem.size(1) < C:
                    mem = mem.repeat(1, C // mem.size(1), 1, 1, 1)
                out = torch.cat([mem, out], dim=-1)
            for block in self.atf_blocks:
                out = block(out, ref_channel=ref_channel)
            mem, out = out[..., : self.memory_size], out[..., self.memory_size :]
            ret.append(out)

        output = torch.cat(ret, dim=-1)[..., :T]
        with torch.cuda.amp.autocast(enabled=False):
            output = self.output(output.mean(1))  # B, output_size, F, T
        return output
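
# Usage sketch (illustrative, not part of the original module): with the default
# single memory-token group, `mem_idx` is left as None; when the model is built
# with `memory_types > 1`, passing e.g. `mem_idx=1` selects the second group and
# thus the corresponding processing behavior. All shapes below are assumptions.
#
#   net = USES(input_size=64, output_size=64, memory_types=2)
#   spec = torch.randn(2, 1, 64, 65, 128)  # (batch, mics, input_size, freq, time)
#   out = net(spec, mem_idx=1)             # -> (batch, output_size, freq, time)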

class ATFBlock(nn.Module):
    def __init__(
        self,
        input_size,
        rnn_type="lstm",
        hidden_size=128,
        att_heads=4,
        dropout=0.0,
        activation="relu",
        bidirectional=True,
        norm_type="cLN",
        ch_mode="att",
        ch_att_dim=256,
        eps=1e-5,
        with_channel_modeling=True,
    ):
        """Container module for a single Attentive Time-Frequency Block.

        Args:
            input_size (int): dimension of the input feature.
            rnn_type (str): type of the RNN cell in the improved Transformer layer.
            hidden_size (int): hidden dimension of the RNN cell.
            att_heads (int): number of attention heads in Transformer.
            dropout (float): dropout ratio. Default is 0.
            activation (str): non-linear activation function applied in each block.
            bidirectional (bool): whether the RNN layers are bidirectional.
            norm_type (str): normalization type in the improved Transformer layer.
            ch_mode (str): mode of channel modeling. Select from "att" and "tac".
            ch_att_dim (int): dimension of the channel attention.
            eps (float): epsilon for layer normalization.
            with_channel_modeling (bool): whether to use channel modeling.
        """
        super().__init__()
        kwargs = dict(
            rnn_type=rnn_type,
            input_size=input_size,
            att_heads=att_heads,
            hidden_size=hidden_size,
            dropout=dropout,
            activation="linear",
            bidirectional=bidirectional,
            norm=norm_type,
        )
        self.freq_nn = SingleTransformer(**kwargs)
        self.temporal_nn = SingleTransformer(**kwargs)

        self.with_channel_modeling = with_channel_modeling
        self.ch_mode = ch_mode
        if with_channel_modeling:
            if ch_mode == "att":
                self.channel_nn = ChannelAttention(
                    input_dim=input_size,
                    att_heads=att_heads,
                    att_dim=ch_att_dim,
                    activation=activation,
                    eps=eps,
                )
            elif ch_mode == "tac":
                self.channel_nn = ChannelTAC(input_dim=input_size, eps=eps)
            else:
                raise NotImplementedError(ch_mode)

    def forward(self, input, ref_channel=None):
        """Forward.

        Args:
            input (torch.Tensor): feature sequence (batch, C, N, freq, time)
            ref_channel (None or int): index of the reference channel.
                if None, simply average all channels.
                if int, take the specified channel instead of averaging.

        Returns:
            output (torch.Tensor): output sequence (batch, C, N, freq, time)
        """
        if not self.with_channel_modeling:
            if input.size(1) > 1 and ref_channel is not None:
                input = input[:, ref_channel].unsqueeze(1)
            else:
                input = input.mean(dim=1, keepdim=True)
        B, C, N, F, T = input.shape
        output = input.reshape(B * C, N, F, T).contiguous()
        output = self.freq_path_process(output)
        output = self.time_path_process(output)
        output = output.contiguous().reshape(B, C, N, F, T)
        if self.with_channel_modeling:
            output = self.channel_nn(output, ref_channel=ref_channel)
        return output

    def freq_path_process(self, x):
        batch, N, freq, time = x.shape
        x = x.permute(0, 3, 2, 1).reshape(batch * time, freq, N)
        x = self.freq_nn(x)
        x = x.reshape(batch, time, freq, N).permute(0, 3, 2, 1)
        return x.contiguous()

    def time_path_process(self, x):
        batch, N, freq, time = x.shape
        x = x.permute(0, 2, 3, 1).reshape(batch * freq, time, N)
        x = self.temporal_nn(x)
        x = x.reshape(batch, freq, time, N).permute(0, 3, 1, 2)
        return x.contiguous()
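
# Shape sketch for the dual-path processing above (axis names follow the code):
#   freq_path_process: (B*C, N, F, T) -> permute/reshape -> (B*C*T, F, N),
#       i.e. one Transformer pass over the frequency axis per frame;
#   time_path_process: (B*C, N, F, T) -> permute/reshape -> (B*C*F, T, N),
#       i.e. one Transformer pass over the time axis per frequency bin.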

class ChannelAttention(nn.Module):
    def __init__(
        self, input_dim, att_heads=4, att_dim=256, activation="relu", eps=1e-5
    ):
        """Channel Attention module.

        Args:
            input_dim (int): dimension of the input feature.
            att_heads (int): number of attention heads in self-attention.
            att_dim (int): projection dimension for query and key before
                self-attention.
            activation (str): non-linear activation function.
            eps (float): epsilon for layer normalization.
        """
        super().__init__()
        self.att_heads = att_heads
        self.att_dim = att_dim
        self.activation = activation
        assert input_dim % att_heads == 0, (input_dim, att_heads)
        self.attn_conv_Q = nn.Sequential(
            nn.Linear(input_dim, att_dim),
            get_layer(activation)(),
            LayerNormalization(att_dim, dim=-1, total_dim=5, eps=eps),
        )
        self.attn_conv_K = nn.Sequential(
            nn.Linear(input_dim, att_dim),
            get_layer(activation)(),
            LayerNormalization(att_dim, dim=-1, total_dim=5, eps=eps),
        )
        self.attn_conv_V = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            get_layer(activation)(),
            LayerNormalization(input_dim, dim=-1, total_dim=5, eps=eps),
        )
        self.attn_concat_proj = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            get_layer(activation)(),
            LayerNormalization(input_dim, dim=-1, total_dim=5, eps=eps),
        )

    def __getitem__(self, key):
        return getattr(self, key)

    def forward(self, x, ref_channel=None):
        """ChannelAttention Forward.

        Args:
            x (torch.Tensor): input feature (batch, C, N, freq, time)
            ref_channel (None or int): index of the reference channel.

        Returns:
            output (torch.Tensor): output feature (batch, C, N, freq, time)
        """
        B, C, N, F, T = x.shape
        batch = x.permute(0, 4, 1, 3, 2)  # [B, T, C, F, N]

        Q = (
            self.attn_conv_Q(batch)
            .reshape(B, T, C, F, -1, self.att_heads)
            .permute(0, 5, 1, 2, 3, 4)
            .contiguous()
        )  # [B, head, T, C, F, D]
        K = (
            self.attn_conv_K(batch)
            .reshape(B, T, C, F, -1, self.att_heads)
            .permute(0, 5, 1, 2, 3, 4)
            .contiguous()
        )  # [B, head, T, C, F, D]
        V = (
            self.attn_conv_V(batch)
            .reshape(B, T, C, F, -1, self.att_heads)
            .permute(0, 5, 1, 2, 3, 4)
            .contiguous()
        )  # [B, head, T, C, F, D']

        emb_dim = V.size(-2) * V.size(-1)
        attn_mat = torch.einsum("bhtcfn,bhtefn->bhce", Q / T, K / emb_dim**0.5)
        attn_mat = nn.functional.softmax(attn_mat, dim=-1)  # [B, head, C, C]
        V = torch.einsum("bhce,bhtefn->bhtcfn", attn_mat, V)  # [B, head, T, C, F, D']

        batch = torch.cat(V.unbind(dim=1), dim=-1)  # [B, T, C, F, D]
        batch = self["attn_concat_proj"](batch)  # [B, T, C, F, N]

        return batch.permute(0, 2, 4, 3, 1) + x
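
# Usage sketch for ChannelAttention (illustrative shapes, not from the source):
# the attention matrix is (B, head, C, C), i.e. channels attend to channels with
# the time/frequency/feature axes pooled inside the einsum, so the module is
# shape-preserving and works for any number of channels C.
#
#   ch_att = ChannelAttention(input_dim=64, att_heads=4, att_dim=256)
#   y = ch_att(torch.randn(1, 4, 64, 33, 50))  # -> (1, 4, 64, 33, 50)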

class ChannelTAC(nn.Module):
    def __init__(self, input_dim, eps=1e-5):
        """Channel Transform-Average-Concatenate (TAC) module.

        Args:
            input_dim (int): dimension of the input feature.
            eps (float): epsilon for layer normalization.
        """
        super().__init__()
        hidden_dim = input_dim * 3
        self.transform = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.PReLU())
        self.average = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.PReLU())
        self.concat = nn.Sequential(
            nn.Linear(hidden_dim * 2, input_dim),
            nn.PReLU(),
            LayerNormalization(input_dim, dim=-1, total_dim=5, eps=eps),
        )

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, x, ref_channel=None):
        """ChannelTAC Forward.

        Args:
            x (torch.Tensor): input feature (batch, C, N, freq, time)
            ref_channel (None or int): index of the reference channel.

        Returns:
            output (torch.Tensor): output feature (batch, C, N, freq, time)
        """
        batch = x.contiguous().permute(0, 4, 1, 3, 2)  # [B, T, C, F, N]
        out = self.transform(batch)
        out_mean = self.average(out.mean(dim=2, keepdim=True)).expand_as(out)
        out = self.concat(torch.cat([out, out_mean], dim=-1))
        out = out.permute(0, 2, 4, 3, 1) + x
        return out
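
# Usage sketch for ChannelTAC (illustrative shapes, not from the source): TAC
# transforms each channel, averages across channels, and concatenates the mean
# back onto each channel, so it is likewise shape-preserving and agnostic to the
# number of channels.
#
#   tac = ChannelTAC(input_dim=64)
#   y = tac(torch.randn(1, 3, 64, 33, 50))  # -> (1, 3, 64, 33, 50)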

class LayerNormalization(nn.Module):
    def __init__(self, input_dim, dim=1, total_dim=4, eps=1e-5):
        super().__init__()
        self.dim = dim if dim >= 0 else total_dim + dim
        param_size = [1 if ii != self.dim else input_dim for ii in range(total_dim)]
        self.gamma = nn.Parameter(torch.Tensor(*param_size).to(torch.float32))
        self.beta = nn.Parameter(torch.Tensor(*param_size).to(torch.float32))
        nn.init.ones_(self.gamma)
        nn.init.zeros_(self.beta)
        self.eps = eps

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, x):
        if x.ndim - 1 < self.dim:
            raise ValueError(
                f"Expect x to have {self.dim + 1} dimensions, but got {x.ndim}"
            )
        mu_ = x.mean(dim=self.dim, keepdim=True)
        std_ = torch.sqrt(x.var(dim=self.dim, unbiased=False, keepdim=True) + self.eps)
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta
        return x_hat
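

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module); the shapes
    # below are illustrative assumptions rather than values from any recipe.
    model = USES(input_size=64, output_size=64)
    # (batch, mics, input_size, freq, time); 100 frames span two 64-frame segments
    feature = torch.randn(1, 2, 64, 33, 100)
    with torch.no_grad():
        enhanced = model(feature, ref_channel=0)
    print(enhanced.shape)  # expected: torch.Size([1, 64, 33, 100])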