Source code for espnet2.asr.layers.fastformer

"""Fastformer attention definition.

Reference:
    Wu et al., "Fastformer: Additive Attention Can Be All You Need"
    https://arxiv.org/abs/2108.09084
    https://github.com/wuch15/Fastformer

"""

import numpy
import torch


class FastSelfAttention(torch.nn.Module):
    """Fast self-attention used in Fastformer."""

    def __init__(
        self,
        size,
        attention_heads,
        dropout_rate,
    ):
        super().__init__()
        if size % attention_heads != 0:
            raise ValueError(
                f"Hidden size ({size}) is not an integer multiple "
                f"of attention heads ({attention_heads})"
            )
        self.attention_head_size = size // attention_heads
        self.num_attention_heads = attention_heads

        self.query = torch.nn.Linear(size, size)
        self.query_att = torch.nn.Linear(size, attention_heads)
        self.key = torch.nn.Linear(size, size)
        self.key_att = torch.nn.Linear(size, attention_heads)
        self.transform = torch.nn.Linear(size, size)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def espnet_initialization_fn(self):
        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, torch.nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def transpose_for_scores(self, x):
        """Reshape and transpose to compute scores.

        Args:
            x: (batch, time, size = n_heads * attn_dim)

        Returns:
            (batch, n_heads, time, attn_dim)
        """
        new_x_shape = x.shape[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        return x.reshape(*new_x_shape).transpose(1, 2)

    def forward(self, xs_pad, mask):
        """Forward method.

        Args:
            xs_pad: (batch, time, size = n_heads * attn_dim)
            mask: (batch, 1, time), nonpadding is 1, padding is 0

        Returns:
            torch.Tensor: (batch, time, size)
        """
        batch_size, seq_len, _ = xs_pad.shape
        mixed_query_layer = self.query(xs_pad)  # (batch, time, size)
        mixed_key_layer = self.key(xs_pad)  # (batch, time, size)

        if mask is not None:
            mask = mask.eq(0)  # padding is 1, nonpadding is 0

        # (batch, n_heads, time)
        query_for_score = (
            self.query_att(mixed_query_layer).transpose(1, 2)
            / self.attention_head_size**0.5
        )
        if mask is not None:
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype
                ).min
            )
            query_for_score = query_for_score.masked_fill(mask, min_value)
            query_weight = torch.softmax(query_for_score, dim=-1).masked_fill(mask, 0.0)
        else:
            query_weight = torch.softmax(query_for_score, dim=-1)

        query_weight = query_weight.unsqueeze(2)  # (batch, n_heads, 1, time)
        query_layer = self.transpose_for_scores(
            mixed_query_layer
        )  # (batch, n_heads, time, attn_dim)

        pooled_query = (
            torch.matmul(query_weight, query_layer)
            .transpose(1, 2)
            .reshape(-1, 1, self.num_attention_heads * self.attention_head_size)
        )  # (batch, 1, size = n_heads * attn_dim)
        pooled_query = self.dropout(pooled_query)
        pooled_query_repeat = pooled_query.repeat(1, seq_len, 1)  # (batch, time, size)

        mixed_query_key_layer = (
            mixed_key_layer * pooled_query_repeat
        )  # (batch, time, size)

        # (batch, n_heads, time)
        query_key_score = (
            self.key_att(mixed_query_key_layer) / self.attention_head_size**0.5
        ).transpose(1, 2)
        if mask is not None:
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype
                ).min
            )
            query_key_score = query_key_score.masked_fill(mask, min_value)
            query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(
                mask, 0.0
            )
        else:
            query_key_weight = torch.softmax(query_key_score, dim=-1)

        query_key_weight = query_key_weight.unsqueeze(2)  # (batch, n_heads, 1, time)
        key_layer = self.transpose_for_scores(
            mixed_query_key_layer
        )  # (batch, n_heads, time, attn_dim)
        pooled_key = torch.matmul(
            query_key_weight, key_layer
        )  # (batch, n_heads, 1, attn_dim)
        pooled_key = self.dropout(pooled_key)

        # NOTE: value = query, due to param sharing
        weighted_value = (pooled_key * query_layer).transpose(
            1, 2
        )  # (batch, time, n_heads, attn_dim)
        weighted_value = weighted_value.reshape(
            weighted_value.shape[:-2]
            + (self.num_attention_heads * self.attention_head_size,)
        )  # (batch, time, size)
        weighted_value = (
            self.dropout(self.transform(weighted_value)) + mixed_query_layer
        )

        return weighted_value
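
A minimal usage sketch, assuming the module is importable as espnet2.asr.layers.fastformer; the hidden size, head count, dropout rate, and tensor shapes below are chosen purely for illustration (the hidden size only needs to be divisible by the number of heads):

import torch

from espnet2.asr.layers.fastformer import FastSelfAttention

# Illustrative hyperparameters: 256 hidden units split across 4 heads.
att = FastSelfAttention(size=256, attention_heads=4, dropout_rate=0.1)
att.espnet_initialization_fn()  # apply the normal(0.0, 0.02) Linear init defined above

xs_pad = torch.randn(2, 10, 256)  # (batch, time, size)
mask = torch.ones(2, 1, 10)       # (batch, 1, time); 1 = nonpadding, 0 = padding
mask[0, :, 7:] = 0                # mark the tail of the first utterance as padding

out = att(xs_pad, mask)
print(out.shape)  # torch.Size([2, 10, 256])

Because queries and keys are summarized into single pooled vectors by additive attention, and the value projection shares parameters with the query, the cost of this layer grows linearly with sequence length rather than quadratically as in standard self-attention.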