Source code for espnet.nets.transducer_decoder_interface

"""Transducer decoder interface module."""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch


@dataclass
class Hypothesis:
    """Default hypothesis definition for Transducer search algorithms.

    Attributes:
        score: Score of the hypothesis.
        yseq: Sequence of label/token IDs.
        dec_state: Decoder state(s) for this hypothesis. Either a
            (tensor, optional tensor) pair, a list of optional tensors,
            or a single tensor, depending on the decoder implementation.
        lm_state: Language-model state, or None (the default), e.g. when
            no LM state has been set.
    """

    score: float
    yseq: List[int]
    dec_state: Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        List[Optional[torch.Tensor]],
        torch.Tensor,
    ]
    # Defaults to None, so the annotation must admit None explicitly
    # (implicit Optional is not allowed by PEP 484).
    lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None
@dataclass
class ExtendedHypothesis(Hypothesis):
    """Extended hypothesis definition for NSC beam search and mAES.

    Adds two optional fields on top of the base hypothesis.

    Attributes:
        dec_out: List of decoder output tensors, or None (the default).
        lm_scores: LM scores tensor, or None (the default).
    """

    # Both fields default to None, so their annotations must be Optional
    # (implicit Optional is not allowed by PEP 484).
    dec_out: Optional[List[torch.Tensor]] = None
    lm_scores: Optional[torch.Tensor] = None
class TransducerDecoderInterface:
    """Decoder interface for Transducer models.

    Every method in this base class raises NotImplementedError; concrete
    decoders are expected to override them.
    """

    def init_state(
        self,
        batch_size: int,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
    ]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size.

        Returns:
            state: Initial decoder hidden states.

        Raises:
            NotImplementedError: Always, in this base interface.
        """
        raise NotImplementedError("init_state(...) is not implemented")

    def score(
        self,
        hyp: Hypothesis,
        cache: Dict[str, Any],
    ) -> Tuple[
        torch.Tensor,
        Union[
            Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
        ],
        torch.Tensor,
    ]:
        """One-step forward hypothesis.

        Args:
            hyp: Hypothesis.
            cache: Cache mapping a token-sequence key to its
                (dec_out, dec_state) pair.

        Returns:
            dec_out: Decoder output sequence.
            new_state: Decoder hidden states.
            lm_tokens: Label ID for LM.

        Raises:
            NotImplementedError: Always, in this base interface.
        """
        raise NotImplementedError("score(...) is not implemented")

    def batch_score(
        self,
        hyps: Union[List[Hypothesis], List[ExtendedHypothesis]],
        dec_states: Union[
            Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
        ],
        cache: Dict[str, Any],
        use_lm: bool,
    ) -> Tuple[
        torch.Tensor,
        Union[
            Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
        ],
        torch.Tensor,
    ]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.
            dec_states: Decoder hidden states.
            cache: Cache mapping a label-sequence key to its
                (dec_out, dec_states) pair.
            use_lm: Whether to compute label ID sequences for LM.

        Returns:
            dec_out: Decoder output sequences.
            dec_states: Decoder hidden states.
            lm_labels: Label ID sequences for LM.

        Raises:
            NotImplementedError: Always, in this base interface.
        """
        raise NotImplementedError("batch_score(...) is not implemented")

    def select_state(
        self,
        batch_states: Union[
            Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
        ],
        idx: int,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
    ]:
        """Get specified ID state from decoder hidden states.

        Args:
            batch_states: Decoder hidden states.
            idx: State ID to extract.

        Returns:
            state_idx: Decoder hidden state for given ID.

        Raises:
            NotImplementedError: Always, in this base interface.
        """
        raise NotImplementedError("select_state(...) is not implemented")

    def create_batch_states(
        self,
        states: Union[
            Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
        ],
        new_states: List[
            Union[
                Tuple[torch.Tensor, Optional[torch.Tensor]],
                List[Optional[torch.Tensor]],
            ]
        ],
        l_tokens: List[List[int]],
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]]
    ]:
        """Create decoder hidden states.

        Args:
            states: Batch of decoder states.
            new_states: List of decoder states.
            l_tokens: List of token sequences for input batch.

        Returns:
            batch_states: Batch of decoder states.

        Raises:
            NotImplementedError: Always, in this base interface.
        """
        raise NotImplementedError("create_batch_states(...) is not implemented")