# Source code for espnet2.asr_transducer.error_calculator

"""Error Calculator module for Transducer."""

from typing import List, Optional, Tuple

import torch

from espnet2.asr_transducer.beam_search_transducer import BeamSearchTransducer
from espnet2.asr_transducer.decoder.abs_decoder import AbsDecoder
from espnet2.asr_transducer.joint_network import JointNetwork


class ErrorCalculator:
    """Calculate CER and WER for transducer models.

    Hypotheses are obtained with a modified adaptive expansion search ("maes")
    of beam size 2, then converted to text and compared with the references.

    Args:
        decoder: Decoder module.
        joint_network: Joint Network module.
        token_list: List of token units (strings), indexed by label ID.
        sym_space: Space symbol.
        sym_blank: Blank symbol.
        nstep: Maximum number of symbol expansions at each time step w/ mAES.
        report_cer: Whether to compute CER.
        report_wer: Whether to compute WER.

    """

    def __init__(
        self,
        decoder: AbsDecoder,
        joint_network: JointNetwork,
        token_list: List[str],
        sym_space: str,
        sym_blank: str,
        nstep: int = 2,
        report_cer: bool = False,
        report_wer: bool = False,
    ) -> None:
        """Construct an ErrorCalculator object."""
        super().__init__()

        # (b-flo): Since the commit #8c9c851 we rely on the mAES algorithm for
        # validation instead of the default algorithm.
        #
        # With the addition of k2 pruned transducer loss, the number of emitted symbols
        # at each timestep can be restricted during training. Performing an unrestricted
        # (/ unconstrained) decoding without regard to the training conditions can lead
        # to huge performance degradation. It won't be an issue with mAES and the user
        # can now control the number of emitted symbols during validation.
        #
        # Also, under certain conditions, using the default algorithm can lead to a long
        # decoding procedure due to the loop break condition. Other algorithms,
        # such as mAES, won't be impacted by that.
        self.beam_search = BeamSearchTransducer(
            decoder=decoder,
            joint_network=joint_network,
            beam_size=2,
            search_type="maes",
            nstep=nstep,
            score_norm=False,
        )

        self.decoder = decoder

        self.token_list = token_list
        self.space = sym_space
        self.blank = sym_blank

        self.report_cer = report_cer
        self.report_wer = report_wer

    def __call__(
        self,
        encoder_out: torch.Tensor,
        target: torch.Tensor,
        encoder_out_lens: torch.Tensor,
    ) -> Tuple[Optional[float], Optional[float]]:
        """Calculate CER and/or WER over a batch for the Transducer model.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)
            encoder_out_lens: Encoder output sequences length. (B,)

        Returns:
            : CER over the batch, or None if report_cer is False.
            : WER over the batch, or None if report_wer is False.

        """
        cer, wer = None, None

        # Nothing to report: skip the (costly) beam search altogether.
        if not (self.report_cer or self.report_wer):
            return cer, wer

        batchsize = int(encoder_out.size(0))

        # Decode on the same device as the decoder parameters.
        encoder_out = encoder_out.to(next(self.decoder.parameters()).device)

        # Each utterance is decoded separately, trimmed to its valid length.
        batch_nbest = [
            self.beam_search(encoder_out[b][: encoder_out_lens[b]])
            for b in range(batchsize)
        ]

        # Keep the best hypothesis and drop the first element of its yseq
        # (presumably the initial blank token priming the decoder - TODO confirm).
        pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]

        char_pred, char_target = self.convert_to_char(pred, target)

        if self.report_cer:
            cer = self.calculate_cer(char_pred, char_target)

        if self.report_wer:
            wer = self.calculate_wer(char_pred, char_target)

        return cer, wer

    def _ids_to_text(self, ids) -> str:
        """Map a label ID sequence to text (space symbol -> " ", blanks removed)."""
        text = "".join(self.token_list[int(idx)] for idx in ids)

        # Order matters: space symbols are substituted before blanks are stripped.
        return text.replace(self.space, " ").replace(self.blank, "")

    def convert_to_char(
        self, pred: List, target: torch.Tensor
    ) -> Tuple[List[str], List[str]]:
        """Convert label ID sequences to character sequences.

        Args:
            pred: Prediction label ID sequences. (B, U)
            target: Target label ID sequences. (B, L)

        Returns:
            char_pred: Prediction character sequences. (B, ?)
            char_target: Target character sequences. (B, ?)

        """
        char_pred, char_target = [], []

        for i, pred_i in enumerate(pred):
            char_pred.append(self._ids_to_text(pred_i))
            char_target.append(self._ids_to_text(target[i]))

        return char_pred, char_target

    @staticmethod
    def _error_rate(total_distance: int, total_length: int) -> float:
        """Return total_distance / total_length with the denominator clamped to 1.

        Clamping avoids a ZeroDivisionError when every reference is empty; in
        that case the raw number of edit operations is returned instead.
        """
        return float(total_distance) / max(total_length, 1)

    def calculate_cer(self, char_pred: List[str], char_target: List[str]) -> float:
        """Calculate the CER over a batch.

        Spaces are removed from both hypothesis and reference before comparing,
        so word boundaries do not count as characters.

        Args:
            char_pred: Prediction character sequences. (B, ?)
            char_target: Target character sequences. (B, ?)

        Returns:
            : Total character edit distance divided by total reference length.

        """
        import editdistance

        total_distance, total_length = 0, 0

        for i, char_pred_i in enumerate(char_pred):
            pred = char_pred_i.replace(" ", "")
            target = char_target[i].replace(" ", "")

            total_distance += editdistance.eval(pred, target)
            total_length += len(target)

        return self._error_rate(total_distance, total_length)

    def calculate_wer(self, char_pred: List[str], char_target: List[str]) -> float:
        """Calculate the WER over a batch.

        '\u2581' (presumably the SentencePiece word-boundary marker) is treated as
        a word separator, in addition to regular whitespace.

        Args:
            char_pred: Prediction character sequences. (B, ?)
            char_target: Target character sequences. (B, ?)

        Returns:
            : Total word edit distance divided by total reference word count.

        """
        import editdistance

        total_distance, total_length = 0, 0

        for i, char_pred_i in enumerate(char_pred):
            pred = char_pred_i.replace("▁", " ").split()
            target = char_target[i].replace("▁", " ").split()

            total_distance += editdistance.eval(pred, target)
            total_length += len(target)

        return self._error_rate(total_distance, total_length)