# Source code for espnet2.enh.espnet_enh_s2t_model

import logging
import random
from contextlib import contextmanager
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from packaging.version import parse as V
from scipy.optimize import linear_sum_assignment
from typeguard import typechecked

from espnet2.asr.espnet_model import ESPnetASRModel
from espnet2.diar.espnet_model import ESPnetDiarizationModel
from espnet2.enh.espnet_model import ESPnetEnhancementModel
from espnet2.st.espnet_model import ESPnetSTModel
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel

if V(torch.__version__) < V("1.6.0"):
    # Native automatic mixed precision only exists from torch 1.6.0 on;
    # older versions get a context manager that does nothing.
    @contextmanager
    def autocast(enabled=True):
        yield

else:
    from torch.cuda.amp import autocast


class ESPnetEnhS2TModel(AbsESPnetModel):
    """Joint model Enhancement and Speech to Text.

    Chains a speech enhancement / separation front-end (``enh_model``) with
    a downstream ASR, ST or diarization model (``s2t_model``).
    """

    @typechecked
    def __init__(
        self,
        enh_model: ESPnetEnhancementModel,
        s2t_model: Union[ESPnetASRModel, ESPnetSTModel, ESPnetDiarizationModel],
        calc_enh_loss: bool = True,
        bypass_enh_prob: float = 0,  # 0 means do not bypass enhancement for all data
    ):
        """Build the joint model.

        Args:
            enh_model: enhancement / separation front-end.
            s2t_model: ASR, ST or DIAR back-end model.
            calc_enh_loss: whether the enhancement loss is added to the total loss.
            bypass_enh_prob: during training, probability of skipping the
                enhancement front-end for batches whose enhancement loss is
                not computed (0 disables bypassing).
        """
        super().__init__()
        self.enh_model = enh_model
        self.s2t_model = s2t_model  # ASR or ST or DIAR model

        self.bypass_enh_prob = bypass_enh_prob

        self.calc_enh_loss = calc_enh_loss
        if isinstance(self.s2t_model, ESPnetDiarizationModel):
            self.extract_feats_in_collect_stats = False
        else:
            self.extract_feats_in_collect_stats = (
                self.s2t_model.extract_feats_in_collect_stats
            )

        if (
            self.enh_model.num_spk is not None
            and self.enh_model.num_spk > 1
            and isinstance(self.s2t_model, ESPnetASRModel)
        ):
            if self.calc_enh_loss:
                logging.warning("The permutation issue will be handled by the Enh loss")
            else:
                logging.warning("The permutation issue will be handled by the CTC loss")

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, ) default None for chunk iterator,
                because the chunk-iterator does not return speech_lengths.
                See espnet2/iterators/chunk_iter_factory.py.
            For Enh+ASR task:
                text_spk1: (Batch, Length)
                text_spk2: (Batch, Length)
                ...
                text_spk1_lengths: (Batch,)
                text_spk2_lengths: (Batch,)
                ...
            For other tasks:
                text: (Batch, Length) default None just to keep the argument order
                text_lengths: (Batch,) default None for the same reason as
                    speech_lengths
            Optional keys read from kwargs: ``src_text``/``src_text_lengths``
            (ST), ``speech_ref{n}`` (required when calc_enh_loss is True) and
            ``utt_id`` (its suffix "CLEAN"/"REAL" selects the training mode).

        Returns:
            loss: scalar total loss (enhancement loss + S2T loss).
            stats: dictionary of detached statistics.
            weight: batch size, made gatherable for DataParallel.
        """
        if "text" in kwargs:
            text = kwargs["text"]
            text_ref_lengths = [kwargs.get("text_lengths", None)]
            if text_ref_lengths[0] is not None:
                text_length_max = max(
                    ref_lengths.max() for ref_lengths in text_ref_lengths
                )
            else:
                text_length_max = text.shape[1]
        else:
            text_ref = [
                kwargs["text_spk{}".format(spk + 1)]
                for spk in range(self.enh_model.num_spk)
            ]
            text_ref_lengths = [
                kwargs.get("text_spk{}_lengths".format(spk + 1), None)
                for spk in range(self.enh_model.num_spk)
            ]

            # for data-parallel
            if text_ref_lengths[0] is not None:
                text_length_max = max(
                    ref_lengths.max() for ref_lengths in text_ref_lengths
                )
            else:
                text_length_max = max(ref.shape[1] for ref in text_ref)
            # pad text sequences of different speakers to the same length
            ignore_id = getattr(self.s2t_model, "ignore_id", -1)
            # -> (Batch, Length, num_spk)
            text = torch.stack(
                [
                    F.pad(ref, (0, text_length_max - ref.shape[1]), value=ignore_id)
                    for ref in text_ref
                ],
                dim=2,
            )

        if text_ref_lengths[0] is not None:
            # A list (not a generator) so the assertion message is readable.
            assert all(ref_lengths.dim() == 1 for ref_lengths in text_ref_lengths), [
                ref_lengths.shape for ref_lengths in text_ref_lengths
            ]

        if speech_lengths is not None and text_ref_lengths[0] is not None:
            # Check that batch_size is unified
            assert (
                speech.shape[0]
                == speech_lengths.shape[0]
                == text.shape[0]
                == text_ref_lengths[0].shape[0]
            ), (
                speech.shape,
                speech_lengths.shape,
                text.shape,
                text_ref_lengths[0].shape,
            )
        else:
            assert speech.shape[0] == text.shape[0], (speech.shape, text.shape)

        # additional checks with valid src_text
        if "src_text" in kwargs:
            src_text = kwargs["src_text"]
            src_text_lengths = kwargs["src_text_lengths"]

            if src_text is not None:
                assert src_text_lengths.dim() == 1, src_text_lengths.shape
                # Compare against `text` (always defined); `text_ref` only
                # exists for the multi-speaker `text_spk{n}` inputs.
                assert (
                    text.shape[0] == src_text.shape[0] == src_text_lengths.shape[0]
                ), (
                    text.shape,
                    src_text.shape,
                    src_text_lengths.shape,
                )
        else:
            src_text = None
            src_text_lengths = None

        batch_size = speech.shape[0]
        speech_lengths = (
            speech_lengths
            if speech_lengths is not None
            else torch.ones(batch_size).int() * speech.shape[1]
        )

        # number of speakers
        # Take the number of speakers from text
        # (= spk_label [Batch, length, num_spk] ) if it is 3-D.
        # This is to handle flexible number of speakers.
        # Used only in "enh + diar" task for now.
        num_spk = text.shape[2] if text.dim() == 3 else self.enh_model.num_spk
        if self.enh_model.num_spk is not None:
            # for compatibility with TCNSeparatorNomask in enh_diar
            assert num_spk == self.enh_model.num_spk, (num_spk, self.enh_model.num_spk)

        # clean speech signal of each speaker
        speech_ref = None
        if self.calc_enh_loss:
            assert "speech_ref1" in kwargs
            speech_ref = [
                kwargs["speech_ref{}".format(spk + 1)] for spk in range(num_spk)
            ]
            # (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
            speech_ref = torch.stack(speech_ref, dim=1)
            # for data-parallel: truncate the *samples* axis (dim 2); `...`
            # would hit the channel axis for multi-channel references.
            speech_ref = speech_ref[:, :, : speech_lengths.max()]
            speech_ref = speech_ref.unbind(dim=1)

        # Calculating enhancement loss
        utt_id = kwargs.get("utt_id", None)
        bypass_enh_flag, skip_enhloss_flag = False, False
        if utt_id is not None and not isinstance(
            self.s2t_model, ESPnetDiarizationModel
        ):
            # TODO(xkc): to pass category info and use predefined category list
            if utt_id[0].endswith("CLEAN"):
                # For clean data
                # bypass Enhancement and do not calculate loss_enh
                bypass_enh_flag = True
                skip_enhloss_flag = True
            elif utt_id[0].endswith("REAL"):
                # For single-speaker real data
                # feed it to Enhancement but without calculating loss_enh
                bypass_enh_flag = False
                skip_enhloss_flag = True
            else:
                # For simulated single-/multi-speaker data
                # feed it to Enhancement and calculate loss_enh
                bypass_enh_flag = False
                skip_enhloss_flag = False

        if not self.calc_enh_loss:
            skip_enhloss_flag = True

        # Bypass the enhancement module
        if self.training and skip_enhloss_flag and not bypass_enh_flag:
            # For single-speaker real data: possibility to bypass frontend.
            # Strict `<` so that bypass_enh_prob == 0 never bypasses
            # (random.random() can return exactly 0.0).
            if random.random() < self.bypass_enh_prob:
                bypass_enh_flag = True

        # 1. Enhancement
        # model forward
        loss_enh = None
        perm = None
        # Stays empty when enhancement is bypassed, so the DIAR branch below
        # falls back to no bottleneck features instead of a NameError.
        others = {}
        if not bypass_enh_flag:
            ret = self.enh_model.forward_enhance(
                speech, speech_lengths, {"num_spk": num_spk}
            )
            speech_pre, feature_mix, feature_pre, others = ret
            # loss computation
            if not skip_enhloss_flag:
                loss_enh, _, _, perm = self.enh_model.forward_loss(
                    speech_pre,
                    speech_lengths,
                    feature_mix,
                    feature_pre,
                    others,
                    speech_ref,
                )
                loss_enh = loss_enh[0]

                # resort the prediction audios with the obtained permutation
                if perm is not None:
                    speech_pre = ESPnetEnhancementModel.sort_by_perm(speech_pre, perm)
        else:
            speech_pre = [speech]

        # for data-parallel
        if text_ref_lengths[0] is not None:
            text = text[:, :text_length_max]
        if src_text is not None:
            src_text = src_text[:, : src_text_lengths.max()]

        # 2. ASR or ST
        if isinstance(self.s2t_model, ESPnetASRModel):  # ASR
            if perm is None:
                # No permutation from the Enh loss: solve it with CTC-based PIT.
                loss_s2t, stats, weight = self.asr_pit_loss(
                    speech_pre, speech_lengths, text.unbind(2), text_ref_lengths
                )
            else:
                loss_s2t, stats, weight = self.s2t_model(
                    torch.cat(speech_pre, dim=0),
                    speech_lengths.repeat(len(speech_pre)),
                    torch.cat(text.unbind(2), dim=0),
                    torch.cat(text_ref_lengths, dim=0),
                )
            stats["loss_asr"] = loss_s2t.detach()
        elif isinstance(self.s2t_model, ESPnetSTModel):  # ST
            loss_s2t, stats, weight = self.s2t_model(
                speech_pre[0],
                speech_lengths,
                text,
                text_ref_lengths[0],
                src_text,
                src_text_lengths,
            )
            stats["loss_st"] = loss_s2t.detach()
        elif isinstance(self.s2t_model, ESPnetDiarizationModel):
            # DIAR
            loss_s2t, stats, weight = self.s2t_model(
                speech=speech.clone(),
                speech_lengths=speech_lengths,
                spk_labels=text,
                spk_labels_lengths=text_ref_lengths[0],
                bottleneck_feats=others.get("bottleneck_feats"),
                bottleneck_feats_lengths=others.get("bottleneck_feats_lengths"),
            )
            stats["loss_diar"] = loss_s2t.detach()
        else:
            raise NotImplementedError(f"{type(self.s2t_model)} is not supported yet.")

        if loss_enh is not None:
            loss = loss_enh + loss_s2t
        else:
            loss = loss_s2t

        stats["loss"] = loss.detach()
        stats["loss_enh"] = loss_enh.detach() if loss_enh is not None else None

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Extract features for statistics collection.

        Delegates to ``s2t_model.collect_feats`` when feature extraction is
        enabled; otherwise returns the raw ``speech`` as dummy features.
        """
        if "text" in kwargs:
            text = kwargs["text"]
            text_lengths = kwargs.get("text_lengths", None)
        else:
            text = kwargs["text_spk1"]
            text_lengths = kwargs.get("text_spk1_lengths", None)

        if self.extract_feats_in_collect_stats:
            # `text`/`text_lengths` are already passed positionally; leaving
            # them in **kwargs would raise "got multiple values for argument".
            extra_kwargs = {
                k: v for k, v in kwargs.items() if k not in ("text", "text_lengths")
            }
            ret = self.s2t_model.collect_feats(
                speech,
                speech_lengths,
                text,
                text_lengths,
                **extra_kwargs,
            )
            feats, feats_lengths = ret["feats"], ret["feats_lengths"]
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
        """Frontend + Encoder.

        Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        Returns:
            encoder_out: tuple with one encoder output per separated stream.
            encoder_out_lens: tuple with the matching encoder output lengths.
        """
        speech_pre, _, _, _ = self.enh_model.forward_enhance(speech, speech_lengths)

        num_spk = len(speech_pre)
        assert num_spk == self.enh_model.num_spk, (num_spk, self.enh_model.num_spk)

        encoder_out, encoder_out_lens = zip(
            *[self.s2t_model.encode(sp, speech_lengths) for sp in speech_pre]
        )

        return encoder_out, encoder_out_lens

    def encode_diar(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, num_spk: int
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Frontend + Encoder.

        Note that this method is used by diar_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            num_spk: int
        Returns:
            encoder_out, encoder_out_lens: outputs of ``s2t_model.encode`` on the
                unenhanced ``speech`` plus the enhancement bottleneck features.
            speech_pre: enhanced signals from the front-end (presumably one
                tensor per speaker).
        """
        speech_pre, _, _, others = self.enh_model.forward_enhance(
            speech, speech_lengths, {"num_spk": num_spk}
        )
        encoder_out, encoder_out_lens = self.s2t_model.encode(
            speech,
            speech_lengths,
            others.get("bottleneck_feats"),
            others.get("bottleneck_feats_lengths"),
        )

        return encoder_out, encoder_out_lens, speech_pre

    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Compute negative log likelihood(nll) from transformer-decoder

        Normally, this function is called in batchify_nll.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)
        """
        return self.s2t_model.nll(
            encoder_out,
            encoder_out_lens,
            ys_pad,
            ys_pad_lens,
        )

    # Reuse the ASR model's batching wrapper around `nll`.
    batchify_nll = ESPnetASRModel.batchify_nll

    def asr_pit_loss(self, speech, speech_lengths, text, text_lengths):
        """Compute the ASR loss after resolving the speaker permutation with CTC.

        Args:
            speech: sequence of enhanced signals, one per hypothesis stream.
            speech_lengths: (Batch,)
            text: sequence of (Batch, Length) reference token tensors, one per speaker.
            text_lengths: sequence of (Batch,) reference lengths, one per speaker.
        Returns:
            loss, stats, weight as returned by ``s2t_model`` on the reordered
            streams concatenated along the batch axis.
        Raises:
            ValueError: if the ASR model has no CTC module.
        """
        if self.s2t_model.ctc is None:
            raise ValueError("CTC must be used to determine the permutation")
        with torch.no_grad():
            # (B, n_ref, n_hyp)
            loss0 = torch.stack(
                [
                    torch.stack(
                        [
                            self.s2t_model._calc_batch_ctc_loss(
                                speech[h],
                                speech_lengths,
                                text[r],
                                text_lengths[r],
                            )
                            for r in range(self.enh_model.num_spk)
                        ],
                        dim=1,
                    )
                    for h in range(self.enh_model.num_spk)
                ],
                dim=2,
            )
            perm_detail, _ = self.permutation_invariant_training(loss0)

        speech = ESPnetEnhancementModel.sort_by_perm(speech, perm_detail)
        loss, stats, weight = self.s2t_model(
            torch.cat(speech, dim=0),
            speech_lengths.repeat(len(speech)),
            torch.cat(text, dim=0),
            torch.cat(text_lengths, dim=0),
        )

        return loss, stats, weight

    def _permutation_loss(self, ref, inf, criterion, perm=None):
        """The basic permutation loss function.

        Args:
            ref (List[torch.Tensor]): [(batch, ...), ...] x n_spk
            inf (List[torch.Tensor]): [(batch, ...), ...]
            criterion (function): Loss function
            perm: (batch) -- currently unused
        Returns:
            loss: torch.Tensor: scalar mean of the per-utterance minimum loss
            perm: list[(num_spk)]
        """
        num_spk = len(ref)

        losses = torch.stack(
            [
                torch.stack([criterion(ref[r], inf[h]) for r in range(num_spk)], dim=1)
                for h in range(num_spk)
            ],
            dim=2,
        )  # (B, n_ref, n_hyp)
        perm_detail, min_loss = self.permutation_invariant_training(losses)
        return min_loss.mean(), perm_detail

    def permutation_invariant_training(self, losses: torch.Tensor):
        """Compute PIT loss.

        Args:
            losses (torch.Tensor): (batch, nref, nhyp)
        Returns:
            perm: list of (n_spk,) long tensors on ``losses.device``, one per
                utterance
            loss: torch.Tensor: (batch)
        """
        hyp_perm, min_perm_loss = [], []
        # scipy needs a CPU array of a numpy-supported dtype (no bf16).
        losses_cpu = losses.detach().float().cpu()
        for b, b_loss in enumerate(losses_cpu):
            # hungarian algorithm
            try:
                row_ind, col_ind = linear_sum_assignment(b_loss)
            except ValueError as err:
                if str(err) != "cost matrix is infeasible":
                    raise
                # Every assignment has infinite cost, so any permutation is
                # as good as another: use the identity, for any speaker count.
                row_ind = col_ind = np.arange(min(b_loss.shape))

            min_perm_loss.append(torch.mean(losses[b, row_ind, col_ind]))
            hyp_perm.append(
                torch.as_tensor(col_ind, dtype=torch.long, device=losses.device)
            )

        return hyp_perm, torch.stack(min_perm_loss)

    @typechecked
    def inherite_attributes(
        self,
        inherite_enh_attrs: Union[List[str], None] = None,
        inherite_s2t_attrs: Union[List[str], None] = None,
    ):
        """Copy the named attributes of the sub-models onto this model.

        Missing attributes are set to None.

        Args:
            inherite_enh_attrs: attribute names to copy from ``enh_model``.
            inherite_s2t_attrs: attribute names to copy from ``s2t_model``.
        """
        for attr in inherite_enh_attrs or []:
            setattr(self, attr, getattr(self.enh_model, attr, None))

        for attr in inherite_s2t_attrs or []:
            setattr(self, attr, getattr(self.s2t_model, attr, None))