espnet2.s2st.aux_attention.multihead.MultiHeadAttention
class espnet2.s2st.aux_attention.multihead.MultiHeadAttention(n_head: int = 4, n_feat: int = 512, dropout_rate: float = 0.0)
Bases: AbsS2STAuxAttention
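Multi-head attention for speech-to-speech translation (S2ST), implementing the AbsS2STAuxAttention interface.
The constructor initializes the internal Module state shared by nn.Module and ScriptModule. n_head is the number of attention heads, n_feat is the feature dimension, and dropout_rate is the dropout probability.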
forward(query: Tensor, key: Tensor, value: Tensor, mask: Tensor)
Compute multi-head attention of the query over the key/value pairs.
- Parameters:
  - query (torch.Tensor) – Query tensor (#batch, time1, size).
  - key (torch.Tensor) – Key tensor (#batch, time2, size).
  - value (torch.Tensor) – Value tensor (#batch, time2, size).
  - mask (torch.Tensor) – Mask tensor (#batch, 1, time2) or (#batch, time1, time2).
- Returns: Output tensor (#batch, time1, d_model).
- Return type: torch.Tensor
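This page does not document the module's internals. The following is a minimal sketch of the forward() shape contract. It assumes the common scaled dot-product design: n_feat is split evenly across n_head heads, and mask entries equal to 0 (False) are ignored. `SketchMultiHeadAttention` and `make_key_mask` are hypothetical names used only for illustration. They are not the ESPnet implementation or API.

```python
import math

import torch
import torch.nn as nn


class SketchMultiHeadAttention(nn.Module):
    """Hypothetical stand-in showing the documented forward() shapes.

    Assumption: scaled dot-product attention with n_feat split evenly
    across n_head heads; mask positions equal to 0 are ignored.
    """

    def __init__(self, n_head: int = 4, n_feat: int = 512, dropout_rate: float = 0.0):
        super().__init__()
        assert n_feat % n_head == 0, "n_feat must be divisible by n_head"
        self.h = n_head
        self.d_k = n_feat // n_head  # per-head feature size
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, query, key, value, mask):
        b = query.size(0)
        # Project, then split into heads: (#batch, n_head, time, d_k)
        q = self.linear_q(query).view(b, -1, self.h, self.d_k).transpose(1, 2)
        k = self.linear_k(key).view(b, -1, self.h, self.d_k).transpose(1, 2)
        v = self.linear_v(value).view(b, -1, self.h, self.d_k).transpose(1, 2)
        # Attention scores: (#batch, n_head, time1, time2)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        # mask (#batch, 1 or time1, time2) -> (#batch, 1, 1 or time1, time2)
        ignore = mask.unsqueeze(1) == 0
        scores = scores.masked_fill(ignore, torch.finfo(scores.dtype).min)
        # Zero out ignored positions after softmax (avoids NaN on fully masked rows)
        attn = torch.softmax(scores, dim=-1).masked_fill(ignore, 0.0)
        attn = self.dropout(attn)
        # Merge heads back: (#batch, time1, n_feat)
        x = (attn @ v).transpose(1, 2).reshape(b, -1, self.h * self.d_k)
        return self.linear_out(x)


def make_key_mask(key_lengths: torch.Tensor, time2: int) -> torch.Tensor:
    """Build a (#batch, 1, time2) mask that is True on valid key frames."""
    positions = torch.arange(time2)
    return (positions[None, :] < key_lengths[:, None]).unsqueeze(1)


torch.manual_seed(0)
attn = SketchMultiHeadAttention(n_head=4, n_feat=512)
query = torch.randn(2, 5, 512)       # (#batch, time1, size)
key = value = torch.randn(2, 7, 512)  # (#batch, time2, size)
mask = make_key_mask(torch.tensor([7, 4]), 7)  # second utterance has 4 valid frames
out = attn(query, key, value, mask)
print(out.shape)  # torch.Size([2, 5, 512])
```

Because the padding mask has shape (#batch, 1, time2), it broadcasts over all time1 query positions and all heads. A (#batch, time1, time2) mask would instead let each query position attend to a different subset of keys.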