Source code for espnet2.gan_tts.parallel_wavegan.upsample

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Upsampling module.

This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.

"""

from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torch.nn.functional as F

from espnet2.gan_tts.wavenet.residual_block import Conv1d


class Stretch2d(torch.nn.Module):
    """Stretch2d module."""

    def __init__(self, x_scale: int, y_scale: int, mode: str = "nearest"):
        """Initialize Stretch2d module.

        Args:
            x_scale (int): X scaling factor (Time axis in spectrogram).
            y_scale (int): Y scaling factor (Frequency axis in spectrogram).
            mode (str): Interpolation mode.

        """
        super().__init__()
        self.x_scale = x_scale
        self.y_scale = y_scale
        self.mode = mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, C, F, T).

        Returns:
            Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale).

        """
        return F.interpolate(
            x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode
        )
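

# A minimal usage sketch (not part of the original module): Stretch2d repeats each
# time frame x_scale times (and each frequency bin y_scale times) via F.interpolate.
# The batch size, channel count, and feature shape below are illustrative assumptions.
#
#     >>> stretch = Stretch2d(x_scale=4, y_scale=1, mode="nearest")
#     >>> x = torch.randn(2, 1, 80, 10)  # (B, C, F, T)
#     >>> stretch(x).shape
#     torch.Size([2, 1, 80, 40])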


class Conv2d(torch.nn.Conv2d):
    """Conv2d module with customized initialization."""

    def __init__(self, *args, **kwargs):
        """Initialize Conv2d module."""
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        """Reset parameters."""
        self.weight.data.fill_(1.0 / np.prod(self.kernel_size))
        if self.bias is not None:
            torch.nn.init.constant_(self.bias, 0.0)
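

# A minimal sketch (not part of the original module) of what the custom initialization
# does: torch.nn.Conv2d calls reset_parameters() during __init__, so a freshly
# constructed Conv2d here starts as an averaging filter with every weight equal to
# 1 / prod(kernel_size), which keeps the stretched features smooth before training.
# The kernel size below is an illustrative assumption.
#
#     >>> conv = Conv2d(1, 1, kernel_size=(1, 3), padding=(0, 1), bias=False)
#     >>> torch.allclose(conv.weight, torch.full_like(conv.weight, 1.0 / 3.0))
#     True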


class UpsampleNetwork(torch.nn.Module):
    """Upsampling network module."""

    def __init__(
        self,
        upsample_scales: List[int],
        nonlinear_activation: Optional[str] = None,
        nonlinear_activation_params: Dict[str, Any] = {},
        interpolate_mode: str = "nearest",
        freq_axis_kernel_size: int = 1,
    ):
        """Initialize UpsampleNetwork module.

        Args:
            upsample_scales (List[int]): List of upsampling scales.
            nonlinear_activation (Optional[str]): Activation function name.
            nonlinear_activation_params (Dict[str, Any]): Arguments for the specified
                activation function.
            interpolate_mode (str): Interpolation mode.
            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.

        """
        super().__init__()
        self.up_layers = torch.nn.ModuleList()
        for scale in upsample_scales:
            # interpolation layer
            stretch = Stretch2d(scale, 1, interpolate_mode)
            self.up_layers += [stretch]

            # conv layer
            assert (
                freq_axis_kernel_size - 1
            ) % 2 == 0, "Even freq_axis_kernel_size is not supported."
            freq_axis_padding = (freq_axis_kernel_size - 1) // 2
            kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
            padding = (freq_axis_padding, scale)
            conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
            self.up_layers += [conv]

            # nonlinear
            if nonlinear_activation is not None:
                nonlinear = getattr(torch.nn, nonlinear_activation)(
                    **nonlinear_activation_params
                )
                self.up_layers += [nonlinear]

    def forward(self, c: torch.Tensor) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, C, T_feats).

        Returns:
            Tensor: Upsampled tensor (B, C, T_wav).

        """
        c = c.unsqueeze(1)  # (B, 1, C, T)
        for f in self.up_layers:
            c = f(c)
        return c.squeeze(1)  # (B, C, T')
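

# A minimal usage sketch (illustrative values, not part of the original module): the
# network stacks Stretch2d + Conv2d pairs, so the time axis grows by
# prod(upsample_scales) while the channel (frequency-bin) dimension is kept.
#
#     >>> net = UpsampleNetwork(upsample_scales=[4, 4])
#     >>> c = torch.randn(2, 80, 10)  # (B, C, T_feats)
#     >>> net(c).shape                # T_wav = 10 * 4 * 4
#     torch.Size([2, 80, 160])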


class ConvInUpsampleNetwork(torch.nn.Module):
    """Convolution + upsampling network module."""

    def __init__(
        self,
        upsample_scales: List[int],
        nonlinear_activation: Optional[str] = None,
        nonlinear_activation_params: Dict[str, Any] = {},
        interpolate_mode: str = "nearest",
        freq_axis_kernel_size: int = 1,
        aux_channels: int = 80,
        aux_context_window: int = 0,
    ):
        """Initialize ConvInUpsampleNetwork module.

        Args:
            upsample_scales (List[int]): List of upsampling scales.
            nonlinear_activation (Optional[str]): Activation function name.
            nonlinear_activation_params (Dict[str, Any]): Arguments for the specified
                activation function.
            interpolate_mode (str): Interpolation mode.
            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
            aux_channels (int): Number of channels of pre-conv layer.
            aux_context_window (int): Context window size of the pre-conv layer.

        """
        super().__init__()
        self.aux_context_window = aux_context_window

        # To capture wide-context information in conditional features
        kernel_size = 2 * aux_context_window + 1
        # NOTE(kan-bayashi): Use pad here, which is not used in parallel_wavegan
        self.pad = torch.nn.ReplicationPad1d(aux_context_window)
        self.conv_in = Conv1d(
            aux_channels,
            aux_channels,
            kernel_size=kernel_size,
            bias=False,
        )
        self.upsample = UpsampleNetwork(
            upsample_scales=upsample_scales,
            nonlinear_activation=nonlinear_activation,
            nonlinear_activation_params=nonlinear_activation_params,
            interpolate_mode=interpolate_mode,
            freq_axis_kernel_size=freq_axis_kernel_size,
        )

    def forward(self, c: torch.Tensor) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, C, T_feats).

        Returns:
            Tensor: Upsampled tensor (B, C, T_wav),
                where T_wav = T_feats * prod(upsample_scales).

        """
        c = self.conv_in(self.pad(c))
        return self.upsample(c)
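

# A minimal usage sketch (illustrative hyperparameters, not part of the original
# module): the replication padding plus the pre-conv kernel of size
# 2 * aux_context_window + 1 consume the added context frames, so the output length
# follows the docstring, T_wav = T_feats * prod(upsample_scales).
#
#     >>> net = ConvInUpsampleNetwork(
#     ...     upsample_scales=[4, 4], aux_channels=80, aux_context_window=2
#     ... )
#     >>> c = torch.randn(2, 80, 10)  # (B, C, T_feats)
#     >>> net(c).shape
#     torch.Size([2, 80, 160])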