Source code for espnet2.asr.frontend.melspec_torch

#!/usr/bin/env python3
#  2023, Jee-weon Jung, CMU
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Torchaudio MFCC"""

from typing import Optional, Tuple

import torch
import torch.nn.functional as F
import torchaudio as ta
from typeguard import typechecked

from espnet2.asr.frontend.abs_frontend import AbsFrontend


class MelSpectrogramTorch(AbsFrontend):
    """Mel-Spectrogram using Torchaudio Implementation.

    Optionally applies a first-order pre-emphasis filter to the waveform,
    computes a Mel-spectrogram with ``torchaudio.transforms.MelSpectrogram``,
    and optionally applies a log and per-utterance mean normalization.

    Args:
        preemp: If True, apply pre-emphasis ``y[t] = x[t] - 0.97 * x[t-1]``
            before feature extraction.
        n_fft: FFT size forwarded to the torchaudio transform.
        log: If True, return ``log(mel + 1e-6)`` instead of the raw
            Mel-spectrogram.
        win_length: Window length forwarded to the torchaudio transform.
        hop_length: Hop length forwarded to the torchaudio transform.
        f_min: Minimum frequency forwarded to the torchaudio transform.
        f_max: Maximum frequency forwarded to the torchaudio transform.
        n_mels: Number of Mel bins; also the value of ``output_size()``.
        window_fn: Name of the analysis window, ``"hann"`` or ``"hamming"``.
        mel_scale: Mel scale name forwarded to the torchaudio transform.
        normalize: ``None`` for no normalization, or ``"mn"`` to subtract
            the mean over the last (frame) axis. Any other value makes
            ``forward`` raise ``NotImplementedError``.

    Raises:
        ValueError: If ``window_fn`` is not one of the supported names.
    """

    @typechecked
    def __init__(
        self,
        preemp: bool = True,
        n_fft: int = 512,
        log: bool = False,
        win_length: int = 400,
        hop_length: int = 160,
        f_min: int = 20,
        f_max: int = 7600,
        n_mels: int = 80,
        window_fn: str = "hamming",
        mel_scale: str = "htk",
        normalize: Optional[str] = None,
    ):
        super().__init__()

        self.log = log
        self.n_mels = n_mels
        self.preemp = preemp
        self.normalize = normalize

        if window_fn == "hann":
            self.window_fn = torch.hann_window
        elif window_fn == "hamming":
            self.window_fn = torch.hamming_window
        else:
            # Fail here with a clear message; without this branch the
            # attribute is never set and construction dies later with an
            # unrelated AttributeError.
            raise ValueError(
                f"Unsupported window_fn: {window_fn!r}. "
                "Supported values are 'hann' and 'hamming'."
            )

        if preemp:
            # Kernel of shape (1, 1, 2) for F.conv1d. conv1d computes a
            # cross-correlation, so [-0.97, 1.0] yields x[t] - 0.97 * x[t-1].
            # Registered as a buffer so it moves with the module's device.
            self.register_buffer(
                "flipped_filter",
                torch.FloatTensor([-0.97, 1.0]).unsqueeze(0).unsqueeze(0),
            )

        # sample_rate is not passed, so torchaudio's default applies.
        self.transform = ta.transforms.MelSpectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            f_min=f_min,
            f_max=f_max,
            n_mels=n_mels,
            window_fn=self.window_fn,
            mel_scale=mel_scale,
        )

    def forward(
        self, input: torch.Tensor, input_length: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract Mel-spectrogram features from a batch of waveforms.

        Runs under ``torch.no_grad()`` with CUDA autocast disabled, so the
        features carry no gradient.

        Args:
            input: 2-D waveform tensor; the first axis is treated as batch
                and the second as time.
            input_length: Accepted for interface compatibility; not read by
                this implementation.

        Returns:
            A tuple ``(feats, feats_length)`` where ``feats`` is the
            transform output with its last two axes swapped (frames before
            Mel bins) and ``feats_length`` holds the total number of frames
            repeated once per batch item.

        Raises:
            AssertionError: If ``input`` is not 2-dimensional.
            NotImplementedError: If ``normalize`` is neither ``None`` nor
                ``"mn"``.
        """
        # input check
        assert (
            len(input.size()) == 2
        ), "The number of dimensions of input tensor must be 2!"

        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=False):
                if self.preemp:
                    # reflect padding to match lengths of in/out
                    x = input.unsqueeze(1)
                    x = F.pad(x, (1, 0), "reflect")

                    # apply preemphasis
                    x = F.conv1d(x, self.flipped_filter).squeeze(1)
                else:
                    x = input

                # apply frame feature extraction
                x = self.transform(x)

                if self.log:
                    # small offset avoids log(0) on silent bins
                    x = torch.log(x + 1e-6)

                if self.normalize is not None:
                    if self.normalize == "mn":
                        # mean subtraction along the last (frame) axis
                        x = x - torch.mean(x, dim=-1, keepdim=True)
                    else:
                        raise NotImplementedError(
                            f"got {self.normalize}, not implemented"
                        )

        # NOTE(review): the incoming input_length is ignored; every batch
        # item is reported with the full frame count of the (possibly padded)
        # batch. The tensor is built with the torch.Tensor constructor, so it
        # is not an integer dtype. Kept as-is to preserve the return contract.
        input_length = torch.Tensor([x.size(-1)]).repeat(x.size(0))

        return x.permute(0, 2, 1), input_length

    def output_size(self) -> int:
        """Return output length of feature dimension D."""
        return self.n_mels