Source code for espnet2.enh.loss.wrappers.mixit_solver

import itertools
from typing import Dict, List, Union

import torch
from torch_complex.tensor import ComplexTensor

from espnet2.enh.layers.complex_utils import einsum as complex_einsum
from espnet2.enh.layers.complex_utils import stack as complex_stack
from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss
from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper


class MixITSolver(AbsLossWrapper):
    """Loss wrapper implementing a Mixture Invariant Training (MixIT) search.

    Every estimated signal is assigned to exactly one reference; the
    estimates assigned to the same reference are summed and compared with
    that reference using ``criterion``. All possible assignments are
    enumerated and, for each sample in the batch, the assignment with the
    smallest averaged loss is kept.
    """

    def __init__(
        self,
        criterion: AbsEnhLoss,
        weight: float = 1.0,
    ):
        """Mixture Invariant Training Solver.

        Args:
            criterion (AbsEnhLoss): an instance of AbsEnhLoss
            weight (float): weight (between 0 and 1) of current loss
                for multi-task learning.
        """
        super().__init__()
        self.criterion = criterion
        self.weight = weight

    @property
    def name(self):
        """Short identifier used as the suffix of the stats key."""
        return "mixit"

    def _complex_einsum(self, equation, *operands):
        """Run ``complex_einsum`` after promoting every operand to ComplexTensor.

        Operands that are not already ``ComplexTensor`` get a zero imaginary
        part attached so that all operands share one representation.

        Args:
            equation (str): einsum equation string.
            *operands: ``ComplexTensor`` or ``torch.Tensor`` operands.

        Returns:
            The result of ``complex_einsum(equation, *operands)``.
        """
        # Build a new list: rebinding the loop variable (as the previous
        # implementation did) never changed the operands actually passed on.
        promoted = [
            op
            if isinstance(op, ComplexTensor)
            else ComplexTensor(op, torch.zeros_like(op))
            for op in operands
        ]
        return complex_einsum(equation, *promoted)

    def forward(
        self,
        ref: Union[List[torch.Tensor], List[ComplexTensor]],
        inf: Union[List[torch.Tensor], List[ComplexTensor]],
        others: Dict = {},
    ):
        """MixIT solver.

        Args:
            ref (List[torch.Tensor]): [(batch, ...), ...] x n_spk
                Only the first ``len(inf) // 2`` entries are used.
            inf (List[torch.Tensor]): [(batch, ...), ...] x n_est
            others (Dict): not used by this solver; accepted for interface
                compatibility.

        Returns:
            loss: (torch.Tensor): batch mean of the minimum loss over all
                assignments of estimates to references
            stats: dict, for collecting training status
            others: dict, ``{"perm": perm}`` where ``perm`` holds, for each
                sample, the selected mixing matrix of shape
                (batch, num_ref, num_inf)

        Raises:
            ValueError: if the stacked estimates have fewer than 3 dimensions,
                i.e. each element of ``inf`` has fewer than 2 dimensions.
        """
        num_inf = len(inf)
        num_ref = num_inf // 2
        device = ref[0].device

        is_complex = isinstance(ref[0], ComplexTensor)
        assert is_complex == isinstance(inf[0], ComplexTensor)

        if not is_complex:
            ref_tensor = torch.stack(ref[:num_ref], dim=1)  # (batch, num_ref, ...)
            inf_tensor = torch.stack(inf, dim=1)  # (batch, num_inf, ...)
            einsum_fn = torch.einsum
        else:
            ref_tensor = complex_stack(ref[:num_ref], dim=1)  # (batch, num_ref, ...)
            inf_tensor = complex_stack(inf, dim=1)  # (batch, num_inf, ...)
            einsum_fn = self._complex_einsum

        # Fail early with a clear message; previously a too-low rank fell
        # through both branches below and surfaced as an unbound variable.
        if inf_tensor.dim() < 3:
            raise ValueError(
                "Each estimated signal must have at least 2 dimensions "
                f"(batch, seq_len, ...), but the stacked estimates have "
                f"{inf_tensor.dim()} dimensions."
            )

        # all assignments of each estimate to one reference, e.g. for
        # num_ref=2 and num_inf=4:
        # [(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 1, 0), ..., (1, 1, 1, 1)]
        all_assignments = list(itertools.product(range(num_ref), repeat=num_inf))
        all_mixture_matrix = torch.stack(
            [
                torch.nn.functional.one_hot(
                    torch.tensor(asm, dtype=torch.int64, device=device),
                    num_classes=num_ref,
                ).transpose(1, 0)
                for asm in all_assignments
            ],
            dim=0,
        ).to(
            inf_tensor.dtype
        )  # (num_ref ^ num_inf, num_ref, num_inf)

        # Sum the estimates assigned to each reference, for every assignment:
        # (num_ref ^ num_inf, batch, num_ref, seq_len, ...)
        if inf_tensor.dim() == 3:
            est_sum_mixture = einsum_fn("ari,bil->abrl", all_mixture_matrix, inf_tensor)
        else:
            est_sum_mixture = einsum_fn(
                "ari,bil...->abrl...", all_mixture_matrix, inf_tensor
            )

        # Average the criterion over the references for each assignment.
        losses = [
            sum(
                [
                    self.criterion(ref_tensor[:, s], est_sum_mixture[i, :, s])
                    for s in range(num_ref)
                ]
            )
            / num_ref
            for i in range(all_mixture_matrix.shape[0])
        ]
        losses = torch.stack(losses, dim=0)  # (num_ref ^ num_inf, batch)

        # Keep the best assignment per sample, then average over the batch.
        loss, perm = torch.min(losses, dim=0)  # (batch)
        loss = loss.mean()

        # Look up the mixing matrix of the selected assignment for each sample.
        perm = torch.index_select(all_mixture_matrix, 0, perm)
        if perm.is_complex():
            # The matrix was cast to the dtype of the estimates; only its
            # real part carries the 0/1 assignment.
            perm = perm.real

        stats = dict()
        stats[f"{self.criterion.name}_{self.name}"] = loss.detach()

        return loss, stats, {"perm": perm}