# Copyright 2020 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Variance predictor related modules."""

from typing import Optional

import torch
from typeguard import typechecked

from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm

class VariancePredictor(torch.nn.Module):
    """Variance predictor module.

    This is a module of variance predictor described in `FastSpeech 2:
    Fast and High-Quality End-to-End Text to Speech`_. It is a stack of
    Conv1d -> ReLU -> LayerNorm -> Dropout layers followed by a linear
    projection to one value per frame.

    .. _`FastSpeech 2: Fast and High-Quality End-to-End Text to Speech`:
        https://arxiv.org/abs/2006.04558

    """

    @typechecked
    def __init__(
        self,
        idim: int,
        n_layers: int = 2,
        n_chans: int = 384,
        kernel_size: int = 3,
        bias: bool = True,
        dropout_rate: float = 0.5,
    ):
        """Initialize variance predictor module.

        Args:
            idim (int): Input dimension.
            n_layers (int): Number of convolutional layers.
            n_chans (int): Number of channels of convolutional layers.
            kernel_size (int): Kernel size of convolutional layers. Should be
                odd: with stride 1 and padding ``(kernel_size - 1) // 2`` an
                even kernel shortens the time axis by one frame.
            bias (bool): Whether the convolutional layers use a bias term.
            dropout_rate (float): Dropout rate.

        """
        super().__init__()
        self.conv = torch.nn.ModuleList()
        for idx in range(n_layers):
            # Only the first layer sees the input dimension; the rest are
            # n_chans -> n_chans.
            in_chans = idim if idx == 0 else n_chans
            self.conv += [
                torch.nn.Sequential(
                    torch.nn.Conv1d(
                        in_chans,
                        n_chans,
                        kernel_size,
                        stride=1,
                        # "Same" padding for odd kernel sizes.
                        padding=(kernel_size - 1) // 2,
                        bias=bias,
                    ),
                    torch.nn.ReLU(),
                    # Inputs here are channel-first (B, C, T), so normalize
                    # over dim=1.
                    LayerNorm(n_chans, dim=1),
                    torch.nn.Dropout(dropout_rate),
                )
            ]
        self.linear = torch.nn.Linear(n_chans, 1)

    def forward(
        self, xs: torch.Tensor, x_masks: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (Tensor, optional): Bool mask indicating padded part,
                either (B, Tmax) or (B, Tmax, 1). Positions where the mask is
                True are set to 0.0 in the output.

        Returns:
            Tensor: Batch of predicted sequences (B, Tmax, 1).

        """
        xs = xs.transpose(1, -1)  # (B, idim, Tmax)
        for f in self.conv:
            xs = f(xs)  # (B, C, Tmax)

        xs = self.linear(xs.transpose(1, 2))  # (B, Tmax, 1)

        if x_masks is not None:
            if x_masks.shape == xs.shape[:-1]:
                # A (B, Tmax) mask cannot broadcast against (B, Tmax, 1);
                # add the trailing singleton axis. (B, Tmax, 1) masks are
                # left untouched.
                x_masks = x_masks.unsqueeze(-1)
            xs = xs.masked_fill(x_masks, 0.0)

        return xs