Source code for espnet2.s2st.espnet_model

import logging
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union

import torch
from packaging.version import parse as V
from typeguard import typechecked

from espnet2.asr.ctc import CTC
from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.encoder.conformer_encoder import ConformerEncoder
from espnet2.asr.encoder.transformer_encoder import TransformerEncoder
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.s2st.aux_attention.abs_aux_attention import AbsS2STAuxAttention
from espnet2.s2st.losses.abs_loss import AbsS2STLoss
from espnet2.s2st.synthesizer.abs_synthesizer import AbsSynthesizer
from espnet2.s2st.tgt_feats_extract.abs_tgt_feats_extract import AbsTgtFeatsExtract
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet.nets.e2e_asr_common import ErrorCalculator as ASRErrorCalculator
from espnet.nets.e2e_mt_common import ErrorCalculator as MTErrorCalculator
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask, th_accuracy
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos

if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Fallback: provide a no-op autocast context for torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


class ESPnetS2STModel(AbsESPnetModel):
    """ESPnet speech-to-speech translation model"""

    @typechecked
    def __init__(
        self,
        s2st_type: str,
        frontend: Optional[AbsFrontend],
        tgt_feats_extract: Optional[AbsTgtFeatsExtract],
        specaug: Optional[AbsSpecAug],
        src_normalize: Optional[AbsNormalize],
        tgt_normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        asr_decoder: Optional[AbsDecoder],
        st_decoder: Optional[AbsDecoder],
        aux_attention: Optional[AbsS2STAuxAttention],
        unit_encoder: Optional[AbsEncoder],
        synthesizer: Optional[AbsSynthesizer],
        asr_ctc: Optional[CTC],
        st_ctc: Optional[CTC],
        losses: Dict[str, AbsS2STLoss],
        tgt_vocab_size: Optional[int],
        tgt_token_list: Optional[Union[Tuple[str, ...], List[str]]],
        src_vocab_size: Optional[int],
        src_token_list: Optional[Union[Tuple[str, ...], List[str]]],
        unit_vocab_size: Optional[int],
        unit_token_list: Optional[Union[Tuple[str, ...], List[str]]],
        ignore_id: int = -1,
        report_cer: bool = True,
        report_wer: bool = True,
        report_bleu: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
    ):
        super().__init__()
        self.sos = tgt_vocab_size - 1 if tgt_vocab_size else None
        self.eos = tgt_vocab_size - 1 if tgt_vocab_size else None
        self.src_sos = src_vocab_size - 1 if src_vocab_size else None
        self.src_eos = src_vocab_size - 1 if src_vocab_size else None
        self.unit_sos = unit_vocab_size - 1 if unit_vocab_size else None
        self.unit_eos = unit_vocab_size - 1 if unit_vocab_size else None
        self.tgt_vocab_size = tgt_vocab_size
        self.src_vocab_size = src_vocab_size
        self.unit_vocab_size = unit_vocab_size
        self.ignore_id = ignore_id
        self.tgt_token_list = tgt_token_list.copy() if tgt_token_list else None
        self.src_token_list = src_token_list.copy() if src_token_list else None
        self.unit_token_list = unit_token_list.copy() if unit_token_list else None
        self.s2st_type = s2st_type
        self.frontend = frontend
        self.tgt_feats_extract = tgt_feats_extract
        self.specaug = specaug
        self.src_normalize = src_normalize
        self.tgt_normalize = tgt_normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder
        self.asr_decoder = asr_decoder
        self.st_decoder = st_decoder
        self.aux_attention = aux_attention
        self.unit_encoder = unit_encoder
        self.synthesizer = synthesizer
        self.asr_ctc = asr_ctc
        self.st_ctc = st_ctc
        self.losses = torch.nn.ModuleDict(losses)

        # ST error calculator
        if st_decoder and tgt_vocab_size and report_bleu:
            self.mt_error_calculator = MTErrorCalculator(
                tgt_token_list, sym_space, sym_blank, report_bleu
            )
        else:
            self.mt_error_calculator = None

        # ASR error calculator
        if asr_decoder and src_vocab_size and (report_cer or report_wer):
            assert (
                src_token_list is not None
            ), "Missing src_token_list, cannot add asr module to st model"
            self.asr_error_calculator = ASRErrorCalculator(
                src_token_list, sym_space, sym_blank, report_cer, report_wer
            )
        else:
            self.asr_error_calculator = None

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats

        if self.s2st_type == "discrete_unit":
            assert isinstance(self.encoder, ConformerEncoder) or isinstance(
                self.encoder, TransformerEncoder
            ), "only support conformer or transformer-based encoder now"

            # synthesizer
            assert (
                "synthesis" in self.losses
            ), "must have synthesis loss in the losses for S2ST"

    def forward(
        self,
        src_speech: torch.Tensor,
        src_speech_lengths: torch.Tensor,
        tgt_speech: torch.Tensor,
        tgt_speech_lengths: torch.Tensor,
        tgt_text: Optional[torch.Tensor] = None,
        tgt_text_lengths: Optional[torch.Tensor] = None,
        src_text: Optional[torch.Tensor] = None,
        src_text_lengths: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,  # TODO(Jiatong)
        sids: Optional[torch.Tensor] = None,  # TODO(Jiatong)
        lids: Optional[torch.Tensor] = None,  # TODO(Jiatong)
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        # TODO(jiatong): add comments etc.
        assert (
            src_speech.shape[0]
            == src_speech_lengths.shape[0]
            == tgt_speech.shape[0]
            == tgt_speech_lengths.shape[0]
        ), (
            src_speech.shape,
            src_speech_lengths.shape,
            tgt_speech.shape,
            tgt_speech_lengths.shape,
        )

        # additional checks with valid tgt_text and src_text
        if tgt_text is not None:
            assert tgt_text_lengths.dim() == 1, tgt_text_lengths.shape
            assert (
                src_speech.shape[0]
                == src_text.shape[0]
                == src_text_lengths.shape[0]
                == tgt_text.shape[0]
                == tgt_text_lengths.shape[0]
            ), (
                src_speech.shape,
                src_text.shape,
                src_text_lengths.shape,
                tgt_text.shape,
                tgt_text_lengths.shape,
            )

        batch_size = src_speech.shape[0]

        # for data-parallel
        src_speech = src_speech[:, : src_speech_lengths.max()]
        if src_text is not None:
            src_text = src_text[:, : src_text_lengths.max()]
        if tgt_text is not None:
            tgt_text = tgt_text[:, : tgt_text_lengths.max()]
        tgt_speech = tgt_speech[:, : tgt_speech_lengths.max()]

        # 0. Target feature extract
        # NOTE(jiatong): only for teacher-forcing on spectrogram targets
        if self.tgt_feats_extract is not None:
            tgt_feats, tgt_feats_lengths = self._extract_feats(
                tgt_speech, tgt_speech_lengths, target=True
            )
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.tgt_normalize is not None:
                tgt_feats, tgt_feats_lengths = self.tgt_normalize(
                    tgt_feats, tgt_feats_lengths
                )
        else:
            # NOTE(jiatong): for discrete unit case
            tgt_feats, tgt_feats_lengths = tgt_speech, tgt_speech_lengths

        # 1. Encoder
        if self.s2st_type == "discrete_unit":
            (encoder_out, inter_encoder_out), encoder_out_lens = self.encode(
                src_speech, src_speech_lengths, return_all_hs=True
            )
        else:
            encoder_out, encoder_out_lens = self.encode(src_speech, src_speech_lengths)

        loss_record = []

        #########################
        # Translatotron Forward #
        #########################
        if self.s2st_type == "translatotron":
            # use a shared encoder with three decoders (i.e., asr, st, s2st)
            # reference: https://arxiv.org/pdf/1904.06037.pdf

            # asr_ctc
            if self.asr_ctc is not None and "asr_ctc" in self.losses:
                asr_ctc_loss, cer_asr_ctc = self._calc_ctc_loss(
                    encoder_out,
                    encoder_out_lens,
                    src_text,
                    src_text_lengths,
                    ctc_type="asr",
                )
                loss_record.append(asr_ctc_loss * self.losses["asr_ctc"].weight)
            else:
                asr_ctc_loss, cer_asr_ctc = None, None

            # asr decoder
            if self.asr_decoder is not None and "src_attn" in self.losses:
                (
                    src_attn_loss,
                    acc_src_attn,
                    cer_src_attn,
                    wer_src_attn,
                ) = self._calc_asr_att_loss(
                    encoder_out, encoder_out_lens, src_text, src_text_lengths
                )
                loss_record.append(src_attn_loss * self.losses["src_attn"].weight)
            else:
                src_attn_loss, acc_src_attn, cer_src_attn, wer_src_attn = (
                    None,
                    None,
                    None,
                    None,
                )

            # st decoder
            if self.st_decoder is not None and "tgt_attn" in self.losses:
                tgt_attn_loss, acc_tgt_attn, bleu_tgt_attn = self._calc_st_att_loss(
                    encoder_out, encoder_out_lens, tgt_text, tgt_text_lengths
                )
                loss_record.append(tgt_attn_loss * self.losses["tgt_attn"].weight)
            else:
                tgt_attn_loss, acc_tgt_attn, bleu_tgt_attn = None, None, None

            # NOTE(jiatong): tgt_feats is also updated based on the reduction_factor
            (
                after_outs,
                before_outs,
                logits,
                att_ws,
                updated_tgt_feats,
                stop_labels,
                updated_tgt_feats_lengths,
            ) = self.synthesizer(
                encoder_out,
                encoder_out_lens,
                tgt_feats,
                tgt_feats_lengths,
                spembs,
                sids,
                lids,
            )

            syn_loss, l1_loss, mse_loss, bce_loss = self.losses["synthesis"](
                after_outs,
                before_outs,
                logits,
                updated_tgt_feats,
                stop_labels,
                updated_tgt_feats_lengths,
            )
            loss_record.append(syn_loss * self.losses["synthesis"].weight)

            # NOTE(jiatong): guided attention will not be used with multi-head attention
            if (
                "syn_guided_attn" in self.losses
                and self.synthesizer.atype != "multihead"
            ):
                # NOTE(kan-bayashi): length of output for auto-regressive
                # input will be changed when r > 1
                if self.synthesizer.reduction_factor > 1:
                    updated_tgt_feats_lengths_in = updated_tgt_feats_lengths.new(
                        [
                            olen // self.synthesizer.reduction_factor
                            for olen in updated_tgt_feats_lengths
                        ]
                    )
                else:
                    updated_tgt_feats_lengths_in = updated_tgt_feats_lengths
                syn_guided_attn_loss = self.losses["syn_guided_attn"](
                    att_ws=att_ws,
                    ilens=encoder_out_lens,
                    olens_in=updated_tgt_feats_lengths_in,
                )
                loss_record.append(
                    syn_guided_attn_loss * self.losses["syn_guided_attn"].weight
                )
            else:
                syn_guided_attn_loss = None

            loss = sum(loss_record)

            stats = dict(
                loss=loss.item(),
                asr_ctc_loss=asr_ctc_loss.item() if asr_ctc_loss is not None else None,
                cer_asr_ctc=cer_asr_ctc,
                src_attn_loss=(
                    src_attn_loss.item() if src_attn_loss is not None else None
                ),
                acc_src_attn=acc_src_attn,
                cer_src_attn=cer_src_attn,
                wer_src_attn=wer_src_attn,
                tgt_attn_loss=(
                    tgt_attn_loss.item() if tgt_attn_loss is not None else None
                ),
                acc_tgt_attn=acc_tgt_attn,
                bleu_tgt_attn=bleu_tgt_attn,
                syn_loss=syn_loss.item() if syn_loss is not None else None,
                syn_guided_attn_loss=(
                    syn_guided_attn_loss.item()
                    if syn_guided_attn_loss is not None
                    else None
                ),
                syn_l1_loss=l1_loss.item(),
                syn_mse_loss=mse_loss.item(),
                syn_bce_loss=bce_loss.item(),
            )

        ##########################
        # Translatotron2 Forward #
        ##########################
        elif self.s2st_type == "translatotron2":
            # use a single decoder for synthesis
            # reference: https://arxiv.org/pdf/2107.08661v5.pdf

            # asr_ctc
            if self.asr_ctc is not None and "asr_ctc" in self.losses:
                asr_ctc_loss, cer_asr_ctc = self._calc_ctc_loss(
                    encoder_out,
                    encoder_out_lens,
                    src_text,
                    src_text_lengths,
                    ctc_type="asr",
                )
                loss_record.append(asr_ctc_loss * self.losses["asr_ctc"].weight)
            else:
                asr_ctc_loss, cer_asr_ctc = None, None

            # st decoder
            if self.st_decoder is not None and "tgt_attn" in self.losses:
                (
                    tgt_attn_loss,
                    acc_tgt_attn,
                    bleu_tgt_attn,
                    decoder_out,
                    _,
                ) = self._calc_st_att_loss(
                    encoder_out,
                    encoder_out_lens,
                    tgt_text,
                    tgt_text_lengths,
                    return_hs=True,
                )
                loss_record.append(tgt_attn_loss * self.losses["tgt_attn"].weight)
            else:
                tgt_attn_loss, acc_tgt_attn, bleu_tgt_attn, decoder_out = (
                    None,
                    None,
                    None,
                    None,
                )

            assert (
                self.aux_attention is not None
            ), "must have aux attention in translatotron loss"

            # NOTE(jiatong): tgt_text_lengths + 1 for <eos>
            encoder_out_mask = (
                make_pad_mask(encoder_out_lens).to(encoder_out.device).unsqueeze(1)
            )
            attention_out = self.aux_attention(
                decoder_out, encoder_out, encoder_out, mask=encoder_out_mask
            )
            decoder_out = torch.cat((decoder_out, attention_out), dim=-1)

            # NOTE(jiatong): tgt_feats is also updated based on the reduction_factor
            # TODO(jiatong): use non-attentive tacotron-based synthesizer
            (
                after_outs,
                before_outs,
                logits,
                att_ws,
                updated_tgt_feats,
                stop_labels,
                updated_tgt_feats_lengths,
            ) = self.synthesizer(
                decoder_out,
                tgt_text_lengths + 1,  # NOTE(jiatong): +1 for <eos>
                tgt_feats,
                tgt_feats_lengths,
                spembs,
                sids,
                lids,
            )

            syn_loss, l1_loss, mse_loss, bce_loss = self.losses["synthesis"](
                after_outs,
                before_outs,
                logits,
                updated_tgt_feats,
                stop_labels,
                updated_tgt_feats_lengths,
            )
            # loss_record.append(syn_loss * self.losses["synthesis"].weight)

            loss = sum(loss_record)

            stats = dict(
                loss=loss.item(),
                asr_ctc_loss=asr_ctc_loss.item() if asr_ctc_loss is not None else None,
                cer_asr_ctc=cer_asr_ctc,
                tgt_attn_loss=(
                    tgt_attn_loss.item() if tgt_attn_loss is not None else None
                ),
                acc_tgt_attn=acc_tgt_attn,
                bleu_tgt_attn=bleu_tgt_attn,
                syn_loss=syn_loss.item() if syn_loss is not None else None,
                syn_l1_loss=l1_loss.item() if l1_loss is not None else None,
                syn_mse_loss=mse_loss.item() if mse_loss is not None else None,
                syn_bce_loss=bce_loss.item() if bce_loss is not None else None,
            )

        #########################
        # Discrete unit Forward #
        #########################
        elif self.s2st_type == "discrete_unit":
            # discrete unit-based synthesis
            # reference: https://arxiv.org/pdf/2107.05604.pdf

            encoder_layer_for_asr = len(inter_encoder_out) // 2
            encoder_layer_for_st = len(inter_encoder_out) * 2 // 3

            # asr_ctc
            if self.asr_ctc is not None and "asr_ctc" in self.losses:
                asr_ctc_loss, cer_asr_ctc = self._calc_ctc_loss(
                    inter_encoder_out[encoder_layer_for_asr],
                    encoder_out_lens,
                    src_text,
                    src_text_lengths,
                    ctc_type="asr",
                )
                loss_record.append(asr_ctc_loss * self.losses["asr_ctc"].weight)
            else:
                asr_ctc_loss, cer_asr_ctc = None, None

            # asr decoder
            if self.asr_decoder is not None and "src_attn" in self.losses:
                (
                    src_attn_loss,
                    acc_src_attn,
                    cer_src_attn,
                    wer_src_attn,
                ) = self._calc_asr_att_loss(
                    encoder_out, encoder_out_lens, src_text, src_text_lengths
                )
                loss_record.append(src_attn_loss * self.losses["src_attn"].weight)
            else:
                src_attn_loss, acc_src_attn, cer_src_attn, wer_src_attn = (
                    None,
                    None,
                    None,
                    None,
                )

            # st decoder
            if self.st_decoder is not None and "tgt_attn" in self.losses:
                (
                    tgt_attn_loss,
                    acc_tgt_attn,
                    bleu_tgt_attn,
                ) = self._calc_st_att_loss(
                    inter_encoder_out[encoder_layer_for_st],
                    encoder_out_lens,
                    tgt_text,
                    tgt_text_lengths,
                )
                loss_record.append(tgt_attn_loss * self.losses["tgt_attn"].weight)
            else:
                tgt_attn_loss, acc_tgt_attn, bleu_tgt_attn, decoder_out = (
                    None,
                    None,
                    None,
                    None,
                )

            # synthesizer
            (
                unit_attn_loss,
                acc_unit_attn,
                syn_hidden,
                syn_hidden_lengths,
            ) = self._calc_unit_att_loss(
                encoder_out,
                encoder_out_lens,
                tgt_speech,
                tgt_speech_lengths,
                return_all_hs=True,
            )
            loss_record.append(unit_attn_loss * self.losses["synthesis"].weight)

            unit_decoder_layer_for_st = len(syn_hidden) // 2

            # st_ctc
            if self.st_ctc is not None and "st_ctc" in self.losses:
                st_ctc_loss, cer_st_ctc = self._calc_ctc_loss(
                    syn_hidden[unit_decoder_layer_for_st],
                    tgt_speech_lengths + 1,
                    tgt_text,
                    tgt_text_lengths,
                    ctc_type="st",
                )
                loss_record.append(st_ctc_loss * self.losses["st_ctc"].weight)
            else:
                st_ctc_loss, cer_st_ctc = None, None

            loss = sum(loss_record)

            stats = dict(
                loss=loss.item(),
                asr_ctc_loss=asr_ctc_loss.item() if asr_ctc_loss is not None else None,
                cer_asr_ctc=cer_asr_ctc,
                src_attn_loss=(
                    src_attn_loss.item() if src_attn_loss is not None else None
                ),
                acc_src_attn=acc_src_attn,
                cer_src_attn=cer_src_attn,
                wer_src_attn=wer_src_attn,
                tgt_attn_loss=(
                    tgt_attn_loss.item() if tgt_attn_loss is not None else None
                ),
                acc_tgt_attn=acc_tgt_attn,
                bleu_tgt_attn=bleu_tgt_attn,
                st_ctc_loss=st_ctc_loss.item() if st_ctc_loss is not None else None,
                cer_st_ctc=cer_st_ctc,
                unit_attn_loss=(
                    unit_attn_loss.item() if unit_attn_loss is not None else None
                ),
                acc_unit_attn=acc_unit_attn if acc_unit_attn is not None else None,
            )

        #################
        # UnitY Forward #
        #################
        elif self.s2st_type == "unity":
            # UnitY
            # reference: https://arxiv.org/pdf/2212.08055.pdf

            # asr_ctc
            if self.asr_ctc is not None and "asr_ctc" in self.losses:
                asr_ctc_loss, cer_asr_ctc = self._calc_ctc_loss(
                    encoder_out,
                    encoder_out_lens,
                    src_text,
                    src_text_lengths,
                    ctc_type="asr",
                )
                loss_record.append(asr_ctc_loss * self.losses["asr_ctc"].weight)
            else:
                asr_ctc_loss, cer_asr_ctc = None, None

            # st decoder
            assert (
                self.st_decoder is not None and "tgt_attn" in self.losses
            ), "st_decoder is necessary for unity-based model"
            (
                tgt_attn_loss,
                acc_tgt_attn,
                bleu_tgt_attn,
                decoder_out,
                _,
            ) = self._calc_st_att_loss(
                encoder_out,
                encoder_out_lens,
                tgt_text,
                tgt_text_lengths,
                return_hs=True,
            )
            loss_record.append(tgt_attn_loss * self.losses["tgt_attn"].weight)

            assert (
                self.unit_encoder is not None
            ), "unit_encoder is necessary for unity-based model"
            unit_encoder_out, unit_encoder_out_lengths, _ = self.unit_encoder(
                decoder_out, tgt_text_lengths + 1
            )

            # synthesizer
            unit_attn_loss, acc_unit_attn = self._calc_unit_att_loss(
                unit_encoder_out,
                unit_encoder_out_lengths,
                tgt_speech,
                tgt_speech_lengths,
            )
            loss_record.append(unit_attn_loss * self.losses["synthesis"].weight)

            loss = sum(loss_record)

            stats = dict(
                loss=loss.item(),
                asr_ctc_loss=asr_ctc_loss.item() if asr_ctc_loss is not None else None,
                cer_asr_ctc=cer_asr_ctc,
                tgt_attn_loss=(
                    tgt_attn_loss.item() if tgt_attn_loss is not None else None
                ),
                acc_tgt_attn=acc_tgt_attn,
                bleu_tgt_attn=bleu_tgt_attn,
                unit_attn_loss=(
                    unit_attn_loss.item() if unit_attn_loss is not None else None
                ),
                acc_unit_attn=acc_unit_attn if acc_unit_attn is not None else None,
            )

        else:
            raise ValueError("Not supported s2st type {}".format(self.s2st_type))

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    @typechecked
    def inference(
        self,
        src_speech: torch.Tensor,
        src_speech_lengths: Optional[torch.Tensor] = None,
        tgt_speech: Optional[torch.Tensor] = None,
        tgt_speech_lengths: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,  # TODO(Jiatong)
        sids: Optional[torch.Tensor] = None,  # TODO(Jiatong)
        lids: Optional[torch.Tensor] = None,  # TODO(Jiatong)
        threshold: float = 0.5,
        minlenratio: float = 0.0,
        maxlenratio: float = 10.0,
        use_att_constraint: bool = False,
        backward_window: int = 1,
        forward_window: int = 3,
        use_teacher_forcing: bool = False,
    ) -> Dict[str, torch.Tensor]:
        # 0. Target feature extract
        # NOTE(jiatong): only for teacher-forcing on spectrogram targets
        if tgt_speech is not None and self.tgt_feats_extract is not None:
            tgt_feats, tgt_feats_lengths = self._extract_feats(
                tgt_speech, tgt_speech_lengths, target=True
            )
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.tgt_normalize is not None:
                tgt_feats, tgt_feats_lengths = self.tgt_normalize(
                    tgt_feats, tgt_feats_lengths
                )
        else:
            # NOTE(jiatong): for discrete unit case
            tgt_feats, tgt_feats_lengths = tgt_speech, tgt_speech_lengths

        # 1. Encoder
        encoder_out, _ = self.encode(src_speech, src_speech_lengths)

        # 2. Decoder
        if self.s2st_type == "translatotron":
            assert encoder_out.size(0) == 1
            output_dict = self.synthesizer.inference(
                encoder_out[0],
                tgt_feats[0],
                spembs,
                sids,
                lids,
                threshold,
                minlenratio,
                maxlenratio,
                use_att_constraint,
                backward_window,
                forward_window,
                use_teacher_forcing,
            )
        elif self.s2st_type == "translatotron2":
            assert encoder_out.size(0) == 1
            output_dict = self.synthesizer.inference(
                encoder_out[0],
                tgt_feats[0],
                spembs,
                sids,
                lids,
                threshold,
                minlenratio,
                maxlenratio,
                use_att_constraint,
                backward_window,
                forward_window,
                use_teacher_forcing,
            )
        else:
            raise ValueError("Not supported s2st type {}".format(self.s2st_type))

        if self.tgt_normalize is not None and output_dict.get("feat_gen") is not None:
            # NOTE: normalize.inverse is in-place operation
            feat_gen_denorm = self.tgt_normalize.inverse(
                output_dict["feat_gen"].clone()[None]
            )[0][0]
            output_dict.update(feat_gen_denorm=feat_gen_denorm)
        return output_dict

    def collect_feats(
        self,
        src_speech: torch.Tensor,
        src_speech_lengths: torch.Tensor,
        tgt_speech: torch.Tensor,
        tgt_speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        if self.extract_feats_in_collect_stats:
            src_feats, src_feats_lengths = self._extract_feats(
                src_speech, src_speech_lengths
            )
            return_dict = {
                "src_feats": src_feats,
                "src_feats_lengths": src_feats_lengths,
            }
            if self.tgt_feats_extract is not None:
                tgt_feats, tgt_feats_lengths = self._extract_feats(
                    tgt_speech, tgt_speech_lengths, target=True
                )
                return_dict.update(
                    tgt_feats=tgt_feats,
                    tgt_feats_lengths=tgt_feats_lengths,
                )
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            return_dict = {
                "src_feats": src_speech,
                "tgt_feats": tgt_speech,
                "src_feats_lengths": src_speech_lengths,
                "tgt_feats_lengths": tgt_speech_lengths,
            }
        return return_dict

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        return_all_hs: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder.

        Note that this method is used by st_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.src_normalize is not None:
                feats, feats_lengths = self.src_normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, encoder_out_lens, _ = self.encoder(
            feats, feats_lengths, return_all_hs=return_all_hs
        )
        if return_all_hs:
            encoder_out, inter_encoder_out = encoder_out

        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        if return_all_hs:
            return (encoder_out, inter_encoder_out), encoder_out_lens
        else:
            return encoder_out, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, target: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if speech_lengths is not None:
            assert speech_lengths.dim() == 1, speech_lengths.shape

            # for data-parallel
            speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            if target:
                feats, feats_lengths = self.tgt_feats_extract(speech, speech_lengths)
            else:
                feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def _calc_unit_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        return_hs: bool = False,
        return_all_hs: bool = False,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(
            ys_pad, self.unit_sos, self.unit_eos, self.ignore_id
        )
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        (
            decoder_outs,
            decoder_out_lengths,
        ) = self.synthesizer(
            encoder_out,
            encoder_out_lens,
            ys_in_pad,
            ys_in_lens,
            spembs,
            sids,
            lids,
            return_hs=return_hs,
            return_all_hs=return_all_hs,
        )
        if return_hs or return_all_hs:
            (decoder_out, decoder_hidden) = decoder_outs
        else:
            decoder_out = decoder_outs

        # 2. Compute attention loss
        loss_att = self.losses["synthesis"](decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.unit_vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        if return_hs or return_all_hs:
            return loss_att, acc_att, decoder_hidden, decoder_out_lengths
        else:
            return loss_att, acc_att

    def _calc_st_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        return_hs: bool = False,
        return_all_hs: bool = False,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        assert (
            not return_hs or not return_all_hs
        ), "cannot return both the last hidden state and all hidden states"

        # 1. Forward decoder
        (
            decoder_outs,
            decoder_out_lengths,
        ) = self.st_decoder(
            encoder_out,
            encoder_out_lens,
            ys_in_pad,
            ys_in_lens,
            return_hs=return_hs,
            return_all_hs=return_all_hs,
        )
        if return_hs or return_all_hs:
            (decoder_out, decoder_hidden) = decoder_outs
        else:
            decoder_out = decoder_outs

        # 2. Compute attention loss
        loss_att = self.losses["tgt_attn"](decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.tgt_vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute BLEU using attention-decoder outputs
        if self.training or self.mt_error_calculator is None:
            bleu_att = None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            bleu_att = self.mt_error_calculator(ys_hat.cpu(), ys_pad.cpu())

        if return_hs:
            return loss_att, acc_att, bleu_att, decoder_hidden, decoder_out_lengths
        else:
            return loss_att, acc_att, bleu_att

    def _calc_asr_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(
            ys_pad, self.src_sos, self.src_eos, self.ignore_id
        )
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.asr_decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.losses["src_attn"](decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.src_vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        if self.training or self.asr_error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.asr_error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        ctc_type: str,
    ):
        if ctc_type == "asr":
            ctc = self.asr_ctc
        elif ctc_type == "st":
            ctc = self.st_ctc
        else:
            raise RuntimeError(
                "Cannot recognize the ctc-type: need 'asr'/'st', but found {}".format(
                    ctc_type
                )
            )

        # Calc CTC loss
        loss_ctc = ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.asr_error_calculator is not None:
            ys_hat = ctc.argmax(encoder_out).data
            cer_ctc = self.asr_error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    @property
    def require_vocoder(self):
        """Return whether or not vocoder is required."""
        return True
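
The `<sos>`/`<eos>` bookkeeping in `__init__` (the last vocabulary index is reused as both symbols, and `ignore_id = -1` marks padding) is consumed by `add_sos_eos` inside the `_calc_*_att_loss` helpers above. The following is a minimal illustrative sketch of that convention, not part of the module itself; it only assumes an ESPnet installation, and the token ids and vocabulary size are dummy values.

import torch

from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos

ignore_id = -1              # padding label that the losses ignore
vocab_size = 10
sos = eos = vocab_size - 1  # same convention as ESPnetS2STModel.__init__

# two padded target sequences of different lengths
ys_pad = torch.tensor([[3, 4, 5], [6, 7, ignore_id]])

ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, sos, eos, ignore_id)
# ys_in_pad : <sos> prepended, re-padded with <eos>   -> decoder input
# ys_out_pad: <eos> appended, re-padded with ignore_id -> decoder target
print(ys_in_pad)   # tensor([[9, 3, 4, 5], [9, 6, 7, 9]])
print(ys_out_pad)  # tensor([[3, 4, 5, 9], [6, 7, 9, -1]])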