Source code for espnet2.asr_transducer.decoder.mega_decoder

"""MEGA decoder definition for Transducer models."""

import math
from typing import Dict, List, Optional, Tuple

import torch
from typeguard import typechecked

from espnet2.asr_transducer.activation import get_activation
from espnet2.asr_transducer.beam_search_transducer import Hypothesis
from espnet2.asr_transducer.decoder.abs_decoder import AbsDecoder
from espnet2.asr_transducer.decoder.blocks.mega import MEGA
from espnet2.asr_transducer.decoder.modules.mega.feed_forward import (
    NormalizedPositionwiseFeedForward,
)
from espnet2.asr_transducer.normalization import get_normalization


class MEGADecoder(AbsDecoder):
    """MEGA decoder module.

    Based on https://arxiv.org/pdf/2209.10655.pdf.

    Args:
        vocab_size: Vocabulary size.
        block_size: Input/Output size.
        linear_size: NormalizedPositionwiseFeedForward hidden size.
        qk_size: Shared query and key size for attention module.
        v_size: Value size for attention module.
        num_heads: Number of EMA heads.
        rel_pos_bias_type: Type of relative position bias in attention module.
        max_positions: Maximum number of position for RelativePositionBias.
        truncation_length: Maximum length for truncation in EMA module.
        normalization_type: Normalization layer type.
        normalization_args: Normalization layer arguments (None = no arguments).
        activation_type: Activation function type.
        activation_args: Activation function arguments (None = no arguments).
        chunk_size: Chunk size for attention computation (-1 = full context).
        num_blocks: Number of MEGA blocks.
        dropout_rate: Dropout rate for MEGA internal modules.
        embed_dropout_rate: Dropout rate for embedding layer.
        att_dropout_rate: Dropout rate for the attention module.
        ema_dropout_rate: Dropout rate for the EMA module.
        ffn_dropout_rate: Dropout rate for the feed-forward module.
        embed_pad: Embedding padding symbol ID.

    """

    @typechecked
    def __init__(
        self,
        vocab_size: int,
        block_size: int = 512,
        linear_size: int = 1024,
        qk_size: int = 128,
        v_size: int = 1024,
        num_heads: int = 4,
        rel_pos_bias_type: str = "simple",
        max_positions: int = 2048,
        truncation_length: Optional[int] = None,
        normalization_type: str = "layer_norm",
        normalization_args: Optional[Dict] = None,
        activation_type: str = "swish",
        activation_args: Optional[Dict] = None,
        chunk_size: int = -1,
        num_blocks: int = 4,
        dropout_rate: float = 0.0,
        embed_dropout_rate: float = 0.0,
        att_dropout_rate: float = 0.0,
        ema_dropout_rate: float = 0.0,
        ffn_dropout_rate: float = 0.0,
        embed_pad: int = 0,
    ) -> None:
        """Construct a MEGADecoder object."""
        super().__init__()

        # Fresh dicts per call instead of shared mutable default arguments.
        if normalization_args is None:
            normalization_args = {}
        if activation_args is None:
            activation_args = {}

        self.embed = torch.nn.Embedding(vocab_size, block_size, padding_idx=embed_pad)
        self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)

        activation = get_activation(activation_type, **activation_args)
        norm_class, norm_args = get_normalization(
            normalization_type, **normalization_args
        )

        # Each block is a (MEGA, NormalizedPositionwiseFeedForward) pair.
        self.mega_blocks = torch.nn.ModuleList(
            [
                torch.nn.ModuleList(
                    [
                        MEGA(
                            block_size,
                            num_heads=num_heads,
                            qk_size=qk_size,
                            v_size=v_size,
                            activation=activation,
                            normalization=norm_class(block_size, **norm_args),
                            rel_pos_bias_type=rel_pos_bias_type,
                            max_positions=max_positions,
                            truncation_length=truncation_length,
                            chunk_size=chunk_size,
                            dropout_rate=dropout_rate,
                            att_dropout_rate=att_dropout_rate,
                            ema_dropout_rate=ema_dropout_rate,
                        ),
                        NormalizedPositionwiseFeedForward(
                            block_size,
                            linear_size,
                            normalization=norm_class(block_size, **norm_args),
                            activation=activation,
                            dropout_rate=ffn_dropout_rate,
                        ),
                    ]
                )
                for _ in range(num_blocks)
            ]
        )

        self.final_norm = norm_class(block_size, **norm_args)

        self.vocab_size = vocab_size
        self.output_size = block_size
        self.chunk_size = chunk_size

        self.mega_num_heads = num_heads
        self.mega_att_k_size = qk_size
        self.mega_att_v_size = v_size
        self.mega_ema_size = block_size
        self.mega_ema_num_heads = num_heads

        self.pad_idx = embed_pad
        self.num_blocks = num_blocks

        # Maps "_"-joined label sequences to (output, states) for score().
        self.score_cache = {}

        self.device = next(self.parameters()).device

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        """Encode source label sequences.

        Args:
            labels: Decoder input sequences. (B, L)

        Returns:
            out: Decoder output sequences. (B, U, D_dec)

        """
        _, length = labels.size()

        # With chunked attention, right-pad the labels (with the padding ID) so
        # the length is a multiple of chunk_size; the padding is cut off again
        # at the end.
        if 0 < self.chunk_size < length and length % self.chunk_size != 0:
            num_paddings = (
                math.ceil(length / self.chunk_size) * self.chunk_size - length
            )
            labels = torch.nn.functional.pad(
                labels, (0, num_paddings), value=self.pad_idx
            )
        else:
            num_paddings = 0

        # Padding mask (B, 1, L'): True where the label equals the padding ID.
        # The first position is always kept unmasked -- presumably because the
        # leading (blank/start) symbol may share the padding ID; confirm.
        mask = (labels == self.pad_idx).unsqueeze(1)
        mask[..., 0] = False

        # Causal attention mask: True strictly above the diagonal (future
        # positions), sized to one chunk when chunking, else the full length.
        _length = self.chunk_size if 0 < self.chunk_size < length else length
        attn_mask = torch.triu(
            torch.ones((_length, _length), device=labels.device, dtype=torch.bool),
            1,
        ).unsqueeze(0)

        # Blocks operate on (L, B, D), hence the transpose.
        x = self.dropout_embed(self.embed(labels)).transpose(0, 1)

        for mega_block, nffn in self.mega_blocks:
            x, _ = mega_block(x, mask=mask, attn_mask=attn_mask)
            x = nffn(x)

        out = self.final_norm(x).transpose(0, 1)

        if num_paddings > 0:
            out = out[:, :length, :]

        return out

    def inference(
        self,
        labels: torch.Tensor,
        states: List[Dict[str, torch.Tensor]],
    ) -> Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]:
        """Encode source label sequences step by step using cached states.

        Unlike forward(), no embedding dropout and no masks are applied.

        Args:
            labels: Decoder input sequences. (B, L)
            states: Decoder hidden states, one dict per MEGA block. [N x Dict]

        Returns:
            out: Decoder output sequences. (B, U, D_dec)
            new_states: Decoder hidden states, one dict per MEGA block. [N x Dict]

        """
        x = self.embed(labels).transpose(0, 1)

        new_states = []
        for idx, (mega_block, nffn) in enumerate(self.mega_blocks):
            x, new_state = mega_block(x, state=states[idx])
            x = nffn(x)

            new_states.append(new_state)

        out = self.final_norm(x).transpose(0, 1)

        return out, new_states

    def set_device(self, device: torch.device) -> None:
        """Set GPU device to use.

        Args:
            device: Device ID.

        """
        self.device = device

    def score(
        self,
        label_sequence: List[int],
        states: List[Dict[str, torch.Tensor]],
    ) -> Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]:
        """One-step forward hypothesis.

        Results are memoized in ``self.score_cache`` keyed on the full label
        sequence; only the last label is fed to the decoder.

        Args:
            label_sequence: Current label sequence.
            states: Decoder hidden states of the hypothesis. [N x Dict]

        Returns:
            : Decoder output for the last label. (1, D_dec)
            states: Updated decoder hidden states. [N x Dict]

        """
        str_labels = "_".join(map(str, label_sequence))

        if str_labels in self.score_cache:
            out, states = self.score_cache[str_labels]
        else:
            label = torch.full(
                (1, 1), label_sequence[-1], dtype=torch.long, device=self.device
            )

            out, states = self.inference(label, states=states)

            self.score_cache[str_labels] = (out, states)

        return out[0], states

    def batch_score(
        self, hyps: List[Hypothesis]
    ) -> Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.

        Returns:
            out: Decoder outputs for the last label of each hypothesis. (B, D_dec)
            states: Batched decoder hidden states. [N x Dict]

        """
        labels = torch.tensor(
            [[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device
        )
        states = self.create_batch_states([h.dec_state for h in hyps])

        out, states = self.inference(labels, states=states)

        return out.squeeze(1), states

    def init_state(self, batch_size: int = 0) -> List[Dict[str, torch.Tensor]]:
        """Initialize MEGADecoder states.

        Args:
            batch_size: Batch size. Currently unused: the returned tensors
                carry no batch-size-dependent dimension.

        Returns:
            states: Decoder hidden states. [N x Dict]

        """
        return [
            {
                "ema_state": torch.zeros(
                    (self.output_size, self.mega_ema_num_heads), device=self.device
                ),
                "prev_key": torch.zeros(
                    (1, 1, self.mega_att_k_size), device=self.device
                ),
                "prev_value": torch.zeros(
                    (1, 1, self.mega_att_v_size), device=self.device
                ),
            }
            for _ in range(self.num_blocks)
        ]

    def select_state(
        self,
        states: List[Dict[str, torch.Tensor]],
        idx: int,
    ) -> List[Dict[str, torch.Tensor]]:
        """Select ID state from batch of decoder hidden states.

        Args:
            states: Decoder hidden states. [N x Dict]
            idx: Index (along the first dimension) of the state to select.

        Returns:
            : Decoder hidden states for given ID. [N x Dict]

        """
        return [
            {
                "ema_state": states[n_b]["ema_state"][idx],
                "prev_key": states[n_b]["prev_key"][idx],
                "prev_value": states[n_b]["prev_value"][idx],
            }
            for n_b in range(self.num_blocks)
        ]

    def stack_qk_states(
        self, state_list: List[torch.Tensor], dim: int
    ) -> torch.Tensor:
        """Stack query or key states with different lengths.

        Shorter states are left-padded with zeros, so the most recent steps of
        every state are aligned at the end of the time axis.

        Args:
            state_list: List of query or key states, each of length T_i along
                dimension 0 and size ``dim`` along the last dimension.
            dim: Feature dimension of the states.

        Returns:
            new_state: Stacked states. (len(state_list), max(T_i), dim)

        """
        max_len = max(state.size(0) for state in state_list)

        # Allocate on the inputs' device and dtype; a bare torch.zeros() would
        # always be CPU/float32 and break GPU or half-precision decoding.
        new_state = state_list[0].new_zeros((len(state_list), max_len, dim))

        for idx, state in enumerate(state_list):
            # Explicit start index: "-state.size(0):" degenerates to "0:" (the
            # whole row) for an empty state, which then fails to broadcast.
            new_state[idx, max_len - state.size(0) :, :] = state

        return new_state

    def create_batch_states(
        self,
        new_states: List[List[Dict[str, torch.Tensor]]],
    ) -> List[Dict[str, torch.Tensor]]:
        """Create batch of decoder hidden states given a list of new states.

        Args:
            new_states: Decoder hidden states. [B x [N x Dict]]

        Returns:
            : Decoder hidden states. [N x Dict]

        """
        return [
            {
                "ema_state": torch.stack(
                    [state[n_b]["ema_state"] for state in new_states]
                ),
                "prev_key": self.stack_qk_states(
                    [state[n_b]["prev_key"] for state in new_states],
                    self.mega_att_k_size,
                ),
                "prev_value": self.stack_qk_states(
                    [state[n_b]["prev_value"] for state in new_states],
                    self.mega_att_v_size,
                ),
            }
            for n_b in range(self.num_blocks)
        ]