espnet2.enh.loss.wrappers.multilayer_pit_solver.MultiLayerPITSolver
class espnet2.enh.loss.wrappers.multilayer_pit_solver.MultiLayerPITSolver(criterion: AbsEnhLoss, weight=1.0, independent_perm=True, layer_weights=None)
Bases: AbsLossWrapper
Multi-Layer Permutation Invariant Training Solver.
Computes the PIT loss given the inferences of multiple layers and a single reference. It also supports a single inference and a single reference in the evaluation stage.
- Parameters:
- criterion (AbsEnhLoss) – an instance of AbsEnhLoss
- weight (float) – weight (between 0 and 1) of current loss for multi-task learning.
- independent_perm (bool) – If True, PIT is performed in forward to find the best permutation; if False, the permutation from the last LossWrapper output is inherited. Note: be careful about the ordering of the loss wrappers defined in the YAML config when this argument is False.
- layer_weights (Optional[List[float]]) – weights for each layer. If not None, the per-layer losses are combined as a weighted sum using the specified weights.
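The weighted sum described for layer_weights can be illustrated with a minimal sketch in plain Python. The helper name `combine_layer_losses` is hypothetical and is not part of ESPnet, and the equal-weight fallback used when `layer_weights` is None is an assumption made for this sketch only; the behavior of the class in that case is not stated above.

```python
from typing import List, Optional


def combine_layer_losses(
    layer_losses: List[float],
    layer_weights: Optional[List[float]] = None,
) -> float:
    """Hypothetical helper: weighted sum of per-layer loss values."""
    if layer_weights is None:
        # Assumption for this sketch only: weight every layer equally.
        layer_weights = [1.0] * len(layer_losses)
    if len(layer_weights) != len(layer_losses):
        raise ValueError("need exactly one weight per layer")
    # Multiply each layer's loss by its weight and add the results.
    return sum(w * loss for w, loss in zip(layer_weights, layer_losses))


# Three layers, with later layers weighted more heavily.
total = combine_layer_losses([0.9, 0.6, 0.3], layer_weights=[0.2, 0.3, 0.5])
```

Giving the final layer the largest weight is one plausible choice, since its output is typically the one used at inference time, but the weights are left entirely to the user.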
forward(ref, infs, others={})
Permutation invariant training solver.
- Parameters:
- ref (List[torch.Tensor]) – [(batch, …), …] x n_spk
- infs (Union[List[torch.Tensor], List[List[torch.Tensor]]]) – [(batch, …), …]
- Returns:
- loss (torch.Tensor) – minimum loss with the best permutation
- stats (dict) – for collecting training status
- others (dict) – in this PIT solver, the permutation order is returned
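The idea behind "minimum loss with the best permutation" can be sketched without ESPnet or PyTorch. The example below is a simplified illustration, not the library's implementation: the names `mse` and `pit_loss` are hypothetical, signals are plain lists of floats rather than batched tensors, and the mean squared error merely stands in for whatever AbsEnhLoss criterion is configured.

```python
from itertools import permutations
from typing import Callable, List, Sequence, Tuple

Signal = Sequence[float]


def mse(ref: Signal, inf: Signal) -> float:
    """Mean squared error between two equal-length signals."""
    return sum((r - i) ** 2 for r, i in zip(ref, inf)) / len(ref)


def pit_loss(
    refs: List[Signal],
    infs: List[Signal],
    criterion: Callable[[Signal, Signal], float] = mse,
) -> Tuple[float, Tuple[int, ...]]:
    """Hypothetical helper: try every speaker ordering, keep the best one."""
    best_loss, best_perm = float("inf"), ()
    for perm in permutations(range(len(refs))):
        # perm[s] is the index of the estimate paired with reference s.
        loss = sum(criterion(refs[s], infs[p]) for s, p in enumerate(perm))
        loss /= len(refs)  # average over speakers
        if loss < best_loss:
            best_loss, best_perm = loss, perm
    return best_loss, best_perm


# Two speakers whose estimates come out in swapped order.
refs = [[1.0, 0.0], [0.0, 1.0]]
infs = [[0.0, 1.0], [1.0, 0.0]]
loss, perm = pit_loss(refs, infs)
```

Here the swapped ordering `(1, 0)` matches each estimate to its reference exactly, so the loss is zero. An exhaustive search visits n_spk! orderings, which is acceptable for the two or three speakers typical in separation tasks but grows quickly beyond that.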