Source code for espnet2.layers.global_mvn

from pathlib import Path
from typing import Tuple, Union

import numpy as np
import torch
from typeguard import typechecked

from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.layers.inversible_interface import InversibleInterface
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask


class GlobalMVN(AbsNormalize, InversibleInterface):
    """Apply global mean and variance normalization

    TODO(kamo): Make this class portable somehow

    The statistics are read once from ``stats_file`` and kept as the
    ``mean`` and ``std`` buffers of the module.

    Args:
        stats_file: npy file holding Kaldi-like accumulated stats, or npz file
            with ``count``, ``sum`` and ``sum_square`` entries
        norm_means: Apply mean normalization
        norm_vars: Apply var normalization
        eps: Lower bound applied to the variance before taking its square
            root, so ``std`` is never smaller than ``sqrt(eps)``
    """

    @typechecked
    def __init__(
        self,
        stats_file: Union[Path, str],
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.eps = eps
        stats_file = Path(stats_file)
        self.stats_file = stats_file

        mean, var = self._load_mean_var(stats_file)
        std = np.sqrt(np.maximum(var, eps))

        self.register_buffer("mean", self._to_tensor(mean))
        self.register_buffer("std", self._to_tensor(std))

    @staticmethod
    def _load_mean_var(stats_file: Path):
        """Return ``(mean, var)`` computed from an npy or npz stats file.

        The results are numpy arrays (or numpy scalars when the stored
        statistics are scalars).
        """
        stats = np.load(stats_file)
        if isinstance(stats, np.ndarray):
            # Kaldi like stats: row 0 holds per-dim sums with the frame count
            # in its last column, row 1 holds the per-dim sums of squares.
            count = stats[0].flatten()[-1]
            mean = stats[0, :-1] / count
            var = stats[1, :-1] / count - mean * mean
        else:
            # New style: Npz file. np.load returns an NpzFile that keeps the
            # underlying file open, so read the entries inside ``with`` to
            # make sure the handle is released.
            with stats:
                count = stats["count"]
                sum_v = stats["sum"]
                sum_square_v = stats["sum_square"]
            mean = sum_v / count
            var = sum_square_v / count - mean * mean
        return mean, var

    @staticmethod
    def _to_tensor(value) -> torch.Tensor:
        """Convert a numpy array (dtype kept) or scalar (as float32) to a tensor."""
        if isinstance(value, np.ndarray):
            return torch.from_numpy(value)
        return torch.tensor(value).float()

    def extra_repr(self):
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )

    @staticmethod
    def _default_ilens(x: torch.Tensor, ilens: torch.Tensor = None) -> torch.Tensor:
        """Return ``ilens``, or full-length integer lengths (B,) when it is None."""
        if ilens is not None:
            return ilens
        # Every sequence spans the whole dim 1 of x. Use an integer dtype:
        # lengths are counts, not features, so they must not inherit x.dtype.
        return torch.full(
            (x.size(0),), x.size(1), dtype=torch.long, device=x.device
        )

    def _stats_like(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mean, std)`` cast to the device and dtype of ``x``.

        The registered buffers are left untouched so that calling the module
        on low-precision input never degrades the stored statistics.
        """
        return (
            self.mean.to(device=x.device, dtype=x.dtype),
            self.std.to(device=x.device, dtype=x.dtype),
        )

    @staticmethod
    def _fill_padding(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Zero the positions selected by ``mask``.

        In place when ``x`` does not require grad. Otherwise out of place,
        which also yields a fresh non-leaf tensor that later in-place ops may
        safely modify.
        """
        if x.requires_grad:
            return x.masked_fill(mask, 0.0)
        return x.masked_fill_(mask, 0.0)

    def forward(
        self, x: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function

        Args:
            x: (B, L, ...)
            ilens: (B,) valid lengths along dim 1; defaults to L for every
                sequence (as an integer tensor)

        Returns:
            The normalized features with padded positions zeroed, and ilens.

        Note:
            When ``x`` does not require grad it is normalized in place, i.e.
            the caller's tensor is modified.
        """
        ilens = self._default_ilens(x, ilens)
        mean, std = self._stats_like(x)
        # Padded positions along dim 1, as computed by make_pad_mask.
        mask = make_pad_mask(ilens, x, 1)

        # feat: (B, T, D)
        if self.norm_means:
            if x.requires_grad:
                # Out of place so the caller's (possibly leaf) tensor is kept.
                x = x - mean
            else:
                x -= mean
        x = self._fill_padding(x, mask)

        if self.norm_vars:
            x /= std

        return x, ilens

    def inverse(
        self, x: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Undo the normalization applied by ``forward``.

        Args:
            x: (B, L, ...) normalized features
            ilens: (B,) valid lengths along dim 1; defaults to L for every
                sequence (as an integer tensor)

        Returns:
            The de-normalized features with padded positions zeroed, and ilens.

        Note:
            When ``x`` does not require grad it is modified in place.
        """
        ilens = self._default_ilens(x, ilens)
        mean, std = self._stats_like(x)
        mask = make_pad_mask(ilens, x, 1)

        # With autograd on this returns a copy, so the in-place ops below
        # never touch the caller's tensor.
        x = self._fill_padding(x, mask)

        if self.norm_vars:
            x *= std

        # feat: (B, T, D)
        if self.norm_means:
            x += mean
            # Adding the mean made padded positions non-zero; clear them again
            # (reusing the mask computed above instead of rebuilding it).
            x.masked_fill_(mask, 0.0)
        return x, ilens