espnet.nets.pytorch_backend.rnn.attentions.AttForwardTA
class espnet.nets.pytorch_backend.rnn.attentions.AttForwardTA(eunits, dunits, att_dim, aconv_chans, aconv_filts, odim)
Bases: torch.nn.Module
Forward attention with transition agent module.
Reference: Forward Attention in Sequence-to-Sequence Acoustic Modeling for Speech Synthesis (Zhang et al., ICASSP 2018, arXiv:1807.06736)
- Parameters:
- eunits (int) – # units of encoder
- dunits (int) – # units of decoder
- att_dim (int) – attention dimension
- aconv_chans (int) – # channels of attention convolution
- aconv_filts (int) – filter size of attention convolution
- odim (int) – output dimension
Initializes internal Module state, shared by both nn.Module and ScriptModule.
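A minimal instantiation sketch (not part of the original docstring; the dimension values are arbitrary examples chosen for illustration):

```python
from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA

# Example dimensions are arbitrary; match them to your encoder/decoder.
att = AttForwardTA(
    eunits=256,      # encoder units
    dunits=320,      # decoder units
    att_dim=128,     # attention dimension
    aconv_chans=32,  # attention convolution channels
    aconv_filts=15,  # attention convolution filter size
    odim=80,         # output dimension (e.g. number of mel bins)
)
```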
forward(enc_hs_pad, enc_hs_len, dec_z, att_prev, out_prev, scaling=1.0, last_attended_idx=None, backward_window=1, forward_window=3)
Calculate AttForwardTA forward propagation.
- Parameters:
- enc_hs_pad (torch.Tensor) – padded encoder hidden state (B, Tmax, eunits)
- enc_hs_len (list) – padded encoder hidden state length (B)
- dec_z (torch.Tensor) – decoder hidden state (B, dunits)
- att_prev (torch.Tensor) – attention weights of previous step
- out_prev (torch.Tensor) – decoder outputs of previous step (B, odim)
- scaling (float) – scaling parameter before applying softmax
- last_attended_idx (int) – index of the input frame attended to at the previous step (used to constrain the attention window)
- backward_window (int) – backward window size in the attention constraint
- forward_window (int) – forward window size in the attention constraint
- Returns: attention weighted encoder state (B, dunits) and attention weights (B, Tmax), to be passed as att_prev at the next step
- Return type: tuple of torch.Tensor
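Continuing the sketch above, a single forward step with random placeholder tensors (shapes follow the parameter descriptions; passing att_prev=None lets the module initialize the attention weights itself):

```python
import torch

B, Tmax = 4, 50
enc_hs_pad = torch.randn(B, Tmax, 256)  # (B, Tmax, eunits)
enc_hs_len = [50, 45, 40, 30]           # true length of each sequence
dec_z = torch.randn(B, 320)             # (B, dunits)
out_prev = torch.randn(B, 80)           # (B, odim), previous decoder output

att.reset()  # clear cached states before a new sequence
c, w = att(enc_hs_pad, enc_hs_len, dec_z, None, out_prev)
```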
reset()
Reset the attention states kept across decoding steps. Call this before decoding each new sequence.
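In a full decoding loop, the returned weights are fed back as att_prev at each step; a rough sketch continuing the example above (the decoder update is elided, and max_decoder_steps is a hypothetical placeholder, not an ESPnet name):

```python
max_decoder_steps = 100         # hypothetical step limit for illustration

att.reset()
att_w = None                    # first step: module initializes the weights
prev_out = torch.zeros(B, 80)   # e.g. an all-zero initial output frame
for _ in range(max_decoder_steps):
    att_c, att_w = att(enc_hs_pad, enc_hs_len, dec_z, att_w, prev_out)
    # ... run the decoder here to update dec_z and prev_out ...
```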