espnet2.asr.pit_espnet_model.PITLossWrapper
class espnet2.asr.pit_espnet_model.PITLossWrapper(criterion_fn: Callable, num_ref: int)
Bases: AbsLossWrapper
Initializes the internal nn.Module state and stores criterion_fn, the loss applied to each inference/reference pair, and num_ref, the number of reference streams.
forward(inf: Tensor, inf_lens: Tensor, ref: Tensor, ref_lens: Tensor, others: Dict | None = None)
Permutation-invariant training (PIT) loss wrapper. Similar to espnet2/enh/loss/wrapper/pit_solver.py.
- Parameters:
  - inf – Iterable[torch.Tensor], (batch, num_inf, …)
  - inf_lens – Iterable[torch.Tensor], (batch, num_inf, …)
  - ref – Iterable[torch.Tensor], (batch, num_ref, …)
  - ref_lens – Iterable[torch.Tensor], (batch, num_ref, …)
  - permute_inf – If true, permute inf and inf_lens according to the optimal permutation.
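The PIT idea behind forward can be sketched as follows: score every assignment of inference streams to reference streams and keep the cheapest one for each batch element. This is a minimal pure-Python sketch, not ESPnet's implementation. It assumes that criterion_fn returns one loss per batch element for a single (inference, reference) pair, that a permutation's loss is the mean over its pairs, and that the result is the batch mean of the per-item minima. Nested lists stand in for tensors of shape (batch, num_ref), and lengths are omitted. The names pit_loss and abs_diff are hypothetical.

```python
from itertools import permutations
from typing import Callable, List, Sequence, Tuple

# Hypothetical criterion type: takes one stream per batch item for inference
# and reference, returns one loss per batch item.
Criterion = Callable[[Sequence[float], Sequence[float]], List[float]]


def pit_loss(
    criterion_fn: Criterion,
    inf: Sequence[Sequence[float]],  # inf[b][j]: inference stream j of item b
    ref: Sequence[Sequence[float]],  # ref[b][i]: reference stream i of item b
    num_ref: int,
) -> Tuple[float, List[Tuple[int, ...]]]:
    batch = len(inf)
    assert all(len(x) == num_ref for x in inf)
    assert all(len(x) == num_ref for x in ref)
    perms = list(permutations(range(num_ref)))
    # losses[b][k]: mean pair loss of item b under permutation k
    losses = [[0.0] * len(perms) for _ in range(batch)]
    for k, perm in enumerate(perms):
        for i, j in enumerate(perm):  # reference i is paired with inference j
            pair = criterion_fn([x[j] for x in inf], [x[i] for x in ref])
            for b in range(batch):
                losses[b][k] += pair[b] / num_ref
    # Index of the cheapest permutation for each batch item
    best = [min(range(len(perms)), key=row.__getitem__) for row in losses]
    min_losses = [losses[b][best[b]] for b in range(batch)]
    opt_perm = [perms[k] for k in best]
    return sum(min_losses) / batch, opt_perm


def abs_diff(inf_col: Sequence[float], ref_col: Sequence[float]) -> List[float]:
    """Hypothetical per-pair criterion: absolute error per batch item."""
    return [abs(a - r) for a, r in zip(inf_col, ref_col)]


if __name__ == "__main__":
    loss, perm = pit_loss(abs_diff, [[1.0, 5.0], [2.0, 6.0]], [[5.0, 1.0], [2.0, 7.0]], 2)
    print(loss, perm)  # 0.25 [(1, 0), (0, 1)]
```

Enumerating all num_ref! permutations is exact but grows factorially, which is acceptable for the small numbers of speakers typical in PIT.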
classmethod permutate(perm, *args)
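permutate is listed without a description. The sketch below rests on two assumptions. First, based on its signature and the permute_inf note, permutate reorders each of *args (each assumed shaped (batch, num_inf, …)) along its second axis using an optimal permutation perm of shape (batch, num_ref). Second, perm[b][i] names the inference stream matched to reference i. Under those assumptions, a pure-Python sketch of the reordering is shown here. permutate_sketch is a hypothetical name, not the ESPnet method.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def permutate_sketch(
    perm: Sequence[Sequence[int]], *args: Sequence[Sequence[T]]
) -> List[List[List[T]]]:
    """Reorder axis 1 of every arg so that out[b][i] == arg[b][perm[b][i]]."""
    ret = []
    for arg in args:
        assert len(arg) == len(perm), "batch sizes must match"
        # For each batch item, pick the streams in the order given by its permutation
        ret.append([[item[j] for j in p] for item, p in zip(arg, perm)])
    return ret


if __name__ == "__main__":
    perm = [(1, 0), (0, 1)]
    new_inf, new_lens = permutate_sketch(perm, [[1.0, 5.0], [2.0, 6.0]], [[10, 50], [20, 60]])
    print(new_inf, new_lens)  # [[5.0, 1.0], [2.0, 6.0]] [[50, 10], [20, 60]]
```

Applying the same permutation to the inference streams and their lengths keeps each stream paired with its own length after reordering.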