espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention
class espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention(n_head, n_feat, dropout_rate, qk_norm=False, use_flash_attn=False, causal=False, cross_attn=False, use_sdpa=False)
Bases: Module
Multi-Head Attention layer.
- Parameters:
- n_head (int) – The number of heads.
- n_feat (int) – The number of features.
- dropout_rate (float) – Dropout rate.
- qk_norm (bool) – Normalize q and k before dot product.
- use_flash_attn (bool) – Use flash_attn implementation.
- causal (bool) – Apply causal attention.
- cross_attn (bool) – Cross attention instead of self attention.
- use_sdpa (bool) – Use PyTorch’s scaled dot product attention.
Construct a MultiHeadedAttention object.
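A minimal construction sketch (sizes are hypothetical; n_feat must be divisible by n_head, so each head operates on d_k = n_feat // n_head features):

```python
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention

# Assumed sizes: 4 heads over 256 features -> 64 features per head.
mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
```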
forward(query, key, value, mask, expand_kv=False)
Compute scaled dot product attention.
- Parameters:
- query (torch.Tensor) – Query tensor (#batch, time1, size).
- key (torch.Tensor) – Key tensor (#batch, time2, size).
- value (torch.Tensor) – Value tensor (#batch, time2, size).
- mask (torch.Tensor) – Mask tensor (#batch, 1, time2) or (#batch, time1, time2).
- expand_kv (bool) – Used only for partially autoregressive (PAR) decoding. When set to True, the key/value Linear layers are computed only for the first batch element and the result is reused for the remaining elements. This reduces memory usage during decoding when the batch size is #beam_size x #mask_count, which can be large; typically, in single-waveform PAR inference, the source-attention Linear layers do not need to be computed for every batch element.
- Returns: Output tensor (#batch, time1, d_model).
- Return type: torch.Tensor
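A self-attention call sketch with dummy tensors of the documented shapes (sizes are hypothetical; nonzero mask entries mark positions that are attended to, zeros are masked out):

```python
import torch

from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention

mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)

batch, time1, size = 2, 7, 256  # size == n_feat
x = torch.randn(batch, time1, size)
# Mask of shape (#batch, 1, time2); all-True keeps every position.
mask = torch.ones(batch, 1, time1, dtype=torch.bool)

out = mha(x, x, x, mask)  # self-attention: query == key == value
print(out.shape)  # torch.Size([2, 7, 256])
```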
forward_attention(value, scores, mask)
Compute attention context vector.
- Parameters:
- value (torch.Tensor) – Transformed value (#batch, n_head, time2, d_k).
- scores (torch.Tensor) – Attention score (#batch, n_head, time1, time2).
- mask (torch.Tensor) – Mask (#batch, 1, time2) or (#batch, time1, time2).
- Returns: Transformed value (#batch, time1, d_model) weighted by the attention score (#batch, time1, time2).
- Return type: torch.Tensor
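A sketch of calling forward_attention directly with dummy inputs of the documented shapes (module sizes as in the construction sketch above; zero entries of the mask are excluded from the softmax):

```python
import torch

from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention

mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)

batch, n_head, d_k = 2, 4, 64   # d_k = n_feat // n_head
time1, time2 = 7, 9
value = torch.randn(batch, n_head, time2, d_k)
scores = torch.randn(batch, n_head, time1, time2)
mask = torch.ones(batch, 1, time2, dtype=torch.bool)

context = mha.forward_attention(value, scores, mask)
print(context.shape)  # torch.Size([2, 7, 256])
```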
forward_qkv(query, key, value, expand_kv=False)
Transform query, key and value.
- Parameters:
- query (torch.Tensor) – Query tensor (#batch, time1, size).
- key (torch.Tensor) – Key tensor (#batch, time2, size).
- value (torch.Tensor) – Value tensor (#batch, time2, size).
- expand_kv (bool) – Used only for partially autoregressive (PAR) decoding.
- Returns:
  - torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
  - torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
  - torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
- Return type: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
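A sketch of the projection step on its own (hypothetical sizes); the returned per-head tensors are what the scaled dot product in forward is computed over:

```python
import torch

from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention

mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)

batch, time1, time2 = 2, 7, 9
query = torch.randn(batch, time1, 256)
key = torch.randn(batch, time2, 256)
value = torch.randn(batch, time2, 256)

q, k, v = mha.forward_qkv(query, key, value)
print(q.shape, k.shape, v.shape)
# torch.Size([2, 4, 7, 64]) torch.Size([2, 4, 9, 64]) torch.Size([2, 4, 9, 64])
```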