# Source code for espnet2.asr.transducer.beam_search_transducer

"""Search algorithms for Transducer models."""

import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet2.asr_transducer.joint_network import JointNetwork
from espnet2.lm.transformer_lm import TransformerLM
from espnet.nets.pytorch_backend.transducer.utils import (
    is_prefix,
    recombine_hyps,
    select_k_expansions,
    subtract,
)


@dataclass
class Hypothesis:
    """Default hypothesis definition for Transducer search algorithms.

    Attributes:
        score: Accumulated log-probability score of the hypothesis.
        yseq: Token ID sequence of the hypothesis.
        dec_state: Decoder state attached to the hypothesis. Its concrete
            layout depends on the decoder in use.
        lm_state: Language model state, or None when no LM state has been
            attached to the hypothesis.

    """

    score: float
    yseq: List[int]
    dec_state: Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        List[Optional[torch.Tensor]],
        torch.Tensor,
    ]
    # Defaults to None, hence Optional.
    lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None
@dataclass
class ExtendedHypothesis(Hypothesis):
    """Extended hypothesis definition for NSC beam search and mAES.

    Attributes:
        dec_out: Decoder output tensors kept with the hypothesis, or None
            when not set.
        lm_scores: Language model scores kept with the hypothesis, or None
            when not set.

    """

    # Both fields default to None, hence Optional.
    dec_out: Optional[List[torch.Tensor]] = None
    lm_scores: Optional[torch.Tensor] = None
class BeamSearchTransducer:
    """Beam search implementation for Transducer.

    The constructor picks one search routine according to ``search_type`` and
    ``beam_size`` and stores it as ``self.search_algorithm``; calling the
    instance runs that routine on an encoder output sequence.

    """

    def __init__(
        self,
        decoder: AbsDecoder,
        joint_network: JointNetwork,
        beam_size: int,
        lm: Optional[torch.nn.Module] = None,
        lm_weight: float = 0.1,
        search_type: str = "default",
        max_sym_exp: int = 2,
        u_max: int = 50,
        nstep: int = 1,
        prefix_alpha: int = 1,
        expansion_gamma: float = 2.3,
        expansion_beta: int = 2,
        multi_blank_durations: Optional[List[int]] = None,
        multi_blank_indices: Optional[List[int]] = None,
        score_norm: bool = True,
        score_norm_during: bool = False,
        nbest: int = 1,
        token_list: Optional[List[str]] = None,
    ):
        """Initialize Transducer search module.

        Args:
            decoder: Decoder module.
            joint_network: Joint network module.
            beam_size: Beam size.
            lm: LM class.
            lm_weight: LM weight for soft fusion.
            search_type: Search algorithm to use during inference.
            max_sym_exp: Number of maximum symbol expansions at each time step.
                         (TSD)
            u_max: Maximum output sequence length. (ALSD)
            nstep: Number of maximum expansion steps at each time step.
                   (NSC/mAES)
            prefix_alpha: Maximum prefix length in prefix search. (NSC/mAES)
            expansion_beta:
              Number of additional candidates for expanded hypotheses selection.
              (mAES)
            expansion_gamma: Allowed logp difference for prune-by-value method.
                             (mAES)
            multi_blank_durations: The duration of each blank token. (MBG)
                                   None is treated as an empty list.
            multi_blank_indices: The index of each blank token in token_list.
                                 (MBG) None is treated as an empty list.
            score_norm: Normalize final scores by length. ("default")
            score_norm_during:
              Normalize scores by length during search. (default, TSD, ALSD)
            nbest: Number of final hypothesis.
            token_list: Token list, stored as ``self.token_list``.

        Raises:
            NotImplementedError: If ``search_type`` is unknown, or if a
                TransformerLM is given for TSD, ALSD, NSC or mAES.
            AssertionError: If, for mAES, ``beam_size + expansion_beta``
                exceeds the vocabulary size.

        """
        self.decoder = decoder
        self.joint_network = joint_network

        self.beam_size = beam_size
        self.hidden_size = decoder.dunits
        self.vocab_size = decoder.odim

        self.sos = self.vocab_size - 1
        self.token_list = token_list

        self.blank_id = decoder.blank_id

        def reject_transformer_lm() -> None:
            # The batched search routines do not support TransformerLM.
            if lm is not None and isinstance(lm, TransformerLM):
                raise NotImplementedError(
                    "TransformerLM is not supported with search_type "
                    f"'{search_type}'."
                )

        if search_type == "mbg":
            # Multi-blank search is greedy, so the beam is forced to 1.
            self.beam_size = 1
            # Store private copies so the instance never aliases a list owned
            # by the caller (or a shared default).
            self.multi_blank_durations = (
                [] if multi_blank_durations is None else list(multi_blank_durations)
            )
            self.multi_blank_indices = (
                [] if multi_blank_indices is None else list(multi_blank_indices)
            )
            self.search_algorithm = self.multi_blank_greedy_search
        elif self.beam_size <= 1:
            self.search_algorithm = self.greedy_search
        elif search_type == "default":
            self.search_algorithm = self.default_beam_search
        elif search_type == "tsd":
            reject_transformer_lm()

            self.max_sym_exp = max_sym_exp
            self.search_algorithm = self.time_sync_decoding
        elif search_type == "alsd":
            reject_transformer_lm()

            self.u_max = u_max
            self.search_algorithm = self.align_length_sync_decoding
        elif search_type == "nsc":
            reject_transformer_lm()

            self.nstep = nstep
            self.prefix_alpha = prefix_alpha
            self.search_algorithm = self.nsc_beam_search
        elif search_type == "maes":
            reject_transformer_lm()

            # mAES always runs at least two expansion steps.
            self.nstep = nstep if nstep > 1 else 2
            self.prefix_alpha = prefix_alpha
            self.expansion_gamma = expansion_gamma

            assert self.vocab_size >= beam_size + expansion_beta, (
                "beam_size (%d) + expansion_beta (%d) "
                "should be smaller or equal to vocabulary size (%d)."
                % (beam_size, expansion_beta, self.vocab_size)
            )
            self.max_candidates = beam_size + expansion_beta

            self.search_algorithm = self.modified_adaptive_expansion_search
        else:
            raise NotImplementedError(f"Unknown search_type: '{search_type}'.")

        self.use_lm = lm is not None
        self.lm = lm
        self.lm_weight = lm_weight

        # Greedy search is selected for any beam_size <= 1 (see above), and
        # it never consults the LM.
        if self.use_lm and self.beam_size <= 1:
            logging.warning("LM is provided but not used, since this is greedy search.")

        self.score_norm = score_norm
        self.score_norm_during = score_norm_during
        self.nbest = nbest

    def __call__(
        self, enc_out: torch.Tensor
    ) -> Union[List[Hypothesis], List[ExtendedHypothesis]]:
        """Perform beam search.

        Args:
            enc_out: Encoder output sequence. (T, D_enc)

        Returns:
            nbest_hyps: N-best decoding results

        """
        # Let the decoder know which device the encoder output lives on.
        self.decoder.set_device(enc_out.device)

        return self.search_algorithm(enc_out)

    def _prune(self, hyps: List[Hypothesis], beam: int) -> List[Hypothesis]:
        """Return the ``beam`` best hypotheses as a new list.

        Ranking uses the score divided by the sequence length when
        ``self.score_norm_during`` is set, and the raw score otherwise.

        """
        if self.score_norm_during:
            ranked = sorted(hyps, key=lambda x: x.score / len(x.yseq), reverse=True)
        else:
            ranked = sorted(hyps, key=lambda x: x.score, reverse=True)

        return ranked[:beam]

    def sort_nbest(
        self, hyps: Union[List[Hypothesis], List[ExtendedHypothesis]]
    ) -> Union[List[Hypothesis], List[ExtendedHypothesis]]:
        """Sort hypotheses by score or score given sequence length.

        The input list is sorted in place; the returned list is a slice
        holding at most ``self.nbest`` entries.

        Args:
            hyps: Hypothesis.

        Return:
            hyps: Sorted hypothesis.

        """
        if self.score_norm:
            hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
        else:
            hyps.sort(key=lambda x: x.score, reverse=True)

        return hyps[: self.nbest]

    def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Time synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        beam = min(self.beam_size, self.vocab_size)
        # Label candidates come from the non-blank columns only, so there are
        # at most vocab_size - 1 of them; asking topk for more would raise.
        beam_k = min(beam, self.vocab_size - 1)

        beam_state = self.decoder.init_state(beam)

        # B: hypotheses carried from one frame to the next.
        B = [
            Hypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]
        cache = {}

        if self.use_lm:
            B[0].lm_state = self.lm.zero_state()

        for enc_out_t in enc_out:
            # A: hypotheses that emitted blank on this frame.
            # C: hypotheses currently being expanded.
            A = []
            C = B

            enc_out_t = enc_out_t.unsqueeze(0)

            for v in range(self.max_sym_exp):
                # D: label expansions of C produced in this round.
                D = []

                beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    C,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                # NOTE(review): blank is read from column 0 and labels from
                # columns 1:, which presumes blank_id == 0 - confirm.
                beam_logp = torch.log_softmax(
                    self.joint_network(enc_out_t, beam_dec_out),
                    dim=-1,
                )

                seq_A = [h.yseq for h in A]

                for i, hyp in enumerate(C):
                    blank_score = hyp.score + float(beam_logp[i, 0])

                    if hyp.yseq not in seq_A:
                        A.append(
                            Hypothesis(
                                score=blank_score,
                                yseq=hyp.yseq[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                            )
                        )
                    else:
                        # Same label sequence already in A: merge the two
                        # paths by summing their probabilities in log space.
                        dict_pos = seq_A.index(hyp.yseq)

                        A[dict_pos].score = float(
                            np.logaddexp(A[dict_pos].score, blank_score)
                        )

                # No label expansion on the last round: its results would
                # never be scored against blank on this frame.
                if v < (self.max_sym_exp - 1):
                    beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)

                    if self.use_lm:
                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
                            beam_lm_tokens, [c.lm_state for c in C], None
                        )

                    for i, hyp in enumerate(C):
                        # +1 maps indices of the sliced tensor back to token IDs.
                        for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                            token = int(k)

                            new_hyp = Hypothesis(
                                score=(hyp.score + float(logp)),
                                yseq=(hyp.yseq + [token]),
                                dec_state=self.decoder.select_state(beam_state, i),
                                lm_state=hyp.lm_state,
                            )

                            if self.use_lm:
                                # Keep the score a plain float instead of
                                # letting it turn into a 0-dim tensor.
                                new_hyp.score += self.lm_weight * float(
                                    beam_lm_scores[i, token]
                                )
                                new_hyp.lm_state = beam_lm_states[i]

                            D.append(new_hyp)

                C = self._prune(D, beam)

            B = self._prune(A, beam)

        return self.sort_nbest(B)

    def align_length_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Alignment-length synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequences. (T, D)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        beam = min(self.beam_size, self.vocab_size)
        # Label candidates come from the non-blank columns only, so there are
        # at most vocab_size - 1 of them; asking topk for more would raise.
        beam_k = min(beam, self.vocab_size - 1)

        t_max = int(enc_out.size(0))
        u_max = min(self.u_max, (t_max - 1))

        beam_state = self.decoder.init_state(beam)

        B = [
            Hypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]
        # Hypotheses that emitted blank on the last encoder frame.
        final = []
        cache = {}

        if self.use_lm:
            B[0].lm_state = self.lm.zero_state()

        for step in range(t_max + u_max):
            A = []

            # Keep only hypotheses whose frame index is still inside the
            # encoder output for this alignment length.
            B_ = []
            B_enc_out = []
            for hyp in B:
                u = len(hyp.yseq) - 1
                t = step - u

                if t > (t_max - 1):
                    continue

                B_.append(hyp)
                B_enc_out.append((t, enc_out[t]))

            if B_:
                beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    B_,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                beam_enc_out = torch.stack([x[1] for x in B_enc_out])

                # NOTE(review): blank is read from column 0 and labels from
                # columns 1:, which presumes blank_id == 0 - confirm.
                beam_logp = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)

                if self.use_lm:
                    beam_lm_scores, beam_lm_states = self.lm.batch_score(
                        beam_lm_tokens,
                        [b.lm_state for b in B_],
                        None,
                    )

                for idx, hyp in enumerate(B_):
                    # Blank transition: same labels, next frame.
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[idx, 0])),
                        yseq=hyp.yseq[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                    )

                    A.append(new_hyp)

                    if B_enc_out[idx][0] == (t_max - 1):
                        final.append(new_hyp)

                    # +1 maps indices of the sliced tensor back to token IDs.
                    for logp, k in zip(beam_topk[0][idx], beam_topk[1][idx] + 1):
                        token = int(k)

                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq[:] + [token]),
                            dec_state=self.decoder.select_state(beam_state, idx),
                            lm_state=hyp.lm_state,
                        )

                        if self.use_lm:
                            # Keep the score a plain float instead of
                            # letting it turn into a 0-dim tensor.
                            new_hyp.score += self.lm_weight * float(
                                beam_lm_scores[idx, token]
                            )
                            new_hyp.lm_state = beam_lm_states[idx]

                        A.append(new_hyp)

                B = self._prune(A, beam)
                B = recombine_hyps(B)

        if final:
            return self.sort_nbest(final)

        # NOTE(review): this fallback returns the current beam as-is, i.e.
        # without the nbest truncation applied to `final` - confirm intended.
        return B