espnet.nets.pytorch_backend.e2e_asr_mix.PIT
class espnet.nets.pytorch_backend.e2e_asr_mix.PIT(num_spkrs)
Bases: object
Permutation Invariant Training (PIT) module.
- Parameters: num_spkrs (int) – number of speakers for PIT process (2 or 3)
Initialize PIT module.
min_pit_sample(loss)
Compute the PIT loss for each sample.
- Parameters: loss (1-D torch.Tensor) – pairwise losses for one sample, where hirj is the loss of hypothesis i against reference j: [h1r1, h1r2, h2r1, h2r2] for two speakers or [h1r1, h1r2, h1r3, h2r1, h2r2, h2r3, h3r1, h3r2, h3r3] for three
- Returns: the minimum loss, obtained with the best permutation (torch.Tensor (1)), and the best permutation itself (List, len=num_spkrs)
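As a rough illustration of what min_pit_sample computes, the standalone function below scores every speaker permutation against the flattened loss vector and keeps the cheapest one. It is a sketch, not ESPnet's code: plain Python floats stand in for torch.Tensor, the function name and the explicit num_spkrs argument are hypothetical, and combining the pairwise losses by summation is an assumption (the page does not say how they are combined).

```python
from itertools import permutations
from typing import List, Optional, Sequence, Tuple


def min_pit_sample(loss: Sequence[float], num_spkrs: int) -> Tuple[float, List[int]]:
    """Hypothetical sketch: return the smallest permutation loss and its permutation.

    `loss` is flattened row-major over (hypothesis, reference) pairs, so the
    loss of hypothesis i against reference j sits at index i * num_spkrs + j.
    """
    assert len(loss) == num_spkrs * num_spkrs, "expected num_spkrs**2 pairwise losses"
    best_loss: Optional[float] = None
    best_perm: List[int] = []
    for perm in permutations(range(num_spkrs)):
        # Hypothesis i is paired with reference perm[i]; summing is an assumption.
        total = sum(loss[i * num_spkrs + perm[i]] for i in range(num_spkrs))
        if best_loss is None or total < best_loss:
            best_loss, best_perm = total, list(perm)
    return best_loss, best_perm


# Two speakers: [h1r1, h1r2, h2r1, h2r2]; pairing h1->r2 and h2->r1 wins.
print(min_pit_sample([5.0, 1.0, 2.0, 6.0], 2))  # (3.0, [1, 0])
```

Trying all num_spkrs! assignments is cheap for the 2 or 3 speakers the class supports (2 or 6 permutations).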
permutationDFS(source, start)
Get permutations with DFS.
The final result is every permutation of the ‘source’ sequence, e.g. [[1, 2], [2, 1]] for two elements or [[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 2, 1], [3, 1, 2]] for three.
- Parameters:
- source (np.ndarray) – (num_spkrs, 1), e.g. [1, 2, …, N]
- start (int) – index at which to start permuting (earlier positions are held fixed)
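A depth-first permutation generator of the kind described above can be sketched with in-place swaps. This hypothetical permutation_dfs takes a Python list instead of an np.ndarray and appends results to an explicit out list rather than to object state; with this swap order it yields the same ordering as the three-element example above.

```python
from typing import List


def permutation_dfs(source: List[int], start: int, out: List[List[int]]) -> None:
    """Hypothetical sketch: append every permutation of source[start:] to out,
    keeping source[:start] fixed."""
    if start == len(source) - 1:
        # Only one element left to place: record the current arrangement.
        out.append(list(source))
        return
    for i in range(start, len(source)):
        # Move element i into position `start`, recurse on the rest, then undo the swap.
        source[start], source[i] = source[i], source[start]
        permutation_dfs(source, start + 1, out)
        source[start], source[i] = source[i], source[start]


perms: List[List[int]] = []
permutation_dfs([1, 2, 3], 0, perms)
print(perms)  # [[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 2, 1], [3, 1, 2]]
```

Swapping and then undoing the swap restores `source` after each branch, so a single buffer is reused for the whole search.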
pit_process(losses)
Compute the PIT loss for a batch.
- Parameters: losses (torch.Tensor) – pairwise losses of shape (B, 1|4|9), one row per sample, each row laid out as in min_pit_sample
- Returns: minimum losses of the batch under the best permutations (torch.Tensor (B)) and the best permutation of each sample (torch.LongTensor (B, 1|2|3))
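A batch-level sketch in the same spirit: each of the B rows is searched independently, and the per-sample minimum losses and permutations are collected. This is a hypothetical, framework-free stand-in: rows are plain lists whose length (1, 4 or 9) determines num_spkrs, pairwise losses are summed by assumption, and results come back as Python lists rather than tensors.

```python
import math
from itertools import permutations
from typing import List, Sequence, Tuple


def pit_process(losses: Sequence[Sequence[float]]) -> Tuple[List[float], List[List[int]]]:
    """Hypothetical sketch: pick the best speaker permutation for every sample.

    Each row holds num_spkrs**2 pairwise losses (1, 4 or 9 values), flattened
    so that hypothesis i vs. reference j is at index i * num_spkrs + j.
    """
    min_losses: List[float] = []
    best_perms: List[List[int]] = []
    for row in losses:
        num_spkrs = math.isqrt(len(row))
        assert num_spkrs * num_spkrs == len(row), "row length must be a perfect square"
        # Score every assignment of hypotheses to references for this sample.
        scored = [
            (sum(row[i * num_spkrs + p[i]] for i in range(num_spkrs)), list(p))
            for p in permutations(range(num_spkrs))
        ]
        # min() keeps the first permutation on ties.
        loss, perm = min(scored, key=lambda item: item[0])
        min_losses.append(loss)
        best_perms.append(perm)
    return min_losses, best_perms


print(pit_process([[5.0, 1.0, 2.0, 6.0], [0.5, 4.0, 3.0, 0.5]]))
# ([3.0, 1.0], [[1, 0], [0, 1]])
```

To mirror the documented return types with PyTorch, the per-sample permutations could be wrapped with torch.tensor(best_perms, dtype=torch.long).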