
"""Utility functions for Transducer models."""

import os
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from espnet.nets.pytorch_backend.nets_utils import pad_list
from espnet.nets.transducer_decoder_interface import ExtendedHypothesis, Hypothesis


def get_decoder_input(
    labels: torch.Tensor, blank_id: int, ignore_id: int
) -> torch.Tensor:
    """Prepare decoder input.

    Args:
        labels: Label ID sequences. (B, L)
        blank_id: Blank symbol ID.
        ignore_id: Padding symbol ID.

    Returns:
        decoder_input: Label ID sequences with blank prefix. (B, U)

    """
    device = labels.device

    labels_unpad = [label[label != ignore_id] for label in labels]
    blank = labels[0].new([blank_id])

    decoder_input = pad_list(
        [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
    ).to(device)

    return decoder_input
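
# Usage sketch (toy label IDs; blank_id=0 and ignore_id=-1 are assumed values
# matching common ESPnet conventions, not fixed by this module):
#
#   >>> labels = torch.tensor([[4, 5, -1], [6, 7, 8]])
#   >>> get_decoder_input(labels, blank_id=0, ignore_id=-1)
#   tensor([[0, 4, 5, 0],
#           [0, 6, 7, 8]])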


def valid_aux_encoder_output_layers(
    aux_layer_id: List[int],
    enc_num_layers: int,
    use_symm_kl_div_loss: bool,
    subsample: List[int],
) -> List[int]:
    """Check whether the provided auxiliary encoder layer IDs are valid.

    Return the valid list sorted, with duplicates removed.

    Args:
        aux_layer_id: Auxiliary encoder layer IDs.
        enc_num_layers: Number of encoder layers.
        use_symm_kl_div_loss: Whether symmetric KL divergence loss is used.
        subsample: Subsampling rate per layer.

    Returns:
        valid: Valid list of auxiliary encoder layer IDs.

    """
    if (
        not isinstance(aux_layer_id, list)
        or not aux_layer_id
        or not all(isinstance(layer, int) for layer in aux_layer_id)
    ):
        raise ValueError(
            "aux-transducer-loss-enc-output-layers option takes a list of layer IDs."
            " Correct argument format is: '[0, 1]'"
        )

    sorted_list = sorted(set(aux_layer_id))
    valid = [layer for layer in sorted_list if 0 <= layer < enc_num_layers]

    if sorted_list != valid:
        raise ValueError(
            "Provided argument for aux-transducer-loss-enc-output-layers is incorrect."
            " IDs should be between [0, %d]" % (enc_num_layers - 1)
        )

    if use_symm_kl_div_loss:
        sorted_list += [enc_num_layers]

        for n in range(1, len(sorted_list)):
            sub_range = subsample[(sorted_list[n - 1] + 1) : sorted_list[n] + 1]

            if any(sub > 1 for sub in sub_range):
                raise ValueError(
                    "Encoder layers %d and %d have different shape due to subsampling."
                    " Symmetric KL divergence loss doesn't cover such case for now."
                    % (sorted_list[n - 1], sorted_list[n])
                )

    return valid
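
# Example (illustrative values; a subsample rate > 1 between two selected
# layers is what triggers the symmetric KL divergence error above):
#
#   >>> valid_aux_encoder_output_layers([3, 0], 6, False, [1, 1, 1, 1, 1, 1])
#   [0, 3]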


def is_prefix(x: List[int], pref: List[int]) -> bool:
    """Check whether pref is a strict prefix of x.

    Args:
        x: Label ID sequence.
        pref: Prefix label ID sequence.

    Returns:
        : Whether pref is a strict prefix of x.

    """
    if len(pref) >= len(x):
        return False

    for i in range(len(pref) - 1, -1, -1):
        if pref[i] != x[i]:
            return False

    return True
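
# Quick check (note the strictness: a sequence is not a prefix of itself):
#
#   >>> is_prefix([1, 2, 3], [1, 2])
#   True
#   >>> is_prefix([1, 2, 3], [1, 2, 3])
#   False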


def subtract(
    x: List[ExtendedHypothesis], subset: List[ExtendedHypothesis]
) -> List[ExtendedHypothesis]:
    """Remove from x the hypotheses whose label ID sequence appears in subset.

    Args:
        x: Set of hypotheses.
        subset: Subset of x.

    Returns:
        final: New set of hypotheses.

    """
    final = []

    for x_ in x:
        if any(x_.yseq == sub.yseq for sub in subset):
            continue
        final.append(x_)

    return final
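
# Sketch (assumes the dataclass fields score, yseq and dec_state from
# espnet.nets.transducer_decoder_interface):
#
#   >>> a = ExtendedHypothesis(score=0.0, yseq=[0, 1], dec_state=None)
#   >>> b = ExtendedHypothesis(score=-1.0, yseq=[0, 2], dec_state=None)
#   >>> [h.yseq for h in subtract([a, b], [a])]
#   [[0, 2]]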


def select_k_expansions(
    hyps: List[ExtendedHypothesis],
    topk_idxs: torch.Tensor,
    topk_logps: torch.Tensor,
    gamma: float,
) -> List[List[Tuple[int, float]]]:
    """Return K hypothesis candidates for expansion from a list of hypotheses.

    K candidates are selected according to the extended hypothesis
    probabilities and a prune-by-value method, where K is equal to
    beam_size + beta.

    Args:
        hyps: Hypotheses.
        topk_idxs: Indices of the best hypothesis expansions.
        topk_logps: Log-probabilities of the best hypothesis expansions.
        gamma: Allowed log-probability difference for the prune-by-value method.

    Returns:
        k_expansions: The K best expansion candidates (label ID, score) for
            each hypothesis.

    """
    k_expansions = []

    for i, hyp in enumerate(hyps):
        hyp_i = [
            (int(k), hyp.score + float(v))
            for k, v in zip(topk_idxs[i], topk_logps[i])
        ]
        k_best_exp = max(hyp_i, key=lambda x: x[1])[1]

        k_expansions.append(
            sorted(
                filter(lambda x: (k_best_exp - gamma) <= x[1], hyp_i),
                key=lambda x: x[1],
                reverse=True,
            )
        )

    return k_expansions
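
# Sketch (toy scores; gamma prunes expansions more than gamma below the best
# one, so the -3.0 candidate is dropped here):
#
#   >>> hyp = ExtendedHypothesis(score=0.0, yseq=[0], dec_state=None)
#   >>> idxs = torch.tensor([[1, 2]])
#   >>> logps = torch.tensor([[-1.0, -3.0]])
#   >>> select_k_expansions([hyp], idxs, logps, gamma=1.5)
#   [[(1, -1.0)]]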


def select_lm_state(
    lm_states: Union[List[Any], Dict[str, Any]],
    idx: int,
    lm_layers: int,
    is_wordlm: bool,
) -> Union[List[Any], Dict[str, Any]]:
    """Get LM state for a given ID from batched LM hidden states.

    Args:
        lm_states: LM hidden states.
        idx: LM state ID to extract.
        lm_layers: Number of LM layers.
        is_wordlm: Whether the provided LM is a word-level LM.

    Returns:
        idx_state: LM hidden state for the given ID.

    """
    if is_wordlm:
        idx_state = lm_states[idx]
    else:
        idx_state = {}

        idx_state["c"] = [lm_states["c"][layer][idx] for layer in range(lm_layers)]
        idx_state["h"] = [lm_states["h"][layer][idx] for layer in range(lm_layers)]

    return idx_state
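
# Shape sketch for the RNNLM case ("c"/"h" hold one (B, D) tensor per layer;
# the sizes below are toy values):
#
#   >>> states = {
#   ...     "c": [torch.zeros(2, 4) for _ in range(2)],
#   ...     "h": [torch.zeros(2, 4) for _ in range(2)],
#   ... }
#   >>> s = select_lm_state(states, idx=1, lm_layers=2, is_wordlm=False)
#   >>> [c.shape for c in s["c"]]
#   [torch.Size([4]), torch.Size([4])]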


def create_lm_batch_states(
    lm_states: Union[List[Any], Dict[str, Any]], lm_layers: int, is_wordlm: bool
) -> Union[List[Any], Dict[str, Any]]:
    """Create batched LM hidden states from per-hypothesis states.

    Args:
        lm_states: LM hidden states.
        lm_layers: Number of LM layers.
        is_wordlm: Whether the provided LM is a word-level LM.

    Returns:
        new_states: Batched LM hidden states.

    """
    if is_wordlm:
        return lm_states

    new_states = {}

    new_states["c"] = [
        torch.stack([state["c"][layer] for state in lm_states])
        for layer in range(lm_layers)
    ]
    new_states["h"] = [
        torch.stack([state["h"][layer] for state in lm_states])
        for layer in range(lm_layers)
    ]

    return new_states
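
# The shape-wise inverse of select_lm_state (toy sizes again):
#
#   >>> per_id = [
#   ...     {"c": [torch.zeros(4)] * 2, "h": [torch.zeros(4)] * 2},
#   ...     {"c": [torch.ones(4)] * 2, "h": [torch.ones(4)] * 2},
#   ... ]
#   >>> batch = create_lm_batch_states(per_id, lm_layers=2, is_wordlm=False)
#   >>> batch["c"][0].shape
#   torch.Size([2, 4])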


def init_lm_state(lm_model: torch.nn.Module):
    """Initialize LM hidden states.

    Args:
        lm_model: LM module.

    Returns:
        lm_state: Initial LM hidden states.

    """
    lm_layers = len(lm_model.rnn)
    lm_units_typ = lm_model.typ
    lm_units = lm_model.n_units

    p = next(lm_model.parameters())

    h = [
        torch.zeros(lm_units).to(device=p.device, dtype=p.dtype)
        for _ in range(lm_layers)
    ]

    lm_state = {"h": h}

    if lm_units_typ == "lstm":
        lm_state["c"] = [
            torch.zeros(lm_units).to(device=p.device, dtype=p.dtype)
            for _ in range(lm_layers)
        ]

    return lm_state
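
# Sketch with a minimal stand-in exposing the attributes this function reads
# (rnn, typ, n_units); the real module is ESPnet's RNNLM:
#
#   >>> class ToyLM(torch.nn.Module):
#   ...     def __init__(self):
#   ...         super().__init__()
#   ...         self.rnn = torch.nn.ModuleList([torch.nn.LSTMCell(4, 4)])
#   ...         self.typ = "lstm"
#   ...         self.n_units = 4
#   >>> state = init_lm_state(ToyLM())
#   >>> sorted(state), state["h"][0].shape
#   (['c', 'h'], torch.Size([4]))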


def recombine_hyps(hyps: List[Hypothesis]) -> List[Hypothesis]:
    """Recombine hypotheses that share the same label ID sequence.

    Args:
        hyps: Hypotheses.

    Returns:
        final: Recombined hypotheses.

    """
    final = []

    for hyp in hyps:
        seq_final = [f.yseq for f in final if f.yseq]

        if hyp.yseq in seq_final:
            seq_pos = seq_final.index(hyp.yseq)

            final[seq_pos].score = np.logaddexp(final[seq_pos].score, hyp.score)
        else:
            final.append(hyp)

    return final
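
# Scores of merged duplicates are combined in log space (toy scores):
#
#   >>> h1 = Hypothesis(score=-1.0, yseq=[0, 3], dec_state=None)
#   >>> h2 = Hypothesis(score=-1.0, yseq=[0, 3], dec_state=None)
#   >>> merged = recombine_hyps([h1, h2])
#   >>> len(merged), float(merged[0].score)  # logaddexp(-1, -1) = ln(2) - 1
#   (1, -0.3068528194400547)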


def pad_sequence(labels: List[List[int]], pad_id: int) -> List[List[int]]:
    """Left pad label ID sequences.

    Args:
        labels: Label ID sequences.
        pad_id: Padding symbol ID.

    Returns:
        final: Padded label ID sequences.

    """
    maxlen = max(len(x) for x in labels)

    final = [([pad_id] * (maxlen - len(x))) + x for x in labels]

    return final
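
# Left padding keeps the most recent labels aligned on the right:
#
#   >>> pad_sequence([[1, 2, 3], [4]], pad_id=0)
#   [[1, 2, 3], [0, 0, 4]]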


def check_state(
    state: List[Optional[torch.Tensor]], max_len: int, pad_id: int
) -> List[Optional[torch.Tensor]]:
    """Check decoder hidden states and left pad or trim if necessary.

    Args:
        state: Decoder hidden states. [N x (1, ?, D_dec)]
        max_len: Maximum sequence length.
        pad_id: Padding symbol ID.

    Returns:
        final: Decoder hidden states. [N x (1, max_len, D_dec)]

    """
    if state is None or max_len < 1 or state[0].size(1) == max_len:
        return state

    curr_len = state[0].size(1)

    if curr_len > max_len:
        trim_val = int(state[0].size(1) - max_len)

        for i, s in enumerate(state):
            state[i] = s[:, trim_val:, :]
    else:
        layers = len(state)
        ddim = state[0].size(2)

        final_dims = (1, max_len, ddim)
        final = [state[0].data.new(*final_dims).fill_(pad_id) for _ in range(layers)]

        for i, s in enumerate(state):
            final[i][:, (max_len - s.size(1)) : max_len, :] = s

        return final

    return state
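
# Left-padding branch (toy shapes): a (1, 2, 4) state padded up to max_len=4.
#
#   >>> out = check_state([torch.ones(1, 2, 4)], max_len=4, pad_id=0)
#   >>> out[0].shape, out[0][0, :, 0]
#   (torch.Size([1, 4, 4]), tensor([0., 0., 1., 1.]))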


def check_batch_states(
    states: List[torch.Tensor], max_len: int, pad_id: int
) -> torch.Tensor:
    """Check batched decoder hidden states and left pad or trim if necessary.

    Args:
        states: Decoder hidden states. [B x (?, D_dec)]
        max_len: Maximum sequence length.
        pad_id: Padding symbol ID.

    Returns:
        final: Decoder hidden states. (B, max_len, D_dec)

    """
    final_dims = (len(states), max_len, states[0].size(1))
    final = states[0].data.new(*final_dims).fill_(pad_id)

    for i, s in enumerate(states):
        curr_len = s.size(0)

        if curr_len < max_len:
            final[i, (max_len - curr_len) : max_len, :] = s
        else:
            final[i, :, :] = s[(curr_len - max_len) :, :]

    return final
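
# Batched variant: each per-hypothesis state is left padded or trimmed to
# max_len, then stacked (toy shapes):
#
#   >>> states = [torch.ones(2, 4), torch.ones(3, 4)]
#   >>> check_batch_states(states, max_len=3, pad_id=0).shape
#   torch.Size([2, 3, 4])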


def custom_torch_load(model_path: str, model: torch.nn.Module, training: bool = True):
    """Load a Transducer model, removing training-only modules when decoding.

    Args:
        model_path: Model path.
        model: Transducer model.
        training: Whether the model is loaded to resume training. If False,
            training-only parameters (MLP, CTC, KL div., LM, error calculator)
            are stripped from the state dict before loading.

    """
    if "snapshot" in os.path.basename(model_path):
        model_state_dict = torch.load(
            model_path, map_location=lambda storage, loc: storage
        )["model"]
    else:
        model_state_dict = torch.load(
            model_path, map_location=lambda storage, loc: storage
        )

    if not training:
        task_keys = ("mlp", "ctc_lin", "kl_div", "lm_lin", "error_calculator")

        model_state_dict = {
            k: v
            for k, v in model_state_dict.items()
            if not any(mod in k for mod in task_keys)
        }

    model.load_state_dict(model_state_dict)

    del model_state_dict
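
# Usage sketch (the path is a placeholder following a typical ESPnet exp/
# layout, not one this module prescribes):
#
#   custom_torch_load("exp/train_rnnt/results/model.loss.best", model, training=False)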