Source code for espnet.nets.pytorch_backend.transducer.transformer_decoder_layer
"""Transformer decoder layer definition for custom Transducer model."""

from typing import Optional, Tuple

import torch

from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,
)


class TransformerDecoderLayer(torch.nn.Module):
"""Transformer decoder layer module for custom Transducer model.
Args:
hdim: Hidden dimension.
self_attention: Self-attention module.
feed_forward: Feed forward module.
dropout_rate: Dropout rate.
"""
def __init__(
self,
hdim: int,
self_attention: MultiHeadedAttention,
feed_forward: PositionwiseFeedForward,
dropout_rate: float,
):
"""Construct an DecoderLayer object."""
super().__init__()
self.self_attention = self_attention
self.feed_forward = feed_forward
self.norm1 = LayerNorm(hdim)
self.norm2 = LayerNorm(hdim)
self.dropout = torch.nn.Dropout(dropout_rate)
self.hdim = hdim

    def forward(
        self,
        sequence: torch.Tensor,
        mask: torch.Tensor,
        cache: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute previous decoder output sequences.
Args:
sequence: Transformer input sequences. (B, U, D_dec)
mask: Transformer intput mask sequences. (B, U)
cache: Cached decoder output sequences. (B, (U - 1), D_dec)
Returns:
sequence: Transformer output sequences. (B, U, D_dec)
mask: Transformer output mask sequences. (B, U)
"""
residual = sequence
sequence = self.norm1(sequence)
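        # One-step decoding: if a cache of already-computed outputs is given,
        # attend with only the newest position as the query and recompute
        # nothing for the earlier positions.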
if cache is None:
sequence_q = sequence
else:
batch = sequence.shape[0]
prev_len = sequence.shape[1] - 1
assert cache.shape == (
batch,
prev_len,
self.hdim,
), f"{cache.shape} == {(batch, prev_len, self.hdim)}"
sequence_q = sequence[:, -1:, :]
residual = residual[:, -1:, :]
if mask is not None:
mask = mask[:, -1:, :]
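        # Pre-norm residual connection around the (possibly single-query)
        # self-attention block.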
sequence = residual + self.dropout(
self.self_attention(sequence_q, sequence, sequence, mask)
)
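        # Second pre-norm residual connection, around the position-wise
        # feed-forward block.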
residual = sequence
sequence = self.norm2(sequence)
sequence = residual + self.dropout(self.feed_forward(sequence))
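        # Re-attach the cached outputs so the returned sequence covers all U
        # positions, not just the newly decoded one.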
if cache is not None:
sequence = torch.cat([cache, sequence], dim=1)
return sequence, mask
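

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original ESPnet module): a minimal, hedged
# example of building and running one decoder layer, and of checking that
# one-step cached decoding matches the full-sequence pass. The hyperparameters
# (4 heads, 256-dim model, 1024 feed-forward units, dropout 0.1) and the
# batch/length values are illustrative assumptions, not ESPnet defaults.
if __name__ == "__main__":
    hdim = 256
    layer = TransformerDecoderLayer(
        hdim,
        MultiHeadedAttention(n_head=4, n_feat=hdim, dropout_rate=0.1),
        PositionwiseFeedForward(hdim, 1024, 0.1),
        dropout_rate=0.1,
    )
    layer.eval()  # disable dropout so the two passes below are comparable

    batch, u_len = 2, 5
    seq = torch.randn(batch, u_len, hdim)  # (B, U, D_dec)

    # Causal (subsequent) mask: position u may attend to positions <= u.
    mask = torch.tril(torch.ones(u_len, u_len, dtype=torch.bool))
    mask = mask.unsqueeze(0).expand(batch, -1, -1)  # (B, U, U)

    with torch.no_grad():
        # Full-sequence pass.
        out, _ = layer(seq, mask)

        # Cached pass: previous outputs go in `cache`; only the last position
        # is recomputed and then concatenated back onto the cache.
        step_out, _ = layer(seq, mask, cache=out[:, :-1, :])

    assert out.shape == (batch, u_len, hdim)
    assert torch.allclose(out, step_out, atol=1e-5)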