# Source code for espnet.nets.pytorch_backend.streaming.window

import torch

# TODO(pzelasko): Currently allows half-streaming only;
#  needs streaming attention decoder implementation
class WindowStreamingE2E(object):
    """Window-based ("half-streaming") wrapper around an E2E ASR model.

    Every call to :meth:`accept_input` runs the encoder on the newly arrived
    input, carrying over the encoder state returned by the previous call, and
    stores the encoder output together with its CTC log-posteriors.  The
    attention decoder is then run over the accumulated windows with
    :meth:`decode_with_attention_offline`.

    :param E2E e2e: E2E ASR object
    :param recog_args: arguments for "recognize" method of E2E
    :param rnnlm: optional language model forwarded to the decoder's
        ``recognize_beam`` call (default: ``None``)
    :raises AssertionError: if ``recog_args.ctc_weight`` is not positive
    """

    def __init__(self, e2e, recog_args, rnnlm=None):
        self._e2e = e2e
        self._recog_args = recog_args
        self._char_list = e2e.char_list
        self._rnnlm = rnnlm

        # Presumably torch.nn.Module.eval(), i.e. inference mode -- the
        # concrete E2E type is not visible from this file.
        self._e2e.eval()

        # Threshold used by _input_window_for_decoder(use_all=False); no
        # method visible in this file updates it after construction.
        self._offset = 0
        # Encoder state returned by the previous accept_input() call
        # (None until the first call).
        self._previous_encoder_recurrent_state = None
        # One entry per accept_input() call, in arrival order.
        self._encoder_states = []
        self._ctc_posteriors = []
        # Not read or written by any method visible in this file.
        self._last_recognition = None

        assert (
            self._recog_args.ctc_weight > 0.0
        ), "WindowStreamingE2E works only with combined CTC and attention decoders."

    def accept_input(self, x):
        """Call this method each time a new batch of input is available.

        :param x: new chunk of input, passed to ``e2e.subsample_frames``
        """
        h, ilen = self._e2e.subsample_frames(x)

        # Streaming encoder: feed the state from the previous call and keep
        # the new one for the next call.
        h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
            h.unsqueeze(0), ilen, self._previous_encoder_recurrent_state
        )
        self._encoder_states.append(h.squeeze(0))

        # CTC posteriors for the incoming audio
        self._ctc_posteriors.append(self._e2e.ctc.log_softmax(h).squeeze(0))

    def _input_window_for_decoder(self, use_all=False):
        """Assemble the decoder inputs from the stored windows.

        :param bool use_all: if True, concatenate every stored window;
            otherwise concatenate only the windows selected relative to
            ``self._offset`` (see ``select_unprocessed_windows`` below)
        :return: tuple ``(encoder_states, ctc_posteriors)``, each the
            concatenation along dim 0 of the chosen windows
        """
        if use_all:
            return (
                torch.cat(self._encoder_states, dim=0),
                torch.cat(self._ctc_posteriors, dim=0),
            )

        def select_unprocessed_windows(window_tensors):
            # Walk the windows in order.  Until the running total exceeds
            # self._offset a window only advances the total; after that,
            # every window is selected and the total stops advancing.
            # NOTE(review): the comparison is strict, so with _offset == 0
            #   the first window is never selected -- confirm this is
            #   intended.
            # NOTE(review): the total advances by es.size(1) while windows
            #   are concatenated along dim 0; if dim 0 is the time axis the
            #   offset is counted in feature units -- TODO confirm.
            # NOTE(review): if no window is selected, torch.cat receives an
            #   empty list and raises.
            last_offset = self._offset
            offset_traversed = 0
            selected_windows = []
            for es in window_tensors:
                if offset_traversed > last_offset:
                    selected_windows.append(es)
                    continue
                offset_traversed += es.size(1)
            return torch.cat(selected_windows, dim=0)

        return (
            select_unprocessed_windows(self._encoder_states),
            select_unprocessed_windows(self._ctc_posteriors),
        )

    def decode_with_attention_offline(self):
        """Run the attention decoder offline.

        Works even if the previous layers (encoder and CTC decoder) were
        being run in the online mode.
        This method should be run after all the audio has been consumed.
        This is used mostly to compare the results between offline
        and online implementation of the previous layers.

        :return: the result of ``e2e.dec.recognize_beam`` on all
            accumulated encoder states and CTC posteriors
        """
        h, lpz = self._input_window_for_decoder(use_all=True)

        return self._e2e.dec.recognize_beam(
            h, lpz, self._recog_args, self._char_list, self._rnnlm
        )