espnet.nets.pytorch_backend.transformer.decoder_layer.DecoderLayer
class espnet.nets.pytorch_backend.transformer.decoder_layer.DecoderLayer(size, self_attn, src_attn, feed_forward, dropout_rate, normalize_before=True, concat_after=False, sequential_attn=None)
Bases: Module
Single decoder layer module.
- Parameters:
- size (int) – Input dimension.
- self_attn (torch.nn.Module) – Self-attention module instance. MultiHeadedAttention instance can be used as the argument.
- src_attn (torch.nn.Module) – Source-attention module instance, used to attend over the encoded memory. MultiHeadedAttention instance can be used as the argument.
- feed_forward (torch.nn.Module) – Feed-forward module instance. PositionwiseFeedForward, MultiLayeredConv1d, or Conv1dLinear instance can be used as the argument.
- dropout_rate (float) – Dropout rate.
- normalize_before (bool) – Whether to use layer_norm before the first block.
- concat_after (bool) – Whether to concatenate the attention layer's input and output. If True, an additional linear layer is applied, i.e. x -> x + linear(concat(x, att(x))). If False, no additional linear layer is applied, i.e. x -> x + att(x).
- sequential_attn (bool) – Computes attention first on pre_x and then on x, thereby attending to two sources in sequence.
Construct a DecoderLayer object.
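The two residual forms described for concat_after can be sketched in plain PyTorch. This is a minimal illustration, not ESPnet's implementation: torch.nn.MultiheadAttention stands in for MultiHeadedAttention, the names attn, concat_linear, and residual_attention are hypothetical, and dropout and layer normalization are left out.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

size = 8  # feature dimension, playing the role of the `size` argument

# Stand-in attention module (assumption: torch's built-in attention,
# not ESPnet's MultiHeadedAttention).
attn = nn.MultiheadAttention(embed_dim=size, num_heads=2, batch_first=True)
# The "additional linear" that is only needed when concat_after is True.
concat_linear = nn.Linear(2 * size, size)


def residual_attention(x, concat_after):
    """Apply self-attention with one of the two residual forms."""
    att_out, _ = attn(x, x, x)  # att(x), shape (#batch, time, size)
    if concat_after:
        # x -> x + linear(concat(x, att(x)))
        return x + concat_linear(torch.cat((x, att_out), dim=-1))
    # x -> x + att(x)
    return x + att_out


x = torch.randn(2, 5, size)  # (#batch, maxlen_out, size)
y_plain = residual_attention(x, concat_after=False)
y_concat = residual_attention(x, concat_after=True)
```

Both variants keep the (#batch, maxlen_out, size) shape; the concatenated form only adds a learned projection from 2 * size back to size.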
forward(tgt, tgt_mask, memory, memory_mask, cache=None, pre_memory=None, pre_memory_mask=None)
Compute decoded features.
- Parameters:
- tgt (torch.Tensor) – Input tensor (#batch, maxlen_out, size).
- tgt_mask (torch.Tensor) – Mask for input tensor (#batch, maxlen_out).
- memory (torch.Tensor) – Encoded memory, float32 (#batch, maxlen_in, size).
- memory_mask (torch.Tensor) – Encoded memory mask (#batch, 1, maxlen_in).
- cache (List[torch.Tensor]) – List of cached tensors. Each tensor shape should be (#batch, maxlen_out - 1, size).
- pre_memory (torch.Tensor) – Encoded memory (#batch, maxlen_in, size).
- pre_memory_mask (torch.Tensor) – Encoded memory mask (#batch, maxlen_in).
- Returns: Four tensors, in order: the output tensor (#batch, maxlen_out, size), the mask for the output tensor (#batch, maxlen_out), the encoded memory (#batch, maxlen_in, size), and the encoded memory mask (#batch, maxlen_in).
- Return type: torch.Tensor
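The cache shape (#batch, maxlen_out - 1, size) next to an output shape of (#batch, maxlen_out, size) suggests incremental decoding, where only the newest position is computed and appended to the cached outputs. The sketch below illustrates that shape contract under stated assumptions: toy_layer and forward_with_cache are hypothetical stand-ins, and a position-wise linear layer replaces the real attention and feed-forward blocks.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

size = 8

# Hypothetical stand-in for a decoder layer's per-position computation.
toy_layer = nn.Linear(size, size)


def forward_with_cache(tgt, cache=None):
    """Return outputs for all positions, reusing cached ones if given."""
    if cache is None:
        return toy_layer(tgt)  # compute every position
    # The cache holds outputs for all positions except the newest one.
    assert cache.shape == (tgt.shape[0], tgt.shape[1] - 1, size)
    last = toy_layer(tgt[:, -1:, :])  # compute only the last position
    return torch.cat([cache, last], dim=1)  # (#batch, maxlen_out, size)


tgt = torch.randn(2, 4, size)  # (#batch, maxlen_out, size)
full = forward_with_cache(tgt)
cache = forward_with_cache(tgt[:, :-1, :])  # outputs for the first 3 steps
incremental = forward_with_cache(tgt, cache=cache)
```

In this toy setting the full and incremental results agree because each position is processed independently; a real decoder layer would rely on a causal target mask for the same property to hold.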
forward_partially_AR(tgt, tgt_mask, tgt_lengths, memory, memory_mask, cache=None)