# Source code for espnet.nets.pytorch_backend.transducer.error_calculator

"""CER/WER computation for Transducer model."""

from typing import List, Tuple, Union

import torch

from espnet.nets.beam_search_transducer import BeamSearchTransducer
from espnet.nets.pytorch_backend.transducer.custom_decoder import CustomDecoder
from espnet.nets.pytorch_backend.transducer.joint_network import JointNetwork
from espnet.nets.pytorch_backend.transducer.rnn_decoder import RNNDecoder


class ErrorCalculator(object):
    """CER and WER computation for Transducer model.

    Hypotheses are produced with a small beam search over the encoder output
    and compared against the reference label sequences.

    Args:
        decoder: Decoder module.
        joint_network: Joint network module.
        token_list: List of unique labels; indexed by label ID to obtain the
            corresponding symbol string.
        sym_space: Space symbol.
        sym_blank: Blank symbol.
        report_cer: Whether to compute CER.
        report_wer: Whether to compute WER.

    """

    def __init__(
        self,
        decoder: "Union[RNNDecoder, CustomDecoder]",
        joint_network: "JointNetwork",
        token_list: List[str],
        sym_space: str,
        sym_blank: str,
        report_cer: bool = False,
        report_wer: bool = False,
    ):
        """Construct an ErrorCalculator object for Transducer model."""
        super().__init__()

        self.beam_search = BeamSearchTransducer(
            decoder=decoder,
            joint_network=joint_network,
            beam_size=2,
            search_type="default",
        )

        self.decoder = decoder

        self.token_list = token_list
        self.space = sym_space
        self.blank = sym_blank

        self.report_cer = report_cer
        self.report_wer = report_wer

    def __call__(
        self, enc_out: torch.Tensor, target: torch.Tensor
    ) -> Tuple[Union[float, None], Union[float, None]]:
        """Calculate CER/WER score for hypotheses sequences.

        Args:
            enc_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)

        Returns:
            cer: CER score, or None when ``report_cer`` is False.
            wer: WER score, or None when ``report_wer`` is False.

        """
        cer, wer = None, None

        batchsize = int(enc_out.size(0))

        # Decode on the same device as the decoder parameters.
        enc_out = enc_out.to(next(self.decoder.parameters()).device)

        batch_nbest = []
        for b in range(batchsize):
            nbest_hyps = self.beam_search(enc_out[b])
            # The last entry of the returned list is used; its first label is
            # dropped (presumably the start symbol prepended by the search --
            # confirm against BeamSearchTransducer).
            batch_nbest.append(nbest_hyps[-1].yseq[1:])

        hyps, refs = self.convert_to_char(batch_nbest, target.cpu())

        if self.report_cer:
            cer = self.calculate_cer(hyps, refs)

        if self.report_wer:
            wer = self.calculate_wer(hyps, refs)

        return cer, wer

    def convert_to_char(
        self, hyps: List[List[int]], refs: torch.Tensor
    ) -> Tuple[List[str], List[str]]:
        """Convert label ID sequences to character strings.

        Space symbols are replaced by " " in both hypotheses and references.
        Blank symbols are removed from hypotheses only.

        Args:
            hyps: Hypotheses sequences. (B, L)
            refs: References sequences. (B, L)

        Returns:
            char_hyps: Character list of hypotheses.
            char_refs: Character list of references.

        """
        char_hyps, char_refs = [], []

        for i, hyp in enumerate(hyps):
            hyp_i = [self.token_list[int(h)] for h in hyp]
            ref_i = [self.token_list[int(r)] for r in refs[i]]

            char_hyp = "".join(hyp_i).replace(self.space, " ")
            char_hyp = char_hyp.replace(self.blank, "")
            char_ref = "".join(ref_i).replace(self.space, " ")

            char_hyps.append(char_hyp)
            char_refs.append(char_ref)

        return char_hyps, char_refs

    def calculate_cer(self, hyps: List[str], refs: List[str]) -> float:
        """Calculate CER score over a batch.

        Whitespace is removed before comparison. The score is the summed
        character edit distance divided by the summed reference length.

        Args:
            hyps: Hypotheses strings, one per utterance.
            refs: References strings, one per utterance.

        Returns:
            : CER score. See ``_error_rate`` for the empty-reference case.

        """
        import editdistance

        distances, lens = [], []

        for i, hyp in enumerate(hyps):
            char_hyp = hyp.replace(" ", "")
            char_ref = refs[i].replace(" ", "")

            distances.append(editdistance.eval(char_hyp, char_ref))
            lens.append(len(char_ref))

        return self._error_rate(sum(distances), sum(lens))

    def calculate_wer(self, hyps: List[str], refs: List[str]) -> float:
        """Calculate WER score over a batch.

        Strings are split on whitespace. The score is the summed word edit
        distance divided by the summed number of reference words.

        Args:
            hyps: Hypotheses strings, one per utterance.
            refs: References strings, one per utterance.

        Returns:
            : WER score. See ``_error_rate`` for the empty-reference case.

        """
        import editdistance

        distances, lens = [], []

        for i, hyp in enumerate(hyps):
            word_hyp = hyp.split()
            word_ref = refs[i].split()

            distances.append(editdistance.eval(word_hyp, word_ref))
            lens.append(len(word_ref))

        return self._error_rate(sum(distances), sum(lens))

    @staticmethod
    def _error_rate(total_distance: int, total_length: int) -> float:
        """Return total_distance / total_length without dividing by zero.

        When the total reference length is 0 the ratio is undefined: 0.0 is
        returned if there are no errors either, ``float("inf")`` otherwise.
        """
        if total_length == 0:
            return 0.0 if total_distance == 0 else float("inf")

        return float(total_distance) / total_length