"""Custom decoder definition for Transducer model."""

from typing import Any, Dict, List, Optional, Tuple, Union

import torch

from espnet.nets.pytorch_backend.transducer.blocks import build_blocks
from espnet.nets.pytorch_backend.transducer.utils import (
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.transducer_decoder_interface import (

class CustomDecoder(TransducerDecoderInterface, torch.nn.Module):
    """Custom decoder module for Transducer model.

    Args:
        odim: Output dimension.
        dec_arch: Decoder block architecture (type and parameters).
        input_layer: Input layer type.
        repeat_block: Number of times dec_arch is repeated.
        joint_activation_type: Type of activation for joint network.
        positional_encoding_type: Positional encoding type.
        positionwise_layer_type: Positionwise layer type.
        positionwise_activation_type: Positionwise activation type.
        input_layer_dropout_rate: Dropout rate for input layer.
        blank_id: Blank symbol ID.

    """

    def __init__(
        self,
        odim: int,
        dec_arch: List,
        input_layer: str = "embed",
        repeat_block: int = 0,
        joint_activation_type: str = "tanh",
        positional_encoding_type: str = "abs_pos",
        positionwise_layer_type: str = "linear",
        positionwise_activation_type: str = "relu",
        input_layer_dropout_rate: float = 0.0,
        blank_id: int = 0,
    ):
        """Construct a CustomDecoder object.

        NOTE(review): `joint_activation_type` is accepted for interface
        compatibility but is not used by this constructor.

        """
        torch.nn.Module.__init__(self)

        # build_blocks returns (input layer, stacked decoder blocks,
        # decoder output dim, <unused 4th value>).
        self.embed, self.decoders, ddim, _ = build_blocks(
            "decoder",
            odim,
            input_layer,
            dec_arch,
            repeat_block=repeat_block,
            positional_encoding_type=positional_encoding_type,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_activation_type=positionwise_activation_type,
            input_layer_dropout_rate=input_layer_dropout_rate,
            padding_idx=blank_id,
        )

        self.after_norm = LayerNorm(ddim)

        self.dlayers = len(self.decoders)
        self.dunits = ddim
        self.odim = odim

        self.blank_id = blank_id

    def set_device(self, device: torch.device):
        """Set the device used to build label tensors during decoding.

        `score` and `batch_score` read `self.device`, which is only assigned
        here, so this must be called before either of them.

        Args:
            device: Device to place label tensors on.

        """
        self.device = device

    def init_state(
        self,
        batch_size: Optional[int] = None,
    ) -> List[Optional[torch.Tensor]]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size. Unused; kept for interface compatibility.

        Returns:
            state: Initial decoder hidden states, one `None` per decoder layer.
                [N x None]

        """
        return [None] * self.dlayers

    def forward(
        self, dec_input: torch.Tensor, dec_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode label ID sequences.

        Args:
            dec_input: Label ID sequences. (B, U)
            dec_mask: Label mask sequences. (B, U)

        Return:
            dec_output: Decoder output sequences. (B, U, D_dec)
            dec_output_mask: Mask of decoder output sequences. (B, U)

        """
        dec_input = self.embed(dec_input)

        dec_output, dec_mask = self.decoders(dec_input, dec_mask)
        dec_output = self.after_norm(dec_output)

        return dec_output, dec_mask

    def score(
        self, hyp: Hypothesis, cache: Dict[str, Any]
    ) -> Tuple[torch.Tensor, List[Optional[torch.Tensor]], torch.Tensor]:
        """One-step forward hypothesis.

        Args:
            hyp: Hypothesis.
            cache: Pairs of (dec_out, dec_state) for each label sequence, keyed
                by the "_"-joined label IDs. Updated in place on a miss.

        Returns:
            dec_out: Decoder output sequence. (D_dec)
            dec_state: Decoder hidden states. [N x (1, U, D_dec)]
            lm_label: Label ID for LM. (1,)

        """
        labels = torch.tensor([hyp.yseq], device=self.device)
        lm_label = labels[:, -1]

        str_labels = "_".join(map(str, hyp.yseq))

        if str_labels in cache:
            dec_out, dec_state = cache[str_labels]
        else:
            # Keep the causal mask on the same device as the labels so the
            # decoder layers never mix CPU and accelerator tensors.
            dec_out_mask = (
                subsequent_mask(len(hyp.yseq)).unsqueeze_(0).to(labels.device)
            )

            new_state = check_state(
                hyp.dec_state, (labels.size(1) - 1), self.blank_id
            )

            dec_out = self.embed(labels)

            # Run each layer with its cached state; keep every layer's output
            # as the new per-layer state.
            dec_state = []
            for s, decoder in zip(new_state, self.decoders):
                dec_out, dec_out_mask = decoder(dec_out, dec_out_mask, cache=s)
                dec_state.append(dec_out)

            # Only the last position is needed for the next prediction.
            dec_out = self.after_norm(dec_out[:, -1])

            cache[str_labels] = (dec_out, dec_state)

        return dec_out[0], dec_state, lm_label

    def batch_score(
        self,
        hyps: Union[List[Hypothesis], List[ExtendedHypothesis]],
        dec_states: List[Optional[torch.Tensor]],
        cache: Dict[str, Any],
        use_lm: bool,
    ) -> Tuple[torch.Tensor, List[Optional[torch.Tensor]], Optional[torch.Tensor]]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.
            dec_states: Decoder hidden states. [N x (B, U, D_dec)]
                Updated in place (via `create_batch_states`) and returned.
            cache: Pairs of (dec_out, dec_state) for each label sequence, keyed
                by the "_"-joined label IDs. Updated in place on a miss.
            use_lm: Whether to compute label ID sequences for LM.

        Returns:
            dec_out: Decoder output sequences. (B, D_dec)
            dec_states: Decoder hidden states. [N x (B, U, D_dec)]
            lm_labels: Label ID sequences for LM. (B,), or None if not use_lm.

        """
        final_batch = len(hyps)

        process = []
        done = [None] * final_batch

        # Split hypotheses into cache hits (filled directly) and misses
        # (queued for one batched decoder pass).
        for i, hyp in enumerate(hyps):
            str_labels = "_".join(map(str, hyp.yseq))

            if str_labels in cache:
                done[i] = cache[str_labels]
            else:
                process.append((str_labels, hyp.yseq, hyp.dec_state))

        if process:
            labels = pad_sequence([p[1] for p in process], self.blank_id)
            # torch.LongTensor(..., device=...) is a legacy CPU-only
            # constructor and rejects non-CPU devices; torch.tensor honors it.
            labels = torch.tensor(labels, dtype=torch.long, device=self.device)

            p_dec_states = self.create_batch_states(
                self.init_state(),
                [p[2] for p in process],
                labels,
            )

            dec_out = self.embed(labels)

            # Causal mask shared by every queued sequence, on the labels' device.
            dec_out_mask = (
                subsequent_mask(labels.size(-1))
                .unsqueeze_(0)
                .to(labels.device)
                .expand(len(process), -1, -1)
            )

            new_states = []
            for s, decoder in zip(p_dec_states, self.decoders):
                dec_out, dec_out_mask = decoder(dec_out, dec_out_mask, cache=s)
                new_states.append(dec_out)

            dec_out = self.after_norm(dec_out[:, -1])

            # Misses appear in `done` in the same order they were queued in
            # `process`, so the j-th missing slot takes the j-th batch row.
            missing = [i for i, d in enumerate(done) if d is None]
            for j, i in enumerate(missing):
                state = self.select_state(new_states, j)

                done[i] = (dec_out[j], state)
                cache[process[j][0]] = (dec_out[j], state)

        dec_out = torch.stack([d[0] for d in done])
        dec_states = self.create_batch_states(
            dec_states, [d[1] for d in done], [[0] + h.yseq for h in hyps]
        )

        if use_lm:
            lm_labels = torch.tensor(
                [hyp.yseq[-1] for hyp in hyps], dtype=torch.long, device=self.device
            )

            return dec_out, dec_states, lm_labels

        return dec_out, dec_states, None

    def select_state(
        self, states: List[Optional[torch.Tensor]], idx: int
    ) -> List[Optional[torch.Tensor]]:
        """Get specified ID state from decoder hidden states.

        Args:
            states: Decoder hidden states. [N x (B, U, D_dec)]
            idx: State ID to extract.

        Returns:
            state_idx: Decoder hidden state for given ID. Integer indexing
                drops the batch dimension. [N x (U, D_dec)]
                If `states[0]` is None, `states` is returned unchanged.

        """
        if states[0] is None:
            return states

        return [states[layer][idx] for layer in range(self.dlayers)]

    def create_batch_states(
        self,
        states: List[Optional[torch.Tensor]],
        new_states: List[Optional[torch.Tensor]],
        check_list: List[List[int]],
    ) -> List[Optional[torch.Tensor]]:
        """Create decoder hidden states sequences.

        Args:
            states: Decoder hidden states. [N x (B, U, D_dec)]
                Overwritten in place, layer by layer, and returned.
            new_states: Decoder hidden states. [B x [N x (1, U, D_dec)]]
            check_list: Label ID sequences; the longest one (minus one)
                sets the target length passed to `check_batch_states`.

        Returns:
            states: New decoder hidden states. [N x (B, U, D_dec)]

        """
        # No per-layer states yet (fresh hypotheses): nothing to batch.
        if new_states[0][0] is None:
            return states

        max_len = max(len(elem) for elem in check_list) - 1

        for layer in range(self.dlayers):
            states[layer] = check_batch_states(
                [s[layer] for s in new_states], max_len, self.blank_id
            )

        return states