espnet.nets.pytorch_backend.rnn.attentions.AttLocRec
class espnet.nets.pytorch_backend.rnn.attentions.AttLocRec(eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False)
Bases: Module
Location-aware recurrent attention.
This attention is an extended version of location-aware attention. By using an RNN, it takes the history of attention weights into account.
- Parameters:
- eprojs (int) – # projection-units of encoder
- dunits (int) – # units of decoder
- att_dim (int) – attention dimension
- aconv_chans (int) – # channels of attention convolution
- aconv_filts (int) – filter size of attention convolution
- han_mode (bool) – flag to switch on hierarchical attention mode and not store pre_compute_enc_h
Initializes internal Module state, shared by both nn.Module and ScriptModule.
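A minimal construction sketch; the dimensions below are hypothetical and chosen only for illustration:

```python
import torch

from espnet.nets.pytorch_backend.rnn.attentions import AttLocRec

# Hypothetical dimensions for illustration.
att = AttLocRec(
    eprojs=320,       # projection units of the encoder
    dunits=320,       # units of the decoder
    att_dim=320,      # attention dimension
    aconv_chans=10,   # channels of the attention convolution
    aconv_filts=100,  # filter size of the attention convolution
)
```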
forward(enc_hs_pad, enc_hs_len, dec_z, att_prev_states, scaling=2.0)
AttLocRec forward
- Parameters:
- enc_hs_pad (torch.Tensor) – padded encoder hidden state (B x T_max x D_enc)
- enc_hs_len (list) – padded encoder hidden state length (B)
- dec_z (torch.Tensor) – decoder hidden state (B x D_dec)
- att_prev_states (tuple) – previous attention weight and lstm states ((B, T_max), ((B, att_dim), (B, att_dim)))
- scaling (float) – scaling parameter before applying softmax
- Returns: attention weighted encoder state (B, D_enc)
- Return type: torch.Tensor
- Returns: attention weights and lstm states (w, (hx, cx)), to be passed as att_prev_states at the next step ((B, T_max), ((B, att_dim), (B, att_dim)))
- Return type: tuple
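Continuing the sketch above, one possible decoding step; this assumes that passing att_prev_states=None on the first step lets the module initialize the attention weights and LSTM states internally:

```python
batch, t_max, eprojs, dunits = 4, 50, 320, 320
enc_hs_pad = torch.randn(batch, t_max, eprojs)   # padded encoder states (B x T_max x D_enc)
enc_hs_len = [50, 45, 40, 30]                    # true encoder lengths per utterance
dec_z = torch.zeros(batch, dunits)               # decoder hidden state (B x D_dec)

att.reset()  # clear any cached encoder projections before a new batch

# First step: att_prev_states=None (assumed to trigger internal initialization).
c, att_states = att(enc_hs_pad, enc_hs_len, dec_z, None)
print(c.shape)            # torch.Size([4, 320]) - attention-weighted encoder state
w, (hx, cx) = att_states  # (B, T_max) weights and ((B, att_dim), (B, att_dim)) LSTM states

# Later steps feed the returned states back in.
c, att_states = att(enc_hs_pad, enc_hs_len, dec_z, att_states)
```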
reset()
Reset states.