# Source code for espnet2.asr.state_spaces.attention

"""Multi-Head Attention layer definition."""

import math

import numpy
import torch
from torch import nn

from espnet2.asr.state_spaces.base import SequenceModule

class MultiHeadedAttention(SequenceModule):
    """Multi-Head Attention layer inheriting SequenceModule.

    Compared with the default MHA module in ESPnet, this module returns an
    additional dummy state and has a step function for autoregressive
    inference.

    Args:
        n_feat (int): The number of features. Must be divisible by ``n_head``.
        n_head (int): The number of heads.
        dropout (float): Dropout rate applied to the attention weights.
        transposed (bool): Accepted but not used by this module.
        **kwargs: Accepted but not used by this module.

    """

    def __init__(self, n_feat, n_head, dropout=0.0, transposed=False, **kwargs):
        """Construct a MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0, "n_feat must be divisible by n_head"
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        # Attention weights of the most recent forward_attention() call.
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        self.d_output = n_feat

    def forward_qkv(self, query, key, value):
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        The (pre-dropout) attention weights are stored in ``self.attn``.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or
                (#batch, time1, time2). Positions equal to 0 are masked out.
                ``None`` disables masking.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # Query the dtype's lowest finite value from torch directly: a
            # numpy round trip fails for dtypes numpy cannot represent
            # (e.g. bfloat16) and builds a throwaway tensor on every call.
            min_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask, min_value)
            # Zero the masked positions again after softmax so that rows
            # which are masked everywhere carry no weight at all.
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, memory=None, mask=None, *args, **kwargs):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            memory (torch.Tensor): Tensor used as both key and value
                (#batch, time2, size). If ``None``, ``query`` is used
                (self-attention).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
            *args: Accepted but not used.
            **kwargs: Accepted but not used.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            None: Dummy state.

        """
        # self-attention
        if memory is None:
            memory = query
        q, k, v = self.forward_qkv(query=query, key=memory, value=memory)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), None

    def step(self, query, state, memory=None, mask=None, **kwargs):
        """Run one attention call and pass the state through unchanged.

        Args:
            query (torch.Tensor): Query tensor, presumably holding a single
                time step -- TODO confirm the expected shape against callers.
            state: Opaque state; returned as-is.
            memory (torch.Tensor): Tensor used as both key and value. If
                ``None``, ``query`` is used.
            mask (torch.Tensor): Mask tensor forwarded to ``forward``.
            **kwargs: Forwarded to ``forward``.

        Returns:
            torch.Tensor: Output of ``forward`` with dimension 1 squeezed.
            The ``state`` argument, unchanged.

        """
        if memory is None:
            memory = query
        return self.forward(query, memory, mask=mask, **kwargs)[0].squeeze(1), state