import logging
import random
from contextlib import contextmanager
from itertools import groupby
from typing import Dict, List, Optional, Tuple, Union
import torch
from packaging.version import parse as V
from torch.nn.utils.rnn import pad_sequence
from typeguard import typechecked
from espnet2.asr.ctc import CTC
from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.asr_transducer.utils import get_transducer_task_io
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet.nets.e2e_asr_common import ErrorCalculator as ASRErrorCalculator
from espnet.nets.e2e_mt_common import ErrorCalculator as MTErrorCalculator
from espnet.nets.pytorch_backend.nets_utils import th_accuracy
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( # noqa: H301
LabelSmoothingLoss,
)
if V(torch.__version__) >= V("1.6.0"):
from torch.cuda.amp import autocast
else:
    # torch<1.6.0 has no torch.cuda.amp; fall back to a no-op autocast
@contextmanager
def autocast(enabled=True):
yield


class ESPnetSTModel(AbsESPnetModel):
"""CTC-attention hybrid Encoder-Decoder model"""
@typechecked
def __init__(
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
hier_encoder: Optional[AbsEncoder],
md_encoder: Optional[AbsEncoder],
extra_mt_encoder: Optional[AbsEncoder],
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
extra_asr_decoder: Optional[AbsDecoder],
extra_mt_decoder: Optional[AbsDecoder],
ctc: Optional[CTC],
st_ctc: Optional[CTC],
st_joint_network: Optional[torch.nn.Module],
src_vocab_size: Optional[int],
src_token_list: Optional[Union[Tuple[str, ...], List[str]]],
asr_weight: float = 0.0,
mt_weight: float = 0.0,
mtlalpha: float = 0.0,
st_mtlalpha: float = 0.0,
ignore_id: int = -1,
tgt_ignore_id: int = -1,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
report_bleu: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
tgt_sym_space: str = "<space>",
tgt_sym_blank: str = "<blank>",
extract_feats_in_collect_stats: bool = True,
ctc_sample_rate: float = 0.0,
tgt_sym_sos: str = "<sos/eos>",
tgt_sym_eos: str = "<sos/eos>",
lang_token_id: int = -1,
):
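        # Interpolation weights (used in forward()):
        #   mtlalpha:    CTC vs. attention weight within the auxiliary ASR loss
        #   st_mtlalpha: CTC vs. attention/transducer weight within the ST loss
        #   asr_weight / mt_weight: auxiliary-task weights in the final loss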
        assert 0.0 <= asr_weight < 1.0, "asr_weight should be in [0.0, 1.0)"
        assert 0.0 <= mt_weight < 1.0, "mt_weight should be in [0.0, 1.0)"
        assert 0.0 <= mtlalpha <= 1.0, "mtlalpha should be in [0.0, 1.0]"
super().__init__()
        # NOTE: eos uses the same token ID as sos
if tgt_sym_sos in token_list:
self.sos = token_list.index(tgt_sym_sos)
else:
self.sos = vocab_size - 1
if tgt_sym_eos in token_list:
self.eos = token_list.index(tgt_sym_eos)
else:
self.eos = vocab_size - 1
self.src_sos = src_vocab_size - 1 if src_vocab_size else None
self.src_eos = src_vocab_size - 1 if src_vocab_size else None
self.vocab_size = vocab_size
self.src_vocab_size = src_vocab_size
self.ignore_id = ignore_id
self.tgt_ignore_id = tgt_ignore_id
self.asr_weight = asr_weight
self.mt_weight = mt_weight
self.mtlalpha = mtlalpha
self.st_mtlalpha = st_mtlalpha
self.token_list = token_list.copy()
        self.src_token_list = (
            src_token_list.copy() if src_token_list is not None else None
        )
self.ctc_sample_rate = ctc_sample_rate
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
self.preencoder = preencoder
self.postencoder = postencoder
self.hier_encoder = hier_encoder
self.encoder = encoder
if self.st_mtlalpha < 1.0:
self.decoder = decoder
elif decoder is not None:
logging.warning(
"Not using decoder because "
"st_mtlalpha is set as {} (== 1.00)".format(st_mtlalpha),
)
self.md_encoder = md_encoder
self.st_use_transducer_decoder = st_joint_network is not None
if self.st_use_transducer_decoder:
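            # warprnnt_pytorch is only needed for the transducer branch,
            # so import it lazily here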
from warprnnt_pytorch import RNNTLoss
self.st_joint_network = st_joint_network
if tgt_sym_blank in token_list:
self.blank_id = token_list.index(tgt_sym_blank)
else:
                # The OpenAI Whisper model doesn't have a <blank> token
self.blank_id = 0
self.st_criterion_transducer = RNNTLoss(
blank=self.blank_id,
fastemit_lambda=0.0,
)
else:
self.criterion_st = LabelSmoothingLoss(
size=vocab_size,
padding_idx=tgt_ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
if self.st_mtlalpha > 0.0:
self.st_ctc = st_ctc
self.criterion_asr = LabelSmoothingLoss(
size=src_vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
# submodule for ASR task
if self.asr_weight > 0:
assert (
src_token_list is not None
), "Missing src_token_list, cannot add asr module to st model"
if self.mtlalpha > 0.0:
self.ctc = ctc
if self.mtlalpha < 1.0:
self.extra_asr_decoder = extra_asr_decoder
elif extra_asr_decoder is not None:
logging.warning(
"Not using extra_asr_decoder because "
"mtlalpha is set as {} (== 1.0)".format(mtlalpha),
)
# submodule for MT task
# TODO(brian): this should be deprecated
if self.mt_weight > 0:
self.extra_mt_decoder = extra_mt_decoder
self.extra_mt_encoder = extra_mt_encoder
elif extra_mt_decoder is not None:
logging.warning(
"Not using extra_mt_decoder because "
"mt_weight is set as {} (== 0)".format(mt_weight),
)
# MT error calculator
if report_bleu:
self.mt_error_calculator = MTErrorCalculator(
token_list, tgt_sym_space, tgt_sym_blank, report_bleu
)
else:
self.mt_error_calculator = None
# ASR error calculator
if self.asr_weight > 0 and (report_cer or report_wer):
assert (
src_token_list is not None
), "Missing src_token_list, cannot add asr module to st model"
self.asr_error_calculator = ASRErrorCalculator(
src_token_list, sym_space, sym_blank, report_cer, report_wer
)
else:
self.asr_error_calculator = None
self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
self.use_multidecoder = self.md_encoder is not None
if hasattr(self, "decoder"):
self.use_speech_attn = getattr(self.decoder, "use_speech_attn", False)
else:
self.use_speech_attn = None
# TODO(jiatong): add multilingual related functions
if lang_token_id != -1:
self.lang_token_id = torch.tensor([[lang_token_id]])
else:
self.lang_token_id = None
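
    # Typical training call (sketch; shapes as documented in forward() below):
    #   loss, stats, weight = model(
    #       speech, speech_lengths, text, text_lengths,
    #       src_text=src_text, src_text_lengths=src_text_lengths,
    #   )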
    def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
src_text: Optional[torch.Tensor] = None,
src_text_lengths: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch,)
text: (Batch, Length)
text_lengths: (Batch,)
src_text: (Batch, length)
src_text_lengths: (Batch,)
kwargs: "utt_id" is among the input.
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
# additional checks with valid src_text
if src_text is not None:
assert src_text_lengths.dim() == 1, src_text_lengths.shape
assert text.shape[0] == src_text.shape[0] == src_text_lengths.shape[0], (
text.shape,
src_text.shape,
src_text_lengths.shape,
)
batch_size = speech.shape[0]
text[text == -1] = self.tgt_ignore_id
# for data-parallel
text = text[:, : text_lengths.max()]
if src_text is not None:
src_text = src_text[:, : src_text_lengths.max()]
# lang id for mbart
if hasattr(self, "lang_token_id") and self.lang_token_id is not None:
text = torch.cat(
[
self.lang_token_id.repeat(text.size(0), 1).to(text.device),
text,
],
dim=1,
)
text_lengths += 1
# 1. Encoder
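        # With a hierarchical encoder (or a post-encoder that also returns its
        # intermediate encoding), keep separate ST-level and ASR-level outputs;
        # otherwise the two branches share one encoder output.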
if self.hier_encoder is not None or (
self.postencoder is not None and self.postencoder.return_int_enc
):
(
st_encoder_out,
st_encoder_out_lens,
asr_encoder_out,
asr_encoder_out_lens,
) = self.encode(speech, speech_lengths, return_int_enc=True)
else:
st_encoder_out, st_encoder_out_lens = self.encode(speech, speech_lengths)
asr_encoder_out, asr_encoder_out_lens = st_encoder_out, st_encoder_out_lens
# 2a. CTC branch
if self.asr_weight > 0:
assert src_text is not None, "missing source text for asr sub-task of ST"
if self.asr_weight > 0 and self.mtlalpha > 0:
loss_asr_ctc, cer_asr_ctc = self._calc_asr_ctc_loss(
asr_encoder_out, asr_encoder_out_lens, src_text, src_text_lengths
)
else:
loss_asr_ctc, cer_asr_ctc = 0.0, None
if self.st_mtlalpha > 0:
if self.postencoder is not None and self.postencoder.return_int_enc:
# run ST CTC without post-encoder downsampling (no hier enc)
loss_st_ctc, bleu_st_ctc = self._calc_mt_ctc_loss(
asr_encoder_out, asr_encoder_out_lens, text, text_lengths
)
else:
loss_st_ctc, bleu_st_ctc = self._calc_mt_ctc_loss(
st_encoder_out, st_encoder_out_lens, text, text_lengths
)
else:
loss_st_ctc, bleu_st_ctc = 0.0, None
# 2b. Attention-decoder branch (extra ASR)
if self.asr_weight > 0 and self.mtlalpha < 1.0:
(
loss_asr_att,
acc_asr_att,
cer_asr_att,
wer_asr_att,
hs_dec_asr,
) = self._calc_asr_att_loss(
asr_encoder_out,
asr_encoder_out_lens,
src_text,
src_text_lengths,
self.use_multidecoder,
)
else:
loss_asr_att, acc_asr_att, cer_asr_att, wer_asr_att = 0.0, None, None, None
# 2c. Attention-decoder branch (extra MT)
if self.mt_weight > 0:
mt_encoder_out, mt_encoder_out_lens = self.extra_mt_encoder(
src_text, src_text_lengths
)
loss_mt_att, acc_mt_att, bleu_mt_att = self._calc_mt_att_loss(
mt_encoder_out,
mt_encoder_out_lens,
text,
text_lengths,
None,
None,
st=False, # uses same decoder as ST
)
# loss_mt_att, acc_mt_att = self._calc_mt_att_loss(
# st_encoder_out, st_encoder_out_lens, text, text_lengths, st=False
# )
else:
            loss_mt_att, acc_mt_att, bleu_mt_att = 0.0, None, None
# 2d. Multi-Decoder encoder
if self.use_speech_attn:
speech_out = st_encoder_out
speech_lens = st_encoder_out_lens
else:
speech_out = None
speech_lens = None
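        # Multi-decoder mode: re-encode the ASR decoder's hidden states so the
        # ST decoder attends over them rather than the speech encoder output.
        # The +1 below accounts for the sos/eos token added to the ASR targets.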
if self.use_multidecoder:
dec_asr_lengths = src_text_lengths + 1
st_encoder_out, st_encoder_out_lens, _ = self.md_encoder(
hs_dec_asr, dec_asr_lengths
)
st_ctc_weight = self.st_mtlalpha
if st_ctc_weight < 1.0:
if self.st_use_transducer_decoder:
# 2e. Transducer decoder branch
(
loss_st_trans,
_,
_,
) = self._calc_st_transducer_loss(
st_encoder_out,
st_encoder_out_lens,
text,
)
                if st_ctc_weight == 0.0:
                    loss_st = loss_st_trans
                else:
                    loss_st = (
                        st_ctc_weight * loss_st_ctc
                        + (1 - st_ctc_weight) * loss_st_trans
                    )
loss_st_att = 0.0
acc_st_att = None
bleu_st_att = None
else:
# 2e. Attention-decoder branch (ST)
loss_st_att, acc_st_att, bleu_st_att = self._calc_mt_att_loss(
st_encoder_out,
st_encoder_out_lens,
text,
text_lengths,
speech_out,
speech_lens,
st=True,
)
                if st_ctc_weight == 0.0:
                    loss_st = loss_st_att
                else:
                    loss_st = (
                        st_ctc_weight * loss_st_ctc
                        + (1 - st_ctc_weight) * loss_st_att
                    )
loss_st_trans = 0.0
else:
loss_st = loss_st_ctc
loss_st_att = 0.0
acc_st_att = None
bleu_st_att = None
loss_st_trans = 0.0
# 3. Loss computation
asr_ctc_weight = self.mtlalpha
if asr_ctc_weight == 1.0:
loss_asr = loss_asr_ctc
elif asr_ctc_weight == 0.0:
loss_asr = loss_asr_att
else:
loss_asr = (
asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att
)
loss_mt = self.mt_weight * loss_mt_att
        loss = (
            (1 - self.asr_weight - self.mt_weight) * loss_st
            + self.asr_weight * loss_asr
            + loss_mt
        )
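        # Worked example with hypothetical weights asr_weight=0.3, mt_weight=0.0,
        # mtlalpha=0.5, st_mtlalpha=0.0:
        #   loss = 0.7 * loss_st_att + 0.3 * (0.5 * loss_asr_ctc + 0.5 * loss_asr_att)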
stats = dict(
loss=loss.detach(),
loss_asr=loss_asr.detach() if type(loss_asr) is not float else loss_asr,
loss_mt=loss_mt.detach() if type(loss_mt) is not float else loss_mt,
loss_st_ctc=(
loss_st_ctc.detach() if type(loss_st_ctc) is not float else loss_st_ctc
),
loss_st_trans=(
loss_st_trans.detach()
if type(loss_st_trans) is not float
else loss_st_trans
),
loss_st_att=(
loss_st_att.detach() if type(loss_st_att) is not float else loss_st_att
),
loss_st=loss_st.detach(),
acc_asr=acc_asr_att,
acc_mt=acc_mt_att,
acc=acc_st_att,
cer_ctc=cer_asr_ctc,
cer=cer_asr_att,
wer=wer_asr_att,
bleu=bleu_st_att,
bleu_ctc=bleu_st_ctc,
)
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
    def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
src_text: Optional[torch.Tensor] = None,
src_text_lengths: Optional[torch.Tensor] = None,
**kwargs,
) -> Dict[str, torch.Tensor]:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}
    def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
return_int_enc: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by st_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
if return_int_enc:
int_encoder_out, int_encoder_out_lens = encoder_out, encoder_out_lens
if self.hier_encoder is not None:
encoder_out, encoder_out_lens, _ = self.hier_encoder(
encoder_out, encoder_out_lens
)
# Post-encoder, e.g. NLU
if self.postencoder is not None:
encoder_out, encoder_out_lens = self.postencoder(
encoder_out, encoder_out_lens
)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
assert encoder_out.size(1) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
if return_int_enc:
return encoder_out, encoder_out_lens, int_encoder_out, int_encoder_out_lens
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
def _calc_mt_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
speech: Optional[torch.Tensor],
speech_lens: Optional[torch.Tensor],
st: bool = True,
):
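        # Teacher forcing: add_sos_eos builds the decoder input (<sos> + y) and
        # the target (y + <eos>), so both are one token longer than ys_pad.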
ys_in_pad, ys_out_pad = add_sos_eos(
ys_pad, self.sos, self.eos, self.tgt_ignore_id
)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
if st:
if self.use_speech_attn:
decoder_out, _ = self.decoder(
encoder_out,
encoder_out_lens,
ys_in_pad,
ys_in_lens,
speech,
speech_lens,
)
else:
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
)
else:
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
)
# 2. Compute attention loss
loss_att = self.criterion_st(decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
ignore_label=self.tgt_ignore_id,
)
        # Compute BLEU using the attention decoder
if self.training or self.mt_error_calculator is None:
bleu_att = None
else:
ys_hat = decoder_out.argmax(dim=-1)
bleu_att = self.mt_error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, acc_att, bleu_att
def _calc_asr_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
return_hs: bool = False,
):
# Use CTC output as AR decoder target; useful for multi-decoder training
skip_loss = False
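        # With probability ctc_sample_rate, swap the ground-truth targets for
        # greedy CTC pseudo-labels (repeats collapsed, blanks removed); the
        # attention loss is then skipped and only the decoder hidden states
        # are returned for the multi-decoder branch.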
if self.training and self.ctc_sample_rate > 0:
if random.uniform(0, 1) < self.ctc_sample_rate:
ys_hat = self.ctc.argmax(encoder_out).data
ys_hat = [[x[0] for x in groupby(ys)] for ys in ys_hat]
ys_hat = [[x for x in filter(lambda x: x != 0, ys)] for ys in ys_hat]
for i, ys in enumerate(ys_hat):
if len(ys) == 0:
ys_hat[i] = [x for x in ys_pad[i] if x != -1]
ys_pad_lens = torch.tensor(
[len(x) for x in ys_hat], device=encoder_out.device
)
            ys_pad = [torch.tensor(ys, device=encoder_out.device) for ys in ys_hat]
ys_pad = pad_sequence(ys_pad, batch_first=True, padding_value=-1)
# skip the loss
skip_loss = True
ys_in_pad, ys_out_pad = add_sos_eos(
ys_pad, self.src_sos, self.src_eos, self.ignore_id
)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
if return_hs:
decoder_out, _, hs_dec_asr = self.extra_asr_decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens, return_hs=True
)
else:
hs_dec_asr = None
decoder_out, _ = self.extra_asr_decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
)
if skip_loss:
return 0.0, None, None, None, hs_dec_asr
# 2. Compute attention loss
loss_att = self.criterion_asr(decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.src_vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
# Compute cer/wer using attention-decoder
if self.training or self.asr_error_calculator is None:
cer_att, wer_att = None, None
else:
ys_hat = decoder_out.argmax(dim=-1)
cer_att, wer_att = self.asr_error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, acc_att, cer_att, wer_att, hs_dec_asr
def _calc_asr_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
# Calc CER using CTC
cer_ctc = None
if not self.training and self.asr_error_calculator is not None:
ys_hat = self.ctc.argmax(encoder_out).data
cer_ctc = self.asr_error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc
def _calc_mt_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
loss_ctc = self.st_ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        # Calc BLEU using CTC
bleu_ctc = None
if not self.training and self.mt_error_calculator is not None:
ys_hat = self.st_ctc.argmax(encoder_out).data
bleu_ctc = self.mt_error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, bleu_ctc
def _calc_st_transducer_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
labels: torch.Tensor,
):
"""Compute Transducer loss.
Args:
encoder_out: Encoder output sequences. (B, T, D_enc)
encoder_out_lens: Encoder output sequences lengths. (B,)
labels: Label ID sequences. (B, L)
Return:
loss_transducer: Transducer loss value.
cer_transducer: Character error rate for Transducer.
wer_transducer: Word Error Rate for Transducer.
"""
decoder_in, target, t_len, u_len = get_transducer_task_io(
labels,
encoder_out_lens,
ignore_id=self.tgt_ignore_id,
blank_id=self.blank_id,
)
self.decoder.set_device(encoder_out.device)
decoder_out = self.decoder(decoder_in)
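        # Broadcast encoder states (B, T, 1, D) against decoder states
        # (B, 1, U, D) to form the (B, T, U, V) transducer joint lattice.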
joint_out = self.st_joint_network(
encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
)
loss_transducer = self.st_criterion_transducer(
joint_out,
target,
t_len,
u_len,
)
cer_transducer, wer_transducer = None, None
# TODO(brian): add error_calculator_trans
return loss_transducer, cer_transducer, wer_transducer