Source code for espnet2.asr.pit_espnet_model

import itertools
from collections import defaultdict
from contextlib import contextmanager
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from packaging.version import parse as V
from typeguard import typechecked

from espnet2.asr.ctc import CTC
from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.espnet_model import ESPnetASRModel as SingleESPnetASRModel
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.torch_utils.device_funcs import force_gatherable

if V(torch.__version__) < V("1.6.0"):
    # torch<1.6.0 has no automatic mixed precision: provide a do-nothing
    # context manager with the same call shape so `with autocast(...)` works.
    @contextmanager
    def autocast(enabled=True):
        yield

else:
    from torch.cuda.amp import autocast


class PITLossWrapper(AbsLossWrapper):
    """Permutation-invariant wrapper around a pairwise criterion.

    ``criterion_fn(inf, inf_lens, ref, ref_lens)`` is evaluated for every
    assignment of the ``num_ref`` inferences to the ``num_ref`` references,
    and the assignment with the lowest loss is selected per utterance.
    """

    def __init__(self, criterion_fn: Callable, num_ref: int):
        super().__init__()
        self.criterion_fn = criterion_fn
        self.num_ref = num_ref

    def forward(
        self,
        inf: torch.Tensor,
        inf_lens: torch.Tensor,
        ref: torch.Tensor,
        ref_lens: torch.Tensor,
        others: Optional[Dict] = None,
    ):
        """PITLoss Wrapper function.

        Similar to espnet2/enh/loss/wrappers/pit_solver.py

        Args:
            inf: (batch, num_inf, ...) model outputs
            inf_lens: (batch, num_inf) lengths of ``inf``
            ref: (batch, num_ref, ...) references
            ref_lens: (batch, num_ref) lengths of ``ref``
            others: unused; kept for interface compatibility

        Returns:
            loss: scalar, the batch mean of the per-utterance minimum loss
            opt_perm: (batch, num_ref) int64 CPU tensor. ``opt_perm[b, i]``
                is the index of the inference paired with reference ``i``;
                pass it to ``permutate`` to reorder the inferences. The
                inputs themselves are NOT permuted here.
        """
        assert (
            self.num_ref
            == inf.shape[1]
            == inf_lens.shape[1]
            == ref.shape[1]
            == ref_lens.shape[1]
        ), (self.num_ref, inf.shape, inf_lens.shape, ref.shape, ref_lens.shape)

        # Plain int tuples, so indexing below uses Python ints.
        permutations = list(itertools.permutations(range(self.num_ref)))

        def pair_loss(permutation):
            # Mean criterion over the (ref i, inf permutation[i]) pairs.
            # criterion_fn must return one loss per utterance (batch,),
            # otherwise the stack along dim=1 below fails.
            return sum(
                self.criterion_fn(inf[:, j], inf_lens[:, j], ref[:, i], ref_lens[:, i])
                for i, j in enumerate(permutation)
            ) / len(permutation)

        losses = torch.stack(
            [pair_loss(p) for p in permutations], dim=1
        )  # (batch_size, num_perm)
        min_losses, min_ids = torch.min(losses, dim=1)

        # The permutation table lives on CPU, so index it with CPU ids.
        all_permutations = torch.as_tensor(permutations)
        opt_perm = all_permutations[min_ids.cpu()]  # (batch_size, num_ref)
        return min_losses.mean(), opt_perm

    @classmethod
    def permutate(cls, perm, *args):
        """Reorder axis 1 of every tensor in ``args`` according to ``perm``.

        Args:
            perm: (batch, num_ref) permutation, e.g. ``opt_perm`` from ``forward``
            *args: tensors of shape (batch, num_ref, ...); all must agree on
                their first two dimensions

        Returns:
            list of tensors where ``out[b, i] = arg[b, perm[b, i]]``
        """
        if not args:
            return []

        batch_size, num_ref = args[0].shape[:2]
        for arg in args[1:]:
            assert torch.Size([batch_size, num_ref]) == arg.shape[:2]

        # (batch, 1) row index broadcast against (batch, num_ref) perm selects
        # arg[b, perm[b, i]] for every (b, i) in one advanced-indexing op.
        rows = torch.arange(batch_size, device=perm.device).unsqueeze(1)
        return [arg[rows, perm] for arg in args]
class ESPnetASRModel(SingleESPnetASRModel):
    """CTC-attention hybrid Encoder-Decoder model (multi-speaker, PIT).

    The CTC branch is evaluated under every speaker permutation to pick the
    best output/reference assignment. The encoder outputs are then reordered
    and flattened across speakers before the losses are computed.
    """

    @typechecked
    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: Optional[AbsDecoder],
        ctc: CTC,
        joint_network: Optional[torch.nn.Module],
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        # In a regular ESPnet recipe, <sos> and <eos> are both "<sos/eos>"
        # Pretrained HF Tokenizer needs custom sym_sos and sym_eos
        sym_sos: str = "<sos/eos>",
        sym_eos: str = "<sos/eos>",
        extract_feats_in_collect_stats: bool = True,
        lang_token_id: int = -1,
        # num_inf: the number of inferences (= number of outputs of the model)
        # num_ref: the number of references (= number of groundtruth seqs)
        num_inf: int = 1,
        num_ref: int = 1,
    ):
        # The permutation is chosen by the CTC branch, so CTC must be enabled.
        assert 0.0 < ctc_weight <= 1.0, ctc_weight
        assert interctc_weight == 0.0, "interctc is not supported for multispeaker ASR"

        super(ESPnetASRModel, self).__init__(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            joint_network=joint_network,
            ctc_weight=ctc_weight,
            interctc_weight=interctc_weight,
            ignore_id=ignore_id,
            lsm_weight=lsm_weight,
            length_normalized_loss=length_normalized_loss,
            report_cer=report_cer,
            report_wer=report_wer,
            sym_space=sym_space,
            sym_blank=sym_blank,
            sym_sos=sym_sos,
            sym_eos=sym_eos,
            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
            lang_token_id=lang_token_id,
        )

        assert num_inf == num_ref, "Current PIT loss wrapper requires num_inf=num_ref"
        self.num_inf = num_inf
        self.num_ref = num_ref

        self.pit_ctc = PITLossWrapper(criterion_fn=self.ctc, num_ref=num_ref)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length), reference of the 1st speaker
            text_lengths: (Batch,)
            kwargs: "utt_id" is among the input. When num_ref > 1 it must also
                contain "text_spk{n}" (Batch, Length_n) and
                "text_spk{n}_lengths" (Batch,) for n = 2, ..., num_ref.

        Returns:
            loss: scalar training loss
            stats: detached statistics for reporting
            weight: batch size, gathered for DataParallel
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]  # for data-parallel

        # Speaker 1 is `text`; speakers 2..num_ref come from kwargs. Both the
        # tokens and their lengths are required, so look both up the same way
        # (a missing key raises KeyError naming the key).
        extra_spks = range(2, self.num_ref + 1)
        text_ref = [text] + [kwargs["text_spk{}".format(spk)] for spk in extra_spks]
        text_ref_lengths = [text_lengths] + [
            kwargs["text_spk{}_lengths".format(spk)] for spk in extra_spks
        ]
        # A list (not a generator) so the failure message shows the shapes.
        assert all(ref_lengths.dim() == 1 for ref_lengths in text_ref_lengths), [
            ref_lengths.shape for ref_lengths in text_ref_lengths
        ]

        text_lengths = torch.stack(text_ref_lengths, dim=1)  # (batch, num_ref)
        text_length_max = int(text_lengths.max())
        # pad text sequences of different speakers to the same length
        text = torch.stack(
            [
                torch.nn.functional.pad(
                    ref, (0, text_length_max - ref.shape[1]), value=self.ignore_id
                )
                for ref in text_ref
            ],
            dim=1,
        )  # (batch, num_ref, seq_len)

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_transducer, cer_transducer, wer_transducer = None, None, None
        stats = dict()

        # 1. CTC branch
        if self.ctc_weight != 0.0:
            # CTC is computed twice.
            # This 1st ctc calculation is only used to decide permutation
            _, perm = self.pit_ctc(encoder_out, encoder_out_lens, text, text_lengths)
            # Reorder outputs so that output i is aligned with reference i.
            encoder_out, encoder_out_lens = PITLossWrapper.permutate(
                perm, encoder_out, encoder_out_lens
            )

            # combine all speakers hidden vectors and labels along the batch
            # axis: (batch, num_ref, ...) -> (batch * num_ref, ...).
            # `text` is always 3-D here (stacked above).
            encoder_out = encoder_out.reshape(-1, *encoder_out.shape[2:])
            encoder_out_lens = encoder_out_lens.reshape(-1)
            text = text.reshape(-1, text.shape[-1])
            text_lengths = text_lengths.reshape(-1)

            # This 2nd ctc calculation is to compute the loss
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            # The PIT search needs per-utterance CTC losses; reduce them here.
            loss_ctc = loss_ctc.sum()

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        if self.use_transducer_decoder:
            # 2a. Transducer decoder branch
            (
                loss_transducer,
                cer_transducer,
                wer_transducer,
            ) = self._calc_transducer_loss(
                encoder_out,
                encoder_out_lens,
                text,
            )

            if loss_ctc is not None:
                loss = loss_transducer + (self.ctc_weight * loss_ctc)
            else:
                loss = loss_transducer

            # Collect Transducer branch stats
            stats["loss_transducer"] = (
                loss_transducer.detach() if loss_transducer is not None else None
            )
            stats["cer_transducer"] = cer_transducer
            stats["wer_transducer"] = wer_transducer

        else:
            # 2b. Attention decoder branch
            if self.ctc_weight != 1.0:
                loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                    encoder_out, encoder_out_lens, text, text_lengths
                )

            # 3. CTC-Att loss definition
            if self.ctc_weight == 0.0:
                loss = loss_att
            elif self.ctc_weight == 1.0:
                loss = loss_ctc
            else:
                loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

            # Collect Attn branch stats
            stats["loss_att"] = loss_att.detach() if loss_att is not None else None
            stats["acc"] = acc_att
            stats["cer"] = cer_att
            stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = loss.detach()

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight