Source code for espnet2.enh.separator.dccrn_separator

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging.version import parse as V
from torch_complex.tensor import ComplexTensor

from espnet2.enh.layers.complexnn import (
    ComplexBatchNorm,
    ComplexConv2d,
    ComplexConvTranspose2d,
    NavieComplexLSTM,
    complex_cat,
)
from espnet2.enh.separator.abs_separator import AbsSeparator

# Machine epsilon of float64; added to denominators in the mask computation
# to guard against division by zero.
EPS = torch.finfo(torch.float64).eps
# True when the installed torch is at least 1.9.0; used to decide whether
# native complex tensors (torch.complex) may be produced.
is_torch_1_9_plus = V("1.9.0") <= V(torch.__version__)


class DCCRNSeparator(AbsSeparator):
    """Complex convolutional recurrent separator (DCCRN).

    A complex-valued convolutional encoder halves the frequency axis at each
    layer, a (complex or real) LSTM stack models the time axis, and a
    transposed-convolution decoder with skip connections estimates one
    complex mask per speaker (plus optionally one for the noise).
    """

    def __init__(
        self,
        input_dim: int,
        num_spk: int = 1,
        rnn_layer: int = 2,
        rnn_units: int = 256,
        masking_mode: str = "E",
        use_clstm: bool = True,
        bidirectional: bool = False,
        use_cbn: bool = False,
        kernel_size: int = 5,
        # NOTE: list default kept for signature compatibility; it is never
        # mutated (it is copied below).
        kernel_num: List[int] = [32, 64, 128, 256, 256, 256],
        use_builtin_complex: bool = True,
        use_noise_mask: bool = False,
    ):
        """DCCRN separator.

        Args:
            input_dim (int): input dimension (number of frequency bins F).
            num_spk (int, optional): number of speakers. Defaults to 1.
            rnn_layer (int, optional): number of LSTM layers in the CRN.
                Defaults to 2.
            rnn_units (int, optional): number of hidden units of each LSTM.
                Defaults to 256.
            masking_mode (str, optional): usage of the estimated mask, one of
                "C", "E" or "R" (see ``apply_masks``). Defaults to "E".
            use_clstm (bool, optional): whether to use the complex LSTM.
                Defaults to True.
            bidirectional (bool, optional): whether to use BLSTM.
                Defaults to False.
            use_cbn (bool, optional): whether to use complex batch norm
                instead of ``nn.BatchNorm2d``. Defaults to False.
            kernel_size (int, optional): convolution kernel size along the
                frequency axis. Defaults to 5.
            kernel_num (list, optional): output channels of each encoder
                layer. Defaults to [32, 64, 128, 256, 256, 256].
            use_builtin_complex (bool, optional): return ``torch.complex``
                tensors if True (and torch >= 1.9), else ``ComplexTensor``.
                Defaults to True.
            use_noise_mask (bool, optional): whether to estimate an extra
                mask for the noise. Defaults to False.

        Raises:
            ValueError: if ``masking_mode`` is not one of "C", "E", "R".
        """
        super().__init__()
        self.use_builtin_complex = use_builtin_complex
        self._num_spk = num_spk
        self.use_noise_mask = use_noise_mask
        self.predict_noise = use_noise_mask
        if masking_mode not in ["C", "E", "R"]:
            raise ValueError("Unsupported masking mode: %s" % masking_mode)

        # Network config
        self.rnn_units = rnn_units
        self.hidden_layers = rnn_layer
        self.kernel_size = kernel_size
        # The encoder input has 2 channels: stacked real and imaginary parts.
        self.kernel_num = [2] + list(kernel_num)
        self.masking_mode = masking_mode
        self.use_clstm = use_clstm

        # Output width multiplier of a (bi)directional LSTM.
        fac = 2 if bidirectional else 1
        num_stages = len(self.kernel_num) - 1

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        for idx in range(num_stages):
            self.encoder.append(
                nn.Sequential(
                    ComplexConv2d(
                        self.kernel_num[idx],
                        self.kernel_num[idx + 1],
                        kernel_size=(self.kernel_size, 2),
                        stride=(2, 1),
                        padding=(2, 1),
                    ),
                    (
                        nn.BatchNorm2d(self.kernel_num[idx + 1])
                        if not use_cbn
                        else ComplexBatchNorm(self.kernel_num[idx + 1])
                    ),
                    nn.PReLU(),
                )
            )

        # ceil((input_dim - 1) / 2**num_stages): forward() drops the first
        # frequency bin and every encoder layer uses stride 2 on frequency.
        hidden_dim = (input_dim - 1 + 2**num_stages - 1) // (2**num_stages)
        hidden_dim = hidden_dim if hidden_dim > 0 else 1
        # Flattened (channels * frequency) feature size at the bottleneck.
        rnn_io_dim = hidden_dim * self.kernel_num[-1]

        if self.use_clstm:
            rnns = []
            for idx in range(rnn_layer):
                rnns.append(
                    NavieComplexLSTM(
                        input_size=(rnn_io_dim if idx == 0 else self.rnn_units * fac),
                        hidden_size=self.rnn_units,
                        bidirectional=bidirectional,
                        batch_first=False,
                        # Only the last layer projects back to the
                        # bottleneck feature size.
                        projection_dim=(
                            rnn_io_dim if idx == rnn_layer - 1 else None
                        ),
                    )
                )
            self.enhance = nn.Sequential(*rnns)
        else:
            self.enhance = nn.LSTM(
                input_size=rnn_io_dim,
                hidden_size=self.rnn_units,
                # Honour ``rnn_layer`` here as the complex-LSTM branch does
                # (it used to be hard-coded to 2, the default value).
                num_layers=rnn_layer,
                dropout=0.0,
                bidirectional=bidirectional,
                batch_first=False,
            )
            # NOTE: the attribute name (sic) is part of the state-dict keys
            # of existing checkpoints, so it must not be renamed.
            self.tranform = nn.Linear(self.rnn_units * fac, rnn_io_dim)

        for idx in range(num_stages, 0, -1):
            if idx != 1:
                # Input channels are doubled by the skip connection.
                self.decoder.append(
                    nn.Sequential(
                        ComplexConvTranspose2d(
                            self.kernel_num[idx] * 2,
                            self.kernel_num[idx - 1],
                            kernel_size=(self.kernel_size, 2),
                            stride=(2, 1),
                            padding=(2, 0),
                            output_padding=(1, 0),
                        ),
                        (
                            nn.BatchNorm2d(self.kernel_num[idx - 1])
                            if not use_cbn
                            else ComplexBatchNorm(self.kernel_num[idx - 1])
                        ),
                        nn.PReLU(),
                    )
                )
            else:
                # Last decoder layer: one (real, imag) channel pair per
                # estimated mask, without normalization or activation.
                self.decoder.append(
                    nn.Sequential(
                        ComplexConvTranspose2d(
                            self.kernel_num[idx] * 2,
                            (
                                self.kernel_num[idx - 1] * (self._num_spk + 1)
                                if self.use_noise_mask
                                else self.kernel_num[idx - 1] * self._num_spk
                            ),
                            kernel_size=(self.kernel_size, 2),
                            stride=(2, 1),
                            padding=(2, 0),
                            output_padding=(1, 0),
                        ),
                    )
                )

        self.flatten_parameters()

    def forward(
        self,
        input: Union[torch.Tensor, ComplexTensor],
        ilens: torch.Tensor,
        additional: Optional[Dict] = None,
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (torch.Tensor or ComplexTensor): Encoded feature [B, T, F]
            ilens (torch.Tensor): input lengths [Batch]
            additional (Dict or None): other data included in model
                NOTE: not used in this model

        Returns:
            masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, F), ...]
            ilens (torch.Tensor): (B,)
            others predicted data, e.g. masks: OrderedDict[
                'mask_spk1': torch.Tensor(Batch, Frames, Freq),
                'mask_spk2': torch.Tensor(Batch, Frames, Freq),
                ...
                'mask_spkn': torch.Tensor(Batch, Frames, Freq),
            ]
            When ``use_noise_mask`` is set, ``others`` additionally holds
            'mask_noise1' and 'noise1', and the noise estimate is removed
            from ``masked``.
        """
        # shape (B, T, F) --> (B, F, T)
        specs = input.permute(0, 2, 1)
        real, imag = specs.real, specs.imag

        # shape (B, 2, F, T)
        cspecs = torch.stack([real, imag], 1)
        # drop the first frequency bin: shape (B, 2, F-1, T)
        cspecs = cspecs[:, :, 1:]

        out = cspecs
        encoder_out = []
        for layer in self.encoder:
            out = layer(out)
            # kept for the decoder skip connections
            encoder_out.append(out)

        # shape (B, C, F, T)
        batch_size, channels, dims, lengths = out.size()
        # shape (T, B, C, F)
        out = out.permute(3, 0, 1, 2)
        if self.use_clstm:
            half = channels // 2
            # first half of the channels is the real part, second half the
            # imaginary part; each has shape (T, B, C // 2, F)
            r_rnn_in = out[:, :, :half]
            i_rnn_in = out[:, :, half:]
            # shape (T, B, C // 2 * F)
            r_rnn_in = torch.reshape(r_rnn_in, [lengths, batch_size, half * dims])
            i_rnn_in = torch.reshape(i_rnn_in, [lengths, batch_size, half * dims])

            r_rnn_in, i_rnn_in = self.enhance([r_rnn_in, i_rnn_in])

            # shape (T, B, C // 2, F)
            r_rnn_in = torch.reshape(r_rnn_in, [lengths, batch_size, half, dims])
            i_rnn_in = torch.reshape(i_rnn_in, [lengths, batch_size, half, dims])
            # shape (T, B, C, F)
            out = torch.cat([r_rnn_in, i_rnn_in], 2)
        else:
            # shape (T, B, C*F)
            out = torch.reshape(out, [lengths, batch_size, channels * dims])
            out, _ = self.enhance(out)
            out = self.tranform(out)
            # shape (T, B, C, F)
            out = torch.reshape(out, [lengths, batch_size, channels, dims])

        # shape (B, C, F, T)
        out = out.permute(1, 2, 3, 0)

        # Decoder layers consume the encoder outputs in reverse order.
        for dec_layer, skip in zip(self.decoder, reversed(encoder_out)):
            # skip connection
            out = complex_cat([out, skip], 1)
            out = dec_layer(out)
            # discard the first frame of each decoder layer's output
            out = out[..., 1:]

        # out shape = (B, 2*num_spk, F-1, T) if self.use_noise_mask == False
        # else (B, 2*(num_spk+1), F-1, T)
        masks = self.create_masks(out)
        masked = self.apply_masks(masks, real, imag)

        # zip() stops after num_spk entries, so a trailing noise mask is
        # not given a speaker key here.
        others = OrderedDict(
            zip(
                ["mask_spk{}".format(i + 1) for i in range(self.num_spk)],
                masks,
            )
        )
        if self.use_noise_mask:
            others["mask_noise1"] = masks[-1]
            others["noise1"] = masked.pop(-1)

        return (masked, ilens, others)

    def flatten_parameters(self):
        """Flatten the LSTM weights when the plain ``nn.LSTM`` is in use."""
        if isinstance(self.enhance, nn.LSTM):
            self.enhance.flatten_parameters()

    def create_masks(self, mask_tensor: torch.Tensor):
        """Create the estimated complex mask for each speaker.

        Args:
            mask_tensor (torch.Tensor): output of decoder,
                shape (B, 2*num_spk, F-1, T), or (B, 2*(num_spk+1), F-1, T)
                when ``use_noise_mask`` is set.

        Returns:
            masks (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, F), ...]
                one complex mask per channel pair.
        """
        if self.use_noise_mask:
            assert mask_tensor.shape[1] == 2 * (self._num_spk + 1), mask_tensor.shape[1]
        else:
            assert mask_tensor.shape[1] == 2 * self._num_spk, mask_tensor.shape[1]

        native_complex = is_torch_1_9_plus and self.use_builtin_complex
        masks = []
        for idx in range(mask_tensor.shape[1] // 2):
            # channels are interleaved (real, imag); each is (B, F-1, T)
            mask_real = mask_tensor[:, idx * 2]
            mask_imag = mask_tensor[:, idx * 2 + 1]
            # zero-pad the frequency bin dropped in forward(): (B, F, T)
            mask_real = F.pad(mask_real, [0, 0, 1, 0])
            mask_imag = F.pad(mask_imag, [0, 0, 1, 0])

            # mask shape (B, T, F)
            mask_real = mask_real.permute(0, 2, 1)
            mask_imag = mask_imag.permute(0, 2, 1)
            if native_complex:
                complex_mask = torch.complex(mask_real, mask_imag)
            else:
                complex_mask = ComplexTensor(mask_real, mask_imag)
            masks.append(complex_mask)

        return masks

    def apply_masks(
        self,
        masks: List[Union[torch.Tensor, ComplexTensor]],
        real: torch.Tensor,
        imag: torch.Tensor,
    ):
        """Apply every estimated mask to the noisy spectrum.

        Each mask is applied independently to the same noisy spectrum.
        Masking modes:
            "E": the mask magnitude (squashed by tanh) scales the noisy
                magnitude and the mask phase is added to the noisy phase.
            "C": complex multiplication of spectrum and mask.
            "R": real and imaginary parts are multiplied separately.

        Args:
            masks : est_masks, [(B, T, F), ...]
            real (torch.Tensor): real part of the noisy spectrum, (B, F, T)
            imag (torch.Tensor): imag part of the noisy spectrum, (B, F, T)

        Returns:
            masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, F), ...]

        Raises:
            ValueError: if ``self.masking_mode`` is not "C", "E" or "R".
        """
        native_complex = is_torch_1_9_plus and self.use_builtin_complex

        if self.masking_mode == "E":
            # Magnitude and phase of the noisy spectrum do not depend on the
            # mask, so compute them once. shape (B, F, T)
            spec_mags = torch.sqrt(real**2 + imag**2 + 1e-8)
            spec_phase = torch.atan2(imag, real)

        masked = []
        for mask in masks:
            # shape (B, T, F) --> (B, F, T)
            mask_real = mask.real.permute(0, 2, 1)
            mask_imag = mask.imag.permute(0, 2, 1)

            # NOTE: ``real`` / ``imag`` must stay untouched inside this loop;
            # rebinding them would feed one speaker's estimate into the next
            # mask instead of the noisy spectrum.
            if self.masking_mode == "E":
                mask_mags = (mask_real**2 + mask_imag**2) ** 0.5
                # EPS guards against division by zero for an all-zero mask.
                real_phase = mask_real / (mask_mags + EPS)
                imag_phase = mask_imag / (mask_mags + EPS)
                mask_phase = torch.atan2(imag_phase, real_phase)
                mask_mags = torch.tanh(mask_mags)
                est_mags = mask_mags * spec_mags
                est_phase = spec_phase + mask_phase
                est_real = est_mags * torch.cos(est_phase)
                est_imag = est_mags * torch.sin(est_phase)
            elif self.masking_mode == "C":
                est_real = real * mask_real - imag * mask_imag
                est_imag = real * mask_imag + imag * mask_real
            elif self.masking_mode == "R":
                est_real = real * mask_real
                est_imag = imag * mask_imag
            else:
                raise ValueError("Unsupported masking mode: %s" % self.masking_mode)

            # shape (B, F, T) --> (B, T, F)
            est_real = est_real.permute(0, 2, 1)
            est_imag = est_imag.permute(0, 2, 1)
            if native_complex:
                masked.append(torch.complex(est_real, est_imag))
            else:
                masked.append(ComplexTensor(est_real, est_imag))

        return masked

    @property
    def num_spk(self):
        """Number of speakers this separator estimates."""
        return self._num_spk