Source code for espnet2.s2st.synthesizer.discrete_synthesizer

# Copyright 2020 Nagoya University (Tomoki Hayashi)
# Copyright 2022 Carnegie Mellon University (Jiatong Shi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Translatotron Synthesizer related modules for ESPnet2."""

from typing import Any, List, Optional, Tuple

import torch
import torch.nn.functional as F
from typeguard import typechecked

from espnet2.asr.decoder.transformer_decoder import TransformerDecoder
from espnet2.s2st.synthesizer.abs_synthesizer import AbsSynthesizer
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.scorer_interface import BatchScorerInterface


class TransformerDiscreteSynthesizer(AbsSynthesizer, BatchScorerInterface):
    """Discrete unit synthesizer related modules for speech-to-speech translation.

    This is the discrete unit prediction network described in
    `Direct speech-to-speech translation with discrete units`_, which converts
    a sequence of hidden states into a sequence of discrete units (obtained
    from SSL models).

    .. _`Direct speech-to-speech translation with discrete units`:
        https://arxiv.org/abs/2107.05604

    """

    @typechecked
    def __init__(
        self,
        # decoder related
        odim: int,
        idim: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        layer_drop_rate: float = 0.0,
        # extra embedding related
        spks: Optional[int] = None,
        langs: Optional[int] = None,
        spk_embed_dim: Optional[int] = None,
        spk_embed_integration_type: str = "concat",
    ):
        """Transformer decoder for the discrete unit module.

        Args:
            odim: output dimension (size of the discrete unit vocabulary)
            idim: dimension of the encoder hidden states (attention dimension)
            attention_heads: the number of heads of multi head attention
            linear_units: the number of units of position-wise feed forward
            num_blocks: the number of decoder blocks
            dropout_rate: dropout rate
            self_attention_dropout_rate: dropout rate for attention
            input_layer: input layer type
            use_output_layer: whether to use output layer
            pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
            normalize_before: whether to use layer_norm before the first block
            concat_after: whether to concat attention layer's input and output;
                if True, additional linear will be applied,
                i.e. x -> x + linear(concat(x, att(x)));
                if False, no additional linear will be applied,
                i.e. x -> x + att(x)
            spks (Optional[int]): Number of speakers. If set to > 1, assume that the
                sids will be provided as the input and use sid embedding layer.
            langs (Optional[int]): Number of languages. If set to > 1, assume that the
                lids will be provided as the input and use lid embedding layer.
            spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0,
                assume that spembs will be provided as the input.
            spk_embed_integration_type (str): How to integrate speaker embedding.
""" super().__init__() self.spks = None if spks is not None and spks > 1: self.spks = spks self.sid_emb = torch.nn.Embedding(spks, idim) self.langs = None if langs is not None and langs > 1: self.langs = langs self.lid_emb = torch.nn.Embedding(langs, idim) self.spk_embed_dim = None if spk_embed_dim is not None and spk_embed_dim > 0: self.spk_embed_dim = spk_embed_dim self.spk_embed_integration_type = spk_embed_integration_type if self.spk_embed_dim is None: dec_idim = idim elif self.spk_embed_integration_type == "concat": dec_idim = idim + spk_embed_dim elif self.spk_embed_integration_type == "add": dec_idim = idim self.projection = torch.nn.Linear(self.spk_embed_dim, dec_idim) else: raise ValueError(f"{spk_embed_integration_type} is not supported.") self.decoder = TransformerDecoder( vocab_size=odim, encoder_output_size=dec_idim, attention_heads=attention_heads, linear_units=linear_units, num_blocks=num_blocks, dropout_rate=dropout_rate, positional_dropout_rate=positional_dropout_rate, self_attention_dropout_rate=self_attention_dropout_rate, src_attention_dropout_rate=src_attention_dropout_rate, input_layer=input_layer, use_output_layer=use_output_layer, pos_enc_class=pos_enc_class, normalize_before=normalize_before, concat_after=concat_after, layer_drop_rate=layer_drop_rate, )

    def forward(
        self,
        enc_outputs: torch.Tensor,
        enc_outputs_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        return_hs: bool = False,
        return_all_hs: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate forward propagation.

        Args:
            enc_outputs (Tensor): Batch of padded encoder hidden states (B, T, idim).
            enc_outputs_lengths (LongTensor): Batch of lengths of each input (B,).
            feats (LongTensor): Batch of padded target discrete unit ids (B, T_feats).
            feats_lengths (LongTensor): Batch of the lengths of each target (B,).
            spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim).
            sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
            lids (Optional[Tensor]): Batch of language IDs (B, 1).

        Returns:
            Tensor: Decoder output sequences over the discrete units (hs).
            LongTensor: Lengths of the output sequences (hlens).
        """
        enc_outputs = enc_outputs[:, : enc_outputs_lengths.max()]
        feats = feats[:, : feats_lengths.max()]  # for data-parallel

        ys = feats
        olens = feats_lengths

        # calculate hidden states for discrete unit outputs
        hs, hlens = self._forward(
            hs=enc_outputs,
            hlens=enc_outputs_lengths,
            ys=ys,
            olens=olens,
            spembs=spembs,
            sids=sids,
            lids=lids,
            return_hs=return_hs,
            return_all_hs=return_all_hs,
        )

        return hs, hlens
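
    # Shape sketch for forward() (all values are illustrative assumptions):
    # with batch size B=2, source length T=50, idim=512, and target length
    # T_feats=30 of discrete unit ids, a teacher-forced pass looks like
    #
    #     hs, hlens = synthesizer(
    #         enc_outputs=torch.randn(2, 50, 512),
    #         enc_outputs_lengths=torch.tensor([50, 42]),
    #         feats=torch.randint(0, 500, (2, 30)),
    #         feats_lengths=torch.tensor([30, 25]),
    #     )
    #
    # The encoder states serve as source-attention memory and the unit targets
    # are fed to the Transformer decoder's embedding layer.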

    def _forward(
        self,
        hs: torch.Tensor,
        hlens: torch.Tensor,
        ys: torch.Tensor,
        olens: torch.Tensor,
        spembs: torch.Tensor,
        sids: torch.Tensor,
        lids: torch.Tensor,
        return_hs: bool = False,
        return_all_hs: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if self.spks is not None:
            sid_embs = self.sid_emb(sids.view(-1))
            hs = hs + sid_embs.unsqueeze(1)
        if self.langs is not None:
            lid_embs = self.lid_emb(lids.view(-1))
            hs = hs + lid_embs.unsqueeze(1)
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)
        return self.decoder(
            hs,
            hlens,
            ys,
            olens,
            return_hs=return_hs,
            return_all_hs=return_all_hs,
        )

    def _integrate_with_spk_embed(
        self, hs: torch.Tensor, spembs: torch.Tensor
    ) -> torch.Tensor:
        """Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, Tmax, eunits).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, Tmax, eunits) if
                integration_type is "add", else (B, Tmax, eunits + spk_embed_dim).

        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)
        else:
            raise NotImplementedError("support only add or concat.")
        return hs
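
    # Shape sketch for _integrate_with_spk_embed (hypothetical shapes):
    #
    #     hs = torch.randn(2, 50, 512)      # (B, Tmax, eunits)
    #     spembs = torch.randn(2, 192)      # (B, spk_embed_dim)
    #
    # "add" L2-normalizes spembs, projects them to eunits, and broadcasts the
    # result over time, keeping (2, 50, 512); "concat" tiles the normalized
    # spembs along time and appends them, giving (2, 50, 704).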

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)

        Returns:
            y, cache: NN output value and cache per `self.decoder.decoders`.
                `y.shape` is (batch, token)
        """
        # FIXME(jiatong): the spk/lang embeddings may be executed too many times;
        # consider adding them before the search.
        # NOTE: sids/lids/spembs are assumed to be supplied via **kwargs when the
        # corresponding embeddings are enabled; otherwise this step is skipped.
        sids = kwargs.get("sids")
        lids = kwargs.get("lids")
        spembs = kwargs.get("spembs")
        if self.spks is not None and sids is not None:
            sid_embs = self.sid_emb(sids.view(-1))
            memory = memory + sid_embs.unsqueeze(1)
        if self.langs is not None and lids is not None:
            lid_embs = self.lid_emb(lids.view(-1))
            memory = memory + lid_embs.unsqueeze(1)
        if self.spk_embed_dim is not None and spembs is not None:
            memory = self._integrate_with_spk_embed(memory, spembs)

        return self.decoder.forward_one_step(tgt, tgt_mask, memory, cache=cache)
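
    # Single-step decoding sketch (illustrative): given memory of shape
    # (n_beam, T, dec_idim) and a running token prefix of shape (n_beam, L),
    #
    #     ys_mask = subsequent_mask(L, device=memory.device).unsqueeze(0)
    #     logp, cache = synthesizer.forward_one_step(prefix, ys_mask, memory)
    #
    # yields log-probabilities over the unit vocabulary for the next position
    # plus the per-layer cache to reuse at the following step.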

    def score(self, ys, state, x):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), state

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor): The encoder feature that generates ys
                (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchified scores for the next token with shape `(n_batch, n_vocab)`
                and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.decoder.decoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        # batch decoding
        ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
        logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
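
    # Batched scoring sketch (illustrative): for n_batch running hypotheses with
    # prefixes ys of shape (n_batch, ylen) and encoder features xs of shape
    # (n_batch, xlen, dec_idim), the first call passes empty states:
    #
    #     logp, states = synthesizer.batch_score(ys, [None] * len(ys), xs)
    #
    # logp has shape (n_batch, n_vocab); states holds one per-layer cache per
    # hypothesis and is handed back (reordered by the beam) on the next call.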

    def inference(self):
        pass
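

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module): builds a tiny
# synthesizer and exercises forward() and score() with random tensors.  All
# dimensions below are illustrative assumptions, not values from ESPnet recipes.
if __name__ == "__main__":
    synthesizer = TransformerDiscreteSynthesizer(
        odim=10,  # assumed size of the discrete unit vocabulary
        idim=32,  # assumed encoder hidden dimension
        attention_heads=2,
        linear_units=64,
        num_blocks=2,
    )

    # teacher-forced, training-style forward pass
    enc_outputs = torch.randn(2, 11, 32)  # (B, T, idim)
    enc_outputs_lengths = torch.tensor([11, 9])
    feats = torch.randint(0, 10, (2, 7))  # discrete unit targets (B, T_feats)
    feats_lengths = torch.tensor([7, 5])
    outs, olens = synthesizer(enc_outputs, enc_outputs_lengths, feats, feats_lengths)
    print(outs.shape)  # (B, T_feats, odim) decoder outputs over the unit vocabulary

    # single-hypothesis scoring step, as used through the scorer interface
    prefix = torch.tensor([0, 3])
    logp, cache = synthesizer.score(prefix, None, enc_outputs[0])
    print(logp.shape)  # (odim,) log-probabilities for the next unit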