espnet2.diar.attractor.rnn_attractor.RnnAttractor
class espnet2.diar.attractor.rnn_attractor.RnnAttractor(encoder_output_size: int, layer: int = 1, unit: int = 512, dropout: float = 0.1, attractor_grad: bool = True)
Bases: AbsAttractor
Encoder-decoder attractor for speaker diarization.
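The entry above gives only the constructor signature. Below is a minimal PyTorch sketch, assuming an EEND-EDA-style design: an LSTM encoder reads the frame sequence, its final state initializes an LSTM decoder fed with the zero `dec_input`, and a linear layer scores each attractor. It also assumes that `layer` is the number of LSTM layers, `unit` is the LSTM hidden size, and `attractor_grad=False` means the attractors are detached before the scoring layer. The class name `TinyRnnAttractor` and its attribute names are hypothetical. This is an illustration, not the ESPnet source.

```python
import torch
import torch.nn as nn


class TinyRnnAttractor(nn.Module):
    """Hypothetical EDA-style sketch; not the ESPnet implementation."""

    def __init__(self, encoder_output_size: int, layer: int = 1, unit: int = 512,
                 dropout: float = 0.1, attractor_grad: bool = True):
        super().__init__()
        # nn.LSTM only applies inter-layer dropout, so disable it for one layer.
        rnn_dropout = dropout if layer > 1 else 0.0
        # Encoder LSTM: summarizes the frame sequence into its final (h, c) state.
        self.encoder = nn.LSTM(encoder_output_size, unit, num_layers=layer,
                               dropout=rnn_dropout, batch_first=True)
        # Decoder LSTM: starts from the encoder state, emits one attractor per step.
        self.decoder = nn.LSTM(encoder_output_size, unit, num_layers=layer,
                               dropout=rnn_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        # Linear head: one existence score per attractor.
        self.proj = nn.Linear(unit, 1)
        self.attractor_grad = attractor_grad

    def forward(self, enc_input, ilens, dec_input):
        # Packing keeps padded frames out of the encoder's final state.
        packed = nn.utils.rnn.pack_padded_sequence(
            enc_input, ilens.cpu(), batch_first=True, enforce_sorted=False)
        _, state = self.encoder(packed)
        attractor, _ = self.decoder(dec_input, state)
        attractor = self.dropout(attractor)
        # When attractor_grad is False, the scoring loss cannot update the attractors.
        source = attractor if self.attractor_grad else attractor.detach()
        att_prob = self.proj(source)
        return attractor, att_prob


torch.manual_seed(0)
batch, frames, feat_dim, num_spk = 2, 50, 8, 3
model = TinyRnnAttractor(encoder_output_size=feat_dim, unit=feat_dim)
enc_input = torch.randn(batch, frames, feat_dim)
ilens = torch.tensor([50, 30])  # the second utterance is padded after frame 30
dec_input = torch.zeros(batch, num_spk + 1, feat_dim)
attractor, att_prob = model(enc_input, ilens, dec_input)
print(attractor.shape)  # torch.Size([2, 4, 8])
print(att_prob.shape)   # torch.Size([2, 4, 1])
```

In this sketch the last attractor dimension is `unit`, so setting `unit` equal to F reproduces the documented `[Batch, num_spk + 1, F]` shape.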
forward(enc_input: Tensor, ilens: Tensor, dec_input: Tensor)
Forward pass: encode the input sequence, then decode one attractor per step of dec_input.
- Parameters:
- enc_input (torch.Tensor) – encoder output (hidden-space sequence) [Batch, T, F]
- ilens (torch.Tensor) – input lengths [Batch]
- dec_input (torch.Tensor) – decoder input (zeros) [Batch, num_spk + 1, F]
- Returns:
- attractor (torch.Tensor) – [Batch, num_spk + 1, F]
- att_prob (torch.Tensor) – [Batch, num_spk + 1, 1]
- Return type: Tuple[torch.Tensor, torch.Tensor]
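The zero `dec_input` and the two outputs can be handled as sketched below. This assumes, following the EEND-EDA formulation, that `att_prob` holds unnormalized scores (logits), that the first `num_spk` attractors are trained as "speaker exists" and the extra one as "no speaker", and that at inference the speaker count is the number of leading attractors whose sigmoid score is at least 0.5. Whether ESPnet applies exactly these rules is an assumption; the helper names `make_decoder_input`, `existence_labels`, `existence_loss`, and `count_speakers` are hypothetical.

```python
import torch
import torch.nn.functional as F


def make_decoder_input(batch: int, num_spk: int, feat_dim: int) -> torch.Tensor:
    # Zero decoder input with shape [Batch, num_spk + 1, F].
    return torch.zeros(batch, num_spk + 1, feat_dim)


def existence_labels(batch: int, num_spk: int) -> torch.Tensor:
    # Hypothetical targets: 1 for each real speaker, 0 for the extra attractor.
    labels = torch.zeros(batch, num_spk + 1, 1)
    labels[:, :num_spk] = 1.0
    return labels


def existence_loss(att_prob: torch.Tensor, num_spk: int) -> torch.Tensor:
    # Assumes att_prob holds logits, so the BCE-with-logits form is used.
    labels = existence_labels(att_prob.size(0), num_spk)
    return F.binary_cross_entropy_with_logits(att_prob, labels)


def count_speakers(att_prob: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    # Sigmoid score per attractor: [Batch, num_spk + 1].
    probs = torch.sigmoid(att_prob.squeeze(-1))
    # cumprod zeroes everything from the first attractor below the threshold,
    # so the sum counts only the leading attractors that pass it.
    keep = (probs >= threshold).long().cumprod(dim=1)
    return keep.sum(dim=1)


dec_input = make_decoder_input(batch=2, num_spk=3, feat_dim=8)
print(dec_input.shape)  # torch.Size([2, 4, 8])

# Made-up logits for two utterances, for illustration only.
att_prob = torch.tensor([[[3.0], [2.0], [-1.0], [4.0]],
                         [[-2.0], [1.0], [1.0], [1.0]]])
print(count_speakers(att_prob))  # tensor([2, 0])
```

Counting stops at the first low-scoring attractor, so a later high score (the 4.0 in the first utterance) is ignored; this matches the sequential "emit until the stop attractor" reading of the decoder.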