# Source code for espnet.nets.scorer_interface

"""Scorer interface module."""

import warnings
from typing import Any, List, Tuple

import torch


class ScorerInterface:
    """Base interface for scorers plugged into beam search.

    A scorer assigns a score to every token in the vocabulary.

    Examples:
        * Search heuristics
            * :class:`espnet.nets.scorers.length_bonus.LengthBonus`
        * Decoder networks of the sequence-to-sequence models
            * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder`
            * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
        * Neural language models
            * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
            * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
            * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`

    """

    def init_state(self, x: torch.Tensor) -> Any:
        """Return the state to start decoding from (optional).

        The default implementation is stateless and ignores ``x``.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state (``None`` by default)

        """
        return None

    def select_state(self, state: Any, i: int, new_id: int = None) -> Any:
        """Pick the state of one hypothesis by its index in the main beam search.

        The default implementation indexes ``state`` with ``i`` and does not
        use ``new_id``; a ``None`` state is passed through unchanged.

        Args:
            state: Decoder state for prefix tokens
            i (int): Index to select a state in the main beam search
            new_id (int): New label index to select a state if necessary

        Returns:
            state: pruned state

        """
        # Stateless scorers carry None around; there is nothing to index.
        if state is None:
            return None
        return state[i]

    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score the next token given a prefix (required).

        Subclasses must override this method.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                scores for next token that has a shape of `(n_vocab)`
                and next state for ys

        Raises:
            NotImplementedError: Always, in this base implementation.

        """
        raise NotImplementedError

    def final_score(self, state: Any) -> float:
        """Score the end-of-sequence token (optional).

        The default implementation contributes nothing to the final score.

        Args:
            state: Scorer state for prefix tokens

        Returns:
            float: final score (``0.0`` by default)

        """
        return 0.0
class BatchScorerInterface(ScorerInterface):
    """Batch scorer interface."""

    def batch_init_state(self, x: torch.Tensor) -> Any:
        """Get an initial state for decoding (optional).

        The default implementation delegates to :meth:`init_state`.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state

        """
        return self.init_state(x)

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch (required).

        This fallback calls :meth:`score` once per batch entry and emits a
        warning saying so; subclasses are expected to override it with a
        parallel implementation.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.

        """
        warnings.warn(
            "{} batch score is implemented through for loop not parallelized".format(
                self.__class__.__name__
            )
        )
        scores = []
        outstates = []
        # Entries are paired positionally; zip stops at the shortest input.
        for y, state, x in zip(ys, states, xs):
            score, outstate = self.score(y, state, x)
            outstates.append(outstate)
            scores.append(score)
        # Concatenate the per-entry scores, then restore one row per prefix.
        batch_scores = torch.cat(scores, 0).view(ys.shape[0], -1)
        return batch_scores, outstates
class PartialScorerInterface(ScorerInterface):
    """Interface for scorers that only score a pre-pruned set of tokens.

    A partial scorer runs after the non-partial scorers have finished
    scoring. It receives pre-pruned next tokens to score, because scoring
    all the tokens would be too heavy.

    Examples:
         * Prefix search for connectionist-temporal-classification models
             * :class:`espnet.nets.scorers.ctc.CTCPrefixScorer`

    """

    def score_partial(
        self, y: torch.Tensor, next_tokens: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score the given candidate next tokens (required).

        Subclasses must override this method.

        Args:
            y (torch.Tensor): 1D prefix token
            next_tokens (torch.Tensor): torch.int64 next token to score
            state: decoder state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
                and next state for ys

        Raises:
            NotImplementedError: Always, in this base implementation.

        """
        raise NotImplementedError
class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface):
    """Interface for partial scorers that score a whole batch at once."""

    def batch_score_partial(
        self,
        ys: torch.Tensor,
        next_tokens: torch.Tensor,
        states: List[Any],
        xs: torch.Tensor,
    ) -> Tuple[torch.Tensor, Any]:
        """Score the given candidate next tokens for a batch (required).

        Subclasses must override this method.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            next_tokens (torch.Tensor): torch.int64 tokens to score (n_batch, n_token).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)`
                and next states for ys

        Raises:
            NotImplementedError: Always, in this base implementation.

        """
        raise NotImplementedError