espnet2.enh.loss.wrappers.pit_solver.PITSolver
class espnet2.enh.loss.wrappers.pit_solver.PITSolver(criterion: AbsEnhLoss, weight=1.0, independent_perm=True, flexible_numspk=False)
Bases: AbsLossWrapper
Permutation Invariant Training Solver.
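The objective behind PIT can be summarized as follows for a single batch item with N speakers. This is the textbook formulation. It assumes the per-pair criterion values are averaged over speakers; the exact normalization used by this class is not stated here.

```latex
\mathcal{L}_{\mathrm{PIT}} = \min_{\pi \in \mathcal{P}_N} \frac{1}{N} \sum_{i=1}^{N} \ell\left(r_i, \hat{s}_{\pi(i)}\right)
```

Here $\mathcal{P}_N$ is the set of permutations of $\{1, \dots, N\}$, $r_i$ is the $i$-th entry of ref, $\hat{s}_j$ is the $j$-th entry of inf, and $\ell$ is the criterion. An exhaustive search evaluates all $N!$ permutations, e.g. 6 for three speakers.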
- Parameters:
  - criterion (AbsEnhLoss) – an instance of AbsEnhLoss.
  - weight (float) – weight (between 0 and 1) of the current loss for multi-task learning.
  - independent_perm (bool) – If True, PIT is performed in forward to find the best permutation; if False, the permutation from the output of the previous LossWrapper is inherited. NOTE (wangyou): if this argument is False, be careful about the order in which loss wrappers are defined in the yaml config.
  - flexible_numspk (bool) – If True, num_spk is taken from inf (rather than ref) to handle a flexible number of speakers, because ref may include dummy data in this case.
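The constructor options can be illustrated with a framework-free sketch in plain Python, using lists of floats instead of torch.Tensor. Everything here is hypothetical: `pit_solve`, `mse`, the `inherited_perm` argument, and the convention that `perm[k]` is the index of the reference paired with `inf[k]` are assumptions for illustration, not ESPnet's API. In the flexible case the sketch assumes each estimate is matched to a distinct entry of the (possibly dummy-padded) references. `weight` is not modeled; presumably it scales the returned loss when several wrappers are combined.

```python
from itertools import permutations
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Batch = List[List[float]]  # one signal (list of floats) per batch item
Perm = Tuple[int, ...]     # hypothetical convention: perm[k] = ref index paired with inf[k]


def mse(ref: Batch, est: Batch) -> List[float]:
    """Hypothetical stand-in for an AbsEnhLoss: per-item mean squared error."""
    return [
        sum((r - e) ** 2 for r, e in zip(r_sig, e_sig)) / len(r_sig)
        for r_sig, e_sig in zip(ref, est)
    ]


def pit_solve(
    criterion: Callable[[Batch, Batch], List[float]],
    ref: Sequence[Batch],
    inf: Sequence[Batch],
    independent_perm: bool = True,
    flexible_numspk: bool = False,
    inherited_perm: Optional[List[Perm]] = None,
) -> Tuple[float, Dict[str, float], Dict[str, List[Perm]]]:
    """Sketch of a PIT solver over plain Python lists (not ESPnet's code)."""
    if flexible_numspk:
        # Speaker count comes from the estimates; ref may contain dummy entries,
        # so each estimate is matched to a distinct subset of the references.
        num_spk = len(inf)
        candidates = list(permutations(range(len(ref)), num_spk))
    else:
        assert len(ref) == len(inf), (len(ref), len(inf))
        num_spk = len(ref)
        candidates = list(permutations(range(num_spk)))
    batch_size = len(inf[0])

    def pair_loss(perm: Perm) -> List[float]:
        # Criterion averaged over the num_spk (reference, estimate) pairs, per item.
        per_pair = [criterion(ref[r], inf[k]) for k, r in enumerate(perm)]
        return [sum(p[b] for p in per_pair) / num_spk for b in range(batch_size)]

    if independent_perm or inherited_perm is None:
        # Exhaustive search (also the fallback when nothing was inherited):
        # score every candidate and keep the best one per batch item.
        losses = [pair_loss(p) for p in candidates]  # [n_candidates][batch]
        best = [
            min(range(len(candidates)), key=lambda c: losses[c][b])
            for b in range(batch_size)
        ]
        per_item = [losses[c][b] for b, c in enumerate(best)]
        perm = [candidates[c] for c in best]
    else:
        # Reuse the permutation found by an earlier loss wrapper.
        perm = list(inherited_perm)
        per_item = [pair_loss(perm[b])[b] for b in range(batch_size)]

    loss = sum(per_item) / batch_size
    return loss, {"loss": loss}, {"perm": perm}
```

For example, with two swapped estimates the search recovers the swap and returns zero loss, while forcing the identity permutation through `inherited_perm` does not. This is why the order of wrappers matters when independent_perm is False.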
forward(ref, inf, others={})
PITSolver forward.
- Parameters:
  - ref (List[torch.Tensor]) – [(batch, …), …] x n_spk
  - inf (List[torch.Tensor]) – [(batch, …), …]
- Returns:
  - loss (torch.Tensor) – minimum loss with the best permutation
  - stats (dict) – for collecting training status
  - others (dict) – in this PIT solver, the permutation order is returned
- Return type: tuple (loss, stats, others)
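Because the returned permutation records which estimate was matched to which reference, it can be used to put the model outputs into reference order. The sketch below is hypothetical: `align_estimates` is not part of ESPnet. It assumes the permutation has already been extracted from others as one tuple per batch item, using the index convention stated in its docstring; the actual layout of the returned permutation may differ.

```python
from typing import List, Sequence, Tuple

Batch = List[List[float]]  # one signal (list of floats) per batch item


def align_estimates(inf: Sequence[Batch], perm: Sequence[Tuple[int, ...]]) -> List[Batch]:
    """Return aligned, where aligned[i][b] is the estimate paired with reference i.

    Hypothetical convention: perm[b][k] is the reference index paired with inf[k]
    for batch item b, and every reference is matched (no dummy references).
    """
    num_spk = len(inf)
    for p in perm:
        assert sorted(p) == list(range(num_spk)), "expected a full permutation"
    # p.index(i) inverts the permutation: which estimate was paired with ref i.
    return [[inf[p.index(i)][b] for b, p in enumerate(perm)] for i in range(num_spk)]
```

A step like this is typically useful before computing per-speaker metrics, since under PIT the order of the model outputs is arbitrary.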