Source code for espnet.nets.batch_beam_search

"""Parallel beam search module."""

import logging
from typing import Any, Dict, List, NamedTuple, Tuple

import torch
from packaging.version import parse as V
from torch.nn.utils.rnn import pad_sequence

from espnet.nets.beam_search import BeamSearch, Hypothesis

is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")

logger = logging.getLogger(__name__)


class BatchHypothesis(NamedTuple):
    """Batchified/vectorized hypothesis data type."""

    yseq: torch.Tensor = torch.tensor([])  # (batch, maxlen)
    score: torch.Tensor = torch.tensor([])  # (batch,)
    length: torch.Tensor = torch.tensor([])  # (batch,)
    scores: Dict[str, torch.Tensor] = dict()  # values: (batch,)
    states: Dict[str, Dict] = dict()
    hs: List[torch.Tensor] = []  # (batch, maxlen, adim)

    def __len__(self) -> int:
        """Return the batch size."""
        return len(self.length)
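For orientation, a minimal sketch (not part of the module) of how the fields line up; the token ids, scores, and the eos value of 2 below are made up for illustration:

import torch

eos = 2  # hypothetical eos id
batch = BatchHypothesis(
    yseq=torch.tensor([[1, 5, eos], [1, 7, eos]]),   # padded to maxlen 3
    score=torch.tensor([-1.2, -3.4]),                # cumulative log-probabilities
    length=torch.tensor([3, 2]),                     # true (unpadded) lengths
    scores={"decoder": torch.tensor([-1.2, -3.4])},  # per-scorer score totals
    states={"decoder": [None, None]},                # per-scorer states
)
assert len(batch) == 2  # __len__ counts the entries of `length`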
class BatchBeamSearch(BeamSearch):
    """Batch beam search implementation."""
    def batchfy(self, hyps: List[Hypothesis]) -> BatchHypothesis:
        """Convert list to batch."""
        if len(hyps) == 0:
            return BatchHypothesis()

        if self.return_hs:
            hs = [h.hs for h in hyps]
        else:
            hs = []

        return BatchHypothesis(
            yseq=pad_sequence(
                [h.yseq for h in hyps], batch_first=True, padding_value=self.eos
            ),
            length=torch.tensor([len(h.yseq) for h in hyps], dtype=torch.int64),
            score=torch.tensor([h.score for h in hyps]),
            scores={
                k: torch.tensor([h.scores[k] for h in hyps]) for k in self.scorers
            },
            states={k: [h.states[k] for h in hyps] for k in self.scorers},
            hs=hs,
        )
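The padding step is the only subtle part here; a toy run of the `pad_sequence` call with made-up token ids and `eos = 2`:

import torch
from torch.nn.utils.rnn import pad_sequence

eos = 2  # hypothetical eos id
yseqs = [torch.tensor([0, 4, 9]), torch.tensor([0, 7])]
padded = pad_sequence(yseqs, batch_first=True, padding_value=eos)
# padded == tensor([[0, 4, 9],
#                   [0, 7, 2]])
# `length` stays [3, 2], so the trailing eos padding can be trimmed later.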
    def _batch_select(self, hyps: BatchHypothesis, ids: List[int]) -> BatchHypothesis:
        if self.return_hs:
            hs = [hyps.hs[i] for i in ids]
        else:
            hs = []
        return BatchHypothesis(
            yseq=hyps.yseq[ids],
            score=hyps.score[ids],
            length=hyps.length[ids],
            scores={k: v[ids] for k, v in hyps.scores.items()},
            states={
                k: [self.scorers[k].select_state(v, i) for i in ids]
                for k, v in hyps.states.items()
            },
            hs=hs,
        )

    def _select(self, hyps: BatchHypothesis, i: int) -> Hypothesis:
        return Hypothesis(
            yseq=hyps.yseq[i, : hyps.length[i]],
            score=hyps.score[i],
            scores={k: v[i] for k, v in hyps.scores.items()},
            states={
                k: self.scorers[k].select_state(v, i) for k, v in hyps.states.items()
            },
            hs=hyps.hs[i] if self.return_hs else [],
        )
    def unbatchfy(self, batch_hyps: BatchHypothesis) -> List[Hypothesis]:
        """Revert batch to list."""
        return [
            Hypothesis(
                yseq=batch_hyps.yseq[i][: batch_hyps.length[i]],
                score=batch_hyps.score[i],
                scores={k: batch_hyps.scores[k][i] for k in self.scorers},
                states={
                    k: v.select_state(batch_hyps.states[k], i)
                    for k, v in self.scorers.items()
                },
                hs=batch_hyps.hs[i] if self.return_hs else [],
            )
            for i in range(len(batch_hyps.length))
        ]
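Note how `length` undoes the padding that `batchfy` introduced; with the same made-up batch as above:

import torch

yseq = torch.tensor([[0, 4, 9], [0, 7, 2]])  # eos-padded by batchfy
length = torch.tensor([3, 2])
trimmed = [yseq[i][: length[i]] for i in range(len(length))]
# trimmed == [tensor([0, 4, 9]), tensor([0, 7])]: padding removed, round trip complete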
    def batch_beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Batch-compute topk full token ids and partial token ids.

        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each token.
                Its shape is `(n_beam, self.vocab_size)`.
            ids (torch.Tensor): The partial token ids to compute topk.
                Its shape is `(n_beam, self.pre_beam_size)`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                The topk full (prev_hyp, new_token) ids
                and partial (prev_hyp, new_token) ids.
                Their shapes are all `(self.beam_size,)`.

        """
        top_ids = weighted_scores.view(-1).topk(self.beam_size)[1]
        # Because of the flatten above, `top_ids` is organized as:
        # [hyp1 * V + token1, hyp2 * V + token2, ..., hypK * V + tokenK],
        # where V is `self.n_vocab` and K is `self.beam_size`.
        if is_torch_1_9_plus:
            prev_hyp_ids = torch.div(top_ids, self.n_vocab, rounding_mode="trunc")
        else:
            prev_hyp_ids = top_ids // self.n_vocab
        new_token_ids = top_ids % self.n_vocab
        return prev_hyp_ids, new_token_ids, prev_hyp_ids, new_token_ids
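The division/modulo decomposition can be traced on a toy case; a minimal sketch assuming a vocabulary of 5 and a beam of 2 (all numbers made up):

import torch

V, K = 5, 2  # hypothetical vocab size and beam size
weighted_scores = torch.tensor([[0.1, 0.9, 0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.8, 0.0, 0.0]])
top_ids = weighted_scores.view(-1).topk(K)[1]                # tensor([1, 7])
prev_hyp_ids = torch.div(top_ids, V, rounding_mode="trunc")  # tensor([0, 1])
new_token_ids = top_ids % V                                  # tensor([1, 2])
# i.e. the best expansions are (hyp 0, token 1) and (hyp 1, token 2).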
    def init_hyp(self, x: torch.Tensor) -> BatchHypothesis:
        """Get an initial hypothesis data.

        Args:
            x (torch.Tensor): The encoder output feature

        Returns:
            BatchHypothesis: The initial hypothesis.

        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.batch_init_state(x)
            init_scores[k] = 0.0

        # NOTE (Shih-Lun): added for OpenAI Whisper ASR
        primer = [self.sos] if self.hyp_primer is None else self.hyp_primer

        return self.batchfy(
            [
                Hypothesis(
                    score=0.0,
                    scores=init_scores,
                    states=init_states,
                    hs=[],
                    yseq=torch.tensor(primer, device=x.device),
                )
            ]
        )
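The primer logic simply swaps the single `sos` token for a caller-supplied prefix when one is set; a sketch with made-up token ids (the Whisper-style ids below are purely illustrative):

sos = 1
hyp_primer = None
primer = [sos] if hyp_primer is None else hyp_primer
# primer == [1]: the usual single start-of-sequence token

hyp_primer = [50258, 50259, 50359]  # e.g. a hypothetical Whisper task prefix
primer = [sos] if hyp_primer is None else hyp_primer
# primer == [50258, 50259, 50359]: decoding starts from the full prefix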
    def score_full(
        self,
        hyp: BatchHypothesis,
        x: torch.Tensor,
        pre_x: torch.Tensor = None,
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.

        Args:
            hyp (BatchHypothesis): Hypothesis with prefix tokens to score
            x (torch.Tensor): Corresponding input feature
            pre_x (torch.Tensor): Encoded speech feature for sequential attn (T, D)
                Sequential attn computes attn first on pre_x then on x,
                thereby attending to two sources in sequence.

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`.
                If `self.return_hs` is True, the decoder hidden states
                are additionally returned as the first element.

        """
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            if "decoder" in k and self.return_hs:
                (scores[k], hs), states[k] = d.batch_score(
                    hyp.yseq, hyp.states[k], x, return_hs=self.return_hs
                )
            elif "decoder" in k and pre_x is not None:
                scores[k], states[k] = d.batch_score(hyp.yseq, hyp.states[k], x, pre_x)
            else:
                scores[k], states[k] = d.batch_score(hyp.yseq, hyp.states[k], x)

        if self.return_hs:
            return hs, scores, states
        return scores, states
    def score_partial(
        self,
        hyp: BatchHypothesis,
        ids: torch.Tensor,
        x: torch.Tensor,
        pre_x: torch.Tensor = None,
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.part_scorers`.

        Args:
            hyp (BatchHypothesis): Hypothesis with prefix tokens to score
            ids (torch.Tensor): 2D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature
            pre_x (torch.Tensor): Encoded speech feature for sequential attn (T, D)
                Sequential attn computes attn first on pre_x then on x,
                thereby attending to two sources in sequence.

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.part_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.part_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            if "ctc" in k and pre_x is not None:
                scores[k], states[k] = d.batch_score_partial(
                    hyp.yseq, ids, hyp.states[k], pre_x
                )
            else:
                scores[k], states[k] = d.batch_score_partial(
                    hyp.yseq, ids, hyp.states[k], x
                )
        return scores, states
    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for new hypothesis.

        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new state dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are states of the scorers.

        """
        new_states = dict()
        for k, v in states.items():
            new_states[k] = v
        for k, v in part_states.items():
            new_states[k] = v
        return new_states
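The two loops amount to a plain dictionary union; an equivalent sketch with placeholder state objects:

states = {"decoder": "decoder_state"}  # hypothetical full-scorer state
part_states = {"ctc": "ctc_state"}     # hypothetical partial-scorer state
merged = {**states, **part_states}
# merged == {"decoder": "decoder_state", "ctc": "ctc_state"}, same as merge_states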
    def search(
        self,
        running_hyps: BatchHypothesis,
        x: torch.Tensor,
        pre_x: torch.Tensor = None,
    ) -> BatchHypothesis:
        """Search new tokens for running hypotheses and encoded speech x.

        Args:
            running_hyps (BatchHypothesis): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)
            pre_x (torch.Tensor): Encoded speech feature for sequential attention (T, D)

        Returns:
            BatchHypothesis: Best sorted hypotheses

        """
        n_batch = len(running_hyps)
        part_ids = None  # no pre-beam
        # batch scoring
        weighted_scores = torch.zeros(
            n_batch, self.n_vocab, dtype=x.dtype, device=x.device
        )
        if self.return_hs:
            hs, scores, states = self.score_full(
                running_hyps,
                x.expand(n_batch, *x.shape),
                pre_x=(
                    pre_x.expand(n_batch, *pre_x.shape) if pre_x is not None else None
                ),
            )
        else:
            scores, states = self.score_full(
                running_hyps,
                x.expand(n_batch, *x.shape),
                pre_x=(
                    pre_x.expand(n_batch, *pre_x.shape) if pre_x is not None else None
                ),
            )
        for k in self.full_scorers:
            weighted_scores += self.weights[k] * scores[k]
        # partial scoring
        if self.do_pre_beam:
            pre_beam_scores = (
                weighted_scores
                if self.pre_beam_score_key == "full"
                else scores[self.pre_beam_score_key]
            )
            part_ids = torch.topk(pre_beam_scores, self.pre_beam_size, dim=-1)[1]
        # NOTE(takaaki-hori): Unlike BeamSearch, we assume that score_partial returns
        # full-size score matrices, which have non-zero scores for part_ids and zeros
        # for others.
        part_scores, part_states = self.score_partial(running_hyps, part_ids, x, pre_x)
        for k in self.part_scorers:
            weighted_scores += self.weights[k] * part_scores[k]
        # add previous hyp scores
        weighted_scores += running_hyps.score.to(
            dtype=x.dtype, device=x.device
        ).unsqueeze(1)

        # TODO(karita): do not use list. use batch instead
        # see also https://github.com/espnet/espnet/pull/1402#discussion_r354561029
        # update hyps
        best_hyps = []
        prev_hyps = self.unbatchfy(running_hyps)
        for (
            full_prev_hyp_id,
            full_new_token_id,
            part_prev_hyp_id,
            part_new_token_id,
        ) in zip(*self.batch_beam(weighted_scores, part_ids)):
            prev_hyp = prev_hyps[full_prev_hyp_id]
            if self.return_hs:
                new_hs = prev_hyp.hs + [hs[full_prev_hyp_id].squeeze(0)]
            else:
                new_hs = []
            best_hyps.append(
                Hypothesis(
                    score=weighted_scores[full_prev_hyp_id, full_new_token_id],
                    yseq=self.append_token(prev_hyp.yseq, full_new_token_id),
                    scores=self.merge_scores(
                        prev_hyp.scores,
                        {k: v[full_prev_hyp_id] for k, v in scores.items()},
                        full_new_token_id,
                        {k: v[part_prev_hyp_id] for k, v in part_scores.items()},
                        part_new_token_id,
                    ),
                    states=self.merge_states(
                        {
                            k: self.full_scorers[k].select_state(v, full_prev_hyp_id)
                            for k, v in states.items()
                        },
                        {
                            k: self.part_scorers[k].select_state(
                                v, part_prev_hyp_id, part_new_token_id
                            )
                            for k, v in part_states.items()
                        },
                        part_new_token_id,
                    ),
                    hs=new_hs,
                )
            )
        return self.batchfy(best_hyps)
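The pre-beam step narrows the candidate tokens handed to the partial scorers; a toy run with made-up scores and `pre_beam_size = 3`:

import torch

pre_beam_size = 3
pre_beam_scores = torch.tensor([[0.1, 0.9, 0.3, 0.7, 0.2]])  # (n_batch=1, n_vocab=5)
part_ids = torch.topk(pre_beam_scores, pre_beam_size, dim=-1)[1]
# part_ids == tensor([[1, 3, 2]]): only these tokens get partial (e.g. CTC) scores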
    def post_process(
        self,
        i: int,
        maxlen: int,
        minlen: int,
        maxlenratio: float,
        running_hyps: BatchHypothesis,
        ended_hyps: List[Hypothesis],
    ) -> BatchHypothesis:
        """Perform post-processing of beam search iterations.

        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            minlen (int): The minimum length of tokens in beam search.
            maxlenratio (float): The maximum length ratio in beam search.
            running_hyps (BatchHypothesis): The running hypotheses in beam search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

        Returns:
            BatchHypothesis: The new running hypotheses.

        """
        n_batch = running_hyps.yseq.shape[0]
        logger.debug(f"the number of running hypotheses: {n_batch}")
        if self.token_list is not None:
            logger.debug(
                "best hypo: "
                + "".join(
                    [
                        self.token_list[x]
                        for x in running_hyps.yseq[0, 1 : running_hyps.length[0]]
                    ]
                )
            )
        # add eos in the final loop so that at least some hypotheses end
        if i == maxlen - 1:
            logger.info("adding <eos> in the last position in the loop")
            yseq_eos = torch.cat(
                (
                    running_hyps.yseq,
                    torch.full(
                        (n_batch, 1),
                        self.eos,
                        device=running_hyps.yseq.device,
                        dtype=torch.int64,
                    ),
                ),
                1,
            )
            running_hyps.yseq.resize_as_(yseq_eos)
            running_hyps.yseq[:] = yseq_eos
            running_hyps.length[:] = yseq_eos.shape[1]

        # add ended hypotheses to a final list, and remove them from the current
        # hypotheses (this can be a problem when the number of hyps < beam)
        is_eos = (
            running_hyps.yseq[torch.arange(n_batch), running_hyps.length - 1]
            == self.eos
        )
        for b in torch.nonzero(is_eos, as_tuple=False).view(-1):
            hyp = self._select(running_hyps, b)
            if i >= minlen:
                ended_hyps.append(hyp)
        remained_ids = torch.nonzero(is_eos == 0, as_tuple=False).view(-1).cpu()
        return self._batch_select(running_hyps, remained_ids)
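The eos check relies on advanced indexing to read each hypothesis's last real token; a toy sketch with `eos = 2` and made-up sequences:

import torch

eos = 2
yseq = torch.tensor([[1, 5, 2], [1, 7, 8]])  # batch of two hypotheses
length = torch.tensor([3, 3])
is_eos = yseq[torch.arange(2), length - 1] == eos
# is_eos == tensor([True, False]): hyp 0 has ended, hyp 1 keeps running
remained_ids = torch.nonzero(is_eos == 0, as_tuple=False).view(-1)
# remained_ids == tensor([1]): only hyp 1 stays in the running batch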