Source code for espnet.nets.scorers.length_bonus

"""Length bonus module."""

from typing import Any, List, Tuple

import torch

from espnet.nets.scorer_interface import BatchScorerInterface


class LengthBonus(BatchScorerInterface):
    """Length bonus scorer for beam search.

    Every candidate token receives the same score of 1.0, independent of
    the prefix tokens and of the scorer state.
    """

    def __init__(self, n_vocab: int):
        """Store the vocabulary size.

        Args:
            n_vocab (int): The number of tokens in vocabulary for beam search

        """
        self.n = n_vocab

    def score(self, y, state, x):
        """Return a constant score of 1.0 for every next-token candidate.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens (not used).
            state: Scorer state for prefix tokens (not used).
            x (torch.Tensor): 2D encoder feature that generates ys; only its
                device and dtype are read.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of scores for next token with
                shape `(n_vocab,)`, created with the dtype and device of
                `x`, and None as the next state.

        """
        # A single-element tensor is expanded rather than materialising
        # n_vocab separate values.
        unit = torch.tensor([1.0], device=x.device, dtype=x.dtype)
        return unit.expand(self.n), None

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Return a constant score of 1.0 for every candidate in the batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen);
                only the batch dimension is read.
            states (List[Any]): Scorer states for prefix tokens (not used).
            xs (torch.Tensor): The encoder feature that generates ys
                (n_batch, xlen, n_feat); only its device and dtype are read.

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of batchfied scores for
                next token with shape of `(n_batch, n_vocab)`, created with
                the dtype and device of `xs`, and None in place of the next
                state list.

        """
        unit = torch.tensor([1.0], device=xs.device, dtype=xs.dtype)
        n_batch = ys.shape[0]
        bonus = unit.expand(n_batch, self.n)
        return bonus, None