Source code for espnet.nets.batch_beam_search_online_sim

"""Parallel beam search module for online simulation."""

import logging
from pathlib import Path
from typing import List

import torch
import yaml

from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.beam_search import Hypothesis
from espnet.nets.e2e_asr_common import end_detect


class BatchBeamSearchOnlineSim(BatchBeamSearch):
    """Online beam search implementation.

    This simulates streaming decoding.
    It requires encoded features of the entire utterance and
    extracts them block by block, as it should be done
    in streaming processing.
    This is based on Tsunoo et al, "STREAMING TRANSFORMER ASR
    WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH"
    (https://arxiv.org/abs/2006.14941).
    """

    def _apply_encoder_conf(self, args) -> None:
        """Copy streaming sizes found under ``args["encoder_conf"]`` onto self.

        Only the keys ``block_size``, ``hop_size`` and ``look_ahead`` are
        considered; keys that are absent leave the attribute untouched.

        Args:
            args: Parsed YAML content. ``None`` (empty file) and a missing or
                null ``encoder_conf`` entry are treated as "nothing to copy".

        """
        enc_args = (args or {}).get("encoder_conf") or {}
        for name in ("block_size", "hop_size", "look_ahead"):
            if name in enc_args:
                setattr(self, name, enc_args[name])

    def set_streaming_config(self, asr_config: str):
        """Set config file for streaming decoding.

        ``block_size``, ``hop_size`` and ``look_ahead`` are first reset to
        ``None`` and then read from the ``encoder_conf`` section of
        ``asr_config``. If that file has no ``encoder_conf`` key but has a
        ``config`` key, the file it points to is read instead.

        Args:
            asr_config (str): The config file for asr training

        """
        train_config_file = Path(asr_config)
        self.block_size = None
        self.hop_size = None
        self.look_ahead = None
        config = None

        with train_config_file.open("r", encoding="utf-8") as f:
            # safe_load yields None for an empty document
            args = yaml.safe_load(f) or {}

        if "encoder_conf" in args:
            self._apply_encoder_conf(args)
        elif "config" in args:
            config = args["config"]
            if config is None:
                logging.info(
                    "Cannot find config file for streaming decoding: "
                    + "apply batch beam search instead."
                )
                return

        sizes_missing = (
            self.block_size is None or self.hop_size is None or self.look_ahead is None
        )
        if sizes_missing and config is not None:
            config_file = Path(config)
            with config_file.open("r", encoding="utf-8") as f:
                args = yaml.safe_load(f) or {}
            # A secondary config without "encoder_conf" simply sets nothing.
            self._apply_encoder_conf(args)

    def set_block_size(self, block_size: int):
        """Set block size for streaming decoding.

        Args:
            block_size (int): The block size of encoder

        """
        self.block_size = block_size

    def set_hop_size(self, hop_size: int):
        """Set hop size for streaming decoding.

        Args:
            hop_size (int): The hop size of encoder

        """
        self.hop_size = hop_size

    def set_look_ahead(self, look_ahead: int):
        """Set look ahead size for streaming decoding.

        Args:
            look_ahead (int): The look ahead size of encoder

        """
        self.look_ahead = look_ahead

    def forward(
        self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
    ) -> List[Hypothesis]:
        """Perform beam search.

        The encoder output is revealed to the search incrementally: decoding
        starts on the first ``block_size - look_ahead`` frames and the visible
        prefix grows by ``hop_size`` frames each time the search decides to
        move to the next block. If any of the three sizes is unset (or falsy),
        the whole of ``x`` is visible from the start.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses a end-detect function
                to automatically find maximum hypothesis lengths
            minlenratio (float): Input length ratio to obtain min output length.
                A negative value is used as ``-int(minlenratio)`` directly.

        Returns:
            list[Hypothesis]: N-best decoding results, best first. Empty when
                no hypothesis ended and ``minlenratio < 0.1``.

        """
        self.conservative = True  # always true

        # Streaming sizes may never have been configured; treat that as
        # "decode the full utterance in one block".
        # NOTE(review): a look_ahead of 0 is falsy and therefore also selects
        # single-block decoding -- confirm this is intended.
        block_size = getattr(self, "block_size", None)
        hop_size = getattr(self, "hop_size", None)
        look_ahead = getattr(self, "look_ahead", None)

        n_frames = x.shape[0]
        if block_size and hop_size and look_ahead:
            cur_end_frame = int(block_size - look_ahead)
        else:
            cur_end_frame = n_frames
        process_idx = 0
        if cur_end_frame < n_frames:
            h = x.narrow(0, 0, cur_end_frame)
        else:
            h = x

        # set length bounds
        if maxlenratio == 0:
            maxlen = n_frames
        else:
            maxlen = max(1, int(maxlenratio * n_frames))
        if minlenratio < 0:
            minlen = -1 * int(minlenratio)
        else:
            minlen = int(minlenratio * n_frames)
        logging.info("decoder input length: " + str(n_frames))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # main loop of prefix search
        running_hyps = self.init_hyp(h)
        prev_hyps = []
        ended_hyps = []
        prev_repeat = False

        continue_decode = True

        while continue_decode:
            move_to_next_block = False
            if cur_end_frame < n_frames:
                h = x.narrow(0, 0, cur_end_frame)
            else:
                h = x

            # extend states for ctc
            self.extend(h, running_hyps)

            while process_idx < maxlen:
                logging.debug("position " + str(process_idx))
                best = self.search(running_hyps, h)

                if process_idx == maxlen - 1:
                    # end decoding
                    # (post_process is called again further down in the same
                    # iteration when no block switch / end detection happens)
                    running_hyps = self.post_process(
                        process_idx, maxlen, minlen, maxlenratio, best, ended_hyps
                    )
                n_batch = best.yseq.shape[0]
                local_ended_hyps = []
                is_local_eos = (
                    best.yseq[torch.arange(n_batch), best.length - 1] == self.eos
                )
                for i in range(is_local_eos.shape[0]):
                    if is_local_eos[i]:
                        hyp = self._select(best, i)
                        local_ended_hyps.append(hyp)
                    # NOTE(tsunoo): check repetitions here
                    # This is an implicit implementation of
                    # Eq (11) in https://arxiv.org/abs/2006.14941
                    # A flag prev_repeat is used instead of using set
                    elif (
                        not prev_repeat
                        and best.yseq[i, -1] in best.yseq[i, :-1]
                        and cur_end_frame < n_frames
                    ):
                        move_to_next_block = True
                        prev_repeat = True
                if maxlenratio == 0.0 and end_detect(
                    [lh.asdict() for lh in local_ended_hyps], process_idx
                ):
                    logging.info(f"end detected at {process_idx}")
                    continue_decode = False
                    break
                # An <eos> before the last block is not trusted: more frames
                # are revealed instead of ending the hypothesis.
                if len(local_ended_hyps) > 0 and cur_end_frame < n_frames:
                    move_to_next_block = True

                if move_to_next_block:
                    if (
                        hop_size
                        and cur_end_frame + int(hop_size) + int(look_ahead) < n_frames
                    ):
                        cur_end_frame += int(hop_size)
                    else:
                        cur_end_frame = n_frames
                    logging.debug("Going to next block: %d", cur_end_frame)
                    if process_idx > 1 and len(prev_hyps) > 0 and self.conservative:
                        # roll back one step and redo it with the larger block
                        running_hyps = prev_hyps
                        process_idx -= 1
                        prev_hyps = []
                    break

                prev_repeat = False
                prev_hyps = running_hyps
                running_hyps = self.post_process(
                    process_idx, maxlen, minlen, maxlenratio, best, ended_hyps
                )

                # ended hypotheses are only accepted once all frames are visible
                if cur_end_frame >= n_frames:
                    ended_hyps.extend(local_ended_hyps)

                if len(running_hyps) == 0:
                    logging.info("no hypothesis. Finish decoding.")
                    continue_decode = False
                    break
                else:
                    logging.debug(f"remained hypotheses: {len(running_hyps)}")
                # increment number
                process_idx += 1

        if self.normalize_length:
            # Note (Jinchuan): -1 since hyp starts with <sos> and
            # initially has score of 0.0
            nbest_hyps = sorted(
                ended_hyps,
                key=lambda hyp: hyp.score / (len(hyp.yseq) - 1),
                reverse=True,
            )
        else:
            nbest_hyps = sorted(ended_hyps, key=lambda hyp: hyp.score, reverse=True)

        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
            )
            return (
                []
                if minlenratio < 0.1
                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
            )

        # report the best result
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: "
                + "".join([self.token_list[tok] for tok in best.yseq[1:-1]])
                + "\n"
            )
        # Compare against maxlen (the value reported in the message), not the
        # number of input frames: the two differ whenever maxlenratio != 0.
        if best.yseq[1:-1].shape[0] == maxlen:
            logging.warning(
                "best hypo length: {} == max output length: {}".format(
                    best.yseq[1:-1].shape[0], maxlen
                )
            )
            logging.warning(
                "decoding may be stopped by the max output length limitation, "
                + "please consider to increase the maxlenratio."
            )
        return nbest_hyps

    def extend(self, x: torch.Tensor, hyps: Hypothesis) -> List[Hypothesis]:
        """Extend probabilities and states with more encoded chunks.

        For every scorer in ``self.scorers`` that provides ``extend_prob``,
        it is called with ``x``; for every scorer that provides
        ``extend_state``, ``hyps.states[k]`` is replaced by its result.

        Args:
            x (torch.Tensor): The extended encoder output feature
            hyps (Hypothesis): Current hypotheses; must expose a ``states``
                mapping keyed by scorer name. Updated in place.

        Note:
            Nothing is returned explicitly; the return annotation is kept
            unchanged for interface compatibility.

        """
        for k, d in self.scorers.items():
            if hasattr(d, "extend_prob"):
                d.extend_prob(x)
            if hasattr(d, "extend_state"):
                hyps.states[k] = d.extend_state(hyps.states[k])