Source code for espnet2.enh.espnet_model

"""Enhancement model module."""

import contextlib
from typing import Dict, List, Optional, OrderedDict, Tuple

import numpy as np
import torch
import torchaudio
from packaging.version import parse as V
from typeguard import typechecked

from espnet2.diar.layers.abs_mask import AbsMask
from espnet2.enh.decoder.abs_decoder import AbsDecoder
from espnet2.enh.encoder.abs_encoder import AbsEncoder
from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainLoss
from espnet2.enh.loss.criterions.time_domain import TimeDomainLoss
from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper
from espnet2.enh.separator.abs_separator import AbsSeparator
from espnet2.enh.separator.bsrnn_separator import BSRNNSeparator
from espnet2.enh.separator.dan_separator import DANSeparator
from espnet2.enh.separator.tfgridnetv3_separator import TFGridNetV3
from espnet2.enh.separator.uses_separator import USESSeparator
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel

# True when the installed PyTorch release is 1.9.0 or newer.
is_torch_1_9_plus = V("1.9.0") <= V(torch.__version__)

# Machine epsilon of the default floating-point dtype active at import time.
EPS = torch.finfo(torch.get_default_dtype()).eps


class ESPnetEnhancementModel(AbsESPnetModel):
    """Speech enhancement or separation Frontend model"""

    @typechecked
    def __init__(
        self,
        encoder: AbsEncoder,
        separator: Optional[AbsSeparator],
        decoder: AbsDecoder,
        mask_module: Optional[AbsMask],
        loss_wrappers: Optional[List[AbsLossWrapper]],
        stft_consistency: bool = False,
        loss_type: str = "mask_mse",
        mask_type: Optional[str] = None,
        flexible_numspk: bool = False,
        extract_feats_in_collect_stats: bool = False,
        normalize_variance: bool = False,
        normalize_variance_per_ch: bool = False,
        categories: Optional[list] = None,
        category_weights: Optional[list] = None,
        always_forward_in_48k: bool = False,
    ):
        """Main entry of speech enhancement/separation model training.

        Args:
            encoder: waveform encoder that converts waveforms to feature
                representations
            separator: separator that enhances or separates the feature
                representations
            decoder: waveform decoder that converts the feature back to waveforms
            mask_module: mask module that converts the feature to masks
                NOTE: Only used for compatibility with joint speaker diarization.
                See test/espnet2/enh/test_espnet_enh_s2t_model.py for details.
            loss_wrappers: list of loss wrappers
                Each loss wrapper contains a criterion for loss calculation and
                the corresponding loss weight.
                The losses will be calculated in the order of the list and summed up.
                Loss (criterion) names must be unique.
            ------------------------------------------------------------------
            stft_consistency: (deprecated, kept for compatibility)
                whether to compute the TF-domain loss while enforcing STFT consistency
                NOTE: STFT consistency is now always used for frequency-domain
                spectrum losses.
            loss_type: (deprecated, kept for compatibility) loss type
            mask_type: (deprecated, kept for compatibility) mask type in TF-domain
                model
            ------------------------------------------------------------------
            flexible_numspk: whether to allow the model to predict a variable number
                of speakers in its output.
                NOTE: This should be used when training a speech separation model
                for unknown number of speakers.
            ------------------------------------------------------------------
            extract_feats_in_collect_stats: used in espnet2/tasks/abs_task.py for
                determining whether or not to skip model building in collect_stats
                stage (stage 5 in egs2/*/enh1/enh.sh).
            normalize_variance: whether to normalize the signal variance before
                model forward, and revert it back after.
            normalize_variance_per_ch: whether to normalize the signal variance for
                each channel instead of the whole signal.
                NOTE: normalize_variance and normalize_variance_per_ch cannot be True
                at the same time.
            ------------------------------------------------------------------
            categories: list of all possible categories of minibatches
                (order matters!)
                (e.g. ["1ch_8k_reverb", "1ch_8k_both"] for multi-condition training)
                NOTE: this will be used to convert category index to the
                corresponding name for logging in forward_loss.
                Different categories will have different loss name suffixes.
                Duplicated names are ignored (the first occurrence is kept).
                None (default) means no category information is used.
            category_weights: list of weights for each category.
                Used to set loss weights for batches of different categories.
                Must have the same length as the de-duplicated `categories`.
            ------------------------------------------------------------------
            always_forward_in_48k: whether to always upsample the input speech to
                48kHz for forward, and then downsample to the original sample rate
                for loss calculation.
                NOTE: this can be useful to train a model capable of handling various
                sampling rates while unifying bandwidth extension + speech
                enhancement.

        Raises:
            ValueError: if loss names are duplicated, or if both
                normalize_variance and normalize_variance_per_ch are True.
        """
        super().__init__()

        self.encoder = encoder
        self.separator = separator
        self.decoder = decoder
        self.mask_module = mask_module
        # set num_spk to -1 if None for compatibility with `espnet2.enh.diffusion_enh`
        self.num_spk = separator.num_spk if separator is not None else -1
        # If True, self.num_spk is regarded as the MAXIMUM possible number of speakers
        self.flexible_numspk = flexible_numspk
        self.num_noise_type = getattr(self.separator, "num_noise_type", 1)

        self.loss_wrappers = loss_wrappers
        if self.loss_wrappers is not None:
            names = [w.criterion.name for w in self.loss_wrappers]
            if len(set(names)) != len(names):
                raise ValueError(
                    "Duplicated loss names are not allowed: {}".format(names)
                )

        # kept for compatibility
        self.mask_type = mask_type.upper() if mask_type else None
        self.loss_type = loss_type
        self.stft_consistency = stft_consistency

        # for multi-channel signal
        self.ref_channel = getattr(self.separator, "ref_channel", None)
        if self.ref_channel is None:
            self.ref_channel = 0

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
        self.normalize_variance = normalize_variance
        self.normalize_variance_per_ch = normalize_variance_per_ch
        if normalize_variance and normalize_variance_per_ch:
            raise ValueError(
                "normalize_variance and normalize_variance_per_ch cannot be True "
                "at the same time."
            )

        # list all possible categories of the batch (order matters!)
        # (used to convert category index to the corresponding name for logging)
        # Maps a contiguous index (0, 1, ...) to a category name.
        self.categories = {}
        for c in categories or []:
            # compare against the stored *names*: the dict keys are indices
            if c not in self.categories.values():
                self.categories[len(self.categories)] = c

        # used to set loss weights for batches of different categories
        if category_weights:
            assert len(category_weights) == len(self.categories), (
                len(category_weights),
                len(self.categories),
            )
            self.category_weights = tuple(category_weights)
        else:
            self.category_weights = tuple(1.0 for _ in self.categories)

        self.always_forward_in_48k = always_forward_in_48k

    @staticmethod
    def _resample_time_axis(x: torch.Tensor, orig_fs: int, new_fs: int):
        """Resample a batch of signals along the time axis (dim 1).

        Args:
            x: (Batch, samples) or (Batch, samples, channels)
            orig_fs: current sampling rate of `x`
            new_fs: target sampling rate
        Returns:
            Tensor with the same layout as `x`, resampled to `new_fs`.
        """
        if x.ndim > 2:
            # torchaudio resamples the last dimension, so move time there and back
            x = x.transpose(1, 2).contiguous()
            x = torchaudio.functional.resample(x, orig_fs, new_fs)
            return x.transpose(1, 2).contiguous()
        return torchaudio.functional.resample(x, orig_fs, new_fs)

    def forward(
        self,
        speech_mix: torch.Tensor,
        speech_mix_lengths: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech_mix: (Batch, samples) or (Batch, samples, channels)
            speech_mix_lengths: (Batch,), default None for chunk iterator,
                because the chunk-iterator does not have the speech_lengths
                returned. See espnet2/iterators/chunk_iter_factory.py
            kwargs: "utt_id" is among the input.
                "speech_ref{n}" (n = 1, 2, ...) are the reference signals of each
                speaker: (Batch, samples) or (Batch, samples, channels).
                Optional: "noise_ref{n}", "dereverb_ref{n}", "utt2fs",
                "utt2category".
        Returns:
            loss: scalar loss tensor
            stats: dict of statistics for logging
            weight: batch size (for DataParallel gathering)
        """
        # reference speech signal of each speaker
        assert "speech_ref1" in kwargs, "At least 1 reference signal input is required."
        speech_ref = [
            kwargs[f"speech_ref{spk + 1}"]
            for spk in range(self.num_spk)
            if f"speech_ref{spk + 1}" in kwargs
        ]
        num_spk = len(speech_ref) if self.flexible_numspk else self.num_spk
        assert len(speech_ref) == num_spk, (len(speech_ref), num_spk)
        # (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
        speech_ref = torch.stack(speech_ref, dim=1)

        if "noise_ref1" in kwargs:
            # noise signal (optional, required when using beamforming-based
            # frontend models)
            noise_ref = [
                kwargs["noise_ref{}".format(n + 1)] for n in range(self.num_noise_type)
            ]
            # (Batch, num_noise_type, samples) or
            # (Batch, num_noise_type, samples, channels)
            noise_ref = torch.stack(noise_ref, dim=1)
        else:
            noise_ref = None

        # dereverberated (noisy) signal
        # (optional, only used for frontend models with WPE)
        if "dereverb_ref1" in kwargs:
            dereverb_speech_ref = [
                kwargs["dereverb_ref{}".format(n + 1)]
                for n in range(num_spk)
                if "dereverb_ref{}".format(n + 1) in kwargs
            ]
            assert len(dereverb_speech_ref) in (1, num_spk), len(dereverb_speech_ref)
            # (Batch, N, samples) or (Batch, N, samples, channels)
            dereverb_speech_ref = torch.stack(dereverb_speech_ref, dim=1)
        else:
            dereverb_speech_ref = None

        batch_size = speech_mix.shape[0]
        speech_lengths = (
            speech_mix_lengths
            if speech_mix_lengths is not None
            else torch.ones(batch_size).int().fill_(speech_mix.shape[1])
        )
        assert speech_lengths.dim() == 1, speech_lengths.shape
        # Check that batch_size is unified
        assert speech_mix.shape[0] == speech_ref.shape[0] == speech_lengths.shape[0], (
            speech_mix.shape,
            speech_ref.shape,
            speech_lengths.shape,
        )

        # for data-parallel
        # NOTE: crop the time axis (dim 2); `[..., :T]` would crop the channel
        # axis of multi-channel references instead.
        max_len = speech_lengths.max()
        speech_ref = speech_ref[:, :, :max_len].unbind(dim=1)
        if noise_ref is not None:
            noise_ref = noise_ref[:, :, :max_len].unbind(dim=1)
        if dereverb_speech_ref is not None:
            dereverb_speech_ref = dereverb_speech_ref[:, :, :max_len].unbind(dim=1)

        # sampling frequency information about the batch
        fs = None
        fs_tuple = None
        speech_lengths0 = None
        if "utt2fs" in kwargs:
            utt2fs = kwargs["utt2fs"]
            # All samples must have the same sampling rate
            fs = utt2fs[0].item()
            assert all(f.item() == fs for f in utt2fs), utt2fs
            tgt_fs = 48000 if self.always_forward_in_48k else fs
            if fs != tgt_fs:
                fs_tuple = (tgt_fs, fs)
                speech_lengths0 = speech_lengths
                # compute the resampled lengths without touching real data
                speech_lengths = speech_lengths.new_tensor(
                    [
                        torchaudio.functional.resample(
                            torch.randn(length, device="meta"), fs, tgt_fs
                        ).size(0)
                        for length in speech_lengths
                    ]
                )
                speech_mix = self._resample_time_axis(speech_mix, fs, tgt_fs)
                fs = tgt_fs

            # Adaptively adjust the STFT/iSTFT window/hop sizes for
            # BSRNNSeparator, USESSeparator and TFGridNetV3
            if not isinstance(
                self.separator, (BSRNNSeparator, USESSeparator, TFGridNetV3)
            ):
                fs = None

        # category information (integer) about the batch
        category = kwargs.get("utt2category", None)
        if (
            self.categories
            and category is not None
            and category[0].item() not in self.categories
        ):
            raise ValueError(f"Category '{category}' is not listed in self.categories")

        additional = {}
        # Additional data is required in Deep Attractor Network
        if isinstance(self.separator, DANSeparator):
            additional["feature_ref"] = [
                self.encoder(r, speech_lengths, fs=fs)[0] for r in speech_ref
            ]
        if self.flexible_numspk:
            additional["num_spk"] = num_spk
        # Additional information is required in USES for multi-condition training
        if category is not None and isinstance(self.separator, USESSeparator):
            cat = self.categories[category[0].item()]
            if cat.endswith("_both"):
                additional["mode"] = "both"
            elif cat.endswith("_reverb"):
                additional["mode"] = "dereverb"
            else:
                additional["mode"] = "no_dereverb"

        speech_mix = speech_mix[:, : speech_lengths.max()]

        ###################################
        # Normalize the signal variance
        # (std is clamped to EPS so that silent mixtures/channels do not
        # produce inf/NaN after division)
        if self.normalize_variance_per_ch:
            dim = 1
            mix_std_ = torch.std(speech_mix, dim=dim, keepdim=True).clamp(min=EPS)
            speech_mix = speech_mix / mix_std_  # RMS normalization
        elif self.normalize_variance:
            if speech_mix.ndim > 2:
                dim = (1, 2)
            else:
                dim = 1
            mix_std_ = torch.std(speech_mix, dim=dim, keepdim=True).clamp(min=EPS)
            speech_mix = speech_mix / mix_std_  # RMS normalization

        # model forward
        speech_pre, feature_mix, feature_pre, others = self.forward_enhance(
            speech_mix, speech_lengths, additional, fs=fs
        )

        ###################################
        # De-normalize the signal variance
        if self.normalize_variance_per_ch and speech_pre is not None:
            if mix_std_.ndim > 2:
                # keep only the scale of the reference channel: (Batch, 1)
                mix_std_ = mix_std_[:, :, self.ref_channel]
            speech_pre = [sp * mix_std_ for sp in speech_pre]
        elif self.normalize_variance and speech_pre is not None:
            if mix_std_.ndim > 2:
                mix_std_ = mix_std_.squeeze(2)
            speech_pre = [sp * mix_std_ for sp in speech_pre]

        # resample back to the original input sample rate
        # NOTE(review): `fs` passed to forward_loss stays at the 48 kHz value here;
        # confirm this is intended for sample-rate-adaptive encoders.
        if self.always_forward_in_48k and fs_tuple is not None:
            fs2, fs0 = fs_tuple
            speech_lengths = speech_lengths0
            if speech_pre is not None:
                speech_pre = [
                    self._resample_time_axis(sp, fs2, fs0)[:, : speech_lengths0.max()]
                    for sp in speech_pre
                ]

        # loss computation
        loss, stats, weight, perm = self.forward_loss(
            speech_pre,
            speech_lengths,
            feature_mix,
            feature_pre,
            others,
            speech_ref,
            noise_ref,
            dereverb_speech_ref,
            category,
            num_spk=num_spk,
            fs=fs,
        )
        return loss, stats, weight

    def forward_enhance(
        self,
        speech_mix: torch.Tensor,
        speech_lengths: torch.Tensor,
        additional: Optional[Dict] = None,
        fs: Optional[int] = None,
    ) -> Tuple[Optional[List], torch.Tensor, Optional[List], Dict]:
        """Run encoder -> separator (-> mask module) -> decoder.

        Args:
            speech_mix: (Batch, samples) or (Batch, samples, channels)
            speech_lengths: (Batch,)
            additional: extra inputs forwarded to the separator
                (e.g. "num_spk", "feature_ref", "mode")
            fs: sampling rate for sample-rate-adaptive encoders/decoders
        Returns:
            speech_pre: list of predicted waveforms (or list of lists for models
                that output multiple lists), or None when no feature is predicted
            feature_mix: encoded mixture
            feature_pre: list of predicted features, or None
            others: dict of auxiliary separator outputs
        """
        feature_mix, flens = self.encoder(speech_mix, speech_lengths, fs=fs)
        if self.mask_module is None:
            feature_pre, flens, others = self.separator(feature_mix, flens, additional)
        else:
            # Obtain bottleneck_feats from separator.
            # This is used for the input of diarization module in "enh + diar" task
            bottleneck_feats, bottleneck_feats_lengths = self.separator(
                feature_mix, flens
            )
            num_spk = additional.get("num_spk") if additional else None
            if num_spk is not None:
                feature_pre, flens, others = self.mask_module(
                    feature_mix, flens, bottleneck_feats, num_spk
                )
                others["bottleneck_feats"] = bottleneck_feats
                others["bottleneck_feats_lengths"] = bottleneck_feats_lengths
            else:
                feature_pre = None
                others = {
                    "bottleneck_feats": bottleneck_feats,
                    "bottleneck_feats_lengths": bottleneck_feats_lengths,
                }
        if feature_pre is not None:
            # for models like SVoice that output multiple lists of separated signals
            pre_is_multi_list = isinstance(feature_pre[0], (list, tuple))
            if pre_is_multi_list:
                speech_pre = [
                    [self.decoder(p, speech_lengths, fs=fs)[0] for p in ps]
                    for ps in feature_pre
                ]
            else:
                speech_pre = [
                    self.decoder(ps, speech_lengths, fs=fs)[0] for ps in feature_pre
                ]
        else:
            # some models (e.g. neural beamformer trained with mask loss)
            # do not predict time-domain signal in the training stage
            speech_pre = None
        return speech_pre, feature_mix, feature_pre, others

    def forward_loss(
        self,
        speech_pre: Optional[List[torch.Tensor]],
        speech_lengths: torch.Tensor,
        feature_mix: torch.Tensor,
        feature_pre: Optional[List[torch.Tensor]],
        others: OrderedDict,
        speech_ref: List[torch.Tensor],
        noise_ref: Optional[List[torch.Tensor]] = None,
        dereverb_speech_ref: Optional[List[torch.Tensor]] = None,
        category: Optional[torch.Tensor] = None,
        num_spk: Optional[int] = None,
        fs: Optional[int] = None,
    ) -> Tuple[
        torch.Tensor, Dict[str, torch.Tensor], torch.Tensor, Optional[torch.Tensor]
    ]:
        """Compute the (weighted) sum of all configured losses.

        Args:
            speech_pre: predicted waveforms (or None)
            speech_lengths: (Batch,)
            feature_mix: encoded mixture
            feature_pre: predicted features (or None)
            others: auxiliary separator outputs; "noise{n}" / "dereverb{n}"
                entries are decoded to waveforms in place
            speech_ref: list of reference waveforms of each speaker
            noise_ref: list of noise references (optional)
            dereverb_speech_ref: list of dereverberated references (optional)
            category: (Batch,) category index of the batch (optional)
            num_spk: number of speakers in this batch (defaults to self.num_spk)
            fs: sampling rate for sample-rate-adaptive encoders/decoders
        Returns:
            loss: scalar loss, scaled by the weight of the batch category
            stats: dict of statistics (loss names get a category suffix when
                `category` is given)
            weight: batch size
            perm: the first permutation reported by a loss wrapper, or None
        Raises:
            ValueError: on missing references or an unknown category
            NotImplementedError: on an unsupported criterion type
            AttributeError: if the loss has no gradient in training mode
        """
        # for calculating loss on estimated noise signals
        if getattr(self.separator, "predict_noise", False):
            assert "noise1" in others, others.keys()
        if noise_ref is not None and "noise1" in others:
            for n in range(self.num_noise_type):
                key = "noise{}".format(n + 1)
                others[key] = self.decoder(others[key], speech_lengths, fs=fs)[0]
        # for calculating loss on dereverberated signals
        if getattr(self.separator, "predict_dereverb", False):
            assert "dereverb1" in others, others.keys()
        if dereverb_speech_ref is not None and "dereverb1" in others:
            for spk in range(num_spk if num_spk else self.num_spk):
                key = "dereverb{}".format(spk + 1)
                if key in others:
                    others[key] = self.decoder(others[key], speech_lengths, fs=fs)[0]

        if (
            self.categories
            and category is not None
            and category[0].item() not in self.categories
        ):
            raise ValueError(f"Category '{category}' is not listed in self.categories")

        loss = speech_ref[0].new_tensor(0.0)
        stats = {}
        o = {}
        perm = None
        for loss_wrapper in self.loss_wrappers or []:
            criterion = loss_wrapper.criterion
            only_for_test = getattr(criterion, "only_for_test", False)
            if only_for_test and self.training:
                continue

            is_noise_loss = getattr(criterion, "is_noise_loss", False)
            is_dereverb_loss = getattr(criterion, "is_dereverb_loss", False)
            if is_noise_loss:
                if noise_ref is None:
                    raise ValueError(
                        "No noise reference for training!\n"
                        'Please specify "--use_noise_ref true" in run.sh'
                    )
                signal_ref = noise_ref
                signal_pre = [
                    others["noise{}".format(n + 1)] for n in range(self.num_noise_type)
                ]
            elif is_dereverb_loss:
                if dereverb_speech_ref is None:
                    raise ValueError(
                        "No dereverberated reference for training!\n"
                        'Please specify "--use_dereverb_ref true" in run.sh'
                    )
                signal_ref = dereverb_speech_ref
                # pair each available dereverb reference with its prediction
                # (one per reference, not one per noise type)
                signal_pre = [
                    others["dereverb{}".format(n + 1)]
                    for n in range(len(dereverb_speech_ref))
                    if "dereverb{}".format(n + 1) in others
                ]
                if len(signal_pre) == 0:
                    signal_pre = None
            else:
                signal_ref = speech_ref
                signal_pre = speech_pre

            zero_weight = loss_wrapper.weight == 0.0
            if isinstance(criterion, TimeDomainLoss):
                if signal_pre is None:
                    if is_noise_loss or is_dereverb_loss:
                        # Skip loss computation for noise/dereverb-specific losses
                        # if no noise/dereverb signals are predicted for this batch
                        if category is not None:
                            for c in self.categories.values():
                                stats[criterion.name + "_" + c] = torch.full_like(
                                    loss, np.nan
                                )
                        else:
                            stats[criterion.name] = torch.full_like(loss, np.nan)
                        continue
                    raise ValueError(
                        "Predicted waveform is required for time-domain loss."
                    )
                sref, spre = self._align_ref_pre_channels(
                    signal_ref, signal_pre, ch_dim=2, force_1ch=True
                )
                # for the time domain criterions
                with torch.no_grad() if zero_weight else contextlib.ExitStack():
                    l, s, o = loss_wrapper(sref, spre, {**others, **o})
            elif isinstance(criterion, FrequencyDomainLoss):
                sref, spre = self._align_ref_pre_channels(
                    signal_ref, signal_pre, ch_dim=2, force_1ch=False
                )
                # for the time-frequency domain criterions
                if criterion.compute_on_mask:
                    # compute loss on masks
                    if is_noise_loss:
                        tf_ref, tf_pre = self._get_noise_masks(
                            criterion,
                            feature_mix,
                            speech_ref,
                            signal_ref,
                            signal_pre,
                            speech_lengths,
                            others,
                            fs=fs,
                        )
                    elif is_dereverb_loss:
                        tf_ref, tf_pre = self._get_dereverb_masks(
                            criterion,
                            feature_mix,
                            noise_ref,
                            signal_ref,
                            signal_pre,
                            speech_lengths,
                            others,
                            fs=fs,
                        )
                    else:
                        tf_ref, tf_pre = self._get_speech_masks(
                            criterion,
                            feature_mix,
                            noise_ref,
                            signal_ref,
                            signal_pre,
                            speech_lengths,
                            others,
                            fs=fs,
                        )
                else:
                    # compute on spectrum
                    tf_ref = [self.encoder(sr, speech_lengths, fs=fs)[0] for sr in sref]
                    # for models like SVoice that output multiple lists of
                    # separated signals
                    pre_is_multi_list = isinstance(spre[0], (list, tuple))
                    with torch.no_grad() if zero_weight else contextlib.ExitStack():
                        if pre_is_multi_list:
                            tf_pre = [
                                [
                                    self.encoder(sp, speech_lengths, fs=fs)[0]
                                    for sp in ps
                                ]
                                for ps in spre
                            ]
                        else:
                            tf_pre = [
                                self.encoder(sp, speech_lengths, fs=fs)[0]
                                for sp in spre
                            ]

                with torch.no_grad() if zero_weight else contextlib.ExitStack():
                    l, s, o = loss_wrapper(tf_ref, tf_pre, {**others, **o})
            else:
                raise NotImplementedError("Unsupported loss type: %s" % str(criterion))

            loss += l * loss_wrapper.weight

            # rename the loss keys with a category prefix
            if category is not None:
                for idx, c in self.categories.items():
                    if idx == category[0].item():
                        s[criterion.name + "_" + c] = s.pop(criterion.name)
                    else:
                        s[criterion.name + "_" + c] = torch.full_like(loss, np.nan)

            stats.update(s)

            if perm is None and "perm" in o:
                perm = o["perm"]

        # Scale the total loss ONCE by the weight of the batch category
        # (the first category is used when no category information is given).
        if self.category_weights:
            cat_idx = category[0].item() if category is not None else 0
            loss = loss * self.category_weights[cat_idx]

        if self.training and not loss.requires_grad:
            raise AttributeError(
                "Loss must be a tensor with gradient in the training mode. "
                "Please check the following:\n"
                "1. At least one criterion must satisfy: only_for_test=False\n"
                "2. At least one criterion must always be computed in the training mode"
                " regardless of is_noise_loss=True or is_dereverb_loss=True"
            )

        stats["loss"] = loss.detach()

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        batch_size = speech_ref[0].shape[0]
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight, perm

    def _align_ref_pre_channels(self, ref, pre, ch_dim=2, force_1ch=False):
        """Make reference and prediction agree in channel layout.

        When one side is multi-channel and the other single-channel, the
        multi-channel side is reduced to `self.ref_channel`. With `force_1ch`,
        both multi-channel sides are reduced to `self.ref_channel`.
        NOTE: input must be a list of time-domain signals.
        """
        if ref is None or pre is None:
            return ref, pre
        index = ref[0].new_tensor(self.ref_channel, dtype=torch.long)

        # for models like SVoice that output multiple lists of separated signals
        pre_is_multi_list = isinstance(pre[0], (list, tuple))
        pre_dim = pre[0][0].dim() if pre_is_multi_list else pre[0].dim()

        if ref[0].dim() > pre_dim:
            # multi-channel reference and single-channel output
            ref = [r.index_select(ch_dim, index).squeeze(ch_dim) for r in ref]
        elif ref[0].dim() < pre_dim:
            # single-channel reference and multi-channel output
            if pre_is_multi_list:
                pre = [
                    p.index_select(ch_dim, index).squeeze(ch_dim)
                    for plist in pre
                    for p in plist
                ]
            else:
                pre = [p.index_select(ch_dim, index).squeeze(ch_dim) for p in pre]
        elif ref[0].dim() == pre_dim == 3 and force_1ch:
            # multi-channel reference and output
            ref = [r.index_select(ch_dim, index).squeeze(ch_dim) for r in ref]
            if pre_is_multi_list:
                pre = [
                    p.index_select(ch_dim, index).squeeze(ch_dim)
                    for plist in pre
                    for p in plist
                ]
            else:
                pre = [p.index_select(ch_dim, index).squeeze(ch_dim) for p in pre]
        return ref, pre

    def _get_noise_masks(
        self,
        criterion,
        feature_mix,
        speech_ref,
        noise_ref,
        noise_pre,
        ilens,
        others,
        fs=None,
    ):
        """Return (reference, predicted) masks for the noise signals.

        The summed speech references act as the "noise" spectrum for the
        noise mask labels. Predicted masks are taken from `others` when the
        separator provides "mask_noise{n}", otherwise derived from `noise_pre`.
        """
        speech_spec = self.encoder(sum(speech_ref), ilens, fs=fs)[0]
        masks_ref = criterion.create_mask_label(
            feature_mix,
            [self.encoder(nr, ilens, fs=fs)[0] for nr in noise_ref],
            noise_spec=speech_spec,
        )
        if "mask_noise1" in others:
            masks_pre = [
                others["mask_noise{}".format(n + 1)]
                for n in range(self.num_noise_type)
            ]
        else:
            assert len(noise_pre) == len(noise_ref), (len(noise_pre), len(noise_ref))
            masks_pre = criterion.create_mask_label(
                feature_mix,
                [self.encoder(npre, ilens, fs=fs)[0] for npre in noise_pre],
                noise_spec=speech_spec,
            )
        return masks_ref, masks_pre

    def _get_dereverb_masks(
        self,
        criterion,
        feat_mix,
        noise_ref,
        dereverb_ref,
        dereverb_pre,
        ilens,
        others,
        fs=None,
    ):
        """Return (reference, predicted) masks for the dereverberated signals.

        Predicted masks are taken from `others` when the separator provides
        "mask_dereverb{n}", otherwise derived from `dereverb_pre`.
        """
        if noise_ref is not None:
            noise_spec = self.encoder(sum(noise_ref), ilens, fs=fs)[0]
        else:
            noise_spec = None
        masks_ref = criterion.create_mask_label(
            feat_mix,
            [self.encoder(dr, ilens, fs=fs)[0] for dr in dereverb_ref],
            noise_spec=noise_spec,
        )
        if "mask_dereverb1" in others:
            masks_pre = [
                others["mask_dereverb{}".format(spk + 1)]
                for spk in range(self.num_spk)
                if "mask_dereverb{}".format(spk + 1) in others
            ]
            assert len(masks_pre) == len(masks_ref), (len(masks_pre), len(masks_ref))
        else:
            assert len(dereverb_pre) == len(dereverb_ref), (
                len(dereverb_pre),
                len(dereverb_ref),
            )
            masks_pre = criterion.create_mask_label(
                feat_mix,
                [self.encoder(dp, ilens, fs=fs)[0] for dp in dereverb_pre],
                noise_spec=noise_spec,
            )
        return masks_ref, masks_pre

    def _get_speech_masks(
        self,
        criterion,
        feature_mix,
        noise_ref,
        speech_ref,
        speech_pre,
        ilens,
        others,
        fs=None,
    ):
        """Return (reference, predicted) masks for the speech signals.

        Predicted masks are taken from `others` when the separator provides
        "mask_spk{n}", otherwise derived from `speech_pre`.
        """
        if noise_ref is not None:
            noise_spec = self.encoder(sum(noise_ref), ilens, fs=fs)[0]
        else:
            noise_spec = None
        masks_ref = criterion.create_mask_label(
            feature_mix,
            [self.encoder(sr, ilens, fs=fs)[0] for sr in speech_ref],
            noise_spec=noise_spec,
        )
        if "mask_spk1" in others:
            masks_pre = [
                others["mask_spk{}".format(spk + 1)]
                for spk in range(self.num_spk)
                if "mask_spk{}".format(spk + 1) in others
            ]
        else:
            masks_pre = criterion.create_mask_label(
                feature_mix,
                [self.encoder(sp, ilens, fs=fs)[0] for sp in speech_pre],
                noise_spec=noise_spec,
            )
        return masks_ref, masks_pre

    @staticmethod
    def sort_by_perm(nn_output, perm):
        """Sort the input list of tensors by the specified permutation.

        Args:
            nn_output: List[torch.Tensor(Batch, ...)], len(nn_output) == num_spk
            perm: (Batch, num_spk) or List[torch.Tensor(num_spk)]
        Returns:
            nn_output_new: List[torch.Tensor(Batch, ...)]
        """
        if len(nn_output) == 1:
            return nn_output
        # (Batch, num_spk, ...)
        nn_output = torch.stack(nn_output, dim=1)
        if not isinstance(perm, torch.Tensor):
            # perm is a list or tuple
            perm = torch.stack(perm, dim=0)
        assert nn_output.size(1) == perm.size(1), (nn_output.shape, perm.shape)
        diff_dim = nn_output.dim() - perm.dim()
        if diff_dim > 0:
            # broadcast the permutation over the trailing feature dimensions
            perm = perm.view(*perm.shape, *[1 for _ in range(diff_dim)]).expand_as(
                nn_output
            )
        return torch.gather(nn_output, 1, perm).unbind(dim=1)

    def collect_feats(
        self, speech_mix: torch.Tensor, speech_mix_lengths: torch.Tensor, **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Return the (length-cropped) raw mixture as features for collect_stats."""
        # for data-parallel
        speech_mix = speech_mix[:, : speech_mix_lengths.max()]

        feats, feats_lengths = speech_mix, speech_mix_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}