# Source code for espnet2.enh.encoder.stft_encoder

import torch
from packaging.version import parse as V
from torch_complex.tensor import ComplexTensor

from espnet2.enh.encoder.abs_encoder import AbsEncoder
from espnet2.layers.stft import Stft

# True when the installed PyTorch version is 1.9.0 or newer; STFTEncoder only
# produces built-in complex tensors (torch.complex) when this flag is set.
is_torch_1_9_plus = V("1.9.0") <= V(torch.__version__)


class STFTEncoder(AbsEncoder):
    """STFT encoder for speech enhancement and separation.

    Converts time-domain waveforms into complex STFT spectra, optionally
    applying a magnitude transform ("exponent" or "log") to the result.
    Offline (`forward`) and frame-by-frame streaming (`forward_streaming`,
    fed by `streaming_frame`) interfaces are provided.
    """

    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window="hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        use_builtin_complex: bool = True,
        default_fs: int = 16000,
        spec_transform_type: str = None,
        spec_factor: float = 0.15,
        spec_abs_exponent: float = 0.5,
    ):
        """Build the encoder.

        Args:
            n_fft: FFT size.
            win_length: window length in samples (defaults to ``n_fft``).
            hop_length: frame shift in samples.
            window: name of a ``torch.<name>_window`` function, or None.
            center: whether the STFT pads the signal so frames are centered.
            normalized: whether the STFT output is normalized.
            onesided: return only the non-negative frequency bins.
            use_builtin_complex: return ``torch.complex`` tensors (when the
                installed torch supports them) instead of ``ComplexTensor``.
            default_fs: sampling rate the window/hop sizes refer to.
            spec_transform_type: "exponent", "log", "none" or None.
            spec_factor: scale applied after the "exponent"/"log" transform.
            spec_abs_exponent: magnitude exponent for the "exponent" transform.
        """
        super().__init__()
        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )

        self._output_dim = n_fft // 2 + 1 if onesided else n_fft
        self.use_builtin_complex = use_builtin_complex
        self.win_length = win_length if win_length else n_fft
        self.hop_length = hop_length
        self.window = window
        self.n_fft = n_fft
        self.center = center
        self.default_fs = default_fs

        # spec transform related. See equation (1) in paper
        # 'Speech Enhancement and Dereverberation With Diffusion-Based Generative
        # Models'. The default value of 0.15, 0.5 also come from the paper.
        # spec_transform_type: "exponent", "log", or "none"
        self.spec_transform_type = spec_transform_type
        # the output spectrum will be scaled with: spec * self.spec_factor
        self.spec_factor = spec_factor
        # the exponent factor used in the "exponent" transform
        self.spec_abs_exponent = spec_abs_exponent

    def spec_transform_func(self, spec):
        """Apply the configured magnitude transform to a complex spectrum.

        Args:
            spec: complex spectrum produced by the STFT.

        Returns:
            For ``"exponent"``: ``|spec| ** spec_abs_exponent`` with the
            original phase, scaled by ``spec_factor``. For ``"log"``:
            ``log(1 + |spec|)`` with the original phase, scaled by
            ``spec_factor``. Any other value (e.g. ``"none"`` or None)
            returns ``spec`` unchanged.
        """
        # NOTE(review): the polar rebuild multiplies by a torch complex tensor;
        # confirm this is supported when `spec` is a ComplexTensor.
        if self.spec_transform_type == "exponent":
            if self.spec_abs_exponent != 1:
                # only do this calculation if spec_exponent != 1,
                # otherwise it's quite a bit of wasted computation
                # and introduced numerical error
                e = self.spec_abs_exponent
                spec = spec.abs() ** e * torch.exp(1j * spec.angle())
            spec = spec * self.spec_factor
        elif self.spec_transform_type == "log":
            spec = torch.log(1 + spec.abs()) * torch.exp(1j * spec.angle())
            spec = spec * self.spec_factor
        return spec

    @property
    def output_dim(self) -> int:
        """Number of frequency bins in the output spectrum."""
        return self._output_dim

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, input: torch.Tensor, ilens: torch.Tensor, fs: int = None):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]
            fs (int): sampling rate in Hz
                If not None, reconfigure STFT window and hop lengths for a new
                sampling rate while keeping their duration fixed.
        Returns:
            spectrum (ComplexTensor): [Batch, T, (C,) F]
            flens (torch.Tensor): [Batch]
        """
        try:
            if fs is not None:
                self._reconfig_for_fs(fs)
            # for supporting half-precision training
            if input.dtype in (torch.float16, torch.bfloat16):
                spectrum, flens = self.stft(input.float(), ilens)
                spectrum = spectrum.to(dtype=input.dtype)
            else:
                spectrum, flens = self.stft(input, ilens)
        finally:
            # Restore the default-fs configuration even when the STFT raises,
            # so a failed call cannot leak a foreign-rate setup into later ones.
            self._reset_config()

        # The STFT output carries real/imag parts on the last axis.
        if is_torch_1_9_plus and self.use_builtin_complex:
            spectrum = torch.complex(spectrum[..., 0], spectrum[..., 1])
        else:
            spectrum = ComplexTensor(spectrum[..., 0], spectrum[..., 1])

        spectrum = self.spec_transform_func(spectrum)

        return spectrum, flens

    def _reset_config(self):
        """Reset the configuration of STFT window and hop lengths."""
        self._reconfig_for_fs(self.default_fs)

    def _reconfig_for_fs(self, fs):
        """Reconfigure STFT window and hop lengths for a new sampling rate
        while keeping their duration fixed.

        Args:
            fs (int): new sampling rate
        """  # noqa: H405
        self.stft.n_fft = self.n_fft * fs // self.default_fs
        self.stft.win_length = self.win_length * fs // self.default_fs
        self.stft.hop_length = self.hop_length * fs // self.default_fs

    def _apply_window_func(self, input):
        """Window each frame and zero-pad it symmetrically to ``n_fft``.

        Args:
            input (torch.Tensor): frames whose last axis has ``win_length``
                samples, e.g. [Batch, win_length]
        Returns:
            torch.Tensor: windowed frames whose last axis has ``n_fft`` samples
        """
        if self.window is None:
            # Same convention as torch.stft: no window means rectangular.
            window = torch.ones(
                self.win_length, dtype=input.dtype, device=input.device
            )
        else:
            window_func = getattr(torch, f"{self.window}_window")
            window = window_func(
                self.win_length, dtype=input.dtype, device=input.device
            )
        n_pad_left = (self.n_fft - window.shape[0]) // 2
        n_pad_right = self.n_fft - window.shape[0] - n_pad_left
        windowed = input * window
        # Zero-pad so the window sits centered inside the n_fft frame; padding
        # via F.pad keeps the dtype/device of `windowed` (CUDA / half safe).
        return torch.nn.functional.pad(windowed, [n_pad_left, n_pad_right])

    def forward_streaming(self, input: torch.Tensor):
        """Forward a single frame (streaming mode).

        Args:
            input (torch.Tensor): mixed speech [Batch, frame_length]

        Returns:
            feature: spectrum of shape [Batch, 1, F]
        """
        assert (
            input.dim() == 2
        ), "forward_streaming only support for single-channel input currently."

        windowed = self._apply_window_func(input)

        # Match the offline path: torch.stft(normalized=True) scales by
        # n_fft ** -0.5, which is exactly the "ortho" norm of an n_fft-point FFT.
        norm = "ortho" if getattr(self.stft, "normalized", False) else "backward"
        if self.stft.onesided:
            feature = torch.fft.rfft(windowed, norm=norm)
        else:
            feature = torch.fft.fft(windowed, norm=norm)
        feature = feature.unsqueeze(1)
        if not (is_torch_1_9_plus and self.use_builtin_complex):
            feature = ComplexTensor(feature.real, feature.imag)

        feature = self.spec_transform_func(feature)

        return feature

    def streaming_frame(self, audio):
        """streaming_frame. It splits the continuous audio into frame-level
        audio chunks in the streaming *simulation*. It is noted that this
        function takes the entire long audio as input for a streaming simulation.
        You may refer to this function to manage your streaming input
        buffer in a real streaming application.

        Args:
            audio: (B, T)
        Returns:
            chunked: tuple of frames, each of shape (B, win_length); empty when
                the audio is too short to hold a single frame.
        """  # noqa: H405
        if self.center:
            pad_len = int(self.win_length // 2)
            signal_dim = audio.dim()
            extended_shape = [1] * (3 - signal_dim) + list(audio.size())
            # the default STFT pad mode is "reflect",
            # which is not configurable in STFT encoder,
            # so, here we just use "reflect mode"
            audio = torch.nn.functional.pad(
                audio.reshape(extended_shape), [pad_len, pad_len], "reflect"
            )
            audio = audio.reshape(audio.shape[-signal_dim:])

        _, audio_len = audio.shape

        n_frames = max(0, 1 + (audio_len - self.win_length) // self.hop_length)
        strides = list(audio.stride())
        shape = list(audio.shape[:-1]) + [self.win_length, n_frames]
        # Consecutive frames start hop_length *elements* apart along time, so
        # scale by that axis' stride, and start at the tensor's own storage
        # offset so that sliced views are framed from the right samples.
        strides = strides + [self.hop_length * strides[-1]]

        return audio.as_strided(
            shape, strides, storage_offset=audio.storage_offset()
        ).unbind(dim=-1)