Source code for espnet2.enh.layers.dprnn

# The implementation of DPRNN in
# Luo et al. "Dual-path RNN: efficient long sequence modeling
# for time-domain single-channel speech separation."
#
# The code is based on:
# https://github.com/yluo42/TAC/blob/master/utility/models.py
# Licensed under CC BY-NC-SA 3.0 US.
#


import torch
import torch.nn as nn

EPS = torch.finfo(torch.get_default_dtype()).eps


class SingleRNN(nn.Module):
    """Container module for a single RNN layer.

    args:
        rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'.
        input_size: int, dimension of the input feature.
            The input should have shape (batch, seq_len, input_size).
        hidden_size: int, dimension of the hidden state.
        dropout: float, dropout ratio. Default is 0.
        bidirectional: bool, whether the RNN layers are bidirectional.
            Default is False.
    """

    def __init__(
        self, rnn_type, input_size, hidden_size, dropout=0, bidirectional=False
    ):
        super().__init__()

        rnn_type = rnn_type.upper()
        assert rnn_type in [
            "RNN",
            "LSTM",
            "GRU",
        ], f"Only support 'RNN', 'LSTM' and 'GRU', current type: {rnn_type}"

        self.rnn_type = rnn_type
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_direction = int(bidirectional) + 1

        self.rnn = getattr(nn, rnn_type)(
            input_size,
            hidden_size,
            1,
            batch_first=True,
            bidirectional=bidirectional,
        )

        self.dropout = nn.Dropout(p=dropout)

        # linear projection layer
        self.proj = nn.Linear(hidden_size * self.num_direction, input_size)

    def forward(self, input, state=None):
        # input shape: batch, seq, dim
        output = input
        rnn_output, state = self.rnn(output, state)
        rnn_output = self.dropout(rnn_output)
        rnn_output = self.proj(
            rnn_output.contiguous().view(-1, rnn_output.shape[2])
        ).view(output.shape)
        return rnn_output, state
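
# A minimal usage sketch for SingleRNN (not part of the original source; the
# sizes below are illustrative assumptions). Because the projection layer maps
# the RNN output back to input_size, the output has the same shape as the input:
#
#     import torch
#
#     batch, seq_len, input_size, hidden_size = 4, 100, 64, 128  # arbitrary sizes
#     rnn = SingleRNN("LSTM", input_size, hidden_size, dropout=0.1, bidirectional=True)
#     x = torch.randn(batch, seq_len, input_size)  # (batch, seq_len, input_size)
#     y, state = rnn(x)
#     assert y.shape == x.shape  # projection restores input_size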
# dual-path RNN
class DPRNN(nn.Module):
    """Deep dual-path RNN.

    args:
        rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'.
        input_size: int, dimension of the input feature.
            The input should have shape (batch, seq_len, input_size).
        hidden_size: int, dimension of the hidden state.
        output_size: int, dimension of the output size.
        dropout: float, dropout ratio. Default is 0.
        num_layers: int, number of stacked RNN layers. Default is 1.
        bidirectional: bool, whether the RNN layers are bidirectional.
            Default is True.
    """

    def __init__(
        self,
        rnn_type,
        input_size,
        hidden_size,
        output_size,
        dropout=0,
        num_layers=1,
        bidirectional=True,
    ):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size

        # dual-path RNN
        self.row_rnn = nn.ModuleList([])
        self.col_rnn = nn.ModuleList([])
        self.row_norm = nn.ModuleList([])
        self.col_norm = nn.ModuleList([])
        for i in range(num_layers):
            self.row_rnn.append(
                SingleRNN(
                    rnn_type, input_size, hidden_size, dropout, bidirectional=True
                )
            )  # intra-segment RNN is always noncausal
            self.col_rnn.append(
                SingleRNN(
                    rnn_type,
                    input_size,
                    hidden_size,
                    dropout,
                    bidirectional=bidirectional,
                )
            )
            self.row_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))
            # default is to use noncausal LayerNorm for inter-chunk RNN.
            # For causal setting change it to causal normalization accordingly.
            self.col_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))

        # output layer
        self.output = nn.Sequential(nn.PReLU(), nn.Conv2d(input_size, output_size, 1))

    def forward(self, input):
        # input shape: batch, N, dim1, dim2
        # apply RNN on dim1 first and then dim2
        # output shape: B, output_size, dim1, dim2
        batch_size, _, dim1, dim2 = input.shape
        output = input
        for i in range(len(self.row_rnn)):
            row_input = (
                output.permute(0, 3, 2, 1)
                .contiguous()
                .view(batch_size * dim2, dim1, -1)
            )  # B*dim2, dim1, N
            row_output, _ = self.row_rnn[i](row_input)  # B*dim2, dim1, N
            row_output = (
                row_output.view(batch_size, dim2, dim1, -1)
                .permute(0, 3, 2, 1)
                .contiguous()
            )  # B, N, dim1, dim2
            row_output = self.row_norm[i](row_output)
            output = output + row_output

            col_input = (
                output.permute(0, 2, 3, 1)
                .contiguous()
                .view(batch_size * dim1, dim2, -1)
            )  # B*dim1, dim2, N
            col_output, _ = self.col_rnn[i](col_input)  # B*dim1, dim2, N
            col_output = (
                col_output.view(batch_size, dim1, dim2, -1)
                .permute(0, 3, 1, 2)
                .contiguous()
            )  # B, N, dim1, dim2
            col_output = self.col_norm[i](col_output)
            output = output + col_output

        output = self.output(output)  # B, output_size, dim1, dim2

        return output
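
# A usage sketch for DPRNN (illustrative sizes, not from the original source).
# The input is a segmented feature map of shape (batch, N, dim1, dim2), where
# dim1 is the within-segment (intra-chunk) axis and dim2 the across-segment
# (inter-chunk) axis; both axes are preserved by the row/column RNNs:
#
#     import torch
#
#     batch, N, dim1, dim2 = 2, 64, 50, 40  # arbitrary sizes
#     model = DPRNN("LSTM", input_size=N, hidden_size=128, output_size=N, num_layers=2)
#     x = torch.randn(batch, N, dim1, dim2)
#     y = model(x)
#     assert y.shape == (batch, N, dim1, dim2)  # (B, output_size, dim1, dim2)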
# dual-path RNN with transform-average-concatenate (TAC)
class DPRNN_TAC(nn.Module):
    """Deep dual-path RNN with TAC applied to each layer/block.

    args:
        rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'.
        input_size: int, dimension of the input feature.
            The input should have shape (batch, seq_len, input_size).
        hidden_size: int, dimension of the hidden state.
        output_size: int, dimension of the output size.
        dropout: float, dropout ratio. Default is 0.
        num_layers: int, number of stacked RNN layers. Default is 1.
        bidirectional: bool, whether the RNN layers are bidirectional.
            Default is True.
    """

    def __init__(
        self,
        rnn_type,
        input_size,
        hidden_size,
        output_size,
        dropout=0,
        num_layers=1,
        bidirectional=True,
    ):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size

        # DPRNN + TAC for 3D input (ch, N, T)
        self.row_rnn = nn.ModuleList([])
        self.col_rnn = nn.ModuleList([])
        self.ch_transform = nn.ModuleList([])
        self.ch_average = nn.ModuleList([])
        self.ch_concat = nn.ModuleList([])

        self.row_norm = nn.ModuleList([])
        self.col_norm = nn.ModuleList([])
        self.ch_norm = nn.ModuleList([])

        for i in range(num_layers):
            self.row_rnn.append(
                SingleRNN(
                    rnn_type, input_size, hidden_size, dropout, bidirectional=True
                )
            )  # intra-segment RNN is always noncausal
            self.col_rnn.append(
                SingleRNN(
                    rnn_type,
                    input_size,
                    hidden_size,
                    dropout,
                    bidirectional=bidirectional,
                )
            )
            self.ch_transform.append(
                nn.Sequential(nn.Linear(input_size, hidden_size * 3), nn.PReLU())
            )
            self.ch_average.append(
                nn.Sequential(nn.Linear(hidden_size * 3, hidden_size * 3), nn.PReLU())
            )
            self.ch_concat.append(
                nn.Sequential(nn.Linear(hidden_size * 6, input_size), nn.PReLU())
            )

            self.row_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))
            # default is to use noncausal LayerNorm for
            # inter-chunk RNN and TAC modules.
            # For causal setting change them to causal normalization
            # techniques accordingly.
            self.col_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))
            self.ch_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))

        # output layer
        self.output = nn.Sequential(nn.PReLU(), nn.Conv2d(input_size, output_size, 1))

    def forward(self, input, num_mic):
        # input shape: batch, ch, N, dim1, dim2
        # num_mic shape: batch,
        # apply RNN on dim1 first, then dim2, then ch

        batch_size, ch, N, dim1, dim2 = input.shape
        output = input
        for i in range(len(self.row_rnn)):
            # intra-segment RNN
            output = output.view(batch_size * ch, N, dim1, dim2)
            row_input = (
                output.permute(0, 3, 2, 1)
                .contiguous()
                .view(batch_size * ch * dim2, dim1, -1)
            )  # B*ch*dim2, dim1, N
            row_output, _ = self.row_rnn[i](row_input)  # B*ch*dim2, dim1, N
            row_output = (
                row_output.view(batch_size * ch, dim2, dim1, -1)
                .permute(0, 3, 2, 1)
                .contiguous()
            )  # B*ch, N, dim1, dim2
            row_output = self.row_norm[i](row_output)
            output = output + row_output  # B*ch, N, dim1, dim2

            # inter-segment RNN
            col_input = (
                output.permute(0, 2, 3, 1)
                .contiguous()
                .view(batch_size * ch * dim1, dim2, -1)
            )  # B*ch*dim1, dim2, N
            col_output, _ = self.col_rnn[i](col_input)  # B*ch*dim1, dim2, N
            col_output = (
                col_output.view(batch_size * ch, dim1, dim2, -1)
                .permute(0, 3, 1, 2)
                .contiguous()
            )  # B*ch, N, dim1, dim2
            col_output = self.col_norm[i](col_output)
            output = output + col_output  # B*ch, N, dim1, dim2

            # TAC for cross-channel communication
            ch_input = output.view(input.shape)  # B, ch, N, dim1, dim2
            ch_input = (
                ch_input.permute(0, 3, 4, 1, 2).contiguous().view(-1, N)
            )  # B*dim1*dim2*ch, N
            ch_output = self.ch_transform[i](ch_input).view(
                batch_size, dim1 * dim2, ch, -1
            )  # B, dim1*dim2, ch, H
            # mean pooling across channels
            if num_mic.max() == 0:
                # fixed geometry array
                ch_mean = ch_output.mean(2).view(
                    batch_size * dim1 * dim2, -1
                )  # B*dim1*dim2, H
            else:
                # only consider valid channels
                ch_mean = [
                    ch_output[b, :, : num_mic[b]].mean(1).unsqueeze(0)
                    for b in range(batch_size)
                ]  # 1, dim1*dim2, H
                ch_mean = torch.cat(ch_mean, 0).view(
                    batch_size * dim1 * dim2, -1
                )  # B*dim1*dim2, H
            ch_output = ch_output.view(
                batch_size * dim1 * dim2, ch, -1
            )  # B*dim1*dim2, ch, H
            ch_mean = (
                self.ch_average[i](ch_mean)
                .unsqueeze(1)
                .expand_as(ch_output)
                .contiguous()
            )  # B*dim1*dim2, ch, H
            ch_output = torch.cat([ch_output, ch_mean], 2)  # B*dim1*dim2, ch, 2H
            ch_output = self.ch_concat[i](
                ch_output.view(-1, ch_output.shape[-1])
            )  # B*dim1*dim2*ch, N
            ch_output = (
                ch_output.view(batch_size, dim1, dim2, ch, -1)
                .permute(0, 3, 4, 1, 2)
                .contiguous()
            )  # B, ch, N, dim1, dim2
            ch_output = self.ch_norm[i](
                ch_output.view(batch_size * ch, N, dim1, dim2)
            )  # B*ch, N, dim1, dim2
            output = output + ch_output

        output = self.output(output)  # B*ch, output_size, dim1, dim2

        return output
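
# A usage sketch for DPRNN_TAC (illustrative sizes and values, not from the
# original source). The channel dimension is folded into the batch dimension
# inside forward, so the output keeps the flattened (batch * ch) layout;
# passing num_mic of all zeros selects the fixed-geometry branch that pools
# over every channel:
#
#     import torch
#
#     batch, ch, N, dim1, dim2 = 2, 4, 64, 50, 40  # arbitrary sizes
#     model = DPRNN_TAC("LSTM", input_size=N, hidden_size=128, output_size=N)
#     x = torch.randn(batch, ch, N, dim1, dim2)
#     num_mic = torch.zeros(batch, dtype=torch.long)  # 0 -> fixed geometry array
#     y = model(x, num_mic)
#     assert y.shape == (batch * ch, N, dim1, dim2)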
def _pad_segment(input, segment_size):
    # input is the features: (B, N, T)
    batch_size, dim, seq_len = input.shape
    segment_stride = segment_size // 2

    rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
    if rest > 0:
        pad = torch.zeros(
            batch_size, dim, rest, dtype=input.dtype, device=input.device
        )
        input = torch.cat([input, pad], 2)

    pad_aux = torch.zeros(
        batch_size, dim, segment_stride, dtype=input.dtype, device=input.device
    )
    input = torch.cat([pad_aux, input, pad_aux], 2)

    return input, rest

def split_feature(input, segment_size):
    # split the feature into chunks of segment size
    # input is the features: (B, N, T)

    input, rest = _pad_segment(input, segment_size)
    batch_size, dim, seq_len = input.shape
    segment_stride = segment_size // 2

    segments1 = (
        input[:, :, :-segment_stride]
        .contiguous()
        .view(batch_size, dim, -1, segment_size)
    )
    segments2 = (
        input[:, :, segment_stride:]
        .contiguous()
        .view(batch_size, dim, -1, segment_size)
    )
    segments = (
        torch.cat([segments1, segments2], 3)
        .view(batch_size, dim, -1, segment_size)
        .transpose(2, 3)
    )

    return segments.contiguous(), rest

def merge_feature(input, rest):
    # merge the split features into a full utterance
    # input is the features: (B, N, L, K)

    batch_size, dim, segment_size, _ = input.shape
    segment_stride = segment_size // 2
    input = (
        input.transpose(2, 3).contiguous().view(batch_size, dim, -1, segment_size * 2)
    )  # B, N, K, L

    input1 = (
        input[:, :, :, :segment_size]
        .contiguous()
        .view(batch_size, dim, -1)[:, :, segment_stride:]
    )
    input2 = (
        input[:, :, :, segment_size:]
        .contiguous()
        .view(batch_size, dim, -1)[:, :, :-segment_stride]
    )

    output = input1 + input2
    if rest > 0:
        output = output[:, :, :-rest]

    return output.contiguous()  # B, N, T
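
# A round-trip sanity check for the segmentation utilities (illustrative
# sizes, not from the original source). split_feature produces 50%-overlapping
# segments of shape (B, N, segment_size, n_segments); since merge_feature
# overlap-adds the two copies of every frame without windowing, merging right
# after splitting returns the input scaled by exactly 2:
#
#     import torch
#
#     batch, N, T, segment_size = 2, 64, 321, 100  # arbitrary sizes
#     x = torch.randn(batch, N, T)
#     segments, rest = split_feature(x, segment_size)
#     merged = merge_feature(segments, rest)  # (B, N, T)
#     assert merged.shape == x.shape
#     assert torch.allclose(merged, 2 * x)  # every frame lies in 2 segments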