Source code for espnet.nets.chainer_backend.transformer.attention

# encoding: utf-8
"""Class Declaration of Transformer's Attention."""

import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np

MIN_VALUE = float(np.finfo(np.float32).min)


class MultiHeadAttention(chainer.Chain):
    """Multi Head Attention Layer.

    Args:
        n_units (int): Number of input units (features).
        h (int): Number of attention heads.
        dropout (float): Dropout rate.
        initialW: Initializer to initialize the weight.
        initial_bias: Initializer to initialize the bias.

    """

    def __init__(self, n_units, h=8, dropout=0.1, initialW=None, initial_bias=None):
        """Initialize MultiHeadAttention."""
        super(MultiHeadAttention, self).__init__()
        assert n_units % h == 0
        stvd = 1.0 / np.sqrt(n_units)
        with self.init_scope():
            self.linear_q = L.Linear(
                n_units,
                n_units,
                initialW=initialW(scale=stvd),
                initial_bias=initial_bias(scale=stvd),
            )
            self.linear_k = L.Linear(
                n_units,
                n_units,
                initialW=initialW(scale=stvd),
                initial_bias=initial_bias(scale=stvd),
            )
            self.linear_v = L.Linear(
                n_units,
                n_units,
                initialW=initialW(scale=stvd),
                initial_bias=initial_bias(scale=stvd),
            )
            self.linear_out = L.Linear(
                n_units,
                n_units,
                initialW=initialW(scale=stvd),
                initial_bias=initial_bias(scale=stvd),
            )
        self.d_k = n_units // h
        self.h = h
        self.dropout = dropout
        self.attn = None

    def forward(self, e_var, s_var=None, mask=None, batch=1):
        """Core function of the multi-head attention layer.

        Args:
            e_var (chainer.Variable): Variable of input array.
            s_var (chainer.Variable): Variable of source array from the encoder.
                If None, self-attention is computed over `e_var`.
            mask (chainer.Variable): Attention mask.
            batch (int): Batch size.

        Returns:
            chainer.Variable: Output of the multi-head attention layer.

        """
        xp = self.xp
        if s_var is None:
            # Self-attention: queries, keys, and values all come from e_var.
            # Shape after reshape: (batch, time1/2, head, d_k).
            Q = self.linear_q(e_var).reshape(batch, -1, self.h, self.d_k)
            K = self.linear_k(e_var).reshape(batch, -1, self.h, self.d_k)
            V = self.linear_v(e_var).reshape(batch, -1, self.h, self.d_k)
        else:
            # Source attention: keys and values come from the encoder output.
            Q = self.linear_q(e_var).reshape(batch, -1, self.h, self.d_k)
            K = self.linear_k(s_var).reshape(batch, -1, self.h, self.d_k)
            V = self.linear_v(s_var).reshape(batch, -1, self.h, self.d_k)
        # Scaled dot-product scores: (batch, head, time1, time2).
        scores = F.matmul(F.swapaxes(Q, 1, 2), K.transpose(0, 2, 3, 1)) / np.sqrt(
            self.d_k
        )
        if mask is not None:
            # Broadcast the mask over heads and push masked positions to
            # MIN_VALUE so they vanish after the softmax.
            mask = xp.stack([mask] * self.h, axis=1)
            scores = F.where(mask, scores, xp.full(scores.shape, MIN_VALUE, "f"))
        self.attn = F.softmax(scores, axis=-1)
        p_attn = F.dropout(self.attn, self.dropout)
        x = F.matmul(p_attn, F.swapaxes(V, 1, 2))
        # Merge the heads back into a (batch * time1, n_units) array.
        x = F.swapaxes(x, 1, 2).reshape(-1, self.h * self.d_k)
        return self.linear_out(x)
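
A minimal usage sketch (not part of the original file) showing how the layer might be instantiated and called. It assumes `chainer.initializers.Uniform` as the initializer factory, since the constructor calls `initialW(scale=...)`; any callable with that signature would do. The input shape and `batch` argument follow from the reshapes in `forward`.

import numpy as np

import chainer

from espnet.nets.chainer_backend.transformer.attention import MultiHeadAttention

batch, time, n_units, h = 2, 5, 256, 4
mha = MultiHeadAttention(
    n_units,
    h=h,
    dropout=0.1,
    initialW=chainer.initializers.Uniform,      # assumption: any factory taking scale=
    initial_bias=chainer.initializers.Uniform,  # works; Uniform is just one choice
)

# forward() expects the input flattened to (batch * time, n_units) plus the
# batch size, because the Q/K/V projections are reshaped internally.
e = np.random.randn(batch * time, n_units).astype(np.float32)
mask = np.ones((batch, time, time), dtype=bool)  # True means "may attend"
with chainer.using_config("train", False):  # disable dropout for the demo
    y = mha(e, mask=mask, batch=batch)
print(y.shape)  # (10, 256), i.e. (batch * time, n_units)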