Source code for espnet.nets.pytorch_backend.rnn.attentions

"""Attention modules for RNN."""

import math

import torch
import torch.nn.functional as F

from espnet.nets.pytorch_backend.nets_utils import make_pad_mask, to_device


def _apply_attention_constraint(
    e, last_attended_idx, backward_window=1, forward_window=3
):
    """Apply monotonic attention constraint.

    This function applies the monotonic attention constraint
    introduced in `Deep Voice 3: Scaling
    Text-to-Speech with Convolutional Sequence Learning`_.

    Args:
        e (Tensor): Attention energy before applying softmax (1, T).
        last_attended_idx (int): Index of the last attended input, in [0, T].
        backward_window (int, optional): Backward window size in attention constraint.
        forward_window (int, optional): Forward window size in attention constraint.

    Returns:
        Tensor: Monotonic constrained attention energy (1, T).

    .. _`Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning`:
        https://arxiv.org/abs/1710.07654

    """
    if e.size(0) != 1:
        raise NotImplementedError("Batch attention constraining is not yet supported.")
    backward_idx = last_attended_idx - backward_window
    forward_idx = last_attended_idx + forward_window
    if backward_idx > 0:
        e[:, :backward_idx] = -float("inf")
    if forward_idx < e.size(1):
        e[:, forward_idx:] = -float("inf")
    return e
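
A quick illustration of the windowing above, using a made-up energy tensor (the sizes and the attended index are arbitrary): only frames in [last_attended_idx - backward_window, last_attended_idx + forward_window) keep their energies, so the subsequent softmax assigns zero probability everywhere else.

e = torch.randn(1, 10)  # energies for one utterance with T=10 encoder frames
# the previous step attended to frame 4; with the default windows only
# frames 3..6 keep their energies, the rest become -inf
e_constrained = _apply_attention_constraint(e, last_attended_idx=4)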


class NoAtt(torch.nn.Module):
    """No attention"""

    def __init__(self):
        super(NoAtt, self).__init__()
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.c = None

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.c = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        """NoAtt forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B, T_max, D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: dummy (not used)
        :param torch.Tensor att_prev: dummy (not used)
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)

        # initialize attention weight with uniform dist.
        if att_prev is None:
            # if no bias, 0 0-pad goes 0
            mask = 1.0 - make_pad_mask(enc_hs_len).float()
            att_prev = mask / mask.new(enc_hs_len).unsqueeze(-1)
            att_prev = att_prev.to(self.enc_h)
            self.c = torch.sum(
                self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1
            )

        return self.c, att_prev
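
A minimal usage sketch of NoAtt with toy, randomly generated inputs (2 utterances, 5 frames, 4 encoder-projection units; all sizes are arbitrary). With att_prev=None the module builds uniform weights over the valid frames and returns their mean as the context vector.

enc_hs_pad = torch.randn(2, 5, 4)  # (B, T_max, D_enc)
enc_hs_len = [5, 3]

att = NoAtt()
att.reset()  # clear cached encoder states before a new batch
c, w = att(enc_hs_pad, enc_hs_len, dec_z=None, att_prev=None)
# c: (2, 4) mean over the valid frames; w: (2, 5) uniform over the valid frames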
class AttDot(torch.nn.Module):
    """Dot product attention

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_enc_h
    """

    def __init__(self, eprojs, dunits, att_dim, han_mode=False):
        super(AttDot, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim)

        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        """AttDot forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: dummy (not used)
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weight (B x T_max)
        :rtype: torch.Tensor
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = torch.tanh(self.mlp_enc(self.enc_h))

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        e = torch.sum(
            self.pre_compute_enc_h
            * torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim),
            dim=2,
        )  # utt x frame

        # NOTE consider zero padding when compute w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)

        # weighted sum over frames
        # utt x hdim
        # NOTE use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        return c, w
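
The energy here is the inner product of the tanh-projected encoder frames and the tanh-projected decoder state, scaled by `scaling` before the softmax; padded frames are masked to -inf. A minimal usage sketch with arbitrary toy sizes:

att = AttDot(eprojs=4, dunits=3, att_dim=8)
att.reset()

enc_hs_pad = torch.randn(2, 6, 4)  # (B, T_max, eprojs)
enc_hs_len = [6, 4]
dec_z = torch.randn(2, 3)          # (B, dunits)

c, w = att(enc_hs_pad, enc_hs_len, dec_z, att_prev=None)
# c: (2, 4) context vectors; w: (2, 6) with zero weight on padded frames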
class AttAdd(torch.nn.Module):
    """Additive attention

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_enc_h
    """

    def __init__(self, eprojs, dunits, att_dim, han_mode=False):
        super(AttAdd, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        """AttAdd forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: dummy (not used)
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights (B x T_max)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(torch.tanh(self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)

        # NOTE consider zero padding when compute w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)

        # weighted sum over frames
        # utt x hdim
        # NOTE use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

        return c, w
class AttLoc(torch.nn.Module):
    """Location-aware attention module.

    Reference: Attention-Based Models for Speech Recognition
        (https://arxiv.org/pdf/1506.07503.pdf)

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_enc_h
    """

    def __init__(
        self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
    ):
        super(AttLoc, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)

        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(
        self,
        enc_hs_pad,
        enc_hs_len,
        dec_z,
        att_prev,
        scaling=2.0,
        last_attended_idx=None,
        backward_window=1,
        forward_window=3,
    ):
        """Calculate AttLoc forward propagation.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: previous attention weight (B x T_max)
        :param float scaling: scaling parameter before applying softmax
        :param int last_attended_idx: index of the last attended input
        :param int backward_window: backward window size in attention constraint
        :param int forward_window: forward window size in attention constraint
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights (B x T_max)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        # initialize attention weight with uniform dist.
        if att_prev is None:
            # if no bias, 0 0-pad goes 0
            att_prev = 1.0 - make_pad_mask(enc_hs_len).to(
                device=dec_z.device, dtype=dec_z.dtype
            )
            att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1)

        # att_prev: utt x frame -> utt x 1 x 1 x frame
        # -> utt x att_conv_chans x 1 x frame
        att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)

        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
        ).squeeze(2)

        # NOTE: consider zero padding when compute w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))

        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
            e = _apply_attention_constraint(
                e, last_attended_idx, backward_window, forward_window
            )

        w = F.softmax(scaling * e, dim=1)

        # weighted sum over frames
        # utt x hdim
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

        return c, w
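
A sketch of the usual decoder loop (toy sizes, random inputs): reset() is called once per batch, and the returned weights are fed back as att_prev so that loc_conv can see where the model attended in the previous step. During TTS inference, last_attended_idx can additionally be passed to activate the monotonic constraint above.

att = AttLoc(eprojs=4, dunits=3, att_dim=8, aconv_chans=4, aconv_filts=5)
att.reset()  # refresh pre_compute_enc_h for the new batch

enc_hs_pad = torch.randn(2, 20, 4)  # (B, T_max, eprojs)
enc_hs_len = [20, 15]
dec_z = torch.zeros(2, 3)           # decoder state, e.g. from an LSTMCell
att_w = None                        # None -> uniform initialization inside forward

for _ in range(5):                  # 5 decoder steps
    c, att_w = att(enc_hs_pad, enc_hs_len, dec_z, att_w)
    # c: (B, eprojs) context fed back to the decoder
    # att_w: (B, T_max) weights reused as the location feature next step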
[docs]class AttCov(torch.nn.Module): """Coverage mechanism attention Reference: Get To The Point: Summarization with Pointer-Generator Network (https://arxiv.org/abs/1704.04368) :param int eprojs: # projection-units of encoder :param int dunits: # units of decoder :param int att_dim: attention dimension :param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_enc_h """ def __init__(self, eprojs, dunits, att_dim, han_mode=False): super(AttCov, self).__init__() self.mlp_enc = torch.nn.Linear(eprojs, att_dim) self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) self.wvec = torch.nn.Linear(1, att_dim) self.gvec = torch.nn.Linear(att_dim, 1) self.dunits = dunits self.eprojs = eprojs self.att_dim = att_dim self.h_length = None self.enc_h = None self.pre_compute_enc_h = None self.mask = None self.han_mode = han_mode
[docs] def reset(self): """reset states""" self.h_length = None self.enc_h = None self.pre_compute_enc_h = None self.mask = None
[docs] def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0): """AttCov forward :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) :param list enc_hs_len: padded encoder hidden state length (B) :param torch.Tensor dec_z: decoder hidden state (B x D_dec) :param list att_prev_list: list of previous attention weight :param float scaling: scaling parameter before applying softmax :return: attention weighted encoder state (B, D_enc) :rtype: torch.Tensor :return: list of previous attention weights :rtype: list """ batch = len(enc_hs_pad) # pre-compute all h outside the decoder loop if self.pre_compute_enc_h is None or self.han_mode: self.enc_h = enc_hs_pad # utt x frame x hdim self.h_length = self.enc_h.size(1) # utt x frame x att_dim self.pre_compute_enc_h = self.mlp_enc(self.enc_h) if dec_z is None: dec_z = enc_hs_pad.new_zeros(batch, self.dunits) else: dec_z = dec_z.view(batch, self.dunits) # initialize attention weight with uniform dist. if att_prev_list is None: # if no bias, 0 0-pad goes 0 att_prev_list = to_device( enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float()) ) att_prev_list = [ att_prev_list / att_prev_list.new(enc_hs_len).unsqueeze(-1) ] # att_prev_list: L' * [B x T] => cov_vec B x T cov_vec = sum(att_prev_list) # cov_vec: B x T => B x T x 1 => B x T x att_dim cov_vec = self.wvec(cov_vec.unsqueeze(-1)) # dec_z_tiled: utt x frame x att_dim dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) # dot with gvec # utt x frame x att_dim -> utt x frame e = self.gvec( torch.tanh(cov_vec + self.pre_compute_enc_h + dec_z_tiled) ).squeeze(2) # NOTE consider zero padding when compute w. if self.mask is None: self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) e.masked_fill_(self.mask, -float("inf")) w = F.softmax(scaling * e, dim=1) att_prev_list += [w] # weighted sum over flames # utt x hdim # NOTE use bmm instead of sum(*) c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) return c, att_prev_list
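
Unlike the single-tensor attentions, AttCov passes a growing list of weights between steps and sums it into a coverage vector. A minimal sketch with toy sizes:

att = AttCov(eprojs=4, dunits=3, att_dim=8)
att.reset()

enc_hs_pad = torch.randn(2, 6, 4)
enc_hs_len = [6, 5]
dec_z = torch.zeros(2, 3)

att_prev_list = None
for _ in range(3):
    c, att_prev_list = att(enc_hs_pad, enc_hs_len, dec_z, att_prev_list)

# len(att_prev_list) == 4: one uniform initialization + three computed weights,
# each of shape (2, 6)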
[docs]class AttLoc2D(torch.nn.Module): """2D location-aware attention This attention is an extended version of location aware attention. It take not only one frame before attention weights, but also earlier frames into account. :param int eprojs: # projection-units of encoder :param int dunits: # units of decoder :param int att_dim: attention dimension :param int aconv_chans: # channels of attention convolution :param int aconv_filts: filter size of attention convolution :param int att_win: attention window size (default=5) :param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_enc_h """ def __init__( self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False ): super(AttLoc2D, self).__init__() self.mlp_enc = torch.nn.Linear(eprojs, att_dim) self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) self.loc_conv = torch.nn.Conv2d( 1, aconv_chans, (att_win, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False, ) self.gvec = torch.nn.Linear(att_dim, 1) self.dunits = dunits self.eprojs = eprojs self.att_dim = att_dim self.h_length = None self.enc_h = None self.pre_compute_enc_h = None self.aconv_chans = aconv_chans self.att_win = att_win self.mask = None self.han_mode = han_mode
[docs] def reset(self): """reset states""" self.h_length = None self.enc_h = None self.pre_compute_enc_h = None self.mask = None
[docs] def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0): """AttLoc2D forward :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) :param list enc_hs_len: padded encoder hidden state length (B) :param torch.Tensor dec_z: decoder hidden state (B x D_dec) :param torch.Tensor att_prev: previous attention weight (B x att_win x T_max) :param float scaling: scaling parameter before applying softmax :return: attention weighted encoder state (B, D_enc) :rtype: torch.Tensor :return: previous attention weights (B x att_win x T_max) :rtype: torch.Tensor """ batch = len(enc_hs_pad) # pre-compute all h outside the decoder loop if self.pre_compute_enc_h is None or self.han_mode: self.enc_h = enc_hs_pad # utt x frame x hdim self.h_length = self.enc_h.size(1) # utt x frame x att_dim self.pre_compute_enc_h = self.mlp_enc(self.enc_h) if dec_z is None: dec_z = enc_hs_pad.new_zeros(batch, self.dunits) else: dec_z = dec_z.view(batch, self.dunits) # initialize attention weight with uniform dist. if att_prev is None: # B * [Li x att_win] # if no bias, 0 0-pad goes 0 att_prev = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float())) att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1) att_prev = att_prev.unsqueeze(1).expand(-1, self.att_win, -1) # att_prev: B x att_win x Tmax -> B x 1 x att_win x Tmax -> B x C x 1 x Tmax att_conv = self.loc_conv(att_prev.unsqueeze(1)) # att_conv: B x C x 1 x Tmax -> B x Tmax x C att_conv = att_conv.squeeze(2).transpose(1, 2) # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim att_conv = self.mlp_att(att_conv) # dec_z_tiled: utt x frame x att_dim dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) # dot with gvec # utt x frame x att_dim -> utt x frame e = self.gvec( torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) ).squeeze(2) # NOTE consider zero padding when compute w. if self.mask is None: self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) e.masked_fill_(self.mask, -float("inf")) w = F.softmax(scaling * e, dim=1) # weighted sum over flames # utt x hdim # NOTE use bmm instead of sum(*) c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) # update att_prev: B x att_win x Tmax -> B x att_win+1 x Tmax # -> B x att_win x Tmax att_prev = torch.cat([att_prev, w.unsqueeze(1)], dim=1) att_prev = att_prev[:, 1:] return c, att_prev
[docs]class AttLocRec(torch.nn.Module): """location-aware recurrent attention This attention is an extended version of location aware attention. With the use of RNN, it take the effect of the history of attention weights into account. :param int eprojs: # projection-units of encoder :param int dunits: # units of decoder :param int att_dim: attention dimension :param int aconv_chans: # channels of attention convolution :param int aconv_filts: filter size of attention convolution :param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_enc_h """ def __init__( self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False ): super(AttLocRec, self).__init__() self.mlp_enc = torch.nn.Linear(eprojs, att_dim) self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) self.loc_conv = torch.nn.Conv2d( 1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False, ) self.att_lstm = torch.nn.LSTMCell(aconv_chans, att_dim, bias=False) self.gvec = torch.nn.Linear(att_dim, 1) self.dunits = dunits self.eprojs = eprojs self.att_dim = att_dim self.h_length = None self.enc_h = None self.pre_compute_enc_h = None self.mask = None self.han_mode = han_mode
[docs] def reset(self): """reset states""" self.h_length = None self.enc_h = None self.pre_compute_enc_h = None self.mask = None
[docs] def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_states, scaling=2.0): """AttLocRec forward :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) :param list enc_hs_len: padded encoder hidden state length (B) :param torch.Tensor dec_z: decoder hidden state (B x D_dec) :param tuple att_prev_states: previous attention weight and lstm states ((B, T_max), ((B, att_dim), (B, att_dim))) :param float scaling: scaling parameter before applying softmax :return: attention weighted encoder state (B, D_enc) :rtype: torch.Tensor :return: previous attention weights and lstm states (w, (hx, cx)) ((B, T_max), ((B, att_dim), (B, att_dim))) :rtype: tuple """ batch = len(enc_hs_pad) # pre-compute all h outside the decoder loop if self.pre_compute_enc_h is None or self.han_mode: self.enc_h = enc_hs_pad # utt x frame x hdim self.h_length = self.enc_h.size(1) # utt x frame x att_dim self.pre_compute_enc_h = self.mlp_enc(self.enc_h) if dec_z is None: dec_z = enc_hs_pad.new_zeros(batch, self.dunits) else: dec_z = dec_z.view(batch, self.dunits) if att_prev_states is None: # initialize attention weight with uniform dist. # if no bias, 0 0-pad goes 0 att_prev = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float())) att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1) # initialize lstm states att_h = enc_hs_pad.new_zeros(batch, self.att_dim) att_c = enc_hs_pad.new_zeros(batch, self.att_dim) att_states = (att_h, att_c) else: att_prev = att_prev_states[0] att_states = att_prev_states[1] # B x 1 x 1 x T -> B x C x 1 x T att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length)) # apply non-linear att_conv = F.relu(att_conv) # B x C x 1 x T -> B x C x 1 x 1 -> B x C att_conv = F.max_pool2d(att_conv, (1, att_conv.size(3))).view(batch, -1) att_h, att_c = self.att_lstm(att_conv, att_states) # dec_z_tiled: utt x frame x att_dim dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) # dot with gvec # utt x frame x att_dim -> utt x frame e = self.gvec( torch.tanh(att_h.unsqueeze(1) + self.pre_compute_enc_h + dec_z_tiled) ).squeeze(2) # NOTE consider zero padding when compute w. if self.mask is None: self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) e.masked_fill_(self.mask, -float("inf")) w = F.softmax(scaling * e, dim=1) # weighted sum over flames # utt x hdim # NOTE use bmm instead of sum(*) c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) return c, (w, (att_h, att_c))
[docs]class AttCovLoc(torch.nn.Module): """Coverage mechanism location aware attention This attention is a combination of coverage and location-aware attentions. :param int eprojs: # projection-units of encoder :param int dunits: # units of decoder :param int att_dim: attention dimension :param int aconv_chans: # channels of attention convolution :param int aconv_filts: filter size of attention convolution :param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_enc_h """ def __init__( self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False ): super(AttCovLoc, self).__init__() self.mlp_enc = torch.nn.Linear(eprojs, att_dim) self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) self.loc_conv = torch.nn.Conv2d( 1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False, ) self.gvec = torch.nn.Linear(att_dim, 1) self.dunits = dunits self.eprojs = eprojs self.att_dim = att_dim self.h_length = None self.enc_h = None self.pre_compute_enc_h = None self.aconv_chans = aconv_chans self.mask = None self.han_mode = han_mode
[docs] def reset(self): """reset states""" self.h_length = None self.enc_h = None self.pre_compute_enc_h = None self.mask = None
[docs] def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0): """AttCovLoc forward :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) :param list enc_hs_len: padded encoder hidden state length (B) :param torch.Tensor dec_z: decoder hidden state (B x D_dec) :param list att_prev_list: list of previous attention weight :param float scaling: scaling parameter before applying softmax :return: attention weighted encoder state (B, D_enc) :rtype: torch.Tensor :return: list of previous attention weights :rtype: list """ batch = len(enc_hs_pad) # pre-compute all h outside the decoder loop if self.pre_compute_enc_h is None or self.han_mode: self.enc_h = enc_hs_pad # utt x frame x hdim self.h_length = self.enc_h.size(1) # utt x frame x att_dim self.pre_compute_enc_h = self.mlp_enc(self.enc_h) if dec_z is None: dec_z = enc_hs_pad.new_zeros(batch, self.dunits) else: dec_z = dec_z.view(batch, self.dunits) # initialize attention weight with uniform dist. if att_prev_list is None: # if no bias, 0 0-pad goes 0 mask = 1.0 - make_pad_mask(enc_hs_len).float() att_prev_list = [ to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1)) ] # att_prev_list: L' * [B x T] => cov_vec B x T cov_vec = sum(att_prev_list) # cov_vec: B x T -> B x 1 x 1 x T -> B x C x 1 x T att_conv = self.loc_conv(cov_vec.view(batch, 1, 1, self.h_length)) # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans att_conv = att_conv.squeeze(2).transpose(1, 2) # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim att_conv = self.mlp_att(att_conv) # dec_z_tiled: utt x frame x att_dim dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) # dot with gvec # utt x frame x att_dim -> utt x frame e = self.gvec( torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) ).squeeze(2) # NOTE consider zero padding when compute w. if self.mask is None: self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) e.masked_fill_(self.mask, -float("inf")) w = F.softmax(scaling * e, dim=1) att_prev_list += [w] # weighted sum over flames # utt x hdim # NOTE use bmm instead of sum(*) c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) return c, att_prev_list
[docs]class AttMultiHeadDot(torch.nn.Module): """Multi head dot product attention Reference: Attention is all you need (https://arxiv.org/abs/1706.03762) :param int eprojs: # projection-units of encoder :param int dunits: # units of decoder :param int aheads: # heads of multi head attention :param int att_dim_k: dimension k in multi head attention :param int att_dim_v: dimension v in multi head attention :param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_k and pre_compute_v """ def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False): super(AttMultiHeadDot, self).__init__() self.mlp_q = torch.nn.ModuleList() self.mlp_k = torch.nn.ModuleList() self.mlp_v = torch.nn.ModuleList() for _ in range(aheads): self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)] self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)] self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)] self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False) self.dunits = dunits self.eprojs = eprojs self.aheads = aheads self.att_dim_k = att_dim_k self.att_dim_v = att_dim_v self.scaling = 1.0 / math.sqrt(att_dim_k) self.h_length = None self.enc_h = None self.pre_compute_k = None self.pre_compute_v = None self.mask = None self.han_mode = han_mode
[docs] def reset(self): """reset states""" self.h_length = None self.enc_h = None self.pre_compute_k = None self.pre_compute_v = None self.mask = None
[docs] def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, **kwargs): """AttMultiHeadDot forward :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) :param list enc_hs_len: padded encoder hidden state length (B) :param torch.Tensor dec_z: decoder hidden state (B x D_dec) :param torch.Tensor att_prev: dummy (does not use) :return: attention weighted encoder state (B x D_enc) :rtype: torch.Tensor :return: list of previous attention weight (B x T_max) * aheads :rtype: list """ batch = enc_hs_pad.size(0) # pre-compute all k and v outside the decoder loop if self.pre_compute_k is None or self.han_mode: self.enc_h = enc_hs_pad # utt x frame x hdim self.h_length = self.enc_h.size(1) # utt x frame x att_dim self.pre_compute_k = [ torch.tanh(self.mlp_k[h](self.enc_h)) for h in range(self.aheads) ] if self.pre_compute_v is None or self.han_mode: self.enc_h = enc_hs_pad # utt x frame x hdim self.h_length = self.enc_h.size(1) # utt x frame x att_dim self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in range(self.aheads)] if dec_z is None: dec_z = enc_hs_pad.new_zeros(batch, self.dunits) else: dec_z = dec_z.view(batch, self.dunits) c = [] w = [] for h in range(self.aheads): e = torch.sum( self.pre_compute_k[h] * torch.tanh(self.mlp_q[h](dec_z)).view(batch, 1, self.att_dim_k), dim=2, ) # utt x frame # NOTE consider zero padding when compute w. if self.mask is None: self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) e.masked_fill_(self.mask, -float("inf")) w += [F.softmax(self.scaling * e, dim=1)] # weighted sum over flames # utt x hdim # NOTE use bmm instead of sum(*) c += [ torch.sum( self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 ) ] # concat all of c c = self.mlp_o(torch.cat(c, dim=1)) return c, w
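
The multi-head variants return a single context vector of size eprojs (the heads are concatenated and projected by mlp_o) together with a list containing one weight tensor per head. A minimal sketch with toy sizes:

att = AttMultiHeadDot(eprojs=4, dunits=3, aheads=2, att_dim_k=8, att_dim_v=8)
att.reset()

enc_hs_pad = torch.randn(2, 6, 4)
enc_hs_len = [6, 4]
dec_z = torch.zeros(2, 3)

c, w_list = att(enc_hs_pad, enc_hs_len, dec_z, att_prev=None)
# c: (2, 4); w_list: 2 tensors of shape (2, 6), one per head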
[docs]class AttMultiHeadAdd(torch.nn.Module): """Multi head additive attention Reference: Attention is all you need (https://arxiv.org/abs/1706.03762) This attention is multi head attention using additive attention for each head. :param int eprojs: # projection-units of encoder :param int dunits: # units of decoder :param int aheads: # heads of multi head attention :param int att_dim_k: dimension k in multi head attention :param int att_dim_v: dimension v in multi head attention :param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_k and pre_compute_v """ def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False): super(AttMultiHeadAdd, self).__init__() self.mlp_q = torch.nn.ModuleList() self.mlp_k = torch.nn.ModuleList() self.mlp_v = torch.nn.ModuleList() self.gvec = torch.nn.ModuleList() for _ in range(aheads): self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)] self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)] self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)] self.gvec += [torch.nn.Linear(att_dim_k, 1)] self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False) self.dunits = dunits self.eprojs = eprojs self.aheads = aheads self.att_dim_k = att_dim_k self.att_dim_v = att_dim_v self.scaling = 1.0 / math.sqrt(att_dim_k) self.h_length = None self.enc_h = None self.pre_compute_k = None self.pre_compute_v = None self.mask = None self.han_mode = han_mode
[docs] def reset(self): """reset states""" self.h_length = None self.enc_h = None self.pre_compute_k = None self.pre_compute_v = None self.mask = None
[docs] def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, **kwargs): """AttMultiHeadAdd forward :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) :param list enc_hs_len: padded encoder hidden state length (B) :param torch.Tensor dec_z: decoder hidden state (B x D_dec) :param torch.Tensor att_prev: dummy (does not use) :return: attention weighted encoder state (B, D_enc) :rtype: torch.Tensor :return: list of previous attention weight (B x T_max) * aheads :rtype: list """ batch = enc_hs_pad.size(0) # pre-compute all k and v outside the decoder loop if self.pre_compute_k is None or self.han_mode: self.enc_h = enc_hs_pad # utt x frame x hdim self.h_length = self.enc_h.size(1) # utt x frame x att_dim self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in range(self.aheads)] if self.pre_compute_v is None or self.han_mode: self.enc_h = enc_hs_pad # utt x frame x hdim self.h_length = self.enc_h.size(1) # utt x frame x att_dim self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in range(self.aheads)] if dec_z is None: dec_z = enc_hs_pad.new_zeros(batch, self.dunits) else: dec_z = dec_z.view(batch, self.dunits) c = [] w = [] for h in range(self.aheads): e = self.gvec[h]( torch.tanh( self.pre_compute_k[h] + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k) ) ).squeeze(2) # NOTE consider zero padding when compute w. if self.mask is None: self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) e.masked_fill_(self.mask, -float("inf")) w += [F.softmax(self.scaling * e, dim=1)] # weighted sum over flames # utt x hdim # NOTE use bmm instead of sum(*) c += [ torch.sum( self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 ) ] # concat all of c c = self.mlp_o(torch.cat(c, dim=1)) return c, w
[docs]class AttMultiHeadLoc(torch.nn.Module): """Multi head location based attention Reference: Attention is all you need (https://arxiv.org/abs/1706.03762) This attention is multi head attention using location-aware attention for each head. :param int eprojs: # projection-units of encoder :param int dunits: # units of decoder :param int aheads: # heads of multi head attention :param int att_dim_k: dimension k in multi head attention :param int att_dim_v: dimension v in multi head attention :param int aconv_chans: # channels of attention convolution :param int aconv_filts: filter size of attention convolution :param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_k and pre_compute_v """ def __init__( self, eprojs, dunits, aheads, att_dim_k, att_dim_v, aconv_chans, aconv_filts, han_mode=False, ): super(AttMultiHeadLoc, self).__init__() self.mlp_q = torch.nn.ModuleList() self.mlp_k = torch.nn.ModuleList() self.mlp_v = torch.nn.ModuleList() self.gvec = torch.nn.ModuleList() self.loc_conv = torch.nn.ModuleList() self.mlp_att = torch.nn.ModuleList() for _ in range(aheads): self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)] self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)] self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)] self.gvec += [torch.nn.Linear(att_dim_k, 1)] self.loc_conv += [ torch.nn.Conv2d( 1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False, ) ] self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)] self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False) self.dunits = dunits self.eprojs = eprojs self.aheads = aheads self.att_dim_k = att_dim_k self.att_dim_v = att_dim_v self.scaling = 1.0 / math.sqrt(att_dim_k) self.h_length = None self.enc_h = None self.pre_compute_k = None self.pre_compute_v = None self.mask = None self.han_mode = han_mode
[docs] def reset(self): """reset states""" self.h_length = None self.enc_h = None self.pre_compute_k = None self.pre_compute_v = None self.mask = None
[docs] def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0, **kwargs): """AttMultiHeadLoc forward :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) :param list enc_hs_len: padded encoder hidden state length (B) :param torch.Tensor dec_z: decoder hidden state (B x D_dec) :param torch.Tensor att_prev: list of previous attention weight (B x T_max) * aheads :param float scaling: scaling parameter before applying softmax :return: attention weighted encoder state (B x D_enc) :rtype: torch.Tensor :return: list of previous attention weight (B x T_max) * aheads :rtype: list """ batch = enc_hs_pad.size(0) # pre-compute all k and v outside the decoder loop if self.pre_compute_k is None or self.han_mode: self.enc_h = enc_hs_pad # utt x frame x hdim self.h_length = self.enc_h.size(1) # utt x frame x att_dim self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in range(self.aheads)] if self.pre_compute_v is None or self.han_mode: self.enc_h = enc_hs_pad # utt x frame x hdim self.h_length = self.enc_h.size(1) # utt x frame x att_dim self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in range(self.aheads)] if dec_z is None: dec_z = enc_hs_pad.new_zeros(batch, self.dunits) else: dec_z = dec_z.view(batch, self.dunits) if att_prev is None: att_prev = [] for _ in range(self.aheads): # if no bias, 0 0-pad goes 0 mask = 1.0 - make_pad_mask(enc_hs_len).float() att_prev += [ to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1)) ] c = [] w = [] for h in range(self.aheads): att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length)) att_conv = att_conv.squeeze(2).transpose(1, 2) att_conv = self.mlp_att[h](att_conv) e = self.gvec[h]( torch.tanh( self.pre_compute_k[h] + att_conv + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k) ) ).squeeze(2) # NOTE consider zero padding when compute w. if self.mask is None: self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) e.masked_fill_(self.mask, -float("inf")) w += [F.softmax(scaling * e, dim=1)] # weighted sum over flames # utt x hdim # NOTE use bmm instead of sum(*) c += [ torch.sum( self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 ) ] # concat all of c c = self.mlp_o(torch.cat(c, dim=1)) return c, w
[docs]class AttMultiHeadMultiResLoc(torch.nn.Module): """Multi head multi resolution location based attention Reference: Attention is all you need (https://arxiv.org/abs/1706.03762) This attention is multi head attention using location-aware attention for each head. Furthermore, it uses different filter size for each head. :param int eprojs: # projection-units of encoder :param int dunits: # units of decoder :param int aheads: # heads of multi head attention :param int att_dim_k: dimension k in multi head attention :param int att_dim_v: dimension v in multi head attention :param int aconv_chans: maximum # channels of attention convolution each head use #ch = aconv_chans * (head + 1) / aheads e.g. aheads=4, aconv_chans=100 => filter size = 25, 50, 75, 100 :param int aconv_filts: filter size of attention convolution :param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_k and pre_compute_v """ def __init__( self, eprojs, dunits, aheads, att_dim_k, att_dim_v, aconv_chans, aconv_filts, han_mode=False, ): super(AttMultiHeadMultiResLoc, self).__init__() self.mlp_q = torch.nn.ModuleList() self.mlp_k = torch.nn.ModuleList() self.mlp_v = torch.nn.ModuleList() self.gvec = torch.nn.ModuleList() self.loc_conv = torch.nn.ModuleList() self.mlp_att = torch.nn.ModuleList() for h in range(aheads): self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)] self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)] self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)] self.gvec += [torch.nn.Linear(att_dim_k, 1)] afilts = aconv_filts * (h + 1) // aheads self.loc_conv += [ torch.nn.Conv2d( 1, aconv_chans, (1, 2 * afilts + 1), padding=(0, afilts), bias=False ) ] self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)] self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False) self.dunits = dunits self.eprojs = eprojs self.aheads = aheads self.att_dim_k = att_dim_k self.att_dim_v = att_dim_v self.scaling = 1.0 / math.sqrt(att_dim_k) self.h_length = None self.enc_h = None self.pre_compute_k = None self.pre_compute_v = None self.mask = None self.han_mode = han_mode
[docs] def reset(self): """reset states""" self.h_length = None self.enc_h = None self.pre_compute_k = None self.pre_compute_v = None self.mask = None
[docs] def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, **kwargs): """AttMultiHeadMultiResLoc forward :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) :param list enc_hs_len: padded encoder hidden state length (B) :param torch.Tensor dec_z: decoder hidden state (B x D_dec) :param torch.Tensor att_prev: list of previous attention weight (B x T_max) * aheads :return: attention weighted encoder state (B x D_enc) :rtype: torch.Tensor :return: list of previous attention weight (B x T_max) * aheads :rtype: list """ batch = enc_hs_pad.size(0) # pre-compute all k and v outside the decoder loop if self.pre_compute_k is None or self.han_mode: self.enc_h = enc_hs_pad # utt x frame x hdim self.h_length = self.enc_h.size(1) # utt x frame x att_dim self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in range(self.aheads)] if self.pre_compute_v is None or self.han_mode: self.enc_h = enc_hs_pad # utt x frame x hdim self.h_length = self.enc_h.size(1) # utt x frame x att_dim self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in range(self.aheads)] if dec_z is None: dec_z = enc_hs_pad.new_zeros(batch, self.dunits) else: dec_z = dec_z.view(batch, self.dunits) if att_prev is None: att_prev = [] for _ in range(self.aheads): # if no bias, 0 0-pad goes 0 mask = 1.0 - make_pad_mask(enc_hs_len).float() att_prev += [ to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1)) ] c = [] w = [] for h in range(self.aheads): att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length)) att_conv = att_conv.squeeze(2).transpose(1, 2) att_conv = self.mlp_att[h](att_conv) e = self.gvec[h]( torch.tanh( self.pre_compute_k[h] + att_conv + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k) ) ).squeeze(2) # NOTE consider zero padding when compute w. if self.mask is None: self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) e.masked_fill_(self.mask, -float("inf")) w += [F.softmax(self.scaling * e, dim=1)] # weighted sum over flames # utt x hdim # NOTE use bmm instead of sum(*) c += [ torch.sum( self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 ) ] # concat all of c c = self.mlp_o(torch.cat(c, dim=1)) return c, w
[docs]class AttForward(torch.nn.Module): """Forward attention module. Reference: Forward attention in sequence-to-sequence acoustic modeling for speech synthesis (https://arxiv.org/pdf/1807.06736.pdf) :param int eprojs: # projection-units of encoder :param int dunits: # units of decoder :param int att_dim: attention dimension :param int aconv_chans: # channels of attention convolution :param int aconv_filts: filter size of attention convolution """ def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts): super(AttForward, self).__init__() self.mlp_enc = torch.nn.Linear(eprojs, att_dim) self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) self.loc_conv = torch.nn.Conv2d( 1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False, ) self.gvec = torch.nn.Linear(att_dim, 1) self.dunits = dunits self.eprojs = eprojs self.att_dim = att_dim self.h_length = None self.enc_h = None self.pre_compute_enc_h = None self.mask = None
[docs] def reset(self): """reset states""" self.h_length = None self.enc_h = None self.pre_compute_enc_h = None self.mask = None
[docs] def forward( self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=1.0, last_attended_idx=None, backward_window=1, forward_window=3, ): """Calculate AttForward forward propagation. :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) :param list enc_hs_len: padded encoder hidden state length (B) :param torch.Tensor dec_z: decoder hidden state (B x D_dec) :param torch.Tensor att_prev: attention weights of previous step :param float scaling: scaling parameter before applying softmax :param int last_attended_idx: index of the inputs of the last attended :param int backward_window: backward window size in attention constraint :param int forward_window: forward window size in attetion constraint :return: attention weighted encoder state (B, D_enc) :rtype: torch.Tensor :return: previous attention weights (B x T_max) :rtype: torch.Tensor """ batch = len(enc_hs_pad) # pre-compute all h outside the decoder loop if self.pre_compute_enc_h is None: self.enc_h = enc_hs_pad # utt x frame x hdim self.h_length = self.enc_h.size(1) # utt x frame x att_dim self.pre_compute_enc_h = self.mlp_enc(self.enc_h) if dec_z is None: dec_z = enc_hs_pad.new_zeros(batch, self.dunits) else: dec_z = dec_z.view(batch, self.dunits) if att_prev is None: # initial attention will be [1, 0, 0, ...] att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2]) att_prev[:, 0] = 1.0 # att_prev: utt x frame -> utt x 1 x 1 x frame # -> utt x att_conv_chans x 1 x frame att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length)) # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans att_conv = att_conv.squeeze(2).transpose(1, 2) # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim att_conv = self.mlp_att(att_conv) # dec_z_tiled: utt x frame x att_dim dec_z_tiled = self.mlp_dec(dec_z).unsqueeze(1) # dot with gvec # utt x frame x att_dim -> utt x frame e = self.gvec( torch.tanh(self.pre_compute_enc_h + dec_z_tiled + att_conv) ).squeeze(2) # NOTE: consider zero padding when compute w. if self.mask is None: self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) e.masked_fill_(self.mask, -float("inf")) # apply monotonic attention constraint (mainly for TTS) if last_attended_idx is not None: e = _apply_attention_constraint( e, last_attended_idx, backward_window, forward_window ) w = F.softmax(scaling * e, dim=1) # forward attention att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1] w = (att_prev + att_prev_shift) * w # NOTE: clamp is needed to avoid nan gradient w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1) # weighted sum over flames # utt x hdim # NOTE use bmm instead of sum(*) c = torch.sum(self.enc_h * w.unsqueeze(-1), dim=1) return c, w
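
A tiny numeric sketch (made-up numbers, no learned parameters) of the forward-attention update used above: the previous weights and their one-step shift gate the new softmax, so probability mass can only stay on the current frame or move one frame ahead.

att_prev = torch.tensor([[0.0, 1.0, 0.0, 0.0]])   # all mass on frame 1
att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]  # shifted: mass on frame 2
w_soft = torch.tensor([[0.1, 0.4, 0.4, 0.1]])     # softmax(e) at this step

w = (att_prev + att_prev_shift) * w_soft          # only frames 1 and 2 survive
w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1)
# w is approximately [[0.0, 0.5, 0.5, 0.0]]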
[docs]class AttForwardTA(torch.nn.Module): """Forward attention with transition agent module. Reference: Forward attention in sequence-to-sequence acoustic modeling for speech synthesis (https://arxiv.org/pdf/1807.06736.pdf) :param int eunits: # units of encoder :param int dunits: # units of decoder :param int att_dim: attention dimension :param int aconv_chans: # channels of attention convolution :param int aconv_filts: filter size of attention convolution :param int odim: output dimension """ def __init__(self, eunits, dunits, att_dim, aconv_chans, aconv_filts, odim): super(AttForwardTA, self).__init__() self.mlp_enc = torch.nn.Linear(eunits, att_dim) self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) self.mlp_ta = torch.nn.Linear(eunits + dunits + odim, 1) self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) self.loc_conv = torch.nn.Conv2d( 1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False, ) self.gvec = torch.nn.Linear(att_dim, 1) self.dunits = dunits self.eunits = eunits self.att_dim = att_dim self.h_length = None self.enc_h = None self.pre_compute_enc_h = None self.mask = None self.trans_agent_prob = 0.5
[docs] def reset(self): self.h_length = None self.enc_h = None self.pre_compute_enc_h = None self.mask = None self.trans_agent_prob = 0.5
[docs] def forward( self, enc_hs_pad, enc_hs_len, dec_z, att_prev, out_prev, scaling=1.0, last_attended_idx=None, backward_window=1, forward_window=3, ): """Calculate AttForwardTA forward propagation. :param torch.Tensor enc_hs_pad: padded encoder hidden state (B, Tmax, eunits) :param list enc_hs_len: padded encoder hidden state length (B) :param torch.Tensor dec_z: decoder hidden state (B, dunits) :param torch.Tensor att_prev: attention weights of previous step :param torch.Tensor out_prev: decoder outputs of previous step (B, odim) :param float scaling: scaling parameter before applying softmax :param int last_attended_idx: index of the inputs of the last attended :param int backward_window: backward window size in attention constraint :param int forward_window: forward window size in attetion constraint :return: attention weighted encoder state (B, dunits) :rtype: torch.Tensor :return: previous attention weights (B, Tmax) :rtype: torch.Tensor """ batch = len(enc_hs_pad) # pre-compute all h outside the decoder loop if self.pre_compute_enc_h is None: self.enc_h = enc_hs_pad # utt x frame x hdim self.h_length = self.enc_h.size(1) # utt x frame x att_dim self.pre_compute_enc_h = self.mlp_enc(self.enc_h) if dec_z is None: dec_z = enc_hs_pad.new_zeros(batch, self.dunits) else: dec_z = dec_z.view(batch, self.dunits) if att_prev is None: # initial attention will be [1, 0, 0, ...] att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2]) att_prev[:, 0] = 1.0 # att_prev: utt x frame -> utt x 1 x 1 x frame # -> utt x att_conv_chans x 1 x frame att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length)) # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans att_conv = att_conv.squeeze(2).transpose(1, 2) # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim att_conv = self.mlp_att(att_conv) # dec_z_tiled: utt x frame x att_dim dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) # dot with gvec # utt x frame x att_dim -> utt x frame e = self.gvec( torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) ).squeeze(2) # NOTE consider zero padding when compute w. if self.mask is None: self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) e.masked_fill_(self.mask, -float("inf")) # apply monotonic attention constraint (mainly for TTS) if last_attended_idx is not None: e = _apply_attention_constraint( e, last_attended_idx, backward_window, forward_window ) w = F.softmax(scaling * e, dim=1) # forward attention att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1] w = ( self.trans_agent_prob * att_prev + (1 - self.trans_agent_prob) * att_prev_shift ) * w # NOTE: clamp is needed to avoid nan gradient w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1) # weighted sum over flames # utt x hdim # NOTE use bmm instead of sum(*) c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) # update transition agent prob self.trans_agent_prob = torch.sigmoid( self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1)) ) return c, w
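
The transition-agent variant blends the "stay" and "move" terms with a learned probability instead of simply adding them. A toy sketch with a fixed p_ta (in the module it is re-estimated from [c, out_prev, dec_z] after every step):

p_ta = torch.tensor(0.8)
att_prev = torch.tensor([[0.0, 1.0, 0.0, 0.0]])
att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]
w_soft = torch.tensor([[0.1, 0.4, 0.4, 0.1]])

w = (p_ta * att_prev + (1 - p_ta) * att_prev_shift) * w_soft
w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1)
# higher p_ta keeps more mass on the current frame: here roughly [[0, 0.8, 0.2, 0]]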
def att_for(args, num_att=1, han_mode=False):
    """Instantiates an attention module given the program arguments

    :param Namespace args: The arguments
    :param int num_att: number of attention modules
        (in multi-speaker case, it can be 2 or more)
    :param bool han_mode: switch on/off mode of hierarchical attention network (HAN)
    :rtype: torch.nn.Module
    :return: The attention module
    """
    att_list = torch.nn.ModuleList()
    num_encs = getattr(args, "num_encs", 1)  # use getattr to keep compatibility
    aheads = getattr(args, "aheads", None)
    awin = getattr(args, "awin", None)
    aconv_chans = getattr(args, "aconv_chans", None)
    aconv_filts = getattr(args, "aconv_filts", None)

    if num_encs == 1:
        for i in range(num_att):
            att = initial_att(
                args.atype,
                args.eprojs,
                args.dunits,
                aheads,
                args.adim,
                awin,
                aconv_chans,
                aconv_filts,
            )
            att_list.append(att)
    elif num_encs > 1:  # no multi-speaker mode
        if han_mode:
            att = initial_att(
                args.han_type,
                args.eprojs,
                args.dunits,
                args.han_heads,
                args.han_dim,
                args.han_win,
                args.han_conv_chans,
                args.han_conv_filts,
                han_mode=True,
            )
            return att
        else:
            att_list = torch.nn.ModuleList()
            for idx in range(num_encs):
                att = initial_att(
                    args.atype[idx],
                    args.eprojs,
                    args.dunits,
                    aheads[idx],
                    args.adim[idx],
                    awin[idx],
                    aconv_chans[idx],
                    aconv_filts[idx],
                )
                att_list.append(att)
    else:
        raise ValueError(
            "Number of encoders needs to be more than one. {}".format(num_encs)
        )
    return att_list
def initial_att(
    atype, eprojs, dunits, aheads, adim, awin, aconv_chans, aconv_filts, han_mode=False
):
    """Instantiates a single attention module

    :param str atype: attention type
    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int adim: attention dimension
    :param int awin: attention window size
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode: flag to switch on mode of hierarchical attention
    :return: The attention module
    """
    if atype == "noatt":
        att = NoAtt()
    elif atype == "dot":
        att = AttDot(eprojs, dunits, adim, han_mode)
    elif atype == "add":
        att = AttAdd(eprojs, dunits, adim, han_mode)
    elif atype == "location":
        att = AttLoc(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode)
    elif atype == "location2d":
        att = AttLoc2D(eprojs, dunits, adim, awin, aconv_chans, aconv_filts, han_mode)
    elif atype == "location_recurrent":
        att = AttLocRec(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode)
    elif atype == "coverage":
        att = AttCov(eprojs, dunits, adim, han_mode)
    elif atype == "coverage_location":
        att = AttCovLoc(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode)
    elif atype == "multi_head_dot":
        att = AttMultiHeadDot(eprojs, dunits, aheads, adim, adim, han_mode)
    elif atype == "multi_head_add":
        att = AttMultiHeadAdd(eprojs, dunits, aheads, adim, adim, han_mode)
    elif atype == "multi_head_loc":
        att = AttMultiHeadLoc(
            eprojs, dunits, aheads, adim, adim, aconv_chans, aconv_filts, han_mode
        )
    elif atype == "multi_head_multi_res_loc":
        att = AttMultiHeadMultiResLoc(
            eprojs, dunits, aheads, adim, adim, aconv_chans, aconv_filts, han_mode
        )
    return att
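
Two equivalent ways to build a location-aware attention: directly via initial_att, or from the usual training arguments via att_for. The numeric values below are only illustrative, and the unused arguments (aheads, awin) may be None for this attention type.

att = initial_att(
    "location", eprojs=320, dunits=300, aheads=None, adim=320,
    awin=None, aconv_chans=10, aconv_filts=100,
)
# type(att).__name__ == "AttLoc"

from argparse import Namespace

args = Namespace(
    atype="location", eprojs=320, dunits=300, adim=320,
    aheads=None, awin=None, aconv_chans=10, aconv_filts=100,
)
att_list = att_for(args)  # ModuleList with one attention module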
def att_to_numpy(att_ws, att):
    """Converts attention weights to a numpy array given the attention

    :param list att_ws: The attention weights
    :param torch.nn.Module att: The attention
    :rtype: np.ndarray
    :return: The numpy array of the attention weights
    """
    # convert to numpy array with the shape (B, Lmax, Tmax)
    if isinstance(att, AttLoc2D):
        # att_ws => list of previous concatenated attentions
        att_ws = torch.stack([aw[:, -1] for aw in att_ws], dim=1).cpu().numpy()
    elif isinstance(att, (AttCov, AttCovLoc)):
        # att_ws => list of list of previous attentions
        att_ws = (
            torch.stack([aw[idx] for idx, aw in enumerate(att_ws)], dim=1).cpu().numpy()
        )
    elif isinstance(att, AttLocRec):
        # att_ws => list of tuple of attention and hidden states
        att_ws = torch.stack([aw[0] for aw in att_ws], dim=1).cpu().numpy()
    elif isinstance(
        att,
        (AttMultiHeadDot, AttMultiHeadAdd, AttMultiHeadLoc, AttMultiHeadMultiResLoc),
    ):
        # att_ws => list of list of each head attention
        n_heads = len(att_ws[0])
        att_ws_sorted_by_head = []
        for h in range(n_heads):
            att_ws_head = torch.stack([aw[h] for aw in att_ws], dim=1)
            att_ws_sorted_by_head += [att_ws_head]
        att_ws = torch.stack(att_ws_sorted_by_head, dim=1).cpu().numpy()
    else:
        # att_ws => list of attentions
        att_ws = torch.stack(att_ws, dim=1).cpu().numpy()
    return att_ws
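
A minimal sketch of the intended use: collect the per-step weights returned by an attention module during decoding and convert them into a (B, Lmax, Tmax) array, e.g. for plotting. The sizes are arbitrary.

att = AttLoc(eprojs=4, dunits=3, att_dim=8, aconv_chans=4, aconv_filts=5)
att.reset()

enc_hs_pad = torch.randn(2, 6, 4)
enc_hs_len = [6, 5]
dec_z = torch.zeros(2, 3)

att_ws, w = [], None
for _ in range(3):  # 3 decoder steps
    c, w = att(enc_hs_pad, enc_hs_len, dec_z, w)
    att_ws += [w]

att_ws = att_to_numpy(att_ws, att)
# att_ws.shape == (2, 3, 6)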
def _apply_dynamic_filter(p, last_attended_idx, backward_window=1, forward_window=3):
    """Apply dynamic filter.

    This function applies the dynamic filter introduced in
    `Singing-Tacotron: Global Duration Control Attention and Dynamic
    Filter for End-to-end Singing Voice Synthesis`_.

    Args:
        p (Tensor): probability before applying softmax (1, T).
        last_attended_idx (int): Index of the last attended input, in [0, T].
        backward_window (int, optional): Backward window size in dynamic filter.
        forward_window (int, optional): Forward window size in dynamic filter.

    Returns:
        Tensor: Dynamic filtered probability (1, T).

    .. _`Singing-Tacotron: Global Duration Control Attention and Dynamic Filter
        for End-to-end Singing Voice Synthesis`: https://arxiv.org/pdf/2202.07907v1.pdf

    """
    if p.size(0) != 1:
        raise NotImplementedError("Batch dynamic filter is not yet supported.")
    backward_idx = last_attended_idx - backward_window
    forward_idx = last_attended_idx + forward_window
    if backward_idx > 0:
        p[:, :backward_idx] = 0
    if forward_idx < p.size(1):
        p[:, forward_idx:] = 0
    return p
[docs]class GDCAttLoc(torch.nn.Module): """Global duration control attention module. Reference: Singing-Tacotron: Global Duration Control Attention and Dynamic Filter for End-to-end Singing Voice Synthesis (https://arxiv.org/abs/2202.07907) :param int eprojs: # projection-units of encoder :param int dunits: # units of decoder :param int att_dim: attention dimension :param int aconv_chans: # channels of attention convolution :param int aconv_filts: filter size of attention convolution :param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_enc_h """ def __init__( self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False ): super(GDCAttLoc, self).__init__() self.pt_zero_linear = torch.nn.Linear(att_dim, 1) self.mlp_enc = torch.nn.Linear(eprojs, att_dim) self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) self.loc_conv = torch.nn.Conv2d( 1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False, ) self.gvec = torch.nn.Linear(att_dim, 1) self.dunits = dunits self.eprojs = eprojs self.att_dim = att_dim self.h_length = None self.enc_h = None self.pre_compute_enc_h = None self.mask = None self.han_mode = han_mode
[docs] def reset(self): """reset states""" self.h_length = None self.enc_h = None self.pre_compute_enc_h = None self.mask = None
[docs] def forward( self, enc_hs_pad, enc_hs_len, trans_token, dec_z, att_prev, scaling=1.0, last_attended_idx=None, backward_window=1, forward_window=3, ): """Calcualte AttLoc forward propagation. :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) :param list enc_hs_len: padded encoder hidden state length (B) :param torch.Tensor trans_token: Global transition token for duration (B x T_max x 1) :param torch.Tensor dec_z: decoder hidden state (B x D_dec) :param torch.Tensor att_prev: previous attention weight (B x T_max) :param float scaling: scaling parameter before applying softmax :param torch.Tensor forward_window: forward window size when constraining attention :param int last_attended_idx: index of the inputs of the last attended :param int backward_window: backward window size in attention constraint :param int forward_window: forward window size in attetion constraint :return: attention weighted encoder state (B, D_enc) :rtype: torch.Tensor :return: previous attention weights (B x T_max) :rtype: torch.Tensor """ batch = len(enc_hs_pad) # pre-compute all h outside the decoder loop if self.pre_compute_enc_h is None or self.han_mode: self.enc_h = enc_hs_pad # utt x frame x hdim self.h_length = self.enc_h.size(1) # utt x frame x att_dim self.pre_compute_enc_h = self.mlp_enc(self.enc_h) if dec_z is None: dec_z = enc_hs_pad.new_zeros(batch, self.dunits) else: dec_z = dec_z.view(batch, self.dunits) # initialize attention weight with uniform dist. if att_prev is None: att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2]) att_prev[:, 0] = 1.0 # att_prev: utt x frame -> utt x 1 x 1 x frame # -> utt x att_conv_chans x 1 x frame att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length)) # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans att_conv = att_conv.squeeze(2).transpose(1, 2) # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim att_conv = self.mlp_att(att_conv) # dec_z_tiled: utt x frame x att_dim dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) # dot with gvec # utt x frame x att_dim -> utt x frame e = self.gvec( torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) ).squeeze(2) # NOTE: consider zero padding when compute w. if self.mask is None: self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) e.masked_fill_(self.mask, -float("inf")) w = F.softmax(scaling * e, dim=1) # dynamic filter if last_attended_idx is not None: att_prev = _apply_dynamic_filter( att_prev, last_attended_idx, backward_window, forward_window ) # GDCA attention att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1] trans_token = trans_token.squeeze(-1) trans_token_shift = F.pad(trans_token, (1, 0))[:, :-1] w = ((1 - trans_token_shift) * att_prev_shift + trans_token * att_prev) * w # NOTE: clamp is needed to avoid nan gradient w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1) # weighted sum over flames # utt x hdim # NOTE use bmm instead of sum(*) c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) return c, w
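
A minimal sketch with toy sizes: compared with AttLoc, forward additionally takes the global transition token (B x T_max x 1), which the Singing-Tacotron reference predicts from the encoder; here it is just a random tensor squashed into (0, 1).

att = GDCAttLoc(eprojs=4, dunits=3, att_dim=8, aconv_chans=4, aconv_filts=5)
att.reset()

enc_hs_pad = torch.randn(2, 6, 4)
enc_hs_len = [6, 5]
trans_token = torch.sigmoid(torch.randn(2, 6, 1))
dec_z = torch.zeros(2, 3)

c, w = att(enc_hs_pad, enc_hs_len, trans_token, dec_z, att_prev=None)
# c: (2, 4) context vectors; w: (2, 6) attention weights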