# Source code for espnet2.svs.naive_rnn.naive_rnn

# Copyright 2021 Carnegie Mellon University (Jiatong Shi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Naive-SVS related modules."""

from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from typeguard import typechecked

from espnet2.svs.abs_svs import AbsSVS
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.torch_utils.initialize import initialize
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet
from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder as EncoderPrenet


class NaiveRNNLoss(torch.nn.Module):
    """Loss function module for NaiveRNN-SVS.

    Computes an L1 loss and a mean-square-error loss, each summed over the
    outputs taken before and after the postnet. Padded frames can either be
    dropped from the loss (``use_masking``) or down-weighted per utterance
    (``use_weighted_masking``).

    """

    def __init__(self, use_masking=True, use_weighted_masking=False):
        """Initialize NaiveRNN loss module.

        Args:
            use_masking (bool): Whether to apply masking for padded part in
                loss calculation.
            use_weighted_masking (bool): Whether to apply weighted masking in
                loss calculation.

        Raises:
            ValueError: If both ``use_masking`` and ``use_weighted_masking``
                are enabled; the two strategies are mutually exclusive.

        """
        super().__init__()
        # Explicit check instead of ``assert`` so that the validation is not
        # stripped when Python runs with optimizations enabled (``-O``).
        if use_masking and use_weighted_masking:
            raise ValueError(
                "use_masking and use_weighted_masking cannot both be enabled."
            )
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

        # define criterions
        # Weighted masking needs the element-wise losses, so no reduction is
        # applied by the criterion in that mode.
        reduction = "none" if self.use_weighted_masking else "mean"
        self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
        self.mse_criterion = torch.nn.MSELoss(reduction=reduction)

        # NOTE(kan-bayashi): register pre hook function for the compatibility
        # self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)

    def forward(self, after_outs, before_outs, ys, olens):
        """Calculate forward propagation.

        Args:
            after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
            before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).

        Returns:
            Tensor: L1 loss value.
            Tensor: Mean square error loss value.

        """
        # make mask and apply it
        if self.use_masking:
            masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            ys = ys.masked_select(masks)
            after_outs = after_outs.masked_select(masks)
            before_outs = before_outs.masked_select(masks)

        # calculate loss
        l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion(
            before_outs, ys
        )
        mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion(
            before_outs, ys
        )

        # make weighted mask and apply it
        if self.use_weighted_masking:
            masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            # Each valid frame is weighted by 1 / (number of valid frames in
            # its utterance), then scaled by 1 / (batch size * feature dim).
            weights = masks.float() / masks.sum(dim=1, keepdim=True).float()
            out_weights = weights.div(ys.size(0) * ys.size(2))

            # apply weight
            l1_loss = l1_loss.mul(out_weights).masked_select(masks).sum()
            mse_loss = mse_loss.mul(out_weights).masked_select(masks).sum()

        return l1_loss, mse_loss
class NaiveRNN(AbsSVS):
    """NaiveRNN-SVS module.

    This is an implementation of naive RNN for singing voice synthesis.
    The features are processed directly over time-domain from music score
    and predict the singing voice features.

    """

    @typechecked
    def __init__(
        self,
        # network structure related
        idim: int,
        odim: int,
        midi_dim: int = 129,
        embed_dim: int = 512,
        eprenet_conv_layers: int = 3,
        eprenet_conv_chans: int = 256,
        eprenet_conv_filts: int = 5,
        elayers: int = 3,
        eunits: int = 1024,
        ebidirectional: bool = True,
        midi_embed_integration_type: str = "add",
        dlayers: int = 3,
        dunits: int = 1024,
        dbidirectional: bool = True,
        postnet_layers: int = 5,
        postnet_chans: int = 256,
        postnet_filts: int = 5,
        use_batch_norm: bool = True,
        reduction_factor: int = 1,
        # extra embedding related
        spks: Optional[int] = None,
        langs: Optional[int] = None,
        spk_embed_dim: Optional[int] = None,
        spk_embed_integration_type: str = "add",
        eprenet_dropout_rate: float = 0.5,
        edropout_rate: float = 0.1,
        ddropout_rate: float = 0.1,
        postnet_dropout_rate: float = 0.5,
        init_type: str = "xavier_uniform",
        use_masking: bool = False,
        use_weighted_masking: bool = False,
        loss_type: str = "L1",
    ):
        """Initialize NaiveRNN module.

        Args:
            idim (int): Dimension of the label inputs.
            odim (int): Dimension of the outputs.
            midi_dim (int): Dimension of the midi inputs.
            embed_dim (int): Dimension of the token embedding.
            eprenet_conv_layers (int): Number of prenet conv layers.
                If 0, plain embedding layers are used instead of the prenet.
            eprenet_conv_filts (int): Number of prenet conv filter size.
            eprenet_conv_chans (int): Number of prenet conv filter channels.
            elayers (int): Number of encoder layers.
            eunits (int): Number of encoder hidden units.
            ebidirectional (bool): If bidirectional in encoder.
            midi_embed_integration_type (str): How to integrate midi
                information ("add" or "cat"). Any value other than "add"
                is handled as concatenation.
            dlayers (int): Number of decoder lstm layers.
            dunits (int): Number of decoder lstm units.
                NOTE(review): currently unused; the decoder LSTM is built
                with ``eunits`` for both input and hidden size.
            dbidirectional (bool): If bidirectional in decoder.
            postnet_layers (int): Number of postnet layers.
                If 0, no postnet is created.
            postnet_filts (int): Number of postnet filter size.
            postnet_chans (int): Number of postnet filter channels.
            use_batch_norm (bool): Whether to use batch normalization.
            reduction_factor (int): Reduction factor.
            spks (Optional[int]): Number of speakers. If set to > 1, assume that
                the sids will be provided as the input and use sid embedding
                layer.
            langs (Optional[int]): Number of languages. If set to > 1, assume
                that the lids will be provided as the input and use lid
                embedding layer.
            spk_embed_dim (Optional[int]): Speaker embedding dimension. If set
                to > 0, assume that spembs will be provided as the input.
            spk_embed_integration_type (str): How to integrate speaker
                embedding ("add" or "concat"). Any value other than "add"
                builds the concat-sized projection here, but
                ``_integrate_with_spk_embed`` accepts only "add" or "concat".
            eprenet_dropout_rate (float): Prenet dropout rate.
            edropout_rate (float): Encoder dropout rate.
            ddropout_rate (float): Decoder dropout rate.
            postnet_dropout_rate (float): Postnet dropout_rate.
            init_type (str): How to initialize the model parameters.
                "pytorch" keeps the default PyTorch initialization.
            use_masking (bool): Whether to mask padded part in loss
                calculation.
            use_weighted_masking (bool): Whether to apply weighted masking in
                loss calculation.
            loss_type (str): Loss function type ("L1", "L2", or "L1+L2").

        """
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.midi_dim = midi_dim
        self.eunits = eunits
        self.odim = odim
        self.eos = idim - 1
        self.reduction_factor = reduction_factor
        self.loss_type = loss_type
        self.midi_embed_integration_type = midi_embed_integration_type

        # use idx 0 as padding idx
        self.padding_idx = 0

        # define encoder input layers (one for labels, one for midi)
        if eprenet_conv_layers != 0:
            # encoder prenet
            self.encoder_input_layer = torch.nn.Sequential(
                EncoderPrenet(
                    idim=idim,
                    embed_dim=embed_dim,
                    elayers=0,
                    econv_layers=eprenet_conv_layers,
                    econv_chans=eprenet_conv_chans,
                    econv_filts=eprenet_conv_filts,
                    use_batch_norm=use_batch_norm,
                    dropout_rate=eprenet_dropout_rate,
                    padding_idx=self.padding_idx,
                ),
                torch.nn.Linear(eprenet_conv_chans, eunits),
            )
            self.midi_encoder_input_layer = torch.nn.Sequential(
                EncoderPrenet(
                    idim=midi_dim,
                    embed_dim=embed_dim,
                    elayers=0,
                    econv_layers=eprenet_conv_layers,
                    econv_chans=eprenet_conv_chans,
                    econv_filts=eprenet_conv_filts,
                    use_batch_norm=use_batch_norm,
                    dropout_rate=eprenet_dropout_rate,
                    padding_idx=self.padding_idx,
                ),
                torch.nn.Linear(eprenet_conv_chans, eunits),
            )
        else:
            self.encoder_input_layer = torch.nn.Embedding(
                num_embeddings=idim, embedding_dim=eunits, padding_idx=self.padding_idx
            )
            self.midi_encoder_input_layer = torch.nn.Embedding(
                num_embeddings=midi_dim,
                embedding_dim=eunits,
                padding_idx=self.padding_idx,
            )

        # define the label and midi encoders
        self.encoder = torch.nn.LSTM(
            input_size=eunits,
            hidden_size=eunits,
            num_layers=elayers,
            batch_first=True,
            dropout=edropout_rate,
            bidirectional=ebidirectional,
        )

        self.midi_encoder = torch.nn.LSTM(
            input_size=eunits,
            hidden_size=eunits,
            num_layers=elayers,
            batch_first=True,
            dropout=edropout_rate,
            bidirectional=ebidirectional,
        )

        # a bidirectional encoder doubles the size of its output features
        dim_direction = 2 if ebidirectional is True else 1

        if self.midi_embed_integration_type == "add":
            self.midi_projection = torch.nn.Linear(
                eunits * dim_direction, eunits * dim_direction
            )
        else:
            self.midi_projection = torch.nn.Linear(
                2 * eunits * dim_direction, eunits * dim_direction
            )

        # NOTE(review): this decoder is not called by forward() or inference()
        # in this class; it is kept so the set of registered parameters (and
        # therefore saved checkpoints) stays unchanged.
        self.decoder = torch.nn.LSTM(
            input_size=eunits,
            hidden_size=eunits,
            num_layers=dlayers,
            batch_first=True,
            dropout=ddropout_rate,
            bidirectional=dbidirectional,
        )

        # define spk and lang embedding
        self.spks = None
        if spks is not None and spks > 1:
            self.spks = spks
            self.sid_emb = torch.nn.Embedding(spks, eunits * dim_direction)
        self.langs = None
        if langs is not None and langs > 1:
            # TODO(Yuning): not encode yet
            self.langs = langs
            self.lid_emb = torch.nn.Embedding(langs, eunits * dim_direction)

        # define projection layer
        self.spk_embed_dim = None
        if spk_embed_dim is not None and spk_embed_dim > 0:
            self.spk_embed_dim = spk_embed_dim
            self.spk_embed_integration_type = spk_embed_integration_type
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(
                    self.spk_embed_dim, eunits * dim_direction
                )
            else:
                self.projection = torch.nn.Linear(
                    eunits * dim_direction + self.spk_embed_dim,
                    eunits * dim_direction,
                )

        # define final projection
        self.feat_out = torch.nn.Linear(eunits * dim_direction, odim * reduction_factor)

        # define postnet
        self.postnet = (
            None
            if postnet_layers == 0
            else Postnet(
                idim=idim,
                odim=odim,
                n_layers=postnet_layers,
                n_chans=postnet_chans,
                n_filts=postnet_filts,
                use_batch_norm=use_batch_norm,
                dropout_rate=postnet_dropout_rate,
            )
        )

        # define loss function
        self.criterion = NaiveRNNLoss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking,
        )

        # initialize parameters
        self._reset_parameters(
            init_type=init_type,
        )

    def _reset_parameters(self, init_type):
        """Re-initialize parameters unless the PyTorch defaults are requested."""
        # initialize parameters
        if init_type != "pytorch":
            initialize(self, init_type)

    def _fuse_encodings(
        self, hs_label: torch.Tensor, hs_midi: torch.Tensor
    ) -> torch.Tensor:
        """Combine label and midi encoder outputs and project the result."""
        if self.midi_embed_integration_type == "add":
            hs = hs_label + hs_midi
        else:
            hs = torch.cat((hs_label, hs_midi), dim=-1)
        return F.leaky_relu(self.midi_projection(hs))

    def _add_id_embeddings(
        self,
        hs: torch.Tensor,
        sids: Optional[torch.Tensor],
        lids: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Add speaker-ID and language-ID embeddings to every time step."""
        if self.spks is not None:
            sid_embs = self.sid_emb(sids.view(-1))
            hs = hs + sid_embs.unsqueeze(1)
        if self.langs is not None:
            lid_embs = self.lid_emb(lids.view(-1))
            hs = hs + lid_embs.unsqueeze(1)
        return hs

    def _generate_feats(self, hs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project hidden states to features and apply the optional postnet.

        Returns:
            Tensor: Outputs before the postnet.
            Tensor: Outputs after the postnet (same tensor as the first one
                when no postnet is configured).

        """
        # (B, T_feats//r, odim * r) -> (B, T_feats//r * r, odim)
        before_outs = F.leaky_relu(
            self.feat_out(hs).view(hs.size(0), -1, self.odim)
        )

        # postnet -> (B, T_feats//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose(1, 2)
            ).transpose(1, 2)
        return before_outs, after_outs

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        label: Optional[Dict[str, torch.Tensor]] = None,
        label_lengths: Optional[Dict[str, torch.Tensor]] = None,
        melody: Optional[Dict[str, torch.Tensor]] = None,
        melody_lengths: Optional[Dict[str, torch.Tensor]] = None,
        pitch: Optional[torch.Tensor] = None,
        pitch_lengths: Optional[torch.Tensor] = None,
        duration: Optional[Dict[str, torch.Tensor]] = None,
        duration_lengths: Optional[Dict[str, torch.Tensor]] = None,
        slur: Optional[torch.LongTensor] = None,
        slur_lengths: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        flag_IsValid=False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Calculate forward propagation.

        Only the "lab" entries of ``label``, ``label_lengths``, ``melody`` and
        ``melody_lengths`` are used. ``pitch``, ``duration``, ``slur`` and
        their lengths are accepted but not read by this model.

        Args:
            text (LongTensor): Batch of padded character ids (B, Tmax).
            text_lengths (LongTensor): Batch of lengths of each input batch (B,).
            feats (Tensor): Batch of padded target features (B, Lmax, odim).
            feats_lengths (LongTensor): Batch of the lengths of each target (B,).
            label (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded label ids (B, Tmax).
            label_lengths (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of the lengths of padded label ids (B, ).
            melody (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded melody (B, Tmax).
            melody_lengths (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of the lengths of padded melody (B, ).
            pitch (FloatTensor): Batch of padded f0 (B, Tmax).
            pitch_lengths (LongTensor): Batch of the lengths of padded f0 (B, ).
            duration (Optional[Dict]): key is "lab", "score";
                value (LongTensor): Batch of padded duration (B, Tmax).
            duration_lengths (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of the lengths of padded duration (B, ).
            slur (LongTensor): Batch of padded slur (B, Tmax).
            slur_lengths (LongTensor): Batch of the lengths of padded slur (B, ).
            spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim).
            sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
            lids (Optional[Tensor]): Batch of language IDs (B, 1).
            flag_IsValid (bool): If exactly ``False`` (training stage), return
                the usual 3-tuple; otherwise (validation stage) additionally
                return the predicted features, targets and target lengths.

        Note:
            The arguments of this function correspond to ``**batch`` from
            espnet_model.py: ``label`` holds the phone sequence and
            ``melody`` holds the pitch sequence.

        Returns:
            Tensor: Loss scalar value.
            Dict: Statistics to be monitored.
            Tensor: Weight value if not joint training else model outputs.

            When ``flag_IsValid`` is not ``False``, three more items follow:
            the outputs after the postnet trimmed to the longest target,
            the target features, and the target lengths.

        Raises:
            ValueError: If ``reduction_factor`` > 1 and a target is shorter
                than the reduction factor, or if ``loss_type`` is unknown.

        """
        label = label["lab"]
        midi = melody["lab"]
        label_lengths = label_lengths["lab"]
        midi_lengths = melody_lengths["lab"]

        # trim the padding to the longest item of this (sub-)batch
        text = text[:, : text_lengths.max()]  # for data-parallel
        feats = feats[:, : feats_lengths.max()]  # for data-parallel
        midi = midi[:, : midi_lengths.max()]  # for data-parallel
        label = label[:, : label_lengths.max()]  # for data-parallel
        batch_size = feats.size(0)

        label_emb = self.encoder_input_layer(label)  # FIX ME: label Float to Int
        midi_emb = self.midi_encoder_input_layer(midi)

        # pack so that the LSTMs skip the padded positions; lengths must be
        # on the CPU for pack_padded_sequence
        label_emb = torch.nn.utils.rnn.pack_padded_sequence(
            label_emb, label_lengths.to("cpu"), batch_first=True, enforce_sorted=False
        )
        midi_emb = torch.nn.utils.rnn.pack_padded_sequence(
            midi_emb, midi_lengths.to("cpu"), batch_first=True, enforce_sorted=False
        )

        hs_label, (_, _) = self.encoder(label_emb)
        hs_midi, (_, _) = self.midi_encoder(midi_emb)

        hs_label, _ = torch.nn.utils.rnn.pad_packed_sequence(hs_label, batch_first=True)
        hs_midi, _ = torch.nn.utils.rnn.pad_packed_sequence(hs_midi, batch_first=True)

        hs = self._fuse_encodings(hs_label, hs_midi)

        # integrate spk & lang embeddings
        hs = self._add_id_embeddings(hs, sids, lids)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        before_outs, after_outs = self._generate_feats(hs)

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            if not bool(feats_lengths.ge(self.reduction_factor).all()):
                raise ValueError(
                    "Output length must be greater than or equal to reduction factor."
                )
            # round every length down to a multiple of the reduction factor
            olens = feats_lengths - feats_lengths % self.reduction_factor
            max_olen = olens.max()
            ys = feats[:, :max_olen]
        else:
            ys = feats
            olens = feats_lengths

        # calculate loss values
        l1_loss, l2_loss = self.criterion(
            after_outs[:, : olens.max()], before_outs[:, : olens.max()], ys, olens
        )

        if self.loss_type == "L1":
            loss = l1_loss
        elif self.loss_type == "L2":
            loss = l2_loss
        elif self.loss_type == "L1+L2":
            loss = l1_loss + l2_loss
        else:
            raise ValueError("unknown --loss-type " + self.loss_type)

        stats = dict(
            loss=loss.item(),
            l1_loss=l1_loss.item(),
            l2_loss=l2_loss.item(),
        )

        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

        if flag_IsValid is False:
            # training stage
            return loss, stats, weight
        else:
            # validation stage
            return loss, stats, weight, after_outs[:, : olens.max()], ys, olens

    def inference(
        self,
        text: torch.Tensor,
        feats: Optional[torch.Tensor] = None,
        label: Optional[Dict[str, torch.Tensor]] = None,
        melody: Optional[Dict[str, torch.Tensor]] = None,
        pitch: Optional[torch.Tensor] = None,
        duration: Optional[Dict[str, torch.Tensor]] = None,
        slur: Optional[Dict[str, torch.Tensor]] = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        use_teacher_forcing: bool = False,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Generate the sequence of features from the music score.

        Only the "lab" entries of ``label`` and ``melody`` are used; ``text``,
        ``feats``, ``pitch``, ``duration``, ``slur`` and
        ``use_teacher_forcing`` are accepted but not read by this model.

        NOTE(review): the shapes below are carried over from the original
        documentation. The code indexes the result with ``[0]`` and broadcasts
        ID embeddings over dim 1, which suggests the inputs actually carry a
        leading batch dimension of size 1 -- confirm against the caller.

        Args:
            text (LongTensor): Batch of padded character ids (Tmax).
            feats (Tensor): Batch of padded target features (Lmax, odim).
            label (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded label ids (Tmax).
            melody (Optional[Dict]): key is "lab" or "score";
                value (LongTensor): Batch of padded melody (Tmax).
            pitch (FloatTensor): Batch of padded f0 (Tmax).
            duration (Optional[Dict]): key is "lab", "score";
                value (LongTensor): Batch of padded duration (Tmax).
            slur (LongTensor): Batch of padded slur (B, Tmax).
            spembs (Optional[Tensor]): Batch of speaker embeddings (spk_embed_dim).
            sids (Optional[Tensor]): Batch of speaker IDs (1).
            lids (Optional[Tensor]): Batch of language IDs (1).
            use_teacher_forcing (bool): Unused.

        Returns:
            Dict[str, Tensor]: Output dict including the following items:
                * feat_gen (Tensor): Output sequence of features (T_feats, odim).
                * prob: Always None.
                * att_w: Always None.

        """
        label = label["lab"]
        midi = melody["lab"]

        label_emb = self.encoder_input_layer(label)
        midi_emb = self.midi_encoder_input_layer(midi)

        hs_label, (_, _) = self.encoder(label_emb)
        hs_midi, (_, _) = self.midi_encoder(midi_emb)

        hs = self._fuse_encodings(hs_label, hs_midi)

        # integrate spk & lang embeddings
        hs = self._add_id_embeddings(hs, sids, lids)

        # add the batch dimension expected by _integrate_with_spk_embed
        if spembs is not None:
            spembs = spembs.unsqueeze(0)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        _, after_outs = self._generate_feats(hs)

        return dict(
            feat_gen=after_outs[0], prob=None, att_w=None
        )  # outs, probs, att_ws

    def _integrate_with_spk_embed(
        self, hs: torch.Tensor, spembs: torch.Tensor
    ) -> torch.Tensor:
        """Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, Tmax, adim).

        Raises:
            NotImplementedError: If the integration type is neither "add"
                nor "concat".

        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds and then apply projection
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = self.projection(torch.cat([hs, spembs], dim=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return hs