import logging
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch.cuda.amp import autocast
from typeguard import typechecked
from espnet2.asr.ctc import CTC
from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet.nets.e2e_asr_common import ErrorCalculator
from espnet.nets.pytorch_backend.nets_utils import pad_list, th_accuracy
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( # noqa: H301
LabelSmoothingLoss,
)
class ESPnetS2TModel(AbsESPnetModel):
"""CTC-attention hybrid Encoder-Decoder model"""
@typechecked
def __init__(
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: Optional[AbsDecoder],
ctc: CTC,
ctc_weight: float = 0.5,
interctc_weight: float = 0.0,
ignore_id: int = -1,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
sym_sos: str = "<sos>",
sym_eos: str = "<eos>",
sym_sop: str = "<sop>", # start of prev
sym_na: str = "<na>", # not available
extract_feats_in_collect_stats: bool = True,
):
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
super().__init__()
self.blank_id = token_list.index(sym_blank)
self.sos = token_list.index(sym_sos)
self.eos = token_list.index(sym_eos)
self.sop = token_list.index(sym_sop)
self.na = token_list.index(sym_na)
self.vocab_size = vocab_size
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.interctc_weight = interctc_weight
self.token_list = token_list.copy()
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
self.preencoder = preencoder
self.postencoder = postencoder
self.encoder = encoder
if not hasattr(self.encoder, "interctc_use_conditioning"):
self.encoder.interctc_use_conditioning = False
if self.encoder.interctc_use_conditioning:
self.encoder.conditioning_layer = torch.nn.Linear(
vocab_size, self.encoder.output_size()
)
self.error_calculator = None
if ctc_weight < 1.0:
assert (
decoder is not None
), "decoder should not be None when attention is used"
else:
decoder = None
logging.warning("Set decoder to none as ctc_weight==1.0")
self.decoder = decoder
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
if report_cer or report_wer:
self.error_calculator = ErrorCalculator(
token_list, sym_space, sym_blank, report_cer, report_wer
)
if ctc_weight == 0.0:
self.ctc = None
else:
self.ctc = ctc
self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
self.is_encoder_whisper = "Whisper" in type(self.encoder).__name__
if self.is_encoder_whisper:
assert (
self.frontend is None
), "frontend should be None when using full Whisper model"
    def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
text_prev: torch.Tensor,
text_prev_lengths: torch.Tensor,
text_ctc: torch.Tensor,
text_ctc_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
text_prev: (Batch, Length)
text_prev_lengths: (Batch,)
text_ctc: (Batch, Length)
text_ctc_lengths: (Batch,)
kwargs: "utt_id" is among the input.
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
== text_prev.shape[0]
== text_prev_lengths.shape[0]
== text_ctc.shape[0]
== text_ctc_lengths.shape[0]
), (
speech.shape,
speech_lengths.shape,
text.shape,
text_lengths.shape,
text_prev.shape,
text_prev_lengths.shape,
text_ctc.shape,
text_ctc_lengths.shape,
)
batch_size = speech.shape[0]
# -1 is used as padding index in collate fn
text[text == -1] = self.ignore_id
# for data-parallel
text = text[:, : text_lengths.max()]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
loss_att, acc_att, cer_att, wer_att = None, None, None, None
loss_ctc, cer_ctc = None, None
stats = dict()
# 1. CTC branch
if self.ctc_weight != 0.0:
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text_ctc, text_ctc_lengths
)
# Collect CTC branch stats
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
# Intermediate CTC (optional)
loss_interctc = 0.0
if self.interctc_weight != 0.0 and intermediate_outs is not None:
for layer_idx, intermediate_out in intermediate_outs:
# we assume intermediate_out has the same length & padding
# as those of encoder_out
loss_ic, cer_ic = self._calc_ctc_loss(
intermediate_out, encoder_out_lens, text_ctc, text_ctc_lengths
)
loss_interctc = loss_interctc + loss_ic
                # Collect Intermediate CTC stats
stats["loss_interctc_layer{}".format(layer_idx)] = (
loss_ic.detach() if loss_ic is not None else None
)
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
loss_interctc = loss_interctc / len(intermediate_outs)
# calculate whole encoder loss
loss_ctc = (
1 - self.interctc_weight
) * loss_ctc + self.interctc_weight * loss_interctc
# 2. Attention decoder branch
if self.ctc_weight != 1.0:
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
encoder_out,
encoder_out_lens,
text,
text_lengths,
text_prev,
text_prev_lengths,
)
# 3. CTC-Att loss definition
if self.ctc_weight == 0.0:
loss = loss_att
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
stats["acc"] = acc_att
stats["cer"] = cer_att
stats["wer"] = wer_att
# Collect total loss stats
stats["loss"] = loss.detach()
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
    def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
text_prev: torch.Tensor,
text_prev_lengths: torch.Tensor,
text_ctc: torch.Tensor,
text_ctc_lengths: torch.Tensor,
**kwargs,
) -> Dict[str, torch.Tensor]:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}
    def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by s2t_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder(
feats, feats_lengths, ctc=self.ctc
)
else:
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
# Post-encoder, e.g. NLU
if self.postencoder is not None:
encoder_out, encoder_out_lens = self.postencoder(
encoder_out, encoder_out_lens
)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
if (
getattr(self.encoder, "selfattention_layer_type", None) != "lf_selfattn"
and not self.is_encoder_whisper
):
assert encoder_out.size(-2) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
if intermediate_outs is not None:
return (encoder_out, intermediate_outs), encoder_out_lens
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
def _calc_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
ys_prev_pad: torch.Tensor,
ys_prev_lens: torch.Tensor,
):
# 0. Prepare input and output with sos, eos, sop
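        # Illustrative sketch (explanatory comment, not in the upstream source):
        # when a previous text is available, the decoder sees
        #     input : <sop> prev_1 ... prev_M <sos> y_1 ... y_N
        #     target:   (M+1 ignored positions)     y_1 ... y_N <eos>
        # so the loss is only computed on the current utterance tokens; when
        # prev contains <na>, this reduces to the usual <sos> y -> y <eos> pair.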
ys = [y[y != self.ignore_id] for y in ys_pad]
ys_prev = [y[y != self.ignore_id] for y in ys_prev_pad]
_sos = ys_pad.new([self.sos])
_eos = ys_pad.new([self.eos])
_sop = ys_pad.new([self.sop])
ys_in = []
ys_in_lens = []
ys_out = []
for y_prev, y in zip(ys_prev, ys):
if self.na in y_prev:
# Prev is not available in this case
y_in = [_sos, y]
y_in_len = len(y) + 1
y_out = [y, _eos]
else:
y_in = [_sop, y_prev, _sos, y]
y_in_len = len(y_prev) + len(y) + 2
y_out = [self.ignore_id * ys_pad.new_ones(len(y_prev) + 1), y, _eos]
ys_in.append(torch.cat(y_in))
ys_in_lens.append(y_in_len)
ys_out.append(torch.cat(y_out))
ys_in_pad = pad_list(ys_in, self.eos)
ys_in_lens = torch.tensor(ys_in_lens).to(ys_pad_lens)
ys_out_pad = pad_list(ys_out, self.ignore_id)
# 1. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
)
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
# Compute cer/wer using attention-decoder
if self.training or self.error_calculator is None:
cer_att, wer_att = None, None
else:
ys_hat = decoder_out.argmax(dim=-1)
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_out_pad.cpu())
return loss_att, acc_att, cer_att, wer_att
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# Filter out invalid samples where text is not available
is_valid = [self.na not in y for y in ys_pad]
if not any(is_valid):
return torch.tensor(0.0), None
encoder_out = encoder_out[is_valid]
encoder_out_lens = encoder_out_lens[is_valid]
ys_pad = ys_pad[is_valid]
ys_pad_lens = ys_pad_lens[is_valid]
# Calc CTC loss
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
# Calc CER using CTC
cer_ctc = None
if not self.training and self.error_calculator is not None:
ys_hat = self.ctc.argmax(encoder_out).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc
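# A minimal usage sketch (illustrative only, not part of the upstream module):
# it shows the batch layout that forward() expects. The model instance and the
# dummy shapes below are assumptions; in practice the model is built from an
# espnet2 S2T task config rather than constructed by hand.
def _example_forward(model: "ESPnetS2TModel") -> Dict[str, torch.Tensor]:
    batch = 2
    speech = torch.randn(batch, 16000)                 # raw waveform input
    speech_lengths = torch.tensor([16000, 12000])
    text = torch.randint(5, 100, (batch, 10))          # target token ids
    text_lengths = torch.tensor([10, 7])
    text[1, 7:] = -1                                   # collate-fn padding index
    text_prev = torch.full((batch, 1), model.na)       # no previous text (<na>)
    text_prev_lengths = torch.tensor([1, 1])
    text_ctc = text.clone()                            # CTC targets
    text_ctc_lengths = text_lengths.clone()
    loss, stats, weight = model(
        speech, speech_lengths,
        text, text_lengths,
        text_prev, text_prev_lengths,
        text_ctc, text_ctc_lengths,
    )
    return stats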