Source code for espnet2.gan_svs.vits.duration_predictor

# Copyright 2021 Tomoki Hayashi
# Copyright 2022 Yifeng Yu
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Duration predictor modules in VISinger.
"""

import torch

from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm


class DurationPredictor(torch.nn.Module):
    def __init__(
        self,
        channels,
        filter_channels,
        kernel_size,
        dropout_rate,
        global_channels=0,
    ):
        """Initialize duration predictor module.

        Args:
            channels (int): Number of input channels.
            filter_channels (int): Number of filter channels.
            kernel_size (int): Size of the convolutional kernel.
            dropout_rate (float): Dropout rate.
            global_channels (int, optional): Number of global conditioning channels.

        """
        super().__init__()
        self.in_channels = channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate

        self.drop = torch.nn.Dropout(dropout_rate)
        self.conv_1 = torch.nn.Conv1d(
            channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = LayerNorm(filter_channels, dim=1)
        self.conv_2 = torch.nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = LayerNorm(filter_channels, dim=1)
        self.conv_3 = torch.nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_3 = LayerNorm(filter_channels, dim=1)
        self.proj = torch.nn.Conv1d(filter_channels, 2, 1)

        if global_channels > 0:
            self.conv = torch.nn.Conv1d(global_channels, channels, 1)
    def forward(self, x, x_mask, g=None):
        """Forward pass through the duration predictor module.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
            g (Tensor, optional): Global condition tensor (B, global_channels, 1).

        Returns:
            Tensor: Predicted duration tensor (B, 2, T).

        """
        # multi-singer
        if g is not None:
            g = torch.detach(g)
            x = x + self.conv(g)

        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.conv_3(x * x_mask)
        x = torch.relu(x)
        x = self.norm_3(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask
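

A minimal usage sketch (not part of the original module). The hyperparameter values and tensor shapes below are illustrative assumptions chosen to match the docstrings above; `g` stands in for a global conditioning embedding such as a singer embedding, and the final projection yields the documented (B, 2, T) output.

if __name__ == "__main__":
    # Hypothetical configuration; channel sizes are assumptions, not defaults.
    predictor = DurationPredictor(
        channels=192,
        filter_channels=256,
        kernel_size=3,
        dropout_rate=0.5,
        global_channels=256,
    )
    x = torch.randn(2, 192, 50)    # (B, in_channels, T)
    x_mask = torch.ones(2, 1, 50)  # (B, 1, T)
    g = torch.randn(2, 256, 1)     # (B, global_channels, 1), e.g. singer embedding
    durations = predictor(x, x_mask, g=g)
    print(durations.shape)         # torch.Size([2, 2, 50])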