# Source code for espnet2.layers.time_warp

"""Time warp module."""

import torch

from espnet.nets.pytorch_backend.nets_utils import pad_list

DEFAULT_TIME_WARP_MODE = "bicubic"


def time_warp(x: torch.Tensor, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
    """Time warping using torch.interpolate.

    A split point ``center`` is drawn uniformly from ``[window, t - window)`` and
    moved to a random position ``warped`` in ``[center - window + 1, center + window]``.
    The frames before the split are resized to ``warped`` frames and the frames
    after it to ``t - warped`` frames, so the time length is preserved.

    Args:
        x: (Batch, Time, Freq) or (Batch, Channel, Time, Freq)
        window: time warp parameter. Inputs with ``Time <= 2 * window``, and any
            non-positive ``window``, are returned unchanged.
        mode: Interpolate mode passed to ``torch.nn.functional.interpolate``.

    Returns:
        Tensor with the same shape as ``x``.

    Note:
        When ``x`` does not require grad, the warped frames are written back into
        ``x`` in place, so the caller's tensor is modified and the returned tensor
        shares its storage.
    """
    # bicubic supports 4D or more dimension tensor
    org_size = x.size()
    if x.dim() == 3:
        # x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
        x = x[:, None]

    t = x.shape[2]
    # A non-positive window leaves no room to move the split point; without this
    # guard torch.randint below would be called with an empty range and raise.
    if window <= 0 or t - window <= window:
        return x.view(*org_size)

    center = int(torch.randint(window, t - window, (1,))[0])
    warped = int(torch.randint(center - window, center + window, (1,))[0]) + 1

    # interpolate() raises ValueError for these modes whenever align_corners is
    # passed at all (even False), so only pass it to modes that accept it.
    align_corners = None if mode in ("nearest", "area", "nearest-exact") else False

    # left: (Batch, Channel, warped, Freq)
    # right: (Batch, Channel, time - warped, Freq)
    left = torch.nn.functional.interpolate(
        x[:, :, :center], (warped, x.shape[3]), mode=mode, align_corners=align_corners
    )
    right = torch.nn.functional.interpolate(
        x[:, :, center:], (t - warped, x.shape[3]), mode=mode, align_corners=align_corners
    )

    if x.requires_grad:
        # Out-of-place concatenation keeps the autograd graph intact.
        x = torch.cat([left, right], dim=-2)
    else:
        # In-place write avoids allocating a new tensor. left/right were already
        # computed from the original frames, so overwriting x here is safe.
        x[:, :, :warped] = left
        x[:, :, warped:] = right

    return x.view(*org_size)


class TimeWarp(torch.nn.Module):
    """Module wrapper around ``time_warp``.

    Args:
        window: time warp parameter forwarded to ``time_warp``.
        mode: interpolation mode forwarded to ``time_warp``.
    """

    def __init__(self, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
        super().__init__()
        self.window = window
        self.mode = mode

    def extra_repr(self):
        return f"window={self.window}, mode={self.mode}"

    def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None):
        """Randomly time-warp a batch of features.

        Args:
            x: (Batch, Time, Freq)
            x_lengths: (Batch,) valid lengths, or None when every sample spans
                the whole time axis.

        Returns:
            Tuple of the warped features and ``x_lengths`` (passed through as
            given). For a ragged batch the per-sample results are re-padded with
            ``pad_list`` using pad value 0.0.

        Note:
            ``time_warp`` writes into its input in place when it does not require
            grad, so ``x`` itself may be modified by this call.
        """
        same_length = x_lengths is None or all(
            length == x_lengths[0] for length in x_lengths
        )
        if same_length:
            # A single random warp is drawn and shared by the whole batch.
            return time_warp(x, window=self.window, mode=self.mode), x_lengths

        # FIXME(kamo): no batched implementation yet, so loop over samples and
        # warp each one over its own valid frames only.
        warped_samples = [
            time_warp(
                x[idx][None, : x_lengths[idx]],
                window=self.window,
                mode=self.mode,
            )[0]
            for idx in range(x.size(0))
        ]
        return pad_list(warped_samples, 0.0), x_lengths