# Source code for espnet2.uasr.discriminator.conv_discriminator

import argparse
from typing import Dict, Optional

import torch
from typeguard import typechecked

from espnet2.uasr.discriminator.abs_discriminator import AbsDiscriminator
from espnet2.utils.types import str2bool


class SamePad(torch.nn.Module):
    """Trim trailing frames from the third axis of a 3-D tensor.

    ``ConvDiscriminator`` places this after its padded ``Conv1d`` layers.

    Args:
        kernel_size: Kernel size used to decide how many frames to trim.
        causal: When True, ``kernel_size - 1`` trailing frames are removed.
            Otherwise one frame is removed for an even ``kernel_size`` and
            none for an odd one.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        even_kernel = kernel_size % 2 == 0
        self.remove = kernel_size - 1 if causal else int(even_kernel)

    def forward(self, x):
        """Return ``x`` with its last ``self.remove`` entries on axis 2 dropped."""
        trim = self.remove
        return x[:, :, :-trim] if trim > 0 else x
class ConvDiscriminator(AbsDiscriminator):
    """Convolutional discriminator for UASR.

    The network has four parts, in order:

    1. An embedding convolution, or a kernel-size-1 convolution when
       ``linear_emb`` is set.
    2. Dropout.
    3. ``conv_depth - 1`` inner convolution / dropout / GELU stages.
    4. A final convolution down to a single channel.

    ``forward`` pools that single channel over time into one value per batch
    element.
    """

    @typechecked
    def __init__(
        self,
        input_dim: int,
        cfg: Optional[Dict] = None,
        conv_channels: int = 384,
        conv_kernel: int = 8,
        conv_dilation: int = 1,
        conv_depth: int = 2,
        linear_emb: str2bool = False,
        causal: str2bool = True,
        max_pool: str2bool = False,
        act_after_linear: str2bool = False,
        dropout: float = 0.0,
        spectral_norm: str2bool = False,
        weight_norm: str2bool = False,
    ):
        """Build the discriminator.

        Args:
            input_dim: Number of input feature channels.
            cfg: Optional dict of hyper-parameters. When given, every setting
                is read from its ``discriminator_*`` keys and the keyword
                arguments below are ignored. All of these keys must be
                present: ``discriminator_dim``, ``_kernel``, ``_dilation``,
                ``_depth``, ``_linear_emb``, ``_causal``, ``_max_pool``,
                ``_act_after_linear``, ``_dropout``, ``_spectral_norm`` and
                ``_weight_norm``.
            conv_channels: Hidden channel count of the convolutions.
            conv_kernel: Kernel size of the non-linear-embedding convolutions.
            conv_dilation: Dilation of the embedding and inner convolutions.
                The final convolution always uses dilation 1.
            conv_depth: The number of inner conv/dropout/GELU stages is
                ``conv_depth - 1``.
            linear_emb: Use a kernel-size-1 embedding convolution.
            causal: Pad with ``conv_kernel - 1`` and trim the tail (causal).
                Otherwise pad with ``conv_kernel // 2``.
            max_pool: Pool over time with max instead of masked mean.
            act_after_linear: Append a GELU after the embedding.
            dropout: Dropout probability.
            spectral_norm: Wrap each convolution in spectral norm. Takes
                precedence over ``weight_norm``.
            weight_norm: Wrap each convolution in weight norm.
        """
        super().__init__()
        if cfg is not None:
            cfg = argparse.Namespace(**cfg)
            self.conv_channels = cfg.discriminator_dim
            self.conv_kernel = cfg.discriminator_kernel
            self.conv_dilation = cfg.discriminator_dilation
            self.conv_depth = cfg.discriminator_depth
            self.linear_emb = cfg.discriminator_linear_emb
            self.causal = cfg.discriminator_causal
            self.max_pool = cfg.discriminator_max_pool
            self.act_after_linear = cfg.discriminator_act_after_linear
            self.dropout = cfg.discriminator_dropout
            self.spectral_norm = cfg.discriminator_spectral_norm
            self.weight_norm = cfg.discriminator_weight_norm
        else:
            self.conv_channels = conv_channels
            self.conv_kernel = conv_kernel
            self.conv_dilation = conv_dilation
            self.conv_depth = conv_depth
            self.linear_emb = linear_emb
            self.causal = causal
            self.max_pool = max_pool
            self.act_after_linear = act_after_linear
            self.dropout = dropout
            self.spectral_norm = spectral_norm
            self.weight_norm = weight_norm

        # SamePad trims the extra frames that this padding adds, given the
        # same kernel size and causal flag.
        # NOTE(review): the padding does not account for dilation. With
        # conv_dilation > 1, the dilated convolutions therefore produce fewer
        # frames than they receive.
        if self.causal:
            self.conv_padding = self.conv_kernel - 1
        else:
            self.conv_padding = self.conv_kernel // 2

        def make_conv(
            in_channel, out_channel, kernel_size, padding_size=0, dilation_value=1
        ):
            """Build a Conv1d, optionally wrapped in spectral or weight norm."""
            conv = torch.nn.Conv1d(
                in_channel,
                out_channel,
                kernel_size=kernel_size,
                padding=padding_size,
                dilation=dilation_value,
            )
            if self.spectral_norm:
                conv = torch.nn.utils.spectral_norm(conv)
            elif self.weight_norm:
                conv = torch.nn.utils.weight_norm(conv)
            return conv

        # initialize embedding
        if self.linear_emb:
            emb_net = [
                make_conv(
                    input_dim, self.conv_channels, 1, dilation_value=self.conv_dilation
                )
            ]
        else:
            emb_net = [
                make_conv(
                    input_dim,
                    self.conv_channels,
                    self.conv_kernel,
                    self.conv_padding,
                    dilation_value=self.conv_dilation,
                ),
                SamePad(kernel_size=self.conv_kernel, causal=self.causal),
            ]

        if self.act_after_linear:
            emb_net.append(torch.nn.GELU())

        # initialize inner conv
        inner_net = [
            torch.nn.Sequential(
                make_conv(
                    self.conv_channels,
                    self.conv_channels,
                    self.conv_kernel,
                    self.conv_padding,
                    dilation_value=self.conv_dilation,
                ),
                SamePad(kernel_size=self.conv_kernel, causal=self.causal),
                torch.nn.Dropout(self.dropout),
                torch.nn.GELU(),
            )
            for _ in range(self.conv_depth - 1)
        ]
        # Final projection to a single output channel.
        inner_net += [
            make_conv(
                self.conv_channels,
                1,
                self.conv_kernel,
                self.conv_padding,
                dilation_value=1,
            ),
            SamePad(kernel_size=self.conv_kernel, causal=self.causal),
        ]
        # Use self.dropout, not the ``dropout`` argument: when cfg is given,
        # the argument is ignored and the value comes from cfg.
        self.net = torch.nn.Sequential(
            *emb_net,
            torch.nn.Dropout(self.dropout),
            *inner_net,
        )

    @typechecked
    def forward(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor]):
        """Compute one discriminator value per sequence.

        Args:
            x: Input features. Per the transposes below this is
                (Batch, Time, Channel), with Channel == ``input_dim``.
            padding_mask: Optional (Batch, Time) mask whose True entries mark
                padded frames. It is assumed to be boolean; TODO confirm with
                callers. The mask is ignored when it is None, all False, or
                1-D.

        Returns:
            Tensor of shape (Batch,). With ``max_pool`` set, this is the max
            over unpadded frames; otherwise it is their mean.
        """
        # (Batch, Time, Channel) -> (Batch, Channel, Time)
        x = x.transpose(1, 2)
        x = self.net(x)
        # (Batch, Channel, Time) -> (Batch, Time, Channel); Channel is 1 here
        x = x.transpose(1, 2)
        x_sz = x.size(1)
        if padding_mask is not None and padding_mask.any() and padding_mask.dim() > 1:
            # Truncate the mask to the network's output length.
            padding_mask = padding_mask[:, : x.size(1)]
            # Tensor.to returns a new tensor and does not modify in place.
            # Keep the result so the mask is on the same device as x.
            padding_mask = padding_mask.to(x.device)
            # A -inf frame never wins the max, and a 0 frame adds nothing to
            # the sum.
            x[padding_mask] = float("-inf") if self.max_pool else 0
            # Count only unpadded frames for the mean.
            x_sz = x_sz - padding_mask.sum(dim=-1)
        x = x.squeeze(-1)
        if self.max_pool:
            x, _ = x.max(dim=-1)
        else:
            x = x.sum(dim=-1)
            x = x / x_sz
        return x