Source code for espnet.nets.pytorch_backend.frontends.dnn_wpe

from typing import Tuple

import torch
from torch_complex.tensor import ComplexTensor

from espnet2.enh.layers.wpe import wpe_one_iteration
from espnet.nets.pytorch_backend.frontends.mask_estimator import MaskEstimator
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask


class DNN_WPE(torch.nn.Module):
    """WPE front-end whose power estimate can be weighted by a DNN mask.

    ``forward`` runs ``iterations`` passes of ``wpe_one_iteration``. On the
    first pass, when ``use_dnn_mask`` is set, a mask produced by
    ``MaskEstimator`` is multiplied into the power before it is averaged over
    the channel axis; later passes use the unmasked power of the previous
    pass's output.

    Args:
        wtype: Forwarded positionally to ``MaskEstimator``.
        widim: Forwarded positionally to ``MaskEstimator``.
        wlayers: Forwarded positionally to ``MaskEstimator``.
        wunits: Forwarded positionally to ``MaskEstimator``.
        wprojs: Forwarded positionally to ``MaskEstimator``.
        dropout_rate: Forwarded positionally to ``MaskEstimator``.
        taps: ``taps`` argument handed to ``wpe_one_iteration``.
        delay: ``delay`` argument handed to ``wpe_one_iteration``.
        use_dnn_mask: Build a ``MaskEstimator`` (with ``nmask=1``) and use its
            mask on the first pass. When false, ``self.mask_est`` is never
            created.
        iterations: Number of ``wpe_one_iteration`` passes per forward call.
        normalization: Divide the mask by its sum over the last axis before
            applying it.
    """

    def __init__(
        self,
        wtype: str = "blstmp",
        widim: int = 257,
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        dropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask: bool = True,
        iterations: int = 1,
        normalization: bool = False,
    ):
        super().__init__()

        # Settings consumed by wpe_one_iteration in forward().
        self.taps = taps
        self.delay = delay
        self.inverse_power = True
        self.iterations = iterations

        # Settings controlling the optional mask.
        self.use_dnn_mask = use_dnn_mask
        self.normalization = normalization

        if use_dnn_mask:
            self.mask_est = MaskEstimator(
                wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
            )

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """Run the configured number of WPE passes over ``data``.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F) -- the layout the permutation below needs in
                order to produce (B, F, C, T).
            ilens: (B,)

        Returns:
            enhanced: (B, T, C, F)
            ilens: (B,), returned unchanged.
            mask: (B, T, C, F) mask from the first pass, or ``None`` when no
                mask was estimated (``use_dnn_mask`` is false or
                ``iterations`` is 0).
        """
        # (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        enhanced = data
        mask = None

        for step in range(self.iterations):
            # Power of the current estimate: (..., C, T)
            power = enhanced.real.pow(2) + enhanced.imag.pow(2)

            # The mask is estimated only once, from the untouched input.
            if self.use_dnn_mask and step == 0:
                # The estimator must hand back exactly one mask: (B, F, C, T)
                masks, _ = self.mask_est(enhanced, ilens)
                (mask,) = masks
                if self.normalization:
                    # Normalize along T
                    mask = mask / mask.sum(dim=-1, keepdim=True)
                # (..., C, T) * (..., C, T) -> (..., C, T)
                power = power * mask

            # Average over the channel axis: (..., C, T) -> (..., T)
            power = power.mean(dim=-2)

            # Every pass is given the original input; only the power estimate
            # carries over from the previous pass. (..., C, T) -> (..., C, T)
            enhanced = wpe_one_iteration(
                data.contiguous(),
                power,
                taps=self.taps,
                delay=self.delay,
                inverse_power=self.inverse_power,
            )

            # Zero, in place, the positions make_pad_mask flags for ilens.
            pad_mask = make_pad_mask(ilens, enhanced.real)
            enhanced.masked_fill_(pad_mask, 0)

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = enhanced.permute(0, 3, 2, 1)
        if mask is not None:
            mask = mask.transpose(-1, -3)
        return enhanced, ilens, mask