Source code for espnet2.s2st.synthesizer.unity_synthesizer

# Copyright 2020 Nagoya University (Tomoki Hayashi)
# Copyright 2022 Carnegie Mellon University (Jiatong Shi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Translatotron Synthesizer related modules for ESPnet2."""

from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from typeguard import typechecked

from espnet2.asr.decoder.transformer_decoder import TransformerDecoder
from espnet2.s2st.synthesizer.abs_synthesizer import AbsSynthesizer
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding

class UnitYSynthesizer(AbsSynthesizer):
    """UnitY Synthesizer related modules for speech-to-speech translation.

    This is a module of discrete unit prediction network in discrete-unit
    described in `Direct speech-to-speech translation with discrete units`_,
    which converts the sequence of hidden states into the sequence of
    discrete unit (from SSLs).

    .. _`Direct speech-to-speech translation with discrete units`:
       https://arxiv.org/abs/2107.05604

    """

    @typechecked
    def __init__(
        self,
        # decoder related
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        layer_drop_rate: float = 0.0,
        # extra embedding related
        spks: Optional[int] = None,
        langs: Optional[int] = None,
        spk_embed_dim: Optional[int] = None,
        spk_embed_integration_type: str = "concat",
    ):
        """Transformer decoder for discrete unit module.

        Args:
            vocab_size: output dim
            encoder_output_size: dimension of attention
            attention_heads: the number of heads of multi head attention
            linear_units: the number of units of position-wise feed forward
            num_blocks: the number of decoder blocks
            dropout_rate: dropout rate
            positional_dropout_rate: dropout rate forwarded to the decoder
                (presumably applied after the positional encoding)
            self_attention_dropout_rate: dropout rate for self attention
            src_attention_dropout_rate: dropout rate for source attention
            input_layer: input layer type
            use_output_layer: whether to use output layer
            pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
            normalize_before: whether to use layer_norm before the first block
            concat_after: whether to concat attention layer's input and output
                if True, additional linear will be applied.
                i.e. x -> x + linear(concat(x, att(x)))
                if False, no additional linear will be applied.
                i.e. x -> x + att(x)
            layer_drop_rate: layer drop rate forwarded to the decoder
            spks (Optional[int]): Number of speakers. If set to > 1, assume that
                the sids will be provided as the input and use sid embedding
                layer.
            langs (Optional[int]): Number of languages. If set to > 1, assume
                that the lids will be provided as the input and use lid
                embedding layer.
            spk_embed_dim (Optional[int]): Speaker embedding dimension. If set
                to > 0, assume that spembs will be provided as the input.
            spk_embed_integration_type (str): How to integrate speaker
                embedding ("concat" or "add").

        Raises:
            ValueError: If ``spk_embed_dim`` is enabled and
                ``spk_embed_integration_type`` is neither "concat" nor "add".

        """
        super().__init__()

        # Speaker-ID embedding: only enabled for more than one speaker.
        self.spks = None
        if spks is not None and spks > 1:
            self.spks = spks
            self.sid_emb = torch.nn.Embedding(spks, encoder_output_size)

        # Language-ID embedding: only enabled for more than one language.
        self.langs = None
        if langs is not None and langs > 1:
            self.langs = langs
            self.lid_emb = torch.nn.Embedding(langs, encoder_output_size)

        # External speaker embedding (spembs): only enabled for a positive dim.
        self.spk_embed_dim = None
        if spk_embed_dim is not None and spk_embed_dim > 0:
            self.spk_embed_dim = spk_embed_dim
            self.spk_embed_integration_type = spk_embed_integration_type

        # The decoder input width depends on how spembs are merged:
        # "concat" widens the hidden states, "add" projects spembs instead.
        if self.spk_embed_dim is None:
            dec_idim = encoder_output_size
        elif self.spk_embed_integration_type == "concat":
            dec_idim = encoder_output_size + spk_embed_dim
        elif self.spk_embed_integration_type == "add":
            dec_idim = encoder_output_size
            self.projection = torch.nn.Linear(self.spk_embed_dim, encoder_output_size)
        else:
            raise ValueError(f"{spk_embed_integration_type} is not supported.")

        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            encoder_output_size=dec_idim,
            attention_heads=attention_heads,
            linear_units=linear_units,
            num_blocks=num_blocks,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            self_attention_dropout_rate=self_attention_dropout_rate,
            src_attention_dropout_rate=src_attention_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
            concat_after=concat_after,
            layer_drop_rate=layer_drop_rate,
        )

    def forward(
        self,
        enc_outputs: torch.Tensor,
        enc_outputs_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        return_last_hidden: bool = False,
        return_all_hiddens: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate forward propagation.

        Args:
            enc_outputs (Tensor): Batch of padded encoder outputs (B, T, idim).
            enc_outputs_lengths (LongTensor): Batch of lengths of each input
                batch (B,).
            feats (Tensor): Batch of padded target sequences; passed to the
                decoder as ``ys`` (originally documented as (B, T_feats, odim)
                -- TODO confirm the expected shape for discrete units).
            feats_lengths (LongTensor): Batch of the lengths of each target (B,).
            spembs (Optional[Tensor]): Batch of speaker embeddings
                (B, spk_embed_dim).
            sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
            lids (Optional[Tensor]): Batch of language IDs (B, 1).
            return_last_hidden (bool): Forwarded unchanged to the decoder.
            return_all_hiddens (bool): Forwarded unchanged to the decoder.

        Returns:
            The two values returned by the wrapped ``TransformerDecoder`` for
            these inputs (decoder outputs and their lengths -- see the decoder
            for the exact contents when the ``return_*`` flags are set).

        """
        # Trim both sequences to the longest item in this (sub-)batch so that
        # padding added for data-parallel scattering is dropped.
        enc_outputs = enc_outputs[:, : enc_outputs_lengths.max()]
        feats = feats[:, : feats_lengths.max()]  # for data-parallel

        # NOTE: the stop-prediction labels that used to be built here were never
        # used by this module, so that dead computation has been dropped.
        hs, hlens = self._forward(
            hs=enc_outputs,
            hlens=enc_outputs_lengths,
            ys=feats,
            olens=feats_lengths,
            spembs=spembs,
            sids=sids,
            lids=lids,
            return_last_hidden=return_last_hidden,
            return_all_hiddens=return_all_hiddens,
        )
        return hs, hlens

    def _forward(
        self,
        hs: torch.Tensor,
        hlens: torch.Tensor,
        ys: torch.Tensor,
        olens: torch.Tensor,
        spembs: Optional[torch.Tensor],
        sids: Optional[torch.Tensor],
        lids: Optional[torch.Tensor],
        return_last_hidden: bool = False,
        return_all_hiddens: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the optional speaker/language conditioning and run the decoder.

        Args:
            hs: Batch of encoder hidden states (B, Tmax, eunits).
            hlens: Lengths of ``hs`` (B,).
            ys: Batch of padded target sequences given to the decoder.
            olens: Lengths of ``ys`` (B,).
            spembs: Speaker embeddings; required when ``spk_embed_dim`` is set.
            sids: Speaker IDs; required when ``spks`` is set.
            lids: Language IDs; required when ``langs`` is set.
            return_last_hidden: Forwarded unchanged to the decoder.
            return_all_hiddens: Forwarded unchanged to the decoder.

        Returns:
            Whatever ``self.decoder`` returns for the conditioned inputs.

        """
        # ID embeddings are broadcast over the time axis and added in place of
        # widening the hidden states.
        if self.spks is not None:
            sid_embs = self.sid_emb(sids.view(-1))
            hs = hs + sid_embs.unsqueeze(1)
        if self.langs is not None:
            lid_embs = self.lid_emb(lids.view(-1))
            hs = hs + lid_embs.unsqueeze(1)

        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        return self.decoder(
            hs,
            hlens,
            ys,
            olens,
            return_last_hidden=return_last_hidden,
            return_all_hiddens=return_all_hiddens,
        )

    def _integrate_with_spk_embed(
        self, hs: torch.Tensor, spembs: torch.Tensor
    ) -> torch.Tensor:
        """Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, Tmax, eunits).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, Tmax, eunits)
                if integration_type is "add" else
                (B, Tmax, eunits + spk_embed_dim).

        Raises:
            NotImplementedError: If the integration type is neither "add" nor
                "concat".

        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds, repeated along the time axis
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)
        else:
            raise NotImplementedError("support only add or concat.")

        return hs