Source code for espnet.nets.beam_search_transducer

"""Search algorithms for Transducer models."""

import logging
from typing import List, Union

import numpy as np
import torch

from espnet.nets.pytorch_backend.transducer.custom_decoder import CustomDecoder
from espnet.nets.pytorch_backend.transducer.joint_network import JointNetwork
from espnet.nets.pytorch_backend.transducer.rnn_decoder import RNNDecoder
from espnet.nets.pytorch_backend.transducer.utils import (
    create_lm_batch_states,
    init_lm_state,
    is_prefix,
    recombine_hyps,
    select_k_expansions,
    select_lm_state,
    subtract,
)
from espnet.nets.transducer_decoder_interface import ExtendedHypothesis, Hypothesis


class BeamSearchTransducer:
    """Beam search implementation for Transducer."""

    def __init__(
        self,
        decoder: Union[RNNDecoder, CustomDecoder],
        joint_network: JointNetwork,
        beam_size: int,
        lm: torch.nn.Module = None,
        lm_weight: float = 0.1,
        search_type: str = "default",
        max_sym_exp: int = 2,
        u_max: int = 50,
        nstep: int = 1,
        prefix_alpha: int = 1,
        expansion_gamma: float = 2.3,
        expansion_beta: int = 2,
        score_norm: bool = True,
        softmax_temperature: float = 1.0,
        nbest: int = 1,
        quantization: bool = False,
    ):
        """Initialize Transducer search module.

        Args:
            decoder: Decoder module.
            joint_network: Joint network module.
            beam_size: Beam size.
            lm: LM class.
            lm_weight: LM weight for soft fusion.
            search_type: Search algorithm to use during inference.
            max_sym_exp: Number of maximum symbol expansions at each time step. (TSD)
            u_max: Maximum output sequence length. (ALSD)
            nstep: Number of maximum expansion steps at each time step. (NSC/mAES)
            prefix_alpha: Maximum prefix length in prefix search. (NSC/mAES)
            expansion_gamma: Allowed logp difference for prune-by-value method. (mAES)
            expansion_beta: Number of additional candidates for expanded
                hypotheses selection. (mAES)
            score_norm: Normalize final scores by length. ("default")
            softmax_temperature: Penalization term for softmax function.
            nbest: Number of final hypotheses.
            quantization: Whether dynamic quantization is used.

        """
        self.decoder = decoder
        self.joint_network = joint_network

        self.beam_size = beam_size
        self.hidden_size = decoder.dunits
        self.vocab_size = decoder.odim
        self.blank_id = decoder.blank_id

        if self.beam_size <= 1:
            self.search_algorithm = self.greedy_search
        elif search_type == "default":
            self.search_algorithm = self.default_beam_search
        elif search_type == "tsd":
            self.max_sym_exp = max_sym_exp

            self.search_algorithm = self.time_sync_decoding
        elif search_type == "alsd":
            self.u_max = u_max

            self.search_algorithm = self.align_length_sync_decoding
        elif search_type == "nsc":
            self.nstep = nstep
            self.prefix_alpha = prefix_alpha

            self.search_algorithm = self.nsc_beam_search
        elif search_type == "maes":
            self.nstep = nstep if nstep > 1 else 2
            self.prefix_alpha = prefix_alpha
            self.expansion_gamma = expansion_gamma

            assert self.vocab_size >= beam_size + expansion_beta, (
                "beam_size (%d) + expansion_beta (%d) "
                "should be smaller than or equal to vocabulary size (%d)."
                % (beam_size, expansion_beta, self.vocab_size)
            )
            self.max_candidates = beam_size + expansion_beta

            self.search_algorithm = self.modified_adaptive_expansion_search
        else:
            raise NotImplementedError

        if lm is not None:
            self.use_lm = True
            self.lm = lm
            self.is_wordlm = hasattr(lm.predictor, "wordlm")
            self.lm_predictor = lm.predictor.wordlm if self.is_wordlm else lm.predictor
            self.lm_layers = len(self.lm_predictor.rnn)
            self.lm_weight = lm_weight
        else:
            self.use_lm = False

        if softmax_temperature > 1.0 and lm is not None:
            logging.warning(
                "Softmax temperature is not supported with LM decoding. "
                "Setting softmax-temperature value to 1.0."
            )

            self.softmax_temperature = 1.0
        else:
            self.softmax_temperature = softmax_temperature

        self.quantization = quantization
        self.score_norm = score_norm
        self.nbest = nbest

    def __call__(
        self, enc_out: torch.Tensor
    ) -> Union[List[Hypothesis], List[ExtendedHypothesis]]:
        """Perform beam search.

        Args:
            enc_out: Encoder output sequence. (T, D_enc)

        Returns:
            nbest_hyps: N-best decoding results.

        """
        self.decoder.set_device(enc_out.device)

        nbest_hyps = self.search_algorithm(enc_out)

        return nbest_hyps
    def sort_nbest(
        self, hyps: Union[List[Hypothesis], List[ExtendedHypothesis]]
    ) -> Union[List[Hypothesis], List[ExtendedHypothesis]]:
        """Sort hypotheses by score or by score normalized by sequence length.

        Args:
            hyps: Hypotheses.

        Return:
            hyps: Sorted hypotheses.

        """
        if self.score_norm:
            hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
        else:
            hyps.sort(key=lambda x: x.score, reverse=True)

        return hyps[: self.nbest]
    def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Time synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        beam = min(self.beam_size, self.vocab_size)

        beam_state = self.decoder.init_state(beam)

        B = [
            Hypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]
        cache = {}

        if self.use_lm and not self.is_wordlm:
            B[0].lm_state = init_lm_state(self.lm_predictor)

        for enc_out_t in enc_out:
            A = []
            C = B

            enc_out_t = enc_out_t.unsqueeze(0)

            for v in range(self.max_sym_exp):
                D = []

                beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    C,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                beam_logp = torch.log_softmax(
                    self.joint_network(enc_out_t, beam_dec_out)
                    / self.softmax_temperature,
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

                seq_A = [h.yseq for h in A]

                for i, hyp in enumerate(C):
                    if hyp.yseq not in seq_A:
                        A.append(
                            Hypothesis(
                                score=(hyp.score + float(beam_logp[i, 0])),
                                yseq=hyp.yseq[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                            )
                        )
                    else:
                        # Same label sequence already in A: merge scores in the
                        # log domain instead of keeping a duplicate hypothesis.
                        dict_pos = seq_A.index(hyp.yseq)

                        A[dict_pos].score = np.logaddexp(
                            A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
                        )

                if v < (self.max_sym_exp - 1):
                    if self.use_lm:
                        beam_lm_states = create_lm_batch_states(
                            [c.lm_state for c in C], self.lm_layers, self.is_wordlm
                        )

                        beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                            beam_lm_states, beam_lm_tokens, len(C)
                        )

                    for i, hyp in enumerate(C):
                        for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                            new_hyp = Hypothesis(
                                score=(hyp.score + float(logp)),
                                yseq=(hyp.yseq + [int(k)]),
                                dec_state=self.decoder.select_state(beam_state, i),
                                lm_state=hyp.lm_state,
                            )

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * beam_lm_scores[i, k]

                                new_hyp.lm_state = select_lm_state(
                                    beam_lm_states, i, self.lm_layers, self.is_wordlm
                                )

                            D.append(new_hyp)

                C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]

            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]

        return self.sort_nbest(B)
    def align_length_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Alignment-length synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        beam = min(self.beam_size, self.vocab_size)

        t_max = int(enc_out.size(0))
        u_max = min(self.u_max, (t_max - 1))

        beam_state = self.decoder.init_state(beam)

        B = [
            Hypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]
        final = []
        cache = {}

        if self.use_lm and not self.is_wordlm:
            B[0].lm_state = init_lm_state(self.lm_predictor)

        for i in range(t_max + u_max):
            A = []

            B_ = []
            B_enc_out = []
            for hyp in B:
                u = len(hyp.yseq) - 1
                t = i - u

                if t > (t_max - 1):
                    continue

                B_.append(hyp)
                B_enc_out.append((t, enc_out[t]))

            if B_:
                beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    B_,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                beam_enc_out = torch.stack([x[1] for x in B_enc_out])

                beam_logp = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out)
                    / self.softmax_temperature,
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

                if self.use_lm:
                    beam_lm_states = create_lm_batch_states(
                        [b.lm_state for b in B_], self.lm_layers, self.is_wordlm
                    )

                    beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                        beam_lm_states, beam_lm_tokens, len(B_)
                    )

                for j, hyp in enumerate(B_):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[j, 0])),
                        yseq=hyp.yseq[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                    )

                    A.append(new_hyp)

                    if B_enc_out[j][0] == (t_max - 1):
                        final.append(new_hyp)

                    for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq[:] + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, j),
                            lm_state=hyp.lm_state,
                        )

                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[j, k]

                            new_hyp.lm_state = select_lm_state(
                                beam_lm_states, j, self.lm_layers, self.is_wordlm
                            )

                        A.append(new_hyp)

                B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
                B = recombine_hyps(B)

        if final:
            return self.sort_nbest(final)
        else:
            return B
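

# A minimal usage sketch, not part of the original module. The names `model`,
# `model.decoder`, `model.joint_network`, and `encoder_out` are assumptions
# standing in for a trained Transducer model and one encoded utterance of
# shape (T, D_enc).
def _demo_usage(model, encoder_out: torch.Tensor) -> List[int]:
    """Decode one utterance with TSD search (sketch, hypothetical `model`)."""
    beam_search = BeamSearchTransducer(
        decoder=model.decoder,
        joint_network=model.joint_network,
        beam_size=5,
        search_type="tsd",
        max_sym_exp=2,
        nbest=1,
    )
    nbest_hyps = beam_search(encoder_out)

    # Drop the leading blank_id used to seed every hypothesis.
    return nbest_hyps[0].yseq[1:]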
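

# Why length normalization matters in sort_nbest: with score_norm=True,
# hypotheses are ranked by score / len(yseq), so a longer hypothesis with a
# lower total log-probability can still win. A self-contained sketch using a
# stand-in `_Hyp` class (illustration only, not the real Hypothesis):
def _demo_score_norm() -> None:
    from dataclasses import dataclass

    @dataclass
    class _Hyp:  # stand-in for Hypothesis
        score: float
        yseq: List[int]

    hyps = [_Hyp(-4.0, [0, 7]), _Hyp(-6.0, [0, 7, 3, 9])]

    by_raw = sorted(hyps, key=lambda x: x.score, reverse=True)
    by_norm = sorted(hyps, key=lambda x: x.score / len(x.yseq), reverse=True)

    assert by_raw[0].yseq == [0, 7]  # -4.0 > -6.0
    assert by_norm[0].yseq == [0, 7, 3, 9]  # -6.0 / 4 = -1.5 > -4.0 / 2 = -2.0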
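

# The recombination step in time_sync_decoding: when two expansion paths end
# in the same label sequence, their scores are merged with np.logaddexp,
# i.e. the path probabilities are summed in the log domain. A small numeric
# check of that identity:
def _demo_logaddexp_merge() -> None:
    score_a = np.log(0.20)  # log P(path A)
    score_b = np.log(0.05)  # log P(path B), same yseq as path A
    merged = np.logaddexp(score_a, score_b)

    assert np.isclose(np.exp(merged), 0.25)  # P(A) + P(B)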
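

# The frame bookkeeping in align_length_sync_decoding: at alignment step i, a
# hypothesis that has emitted u labels (u = len(yseq) - 1, the leading blank
# excluded) reads encoder frame t = i - u, and is skipped once t runs past the
# last frame. A toy check of that arithmetic:
def _demo_alsd_frame_index() -> None:
    t_max = 4  # number of encoder frames
    i = 5  # alignment step, i = t + u

    for u in range(i + 1):
        t = i - u
        if t > (t_max - 1):
            continue  # hypothesis would read past the encoder output

        assert 0 <= t < t_max and t + u == i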