"""Utility functions for Transducer models."""
import os
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from espnet.nets.pytorch_backend.nets_utils import pad_list
from espnet.nets.transducer_decoder_interface import ExtendedHypothesis, Hypothesis
[docs]def valid_aux_encoder_output_layers(
aux_layer_id: List[int],
enc_num_layers: int,
use_symm_kl_div_loss: bool,
subsample: List[int],
) -> List[int]:
"""Check whether provided auxiliary encoder layer IDs are valid.
Return the valid list sorted with duplicates removed.
Args:
aux_layer_id: Auxiliary encoder layer IDs.
enc_num_layers: Number of encoder layers.
use_symm_kl_div_loss: Whether symmetric KL divergence loss is used.
subsample: Subsampling rate per layer.
Returns:
valid: Valid list of auxiliary encoder layers.
"""
if (
not isinstance(aux_layer_id, list)
or not aux_layer_id
or not all(isinstance(layer, int) for layer in aux_layer_id)
):
raise ValueError(
"aux-transducer-loss-enc-output-layers option takes a list of layer IDs."
" Correct argument format is: '[0, 1]'"
)
sorted_list = sorted(aux_layer_id, key=int, reverse=False)
valid = list(filter(lambda x: 0 <= x < enc_num_layers, sorted_list))
if sorted_list != valid:
raise ValueError(
"Provided argument for aux-transducer-loss-enc-output-layers is incorrect."
" IDs should be between [0, %d]" % enc_num_layers
)
if use_symm_kl_div_loss:
sorted_list += [enc_num_layers]
for n in range(1, len(sorted_list)):
sub_range = subsample[(sorted_list[n - 1] + 1) : sorted_list[n] + 1]
valid_shape = [False if n > 1 else True for n in sub_range]
if False in valid_shape:
raise ValueError(
"Encoder layers %d and %d have different shape due to subsampling."
" Symmetric KL divergence loss doesn't cover such case for now."
% (sorted_list[n - 1], sorted_list[n])
)
return valid
[docs]def is_prefix(x: List[int], pref: List[int]) -> bool:
"""Check if pref is a prefix of x.
Args:
x: Label ID sequence.
pref: Prefix label ID sequence.
Returns:
: Whether pref is a prefix of x.
"""
if len(pref) >= len(x):
return False
for i in range(len(pref) - 1, -1, -1):
if pref[i] != x[i]:
return False
return True
[docs]def subtract(
x: List[ExtendedHypothesis], subset: List[ExtendedHypothesis]
) -> List[ExtendedHypothesis]:
"""Remove elements of subset if corresponding label ID sequence already exist in x.
Args:
x: Set of hypotheses.
subset: Subset of x.
Returns:
final: New set of hypotheses.
"""
final = []
for x_ in x:
if any(x_.yseq == sub.yseq for sub in subset):
continue
final.append(x_)
return final
[docs]def select_k_expansions(
hyps: List[ExtendedHypothesis],
topk_idxs: torch.Tensor,
topk_logps: torch.Tensor,
gamma: float,
) -> List[ExtendedHypothesis]:
"""Return K hypotheses candidates for expansion from a list of hypothesis.
K candidates are selected according to the extended hypotheses probabilities
and a prune-by-value method. Where K is equal to beam_size + beta.
Args:
hyps: Hypotheses.
topk_idxs: Indices of candidates hypothesis.
topk_logps: Log-probabilities for hypotheses expansions.
gamma: Allowed logp difference for prune-by-value method.
Return:
k_expansions: Best K expansion hypotheses candidates.
"""
k_expansions = []
for i, hyp in enumerate(hyps):
hyp_i = [
(int(k), hyp.score + float(v)) for k, v in zip(topk_idxs[i], topk_logps[i])
]
k_best_exp = max(hyp_i, key=lambda x: x[1])[1]
k_expansions.append(
sorted(
filter(lambda x: (k_best_exp - gamma) <= x[1], hyp_i),
key=lambda x: x[1],
reverse=True,
)
)
return k_expansions
[docs]def select_lm_state(
lm_states: Union[List[Any], Dict[str, Any]],
idx: int,
lm_layers: int,
is_wordlm: bool,
) -> Union[List[Any], Dict[str, Any]]:
"""Get ID state from LM hidden states.
Args:
lm_states: LM hidden states.
idx: LM state ID to extract.
lm_layers: Number of LM layers.
is_wordlm: Whether provided LM is a word-level LM.
Returns:
idx_state: LM hidden state for given ID.
"""
if is_wordlm:
idx_state = lm_states[idx]
else:
idx_state = {}
idx_state["c"] = [lm_states["c"][layer][idx] for layer in range(lm_layers)]
idx_state["h"] = [lm_states["h"][layer][idx] for layer in range(lm_layers)]
return idx_state
[docs]def create_lm_batch_states(
lm_states: Union[List[Any], Dict[str, Any]], lm_layers, is_wordlm: bool
) -> Union[List[Any], Dict[str, Any]]:
"""Create LM hidden states.
Args:
lm_states: LM hidden states.
lm_layers: Number of LM layers.
is_wordlm: Whether provided LM is a word-level LM.
Returns:
new_states: LM hidden states.
"""
if is_wordlm:
return lm_states
new_states = {}
new_states["c"] = [
torch.stack([state["c"][layer] for state in lm_states])
for layer in range(lm_layers)
]
new_states["h"] = [
torch.stack([state["h"][layer] for state in lm_states])
for layer in range(lm_layers)
]
return new_states
[docs]def init_lm_state(lm_model: torch.nn.Module):
"""Initialize LM hidden states.
Args:
lm_model: LM module.
Returns:
lm_state: Initial LM hidden states.
"""
lm_layers = len(lm_model.rnn)
lm_units_typ = lm_model.typ
lm_units = lm_model.n_units
p = next(lm_model.parameters())
h = [
torch.zeros(lm_units).to(device=p.device, dtype=p.dtype)
for _ in range(lm_layers)
]
lm_state = {"h": h}
if lm_units_typ == "lstm":
lm_state["c"] = [
torch.zeros(lm_units).to(device=p.device, dtype=p.dtype)
for _ in range(lm_layers)
]
return lm_state
[docs]def recombine_hyps(hyps: List[Hypothesis]) -> List[Hypothesis]:
"""Recombine hypotheses with same label ID sequence.
Args:
hyps: Hypotheses.
Returns:
final: Recombined hypotheses.
"""
final = []
for hyp in hyps:
seq_final = [f.yseq for f in final if f.yseq]
if hyp.yseq in seq_final:
seq_pos = seq_final.index(hyp.yseq)
final[seq_pos].score = np.logaddexp(final[seq_pos].score, hyp.score)
else:
final.append(hyp)
return final
[docs]def pad_sequence(labels: List[int], pad_id: int) -> List[int]:
"""Left pad label ID sequences.
Args:
labels: Label ID sequence.
pad_id: Padding symbol ID.
Returns:
final: Padded label ID sequences.
"""
maxlen = max(len(x) for x in labels)
final = [([pad_id] * (maxlen - len(x))) + x for x in labels]
return final
[docs]def check_state(
state: List[Optional[torch.Tensor]], max_len: int, pad_id: int
) -> List[Optional[torch.Tensor]]:
"""Check decoder hidden states and left pad or trim if necessary.
Args:
state: Decoder hidden states. [N x (?, D_dec)]
max_len: maximum sequence length.
pad_id: Padding symbol ID.
Returns:
final: Decoder hidden states. [N x (1, max_len, D_dec)]
"""
if state is None or max_len < 1 or state[0].size(1) == max_len:
return state
curr_len = state[0].size(1)
if curr_len > max_len:
trim_val = int(state[0].size(1) - max_len)
for i, s in enumerate(state):
state[i] = s[:, trim_val:, :]
else:
layers = len(state)
ddim = state[0].size(2)
final_dims = (1, max_len, ddim)
final = [state[0].data.new(*final_dims).fill_(pad_id) for _ in range(layers)]
for i, s in enumerate(state):
final[i][:, (max_len - s.size(1)) : max_len, :] = s
return final
return state
[docs]def check_batch_states(states, max_len, pad_id):
"""Check decoder hidden states and left pad or trim if necessary.
Args:
state: Decoder hidden states. [N x (B, ?, D_dec)]
max_len: maximum sequence length.
pad_id: Padding symbol ID.
Returns:
final: Decoder hidden states. [N x (B, max_len, dec_dim)]
"""
final_dims = (len(states), max_len, states[0].size(1))
final = states[0].data.new(*final_dims).fill_(pad_id)
for i, s in enumerate(states):
curr_len = s.size(0)
if curr_len < max_len:
final[i, (max_len - curr_len) : max_len, :] = s
else:
final[i, :, :] = s[(curr_len - max_len) :, :]
return final
[docs]def custom_torch_load(model_path: str, model: torch.nn.Module, training: bool = True):
"""Load Transducer model with training-only modules and parameters removed.
Args:
model_path: Model path.
model: Transducer model.
"""
if "snapshot" in os.path.basename(model_path):
model_state_dict = torch.load(
model_path, map_location=lambda storage, loc: storage
)["model"]
else:
model_state_dict = torch.load(
model_path, map_location=lambda storage, loc: storage
)
if not training:
task_keys = ("mlp", "ctc_lin", "kl_div", "lm_lin", "error_calculator")
model_state_dict = {
k: v
for k, v in model_state_dict.items()
if not any(mod in k for mod in task_keys)
}
model.load_state_dict(model_state_dict)
del model_state_dict