espnet2.asr.state_spaces.attention.MultiHeadedAttention
class espnet2.asr.state_spaces.attention.MultiHeadedAttention(n_feat, n_head, dropout=0.0, transposed=False, **kwargs)
Bases: SequenceModule
Multi-Head Attention layer inheriting SequenceModule.
Compared to the default MHA module in ESPnet, this module returns an additional dummy state and provides a step function for autoregressive inference.
- Parameters:
- n_feat (int) – The number of features.
- n_head (int) – The number of heads.
- dropout (float) – Dropout rate. Defaults to 0.0.
Construct a MultiHeadedAttention object.
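A minimal usage sketch; the hyperparameters and shapes below are illustrative, not from the source:

```python
import torch

from espnet2.asr.state_spaces.attention import MultiHeadedAttention

# 4 heads over a 256-dim feature space; dropout and sizes are illustrative.
mha = MultiHeadedAttention(n_feat=256, n_head=4, dropout=0.1)

x = torch.randn(8, 50, 256)  # (#batch, time1, n_feat)
out, state = mha(x)          # self-attention; state is the dummy state
print(out.shape)             # torch.Size([8, 50, 256])
```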
forward(query, memory=None, mask=None, *args, **kwargs)
Compute scaled dot product attention.
- Parameters:
- query (torch.Tensor) – Query tensor (#batch, time1, size).
- memory (torch.Tensor) – Memory tensor (#batch, time2, size), used as both key and value. If None, query is used (self-attention).
- mask (torch.Tensor) – Mask tensor (#batch, 1, time2) or (#batch, time1, time2).
- Returns: Output tensor (#batch, time1, d_model) and a dummy state (None).
- Return type: Tuple[torch.Tensor, None]
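A hedged sketch of cross-attention with a padding mask, reusing the `mha` instance from above; the memory-as-key/value behavior follows the signature shown here:

```python
import torch

query = torch.randn(8, 10, 256)                # (#batch, time1, n_feat)
memory = torch.randn(8, 50, 256)               # (#batch, time2, n_feat), serves as key and value
mask = torch.ones(8, 1, 50, dtype=torch.bool)  # (#batch, 1, time2); True marks valid positions

out, _ = mha(query, memory=memory, mask=mask)  # second return value is the dummy state
print(out.shape)                               # torch.Size([8, 10, 256])
```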
forward_attention(value, scores, mask)
Compute attention context vector.
- Parameters:
- value (torch.Tensor) – Transformed value (#batch, n_head, time2, d_k).
- scores (torch.Tensor) – Attention score (#batch, n_head, time1, time2).
- mask (torch.Tensor) – Mask (#batch, 1, time2) or (#batch, time1, time2).
- Returns: Transformed value (#batch, time1, d_model) weighted by the attention score (#batch, time1, time2).
- Return type: torch.Tensor
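A short sketch of calling `forward_attention` directly with randomly generated scores, again reusing `mha` from above; this assumes, as in ESPnet's standard MHA, that masked positions are suppressed before the softmax:

```python
import torch

n_batch, n_head, time1, time2, d_k = 8, 4, 10, 50, 64  # d_model = n_head * d_k = 256
value = torch.randn(n_batch, n_head, time2, d_k)
scores = torch.randn(n_batch, n_head, time1, time2)
mask = torch.ones(n_batch, 1, time2, dtype=torch.bool)

context = mha.forward_attention(value, scores, mask)
print(context.shape)  # torch.Size([8, 10, 256])
```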
forward_qkv(query, key, value)
Transform query, key and value.
- Parameters:
- query (torch.Tensor) – Query tensor (#batch, time1, size).
- key (torch.Tensor) – Key tensor (#batch, time2, size).
- value (torch.Tensor) – Value tensor (#batch, time2, size).
- Returns: Transformed query tensor (#batch, n_head, time1, d_k), transformed key tensor (#batch, n_head, time2, d_k), and transformed value tensor (#batch, n_head, time2, d_k).
- Return type: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
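A sketch of the per-head split performed by `forward_qkv`, reusing the `query` and `memory` tensors from the forward example above:

```python
q, k, v = mha.forward_qkv(query, memory, memory)
print(q.shape)  # torch.Size([8, 4, 10, 64])  -> (#batch, n_head, time1, d_k)
print(k.shape)  # torch.Size([8, 4, 50, 64])  -> (#batch, n_head, time2, d_k)
print(v.shape)  # torch.Size([8, 4, 50, 64])  -> (#batch, n_head, time2, d_k)
```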
step(query, state, memory=None, mask=None, **kwargs)
Step the model recurrently for one step of the input sequence.
For example, this should correspond to unrolling an RNN for one step. If the forward pass has signature (B, L, H1) -> (B, L, H2), this method should generally have signature (B, H1) -> (B, H2) with an optional recurrent state.
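A hedged sketch of autoregressive inference with `step`, assuming it follows the SequenceModule convention of returning `(output, state)`:

```python
import torch

state = None
outputs = []
for t in range(10):
    x_t = torch.randn(8, 256)          # one input step, (B, n_feat)
    y_t, state = mha.step(x_t, state)  # (B, n_feat) -> (B, n_feat)
    outputs.append(y_t)

y = torch.stack(outputs, dim=1)        # (B, L, n_feat)
print(y.shape)                         # torch.Size([8, 10, 256])
```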