Source code for espnet.nets.pytorch_backend.frontends.feature_transform

from typing import List, Tuple, Union

import librosa
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor

from espnet.nets.pytorch_backend.nets_utils import make_pad_mask


class FeatureTransform(torch.nn.Module):
    """Turn a complex STFT into (optionally normalized) log-Mel features.

    The pipeline applied by :meth:`forward` is::

        power spectrum -> LogMel -> [GlobalMVN] -> [UtteranceMVN]

    Args:
        fs: Sampling rate passed to ``LogMel``.
        n_fft: Number of FFT components passed to ``LogMel``.
        n_mels: Number of Mel bands passed to ``LogMel``.
        fmin: Lowest filterbank frequency passed to ``LogMel``.
        fmax: Highest filterbank frequency passed to ``LogMel``.
        stats_file: Statistics file for ``GlobalMVN``. Global normalization
            is skipped when this is ``None``.
        apply_uttmvn: Whether to apply ``UtteranceMVN`` after the global
            normalization.
        uttmvn_norm_means: ``norm_means`` option of ``UtteranceMVN``.
        uttmvn_norm_vars: ``norm_vars`` option of ``UtteranceMVN``.
    """

    def __init__(
        self,
        # Mel options,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        # Normalization
        stats_file: str = None,
        apply_uttmvn: bool = True,
        uttmvn_norm_means: bool = True,
        uttmvn_norm_vars: bool = False,
    ):
        super().__init__()
        self.apply_uttmvn = apply_uttmvn

        self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
        self.stats_file = stats_file
        if stats_file is not None:
            self.global_mvn = GlobalMVN(stats_file)
        else:
            self.global_mvn = None

        # ``apply_uttmvn`` is a boolean flag: test its truth value. An
        # ``is not None`` check would build the module even for ``False``.
        if self.apply_uttmvn:
            self.uttmvn = UtteranceMVN(
                norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
            )
        else:
            self.uttmvn = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Compute features from a complex STFT.

        Args:
            x: Complex STFT of shape (B, T, F) or (B, T, C, F).
            ilens: Number of valid frames per utterance; a tensor, a numpy
                array or a list of ints.

        Returns:
            The feature tensor and ``ilens`` (converted to a tensor on
            ``x.device`` when it was not one already).

        Raises:
            ValueError: If ``x`` is neither 3- nor 4-dimensional.
        """
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)

        if x.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(x.size(2))
                h = x[:, :, ch, :]
            else:
                # Use the first channel
                h = x[:, :, 0, :]
        else:
            h = x

        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F) (power spectrum)
        h = h.real**2 + h.imag**2

        h, _ = self.logmel(h, ilens)
        # Gate on the sub-modules themselves so that forward always agrees
        # with what __init__ actually constructed.
        if self.global_mvn is not None:
            h, _ = self.global_mvn(h, ilens)
        if self.uttmvn is not None:
            h, _ = self.uttmvn(h, ilens)

        return h, ilens
class LogMel(torch.nn.Module):
    """Project a spectrum onto a Mel filterbank and take the logarithm.

    All constructor arguments are forwarded to ``librosa.filters.mel``.

    Args:
        fs: Sampling rate of the incoming signal (forwarded as ``sr``).
        n_fft: Number of FFT components.
        n_mels: Number of Mel bands to generate.
        fmin: Lowest frequency (in Hz).
        fmax: Highest frequency (in Hz), forwarded unchanged.
        htk: Use the HTK formula instead of Slaney.
        norm: Normalization mode, forwarded unchanged.
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        htk: bool = False,
        norm=1,
    ):
        super().__init__()

        # Kept on the instance so that extra_repr() can print the options.
        self.mel_options = {
            "sr": fs,
            "n_fft": n_fft,
            "n_mels": n_mels,
            "fmin": fmin,
            "fmax": fmax,
            "htk": htk,
            "norm": norm,
        }

        # Note(kamo): The mel matrix of librosa is different from kaldi.
        filterbank = librosa.filters.mel(**self.mel_options)
        # filterbank: (D2, D1) -> buffer: (D1, D2)
        self.register_buffer("melmat", torch.from_numpy(filterbank.T).float())

    def extra_repr(self):
        """Return the Mel options as a ``key=value`` list for ``repr``."""
        parts = [f"{key}={value}" for key, value in self.mel_options.items()]
        return ", ".join(parts)

    def forward(
        self, feat: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Apply the filterbank and the log, then zero the padded frames.

        Args:
            feat: Input of shape (B, T, D1).
            ilens: Number of valid frames per utterance.

        Returns:
            The log-Mel features of shape (B, T, D2) and ``ilens``.
        """
        # feat: (B, T, D1) x melmat: (D1, D2) -> (B, T, D2)
        projected = torch.matmul(feat, self.melmat)
        # The small offset keeps the log finite for all-zero bins.
        log_feat = torch.log(projected + 1e-20)

        # Zero padding
        pad_mask = make_pad_mask(ilens, log_feat, 1)
        log_feat = log_feat.masked_fill(pad_mask, 0.0)
        return log_feat, ilens
class GlobalMVN(torch.nn.Module):
    """Apply global mean and variance normalization.

    Args:
        stats_file: Path to an ``.npy`` file holding a 1-dim array of length
            ``2 * D + 1``. The first ``D`` elements are the per-dimension
            sums of the features, the next ``D`` elements are the sums of
            the squared features, and the last element is the number of
            samples.
        norm_means: Subtract the global mean.
        norm_vars: Divide by the global standard deviation.
        eps: Lower bound applied to the standard deviation.
    """

    def __init__(
        self,
        stats_file: str,
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars

        self.stats_file = stats_file
        stats = np.load(stats_file)

        stats = stats.astype(float)
        assert (len(stats) - 1) % 2 == 0, stats.shape

        count = stats.flatten()[-1]
        n_dim = (len(stats) - 1) // 2
        mean = stats[:n_dim] / count
        var = stats[n_dim:-1] / count - mean * mean
        # E[x^2] - E[x]^2 can come out slightly negative through rounding.
        # sqrt would then give NaN, which np.maximum propagates rather than
        # floors, so clamp the variance at zero first.
        std = np.maximum(np.sqrt(np.maximum(var, 0.0)), eps)

        # Stored as "bias" (= -mean) and "scale" (= 1 / std) so that forward
        # only needs an addition and a multiplication.
        self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
        self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))

    def extra_repr(self):
        """Return the configuration shown by ``repr``."""
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Normalize ``x`` with the global statistics.

        The input tensor is left unmodified; a new tensor is returned.

        Args:
            x: Features of shape (B, T, D).
            ilens: Number of valid frames per utterance.

        Returns:
            The normalized features and ``ilens``.
        """
        # feat: (B, T, D)
        if self.norm_means:
            x = x + self.bias.type_as(x)
            # Adding the bias makes the padded frames non-zero, so zero them
            # again. masked_fill is out-of-place: its result must be kept.
            x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)

        if self.norm_vars:
            x = x * self.scale.type_as(x)
        return x, ilens
class UtteranceMVN(torch.nn.Module):
    """Module wrapper around :func:`utterance_mvn`.

    Args:
        norm_means: Forwarded to ``utterance_mvn`` as ``norm_means``.
        norm_vars: Forwarded to ``utterance_mvn`` as ``norm_vars``.
        eps: Forwarded to ``utterance_mvn`` as ``eps``.
    """

    def __init__(
        self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
    ):
        super().__init__()
        self.eps = eps
        self.norm_vars = norm_vars
        self.norm_means = norm_means

    def extra_repr(self):
        """Return the configuration shown by ``repr``."""
        return "norm_means={}, norm_vars={}".format(self.norm_means, self.norm_vars)

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Normalize ``x`` per utterance using the stored options."""
        options = dict(
            norm_means=self.norm_means, norm_vars=self.norm_vars, eps=self.eps
        )
        return utterance_mvn(x, ilens, **options)
def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.LongTensor,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Apply utterance mean and variance normalization.

    The input tensor is left unmodified; a new tensor is returned whenever
    any normalization is applied.

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,) number of valid frames of each utterance
        norm_means: Subtract the per-utterance mean.
        norm_vars: Divide by the per-utterance standard deviation.
        eps: Lower bound applied to the variance before the square root.

    Returns:
        The normalized tensor of shape (B, T, D) and ``ilens``.
    """
    ilens_ = ilens.type_as(x)
    # mean: (B, D); relies on the padded frames of x being zero
    mean = x.sum(dim=1) / ilens_[:, None]

    # Subtracting the mean turns the padded frames into -mean, so they must be
    # zeroed again. masked_fill is out-of-place: its result has to be kept,
    # otherwise the padding leaks into both the output and the variance.
    pad_mask = make_pad_mask(ilens, x, 1)
    centered = (x - mean[:, None, :]).masked_fill(pad_mask, 0.0)

    # The centered tensor is always needed for the variance, but it is only
    # returned when mean normalization was requested.
    out = centered if norm_means else x

    if norm_vars:
        var = centered.pow(2).sum(dim=1) / ilens_[:, None]
        var = torch.clamp(var, min=eps)
        # Out-of-place division: an in-place update would overwrite the tensor
        # that pow() saved for the backward pass.
        out = out / var.sqrt()[:, None, :]
    return out, ilens
def feature_transform_for(args, n_fft):
    """Build a ``FeatureTransform`` from the attributes of ``args``.

    Args:
        args: Object providing ``fbank_fs``, ``n_mels``, ``fbank_fmin``,
            ``fbank_fmax``, ``stats_file``, ``apply_uttmvn``,
            ``uttmvn_norm_means`` and ``uttmvn_norm_vars``.
        n_fft: Number of FFT components.

    Returns:
        The configured ``FeatureTransform`` instance.
    """
    # Mel options
    mel_kwargs = {
        "fs": args.fbank_fs,
        "n_fft": n_fft,
        "n_mels": args.n_mels,
        "fmin": args.fbank_fmin,
        "fmax": args.fbank_fmax,
    }
    # Normalization
    norm_kwargs = {
        "stats_file": args.stats_file,
        "apply_uttmvn": args.apply_uttmvn,
        "uttmvn_norm_means": args.uttmvn_norm_means,
        "uttmvn_norm_vars": args.uttmvn_norm_vars,
    }
    return FeatureTransform(**mel_kwargs, **norm_kwargs)