# Source code for espnet2.asr.specaug.specaug

"""SpecAugment module."""

from typing import Optional, Sequence, Union

from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.layers.mask_along_axis import MaskAlongAxis, MaskAlongAxisVariableMaxWidth
from espnet2.layers.time_warp import TimeWarp


class SpecAug(AbsSpecAug):
    """Implementation of SpecAug.

    Reference:
        Daniel S. Park et al.
        "SpecAugment: A Simple Data
         Augmentation Method for Automatic Speech Recognition"

    Up to three stages are applied, always in this order: time warping,
    frequency masking, then time masking. Each stage can be disabled
    individually, but at least one must stay enabled.

    .. warning::
        When using cuda mode, time_warp doesn't have reproducibility
        due to `torch.nn.functional.interpolate`.
    """

    def __init__(
        self,
        apply_time_warp: bool = True,
        time_warp_window: int = 5,
        time_warp_mode: str = "bicubic",
        apply_freq_mask: bool = True,
        freq_mask_width_range: Union[int, Sequence[int]] = (0, 20),
        num_freq_mask: int = 2,
        apply_time_mask: bool = True,
        time_mask_width_range: Optional[Union[int, Sequence[int]]] = None,
        time_mask_width_ratio_range: Optional[Union[float, Sequence[float]]] = None,
        num_time_mask: int = 2,
        replace_with_zero: bool = True,
    ):
        """Build the enabled augmentation stages.

        Args:
            apply_time_warp: Build a ``TimeWarp`` stage.
            time_warp_window: Passed to ``TimeWarp`` as ``window``.
            time_warp_mode: Passed to ``TimeWarp`` as ``mode``.
            apply_freq_mask: Build a frequency-axis ``MaskAlongAxis`` stage.
            freq_mask_width_range: ``mask_width_range`` for the frequency mask.
            num_freq_mask: ``num_mask`` for the frequency mask.
            apply_time_mask: Build a time-axis masking stage. This requires
                exactly one of the two range arguments below.
            time_mask_width_range: When given, selects ``MaskAlongAxis`` on
                the time axis with this ``mask_width_range``.
            time_mask_width_ratio_range: When given, selects
                ``MaskAlongAxisVariableMaxWidth`` on the time axis with this
                ``mask_width_ratio_range``.
            num_time_mask: ``num_mask`` for the time mask.
            replace_with_zero: Forwarded unchanged to both masking layers.

        Raises:
            ValueError: If all three stages are disabled, or if time masking
                is enabled with both range arguments given or with neither.
        """
        # De Morgan form of "none of the stages is enabled".
        if not (apply_time_warp or apply_time_mask or apply_freq_mask):
            raise ValueError(
                "Either one of time_warp, time_mask, or freq_mask should be applied"
            )

        both_time_ranges_given = (
            time_mask_width_range is not None
            and time_mask_width_ratio_range is not None
        )
        if apply_time_mask and both_time_ranges_given:
            raise ValueError(
                'Either one of "time_mask_width_range" or '
                '"time_mask_width_ratio_range" can be used'
            )

        super().__init__()
        self.apply_time_warp = apply_time_warp
        self.apply_freq_mask = apply_freq_mask
        self.apply_time_mask = apply_time_mask

        # Stages are built in the same order they run in forward(). This keeps
        # any constructor error surfacing in the same sequence as before.
        self.time_warp = (
            TimeWarp(window=time_warp_window, mode=time_warp_mode)
            if apply_time_warp
            else None
        )

        self.freq_mask = (
            MaskAlongAxis(
                dim="freq",
                mask_width_range=freq_mask_width_range,
                num_mask=num_freq_mask,
                replace_with_zero=replace_with_zero,
            )
            if apply_freq_mask
            else None
        )

        self.time_mask = (
            self._build_time_mask(
                time_mask_width_range,
                time_mask_width_ratio_range,
                num_time_mask,
                replace_with_zero,
            )
            if apply_time_mask
            else None
        )

    @staticmethod
    def _build_time_mask(width_range, ratio_range, num_mask, replace_with_zero):
        """Return the time-axis mask layer matching whichever range was given.

        A fixed ``width_range`` takes precedence. The caller has already
        rejected the case where both ranges are set. Raises ``ValueError``
        when neither range is set.
        """
        if width_range is not None:
            return MaskAlongAxis(
                dim="time",
                mask_width_range=width_range,
                num_mask=num_mask,
                replace_with_zero=replace_with_zero,
            )
        if ratio_range is not None:
            return MaskAlongAxisVariableMaxWidth(
                dim="time",
                mask_width_ratio_range=ratio_range,
                num_mask=num_mask,
                replace_with_zero=replace_with_zero,
            )
        raise ValueError(
            'Either one of "time_mask_width_range" or '
            '"time_mask_width_ratio_range" should be used.'
        )

    def forward(self, x, x_lengths=None):
        """Run ``x`` through every enabled stage and return the result.

        Args:
            x: Batch of features. Presumably shaped (batch, time, freq);
                TODO confirm against TimeWarp / MaskAlongAxis.
            x_lengths: Optional lengths. These are handed to each stage, which
                returns a possibly updated value.

        Returns:
            Tuple ``(x, x_lengths)`` as produced by the last enabled stage.
        """
        # Fixed order: time warp -> frequency mask -> time mask.
        for stage in (self.time_warp, self.freq_mask, self.time_mask):
            if stage is None:
                continue
            x, x_lengths = stage(x, x_lengths)
        return x, x_lengths