Source code for espnet2.enh.espnet_model_tse

"""Enhancement model module."""

import contextlib
from typing import Dict, List, Optional, OrderedDict, Tuple

import torch
from typeguard import typechecked

from espnet2.enh.decoder.abs_decoder import AbsDecoder
from espnet2.enh.encoder.abs_encoder import AbsEncoder
from espnet2.enh.extractor.abs_extractor import AbsExtractor
from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainLoss
from espnet2.enh.loss.criterions.time_domain import TimeDomainLoss
from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel

EPS = torch.finfo(torch.get_default_dtype()).eps


class ESPnetExtractionModel(AbsESPnetModel):
    """Target Speaker Extraction Frontend model.

    The model chains an encoder, a target-speaker extractor (run once per
    enrollment signal) and a decoder, and accumulates the weighted losses
    computed by the configured loss wrappers.
    """

    @typechecked
    def __init__(
        self,
        encoder: AbsEncoder,
        extractor: AbsExtractor,
        decoder: AbsDecoder,
        loss_wrappers: List[AbsLossWrapper],
        num_spk: int = 1,
        flexible_numspk: bool = False,
        share_encoder: bool = True,
        extract_feats_in_collect_stats: bool = False,
    ):
        """Build the extraction model.

        Args:
            encoder: encoder applied to the mixture (and to the enrollment
                signals when ``share_encoder`` is True).
            extractor: target speaker extractor.
            decoder: decoder applied to the extracted features.
            loss_wrappers: loss wrappers; their criterion names must be unique.
            num_spk: number of speakers (maximum number of speakers when
                ``flexible_numspk`` is True).
            flexible_numspk: if True, the number of speakers is taken from the
                number of reference signals given to ``forward``.
            share_encoder: whether to encode the enrollment signals with the
                same encoder as the mixture.
            extract_feats_in_collect_stats: stored as-is on the instance.

        Raises:
            ValueError: if loss names are duplicated, or if a criterion has
                ``is_noise_loss`` or ``is_dereverb_loss`` set to True.
        """
        super().__init__()

        self.encoder = encoder
        self.extractor = extractor
        self.decoder = decoder
        # Whether to share encoder for both mixture and enrollment
        self.share_encoder = share_encoder
        self.num_spk = num_spk
        # If True, self.num_spk is regarded as the MAXIMUM possible number of speakers
        self.flexible_numspk = flexible_numspk

        self.loss_wrappers = loss_wrappers
        names = [w.criterion.name for w in self.loss_wrappers]
        if len(set(names)) != len(names):
            raise ValueError("Duplicated loss names are not allowed: {}".format(names))
        for w in self.loss_wrappers:
            if getattr(w.criterion, "is_noise_loss", False):
                raise ValueError("is_noise_loss=True is not supported")
            elif getattr(w.criterion, "is_dereverb_loss", False):
                raise ValueError("is_dereverb_loss=True is not supported")

        # for multi-channel signal
        self.ref_channel = getattr(self.extractor, "ref_channel", None)
        if self.ref_channel is None:
            self.ref_channel = 0

        # Used in espnet2/tasks/abs_task.py for determining whether or not to do
        # collect_feats during collect stats (stage 5).
        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats

    def forward(
        self,
        speech_mix: torch.Tensor,
        speech_mix_lengths: torch.Tensor = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech_mix: (Batch, samples) or (Batch, samples, channels)
            speech_ref1: (Batch, samples)
                        or (Batch, samples, channels)
            speech_ref2: (Batch, samples)
                        or (Batch, samples, channels)
            ...
            speech_mix_lengths: (Batch,), default None for chunk iterator,
                            because the chunk-iterator does not have the
                            speech_lengths returned.
                            see in
                            espnet2/iterators/chunk_iter_factory.py
            enroll_ref1: (Batch, samples_aux)
                        enrollment (raw audio or embedding) for speaker 1
            enroll_ref2: (Batch, samples_aux)
                        enrollment (raw audio or embedding) for speaker 2
            ...
            kwargs: "utt_id" is among the input.

        Returns:
            The ``(loss, stats, weight)`` triple produced by ``forward_loss``
            (its fourth output, the permutation, is dropped here).
        """
        # reference speech signal of each speaker
        assert "speech_ref1" in kwargs, "At least 1 reference signal input is required."
        # The filter guarantees that the key exists, so plain indexing is enough
        # (no default tensor needs to be allocated).
        speech_ref = [
            kwargs[f"speech_ref{spk + 1}"]
            for spk in range(self.num_spk)
            if f"speech_ref{spk + 1}" in kwargs
        ]
        num_spk = len(speech_ref) if self.flexible_numspk else self.num_spk
        assert len(speech_ref) == num_spk, (len(speech_ref), num_spk)
        # (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
        speech_ref = torch.stack(speech_ref, dim=1)
        batch_size = speech_mix.shape[0]

        assert "enroll_ref1" in kwargs, "At least 1 enrollment signal is required."
        # enrollment signal for each speaker (as the target) and its lengths.
        # Both lists are built together so that entry i of one always
        # corresponds to entry i of the other, even if some keys are missing.
        enroll_ref = []  # each: (Batch, samples_aux)
        enroll_ref_lengths = []  # each: (Batch,)
        for spk in range(num_spk):
            key = "enroll_ref{}".format(spk + 1)
            if key not in kwargs:
                continue
            enroll = kwargs[key]
            lengths = kwargs.get("{}_lengths".format(key))
            if lengths is None:
                # No lengths given: treat every utterance as full length.
                lengths = torch.ones(batch_size).int().fill_(enroll.size(1))
            enroll_ref.append(enroll)
            enroll_ref_lengths.append(lengths)

        speech_lengths = (
            speech_mix_lengths
            if speech_mix_lengths is not None
            else torch.ones(batch_size).int().fill_(speech_mix.shape[1])
        )
        assert speech_lengths.dim() == 1, speech_lengths.shape
        # Check that batch_size is unified
        assert speech_mix.shape[0] == speech_ref.shape[0] == speech_lengths.shape[0], (
            speech_mix.shape,
            speech_ref.shape,
            speech_lengths.shape,
        )
        for aux in enroll_ref:
            assert aux.shape[0] == speech_mix.shape[0], (aux.shape, speech_mix.shape)

        # for data-parallel
        max_len = speech_lengths.max()
        # Slice the sample axis (dim 2) explicitly: with an Ellipsis the slice
        # would hit the channel axis of a multi-channel reference
        # (Batch, num_speaker, samples, channels).
        speech_ref = speech_ref[:, :, :max_len].unbind(dim=1)
        speech_mix = speech_mix[:, :max_len]
        enroll_ref = [
            enroll[:, : lengths.max()]
            for enroll, lengths in zip(enroll_ref, enroll_ref_lengths)
        ]
        assert len(speech_ref) == len(enroll_ref), (len(speech_ref), len(enroll_ref))

        additional = {}
        # Additional data for training the TSE model
        if self.flexible_numspk:
            additional["num_spk"] = num_spk

        # model forward
        speech_pre, feature_mix, feature_pre, others = self.forward_enhance(
            speech_mix, speech_lengths, enroll_ref, enroll_ref_lengths, additional
        )

        # loss computation
        loss, stats, weight, perm = self.forward_loss(
            speech_pre,
            speech_lengths,
            feature_mix,
            feature_pre,
            others,
            speech_ref,
        )
        return loss, stats, weight

    def forward_enhance(
        self,
        speech_mix: torch.Tensor,
        speech_lengths: torch.Tensor,
        enroll_ref: List[torch.Tensor],
        enroll_ref_lengths: List[torch.Tensor],
        additional: Optional[Dict] = None,
    ) -> Tuple[Optional[List[torch.Tensor]], torch.Tensor, Tuple, Dict]:
        """Run encoder, per-speaker extractor and decoder.

        Args:
            speech_mix: mixture signal passed to the encoder.
            speech_lengths: lengths passed to the encoder and the decoder.
            enroll_ref: one enrollment signal per target speaker.
            enroll_ref_lengths: one length tensor per enrollment signal.
            additional: extra data forwarded to the extractor.

        Returns:
            speech_pre: list of decoded signals (one per enrollment), or None
                when the extractor returns None as its first feature.
            feature_mix: encoder output for the mixture.
            feature_pre: tuple with the first output of each extractor call.
            others: merged dict of the third output of each extractor call.
        """
        feature_mix, flens = self.encoder(speech_mix, speech_lengths)
        if self.share_encoder:
            feature_aux, flens_aux = zip(
                *[
                    self.encoder(enroll_ref[spk], enroll_ref_lengths[spk])
                    for spk in range(len(enroll_ref))
                ]
            )
        else:
            # The enrollment is handed to the extractor without encoding.
            feature_aux = enroll_ref
            flens_aux = enroll_ref_lengths

        feature_pre, _, others = zip(
            *[
                self.extractor(
                    feature_mix,
                    flens,
                    feature_aux[spk],
                    flens_aux[spk],
                    suffix_tag=f"_spk{spk + 1}",
                    additional=additional,
                )
                for spk in range(len(enroll_ref))
            ]
        )
        # Merge the per-speaker dicts; later speakers win on key collisions.
        others = {k: v for dic in others for k, v in dic.items()}
        if feature_pre[0] is not None:
            speech_pre = [self.decoder(ps, speech_lengths)[0] for ps in feature_pre]
        else:
            # some models (e.g. neural beamformer trained with mask loss)
            # do not predict time-domain signal in the training stage
            speech_pre = None
        return speech_pre, feature_mix, feature_pre, others

    def forward_loss(
        self,
        speech_pre: torch.Tensor,
        speech_lengths: torch.Tensor,
        feature_mix: torch.Tensor,
        feature_pre: torch.Tensor,
        others: OrderedDict,
        speech_ref: torch.Tensor,
    ) -> Tuple[
        torch.Tensor, Dict[str, torch.Tensor], torch.Tensor, Optional[torch.Tensor]
    ]:
        """Accumulate the weighted losses of all loss wrappers.

        Args:
            speech_pre: list of estimated time-domain signals (may be None).
            speech_lengths: lengths used when encoding signals.
            feature_mix: encoder output for the mixture.
            feature_pre: extractor outputs (not used in this method).
            others: extra extractor outputs, forwarded to each loss wrapper.
            speech_ref: sequence of reference signals, one per speaker.

        Returns:
            loss: sum of ``wrapper_loss * wrapper.weight`` over the wrappers.
            stats: merged statistics of all wrappers plus the detached "loss".
            weight: the batch size, made gatherable by ``force_gatherable``.
            perm: the "perm" entry of the first wrapper output containing
                one, or None.

        Raises:
            AttributeError: in training mode when every criterion was skipped.
            NotImplementedError: for a criterion that is neither a
                TimeDomainLoss nor a FrequencyDomainLoss.
        """
        loss = 0.0
        stats = {}
        # Output dict of the previous wrapper; it is merged into the next call.
        o = {}
        perm = None
        for loss_wrapper in self.loss_wrappers:
            criterion = loss_wrapper.criterion
            if getattr(criterion, "only_for_test", False) and self.training:
                continue

            # Zero-weight criteria are evaluated without tracking gradients.
            zero_weight = loss_wrapper.weight == 0.0
            if isinstance(criterion, TimeDomainLoss):
                assert speech_pre is not None
                sref, spre = self._align_ref_pre_channels(
                    speech_ref, speech_pre, ch_dim=2, force_1ch=True
                )
                # for the time domain criterions
                with torch.no_grad() if zero_weight else contextlib.ExitStack():
                    loss_value, s, o = loss_wrapper(sref, spre, {**others, **o})
            elif isinstance(criterion, FrequencyDomainLoss):
                sref, spre = self._align_ref_pre_channels(
                    speech_ref, speech_pre, ch_dim=2, force_1ch=False
                )
                # for the time-frequency domain criterions
                if criterion.compute_on_mask:
                    # compute loss on masks
                    tf_ref, tf_pre = self._get_speech_masks(
                        criterion,
                        feature_mix,
                        None,
                        speech_ref,
                        speech_pre,
                        speech_lengths,
                        others,
                    )
                else:
                    # compute on spectrum
                    tf_ref = [self.encoder(sr, speech_lengths)[0] for sr in sref]
                    tf_pre = [self.encoder(sp, speech_lengths)[0] for sp in spre]

                with torch.no_grad() if zero_weight else contextlib.ExitStack():
                    loss_value, s, o = loss_wrapper(tf_ref, tf_pre, {**others, **o})
            else:
                raise NotImplementedError("Unsupported loss type: %s" % str(criterion))

            loss += loss_value * loss_wrapper.weight
            stats.update(s)

            if perm is None and "perm" in o:
                perm = o["perm"]

        # A float here means that no criterion contributed to the loss.
        if self.training and isinstance(loss, float):
            raise AttributeError(
                "At least one criterion must satisfy: only_for_test=False"
            )
        stats["loss"] = loss.detach()

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        batch_size = speech_ref[0].shape[0]
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight, perm

    def _align_ref_pre_channels(self, ref, pre, ch_dim=2, force_1ch=False):
        """Make references and estimates agree on the number of dimensions.

        The channel ``self.ref_channel`` is selected along ``ch_dim`` from
        whichever side has more dimensions; when both sides are 3-dimensional
        and ``force_1ch`` is True it is selected from both.

        Args:
            ref: list of reference signals, or None.
            pre: list of estimated signals (or list of lists), or None.
            ch_dim: channel dimension of the multi-channel signals.
            force_1ch: reduce both sides to one channel when both are 3-dim.

        Returns:
            The ``(ref, pre)`` pair. Inputs are returned unchanged when either
            is None or when no selection is needed. A nested ``pre`` comes
            back flattened into one list whenever a channel is selected from it.
        """
        if ref is None or pre is None:
            return ref, pre
        # NOTE: input must be a list of time-domain signals
        index = ref[0].new_tensor(self.ref_channel, dtype=torch.long)

        # for models like SVoice that output multiple lists of separated signals
        pre_is_multi_list = isinstance(pre[0], (list, tuple))
        pre_dim = pre[0][0].dim() if pre_is_multi_list else pre[0].dim()

        def _select(signal):
            # Pick the reference channel and drop the channel dimension.
            return signal.index_select(ch_dim, index).squeeze(ch_dim)

        def _select_pre(signals):
            if pre_is_multi_list:
                return [_select(p) for plist in signals for p in plist]
            return [_select(p) for p in signals]

        if ref[0].dim() > pre_dim:
            # multi-channel reference and single-channel output
            ref = [_select(r) for r in ref]
        elif ref[0].dim() < pre_dim:
            # single-channel reference and multi-channel output
            pre = _select_pre(pre)
        elif ref[0].dim() == pre_dim == 3 and force_1ch:
            # multi-channel reference and output
            ref = [_select(r) for r in ref]
            pre = _select_pre(pre)
        return ref, pre

    def _get_speech_masks(
        self, criterion, feature_mix, noise_ref, speech_ref, speech_pre, ilens, others
    ):
        """Build reference and predicted masks for mask-based criteria.

        Args:
            criterion: criterion providing ``create_mask_label``.
            feature_mix: encoder output for the mixture.
            noise_ref: list of noise signals, or None.
            speech_ref: reference signals, encoded to build the mask labels.
            speech_pre: estimated signals, only encoded when ``others`` has
                no "mask_spk1" entry.
            ilens: lengths passed to the encoder.
            others: extractor outputs, possibly holding "mask_spk{n}" entries.

        Returns:
            The ``(masks_ref, masks_pre)`` pair.
        """
        if noise_ref is not None:
            noise_spec = self.encoder(sum(noise_ref), ilens)[0]
        else:
            noise_spec = None
        masks_ref = criterion.create_mask_label(
            feature_mix,
            [self.encoder(sr, ilens)[0] for sr in speech_ref],
            noise_spec=noise_spec,
        )
        if "mask_spk1" in others:
            # Use the masks predicted by the extractor. The filter must test
            # the very key that is read, so that speakers without a predicted
            # mask (e.g. fewer speakers than self.num_spk) are skipped.
            masks_pre = [
                others["mask_spk{}".format(spk + 1)]
                for spk in range(self.num_spk)
                if "mask_spk{}".format(spk + 1) in others
            ]
        else:
            masks_pre = criterion.create_mask_label(
                feature_mix,
                [self.encoder(sp, ilens)[0] for sp in speech_pre],
                noise_spec=noise_spec,
            )
        return masks_ref, masks_pre

    def collect_feats(
        self, speech_mix: torch.Tensor, speech_mix_lengths: torch.Tensor, **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Return the truncated mixture and its lengths as features.

        Args:
            speech_mix: mixture signal; truncated along dim 1 to the maximum
                of ``speech_mix_lengths``.
            speech_mix_lengths: lengths of the mixture signals.
            kwargs: ignored.

        Returns:
            Dict with "feats" (truncated mixture) and "feats_lengths".
        """
        # for data-parallel
        speech_mix = speech_mix[:, : speech_mix_lengths.max()]

        feats, feats_lengths = speech_mix, speech_mix_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}