import logging
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union
import torch
from packaging.version import parse as V
from typeguard import typechecked
from espnet2.asr.ctc import CTC
from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.encoder.conformer_encoder import ConformerEncoder
from espnet2.asr.encoder.transformer_encoder import TransformerEncoder
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.s2st.aux_attention.abs_aux_attention import AbsS2STAuxAttention
from espnet2.s2st.losses.abs_loss import AbsS2STLoss
from espnet2.s2st.synthesizer.abs_synthesizer import AbsSynthesizer
from espnet2.s2st.tgt_feats_extract.abs_tgt_feats_extract import AbsTgtFeatsExtract
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet.nets.e2e_asr_common import ErrorCalculator as ASRErrorCalculator
from espnet.nets.e2e_mt_common import ErrorCalculator as MTErrorCalculator
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask, th_accuracy
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
if V(torch.__version__) >= V("1.6.0"):
from torch.cuda.amp import autocast
else:
# torch<1.6.0 has no native AMP; provide a no-op autocast instead
@contextmanager
def autocast(enabled=True):
yield
class ESPnetS2STModel(AbsESPnetModel):
"""ESPnet speech-to-speech translation model"""
@typechecked
def __init__(
self,
s2st_type: str,
frontend: Optional[AbsFrontend],
tgt_feats_extract: Optional[AbsTgtFeatsExtract],
specaug: Optional[AbsSpecAug],
src_normalize: Optional[AbsNormalize],
tgt_normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
asr_decoder: Optional[AbsDecoder],
st_decoder: Optional[AbsDecoder],
aux_attention: Optional[AbsS2STAuxAttention],
unit_encoder: Optional[AbsEncoder],
synthesizer: Optional[AbsSynthesizer],
asr_ctc: Optional[CTC],
st_ctc: Optional[CTC],
losses: Dict[str, AbsS2STLoss],
tgt_vocab_size: Optional[int],
tgt_token_list: Optional[Union[Tuple[str, ...], List[str]]],
src_vocab_size: Optional[int],
src_token_list: Optional[Union[Tuple[str, ...], List[str]]],
unit_vocab_size: Optional[int],
unit_token_list: Optional[Union[Tuple[str, ...], List[str]]],
ignore_id: int = -1,
report_cer: bool = True,
report_wer: bool = True,
report_bleu: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
extract_feats_in_collect_stats: bool = True,
):
super().__init__()
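# NOTE: by convention <sos> and <eos> share the last vocabulary index
# (vocab_size - 1); the same scheme is applied to the source, target,
# and discrete-unit token sets below.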
self.sos = tgt_vocab_size - 1 if tgt_vocab_size else None
self.eos = tgt_vocab_size - 1 if tgt_vocab_size else None
self.src_sos = src_vocab_size - 1 if src_vocab_size else None
self.src_eos = src_vocab_size - 1 if src_vocab_size else None
self.unit_sos = unit_vocab_size - 1 if unit_vocab_size else None
self.unit_eos = unit_vocab_size - 1 if unit_vocab_size else None
self.tgt_vocab_size = tgt_vocab_size
self.src_vocab_size = src_vocab_size
self.unit_vocab_size = unit_vocab_size
self.ignore_id = ignore_id
self.tgt_token_list = tgt_token_list.copy() if tgt_token_list else None
self.src_token_list = src_token_list.copy() if src_token_list else None
self.unit_token_list = unit_token_list.copy() if unit_token_list else None
self.s2st_type = s2st_type
self.frontend = frontend
self.tgt_feats_extract = tgt_feats_extract
self.specaug = specaug
self.src_normalize = src_normalize
self.tgt_normalize = tgt_normalize
self.preencoder = preencoder
self.postencoder = postencoder
self.encoder = encoder
self.asr_decoder = asr_decoder
self.st_decoder = st_decoder
self.aux_attention = aux_attention
self.unit_encoder = unit_encoder
self.synthesizer = synthesizer
self.asr_ctc = asr_ctc
self.st_ctc = st_ctc
self.losses = torch.nn.ModuleDict(losses)
# ST error calculator
if st_decoder and tgt_vocab_size and report_bleu:
self.mt_error_calculator = MTErrorCalculator(
tgt_token_list, sym_space, sym_blank, report_bleu
)
else:
self.mt_error_calculator = None
# ASR error calculator
if asr_decoder and src_vocab_size and (report_cer or report_wer):
assert (
src_token_list is not None
), "Missing src_token_list, cannot add asr module to st model"
self.asr_error_calculator = ASRErrorCalculator(
src_token_list, sym_space, sym_blank, report_cer, report_wer
)
else:
self.asr_error_calculator = None
self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
if self.s2st_type == "discrete_unit":
assert isinstance(self.encoder, ConformerEncoder) or isinstance(
self.encoder, TransformerEncoder
), "only support conformer or transformer-based encoder now"
# synthesizer
assert (
"synthesis" in self.losses
), "must have synthesis loss in the losses for S2ST"
def forward(
self,
src_speech: torch.Tensor,
src_speech_lengths: torch.Tensor,
tgt_speech: torch.Tensor,
tgt_speech_lengths: torch.Tensor,
tgt_text: Optional[torch.Tensor] = None,
tgt_text_lengths: Optional[torch.Tensor] = None,
src_text: Optional[torch.Tensor] = None,
src_text_lengths: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None, # TODO(Jiatong)
sids: Optional[torch.Tensor] = None, # TODO(Jiatong)
lids: Optional[torch.Tensor] = None, # TODO(Jiatong)
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
# TODO(jiatong): add comments etc.
assert (
src_speech.shape[0]
== src_speech_lengths.shape[0]
== tgt_speech.shape[0]
== tgt_speech_lengths.shape[0]
), (
src_speech.shape,
src_speech_lengths.shape,
tgt_speech.shape,
tgt_speech_lengths.shape,
)
# additional checks with valid tgt_text and src_text
if tgt_text is not None:
assert tgt_text_lengths.dim() == 1, tgt_text_lengths.shape
assert (
src_speech.shape[0]
== src_text.shape[0]
== src_text_lengths.shape[0]
== tgt_text.shape[0]
== tgt_text_lengths.shape[0]
), (
src_speech.shape,
src_text.shape,
src_text_lengths.shape,
tgt_text.shape,
tgt_text_lengths.shape,
)
batch_size = src_speech.shape[0]
# for data-parallel
src_speech = src_speech[:, : src_speech_lengths.max()]
if src_text is not None:
src_text = src_text[:, : src_text_lengths.max()]
if tgt_text is not None:
tgt_text = tgt_text[:, : tgt_text_lengths.max()]
tgt_speech = tgt_speech[:, : tgt_speech_lengths.max()]
# 0. Target feature extract
# NOTE(jiatong): only for teacher-forcing on spectrogram targets
if self.tgt_feats_extract is not None:
tgt_feats, tgt_feats_lengths = self._extract_feats(
tgt_speech, tgt_speech_lengths, target=True
)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.tgt_normalize is not None:
tgt_feats, tgt_feats_lengths = self.tgt_normalize(
tgt_feats, tgt_feats_lengths
)
else:
# NOTE(jiatong): for discrete unit case
tgt_feats, tgt_feats_lengths = tgt_speech, tgt_speech_lengths
# 1. Encoder
if self.s2st_type == "discrete_unit":
(encoder_out, inter_encoder_out), encoder_out_lens = self.encode(
src_speech, src_speech_lengths, return_all_hs=True
)
else:
encoder_out, encoder_out_lens = self.encode(src_speech, src_speech_lengths)
loss_record = []
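# Every enabled objective is appended to loss_record as loss * weight;
# the final training loss is their sum.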
#########################
# Translatotron Forward #
#########################
if self.s2st_type == "translatotron":
# use a shared encoder with three decoders (i.e., asr, st, s2st)
# reference https://arxiv.org/pdf/1904.06037.pdf
# asr_ctc
if self.asr_ctc is not None and "asr_ctc" in self.losses:
asr_ctc_loss, cer_asr_ctc = self._calc_ctc_loss(
encoder_out,
encoder_out_lens,
src_text,
src_text_lengths,
ctc_type="asr",
)
loss_record.append(asr_ctc_loss * self.losses["asr_ctc"].weight)
else:
asr_ctc_loss, cer_asr_ctc = None, None
# asr decoder
if self.asr_decoder is not None and "src_attn" in self.losses:
(
src_attn_loss,
acc_src_attn,
cer_src_attn,
wer_src_attn,
) = self._calc_asr_att_loss(
encoder_out, encoder_out_lens, src_text, src_text_lengths
)
loss_record.append(src_attn_loss * self.losses["src_attn"].weight)
else:
src_attn_loss, acc_src_attn, cer_src_attn, wer_src_attn = (
None,
None,
None,
None,
)
# st decoder
if self.st_decoder is not None and "tgt_attn" in self.losses:
tgt_attn_loss, acc_tgt_attn, bleu_tgt_attn = self._calc_st_att_loss(
encoder_out, encoder_out_lens, tgt_text, tgt_text_lengths
)
loss_record.append(tgt_attn_loss * self.losses["tgt_attn"].weight)
else:
tgt_attn_loss, acc_tgt_attn, bleu_tgt_attn = None, None, None
# NOTE(jiatong): the tgt_feats is also updated based on the reduction_factor
(
after_outs,
before_outs,
logits,
att_ws,
updated_tgt_feats,
stop_labels,
updated_tgt_feats_lengths,
) = self.synthesizer(
encoder_out,
encoder_out_lens,
tgt_feats,
tgt_feats_lengths,
spembs,
sids,
lids,
)
syn_loss, l1_loss, mse_loss, bce_loss = self.losses["synthesis"](
after_outs,
before_outs,
logits,
updated_tgt_feats,
stop_labels,
updated_tgt_feats_lengths,
)
loss_record.append(syn_loss * self.losses["synthesis"].weight)
# NOTE(jiatong): guided attention is not used with multi-head attention
if (
"syn_guided_attn" in self.losses
and self.synthesizer.atype != "multihead"
):
# NOTE(kan-bayashi): length of output for auto-regressive
# input will be changed when r > 1
if self.synthesizer.reduction_factor > 1:
updated_tgt_feats_lengths_in = updated_tgt_feats_lengths.new(
[
olen // self.synthesizer.reduction_factor
for olen in updated_tgt_feats_lengths
]
)
else:
updated_tgt_feats_lengths_in = updated_tgt_feats_lengths
syn_guided_attn_loss = self.losses["syn_guided_attn"](
att_ws=att_ws,
ilens=encoder_out_lens,
olens_in=updated_tgt_feats_lengths_in,
)
loss_record.append(
syn_guided_attn_loss * self.losses["syn_guided_attn"].weight
)
else:
syn_guided_attn_loss = None
loss = sum(loss_record)
stats = dict(
loss=loss.item(),
asr_ctc_loss=asr_ctc_loss.item() if asr_ctc_loss is not None else None,
cer_asr_ctc=cer_asr_ctc,
src_attn_loss=(
src_attn_loss.item() if src_attn_loss is not None else None
),
acc_src_attn=acc_src_attn,
cer_src_attn=cer_src_attn,
wer_src_attn=wer_src_attn,
tgt_attn_loss=(
tgt_attn_loss.item() if tgt_attn_loss is not None else None
),
acc_tgt_attn=acc_tgt_attn,
bleu_tgt_attn=bleu_tgt_attn,
syn_loss=syn_loss.item() if syn_loss is not None else None,
syn_guided_attn_loss=(
syn_guided_attn_loss.item()
if syn_guided_attn_loss is not None
else None
),
syn_l1_loss=l1_loss.item(),
syn_mse_loss=mse_loss.item(),
syn_bce_loss=bce_loss.item(),
)
##########################
# Translatotron2 Forward #
##########################
elif self.s2st_type == "translatotron2":
# use a single decoder for synthesis
# reference https://arxiv.org/pdf/2107.08661v5.pdf
# asr_ctc
if self.asr_ctc is not None and "asr_ctc" in self.losses:
asr_ctc_loss, cer_asr_ctc = self._calc_ctc_loss(
encoder_out,
encoder_out_lens,
src_text,
src_text_lengths,
ctc_type="asr",
)
loss_record.append(asr_ctc_loss * self.losses["asr_ctc"].weight)
else:
asr_ctc_loss, cer_asr_ctc = None, None
# st decoder
if self.st_decoder is not None and "tgt_attn" in self.losses:
(
tgt_attn_loss,
acc_tgt_attn,
bleu_tgt_attn,
decoder_out,
_,
) = self._calc_st_att_loss(
encoder_out,
encoder_out_lens,
tgt_text,
tgt_text_lengths,
return_hs=True,
)
loss_record.append(tgt_attn_loss * self.losses["tgt_attn"].weight)
else:
tgt_attn_loss, acc_tgt_attn, bleu_tgt_attn, decoder_out = (
None,
None,
None,
None,
)
assert (
self.aux_attention is not None
), "must have aux attention in translatotron loss"
# NOTE(jiatong): tgt_text_lengths + 1 for <eos>
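# make_pad_mask(encoder_out_lens) is True at padded encoder frames and is
# broadcast over attention heads via unsqueeze(1).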
encoder_out_mask = (
make_pad_mask(encoder_out_lens).to(encoder_out.device).unsqueeze(1)
)
attention_out = self.aux_attention(
decoder_out, encoder_out, encoder_out, mask=encoder_out_mask
)
decoder_out = torch.cat((decoder_out, attention_out), dim=-1)
# NOTE(jiatong): the tgt_feats is also updated based on the reduction_factor
# TODO(jiatong): use non-attentive tacotron-based synthesizer
(
after_outs,
before_outs,
logits,
att_ws,
updated_tgt_feats,
stop_labels,
updated_tgt_feats_lengths,
) = self.synthesizer(
decoder_out,
tgt_text_lengths + 1, # NOTE(jiatong): +1 for <eos>
tgt_feats,
tgt_feats_lengths,
spembs,
sids,
lids,
)
syn_loss, l1_loss, mse_loss, bce_loss = self.losses["synthesis"](
after_outs,
before_outs,
logits,
updated_tgt_feats,
stop_labels,
updated_tgt_feats_lengths,
)
# loss_record.append(syn_loss * self.losses["synthesis"].weight)
loss = sum(loss_record)
stats = dict(
loss=loss.item(),
asr_ctc_loss=asr_ctc_loss.item() if asr_ctc_loss is not None else None,
cer_asr_ctc=cer_asr_ctc,
tgt_attn_loss=(
tgt_attn_loss.item() if tgt_attn_loss is not None else None
),
acc_tgt_attn=acc_tgt_attn,
bleu_tgt_attn=bleu_tgt_attn,
syn_loss=syn_loss.item() if syn_loss is not None else None,
syn_l1_loss=l1_loss.item() if l1_loss is not None else None,
syn_mse_loss=mse_loss.item() if mse_loss is not None else None,
syn_bce_loss=bce_loss.item() if bce_loss is not None else None,
)
#########################
# Discrete unit Forward #
#########################
elif self.s2st_type == "discrete_unit":
# discrete unit-based synthesis
# Reference: https://arxiv.org/pdf/2107.05604.pdf
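# Auxiliary tasks tap intermediate encoder layers: ASR CTC uses the middle
# layer, while the ST decoder attends to the layer at ~2/3 of the depth.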
encoder_layer_for_asr = len(inter_encoder_out) // 2
encoder_layer_for_st = len(inter_encoder_out) * 2 // 3
# asr_ctc
if self.asr_ctc is not None and "asr_ctc" in self.losses:
asr_ctc_loss, cer_asr_ctc = self._calc_ctc_loss(
inter_encoder_out[encoder_layer_for_asr],
encoder_out_lens,
src_text,
src_text_lengths,
ctc_type="asr",
)
loss_record.append(asr_ctc_loss * self.losses["asr_ctc"].weight)
else:
asr_ctc_loss, cer_asr_ctc = None, None
# asr decoder
if self.asr_decoder is not None and "src_attn" in self.losses:
(
src_attn_loss,
acc_src_attn,
cer_src_attn,
wer_src_attn,
) = self._calc_asr_att_loss(
encoder_out, encoder_out_lens, src_text, src_text_lengths
)
loss_record.append(src_attn_loss * self.losses["src_attn"].weight)
else:
src_attn_loss, acc_src_attn, cer_src_attn, wer_src_attn = (
None,
None,
None,
None,
)
# st decoder
if self.st_decoder is not None and "tgt_attn" in self.losses:
(
tgt_attn_loss,
acc_tgt_attn,
bleu_tgt_attn,
) = self._calc_st_att_loss(
inter_encoder_out[encoder_layer_for_st],
encoder_out_lens,
tgt_text,
tgt_text_lengths,
)
loss_record.append(tgt_attn_loss * self.losses["tgt_attn"].weight)
else:
tgt_attn_loss, acc_tgt_attn, bleu_tgt_attn, decoder_out = (
None,
None,
None,
None,
)
# synthesizer
(
unit_attn_loss,
acc_unit_attn,
syn_hidden,
syn_hidden_lengths,
) = self._calc_unit_att_loss(
encoder_out,
encoder_out_lens,
tgt_speech,
tgt_speech_lengths,
return_all_hs=True,
)
loss_record.append(unit_attn_loss * self.losses["synthesis"].weight)
unit_decoder_layer_for_st = len(syn_hidden) // 2
# st_ctc
if self.st_ctc is not None and "st_ctc" in self.losses:
st_ctc_loss, cer_st_ctc = self._calc_ctc_loss(
syn_hidden[unit_decoder_layer_for_st],
tgt_speech_lengths + 1,
tgt_text,
tgt_text_lengths,
ctc_type="st",
)
loss_record.append(st_ctc_loss * self.losses["st_ctc"].weight)
else:
st_ctc_loss, cer_st_ctc = None, None
loss = sum(loss_record)
stats = dict(
loss=loss.item(),
asr_ctc_loss=asr_ctc_loss.item() if asr_ctc_loss is not None else None,
cer_asr_ctc=cer_asr_ctc,
src_attn_loss=(
src_attn_loss.item() if src_attn_loss is not None else None
),
acc_src_attn=acc_src_attn,
cer_src_attn=cer_src_attn,
wer_src_attn=wer_src_attn,
tgt_attn_loss=(
tgt_attn_loss.item() if tgt_attn_loss is not None else None
),
acc_tgt_attn=acc_tgt_attn,
bleu_tgt_attn=bleu_tgt_attn,
st_ctc_loss=st_ctc_loss.item() if st_ctc_loss is not None else None,
cer_st_ctc=cer_st_ctc,
unit_attn_loss=(
unit_attn_loss.item() if unit_attn_loss is not None else None
),
acc_unit_attn=acc_unit_attn if acc_unit_attn is not None else None,
)
#################
# Unity Forward #
#################
elif self.s2st_type == "unity":
# unity
# Reference: https://arxiv.org/pdf/2212.08055.pdf
# asr_ctc
if self.asr_ctc is not None and "asr_ctc" in self.losses:
asr_ctc_loss, cer_asr_ctc = self._calc_ctc_loss(
encoder_out,
encoder_out_lens,
src_text,
src_text_lengths,
ctc_type="asr",
)
loss_record.append(asr_ctc_loss * self.losses["asr_ctc"].weight)
else:
asr_ctc_loss, cer_asr_ctc = None, None
# st decoder
assert (
self.st_decoder is not None and "tgt_attn" in self.losses
), "st_decoder is necessary for unity-based model"
(
tgt_attn_loss,
acc_tgt_attn,
bleu_tgt_attn,
decoder_out,
_,
) = self._calc_st_att_loss(
encoder_out,
encoder_out_lens,
tgt_text,
tgt_text_lengths,
return_hs=True,
)
loss_record.append(tgt_attn_loss * self.losses["tgt_attn"].weight)
assert (
self.unit_encoder is not None
), "unit_encoder is necessary for unity-based model"
unit_encoder_out, unit_encoder_out_lengths, _ = self.unit_encoder(
decoder_out, tgt_text_lengths + 1
)
# synthesizer
unit_attn_loss, acc_unit_attn = self._calc_unit_att_loss(
unit_encoder_out,
unit_encoder_out_lengths,
tgt_speech,
tgt_speech_lengths,
)
loss_record.append(unit_attn_loss * self.losses["synthesis"].weight)
loss = sum(loss_record)
stats = dict(
loss=loss.item(),
asr_ctc_loss=asr_ctc_loss.item() if asr_ctc_loss is not None else None,
cer_asr_ctc=cer_asr_ctc,
tgt_attn_loss=(
tgt_attn_loss.item() if tgt_attn_loss is not None else None
),
acc_tgt_attn=acc_tgt_attn,
bleu_tgt_attn=bleu_tgt_attn,
unit_attn_loss=(
unit_attn_loss.item() if unit_attn_loss is not None else None
),
acc_unit_attn=acc_unit_attn if acc_unit_attn is not None else None,
)
else:
raise ValueError("Not supported s2st type {}")
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
@typechecked
def inference(
self,
src_speech: torch.Tensor,
src_speech_lengths: Optional[torch.Tensor] = None,
tgt_speech: Optional[torch.Tensor] = None,
tgt_speech_lengths: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None, # TODO(Jiatong)
sids: Optional[torch.Tensor] = None, # TODO(Jiatong)
lids: Optional[torch.Tensor] = None, # TODO(Jiatong)
threshold: float = 0.5,
minlenratio: float = 0.0,
maxlenratio: float = 10.0,
use_att_constraint: bool = False,
backward_window: int = 1,
forward_window: int = 3,
use_teacher_forcing: bool = False,
) -> Dict[str, torch.Tensor]:
# 0. Target feature extract
# NOTE(jiatong): only for teacher-forcing on spectrogram targets
if tgt_speech is not None and self.tgt_feats_extract is not None:
tgt_feats, tgt_feats_lengths = self._extract_feats(
tgt_speech, tgt_speech_lengths, target=True
)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.tgt_normalize is not None:
tgt_feats, tgt_feats_lengths = self.tgt_normalize(
tgt_feats, tgt_feats_lengths
)
else:
# NOTE(jiatong): for discrete unit case
tgt_feats, tgt_feats_lengths = tgt_speech, tgt_speech_lengths
# 1. Encoder
encoder_out, _ = self.encode(src_speech, src_speech_lengths)
# 2. Decoder
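# NOTE: inference operates on a single utterance (batch size 1) and
# delegates autoregressive generation to the synthesizer's inference().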
if self.s2st_type == "translatotron":
assert encoder_out.size(0) == 1
output_dict = self.synthesizer.inference(
encoder_out[0],
tgt_feats[0],
spembs,
sids,
lids,
threshold,
minlenratio,
maxlenratio,
use_att_constraint,
backward_window,
forward_window,
use_teacher_forcing,
)
elif self.s2st_type == "translatotron2":
assert encoder_out.size(0) == 1
output_dict = self.synthesizer.inference(
encoder_out[0],
tgt_feats[0],
spembs,
sids,
lids,
threshold,
minlenratio,
maxlenratio,
use_att_constraint,
backward_window,
forward_window,
use_teacher_forcing,
)
else:
raise ValueError("Not supported s2st type {}")
if self.tgt_normalize is not None and output_dict.get("feat_gen") is not None:
# NOTE: normalize.inverse is in-place operation
feat_gen_denorm = self.tgt_normalize.inverse(
output_dict["feat_gen"].clone()[None]
)[0][0]
output_dict.update(feat_gen_denorm=feat_gen_denorm)
return output_dict
def collect_feats(
self,
src_speech: torch.Tensor,
src_speech_lengths: torch.Tensor,
tgt_speech: torch.Tensor,
tgt_speech_lengths: torch.Tensor,
**kwargs,
) -> Dict[str, torch.Tensor]:
if self.extract_feats_in_collect_stats:
src_feats, src_feats_lengths = self._extract_feats(
src_speech, src_speech_lengths
)
return_dict = {
"src_feats": src_feats,
"src_feats_lengths": src_feats_lengths,
}
if self.tgt_feats_extract is not None:
tgt_feats, tgt_feats_lengths = self._extract_feats(
tgt_speech, tgt_speech_lengths, target=True
)
return_dict.update(
tgt_feats=tgt_feats,
tgt_feats_lengths=tgt_feats_lengths,
)
else:
# Generate dummy stats if extract_feats_in_collect_stats is False
logging.warning(
"Generating dummy stats for feats and feats_lengths, "
"because encoder_conf.extract_feats_in_collect_stats is "
f"{self.extract_feats_in_collect_stats}"
)
return_dict = {
"src_feats": src_speech,
"tgt_feats": tgt_speech,
"src_feats_lengths": src_speech_lengths,
"tgt_feats_lengths": tgt_speech_lengths,
}
return return_dict
def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
return_all_hs: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by st_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.src_normalize is not None:
feats, feats_lengths = self.src_normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
encoder_out, encoder_out_lens, _ = self.encoder(
feats, feats_lengths, return_all_hs=return_all_hs
)
if return_all_hs:
encoder_out, inter_encoder_out = encoder_out
# Post-encoder, e.g. NLU
if self.postencoder is not None:
encoder_out, encoder_out_lens = self.postencoder(
encoder_out, encoder_out_lens
)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
assert encoder_out.size(1) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
if return_all_hs:
return (encoder_out, inter_encoder_out), encoder_out_lens
else:
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor, target: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if speech_lengths is not None:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
if target:
feats, feats_lengths = self.tgt_feats_extract(speech, speech_lengths)
else:
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
def _calc_unit_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
spembs: Optional[torch.Tensor] = None,
sids: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
return_hs: bool = False,
return_all_hs: bool = False,
):
ys_in_pad, ys_out_pad = add_sos_eos(
ys_pad, self.unit_sos, self.unit_eos, self.ignore_id
)
ys_in_lens = ys_pad_lens + 1
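# add_sos_eos prepends <sos> to the decoder input and appends <eos> to the
# target sequence, hence the +1 on the input lengths.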
# 1. Forward decoder
(
decoder_outs,
decoder_out_lengths,
) = self.synthesizer(
encoder_out,
encoder_out_lens,
ys_in_pad,
ys_in_lens,
spembs,
sids,
lids,
return_hs=return_hs,
return_all_hs=return_all_hs,
)
if return_hs or return_all_hs:
(decoder_out, decoder_hidden) = decoder_outs
else:
decoder_out = decoder_outs
# 2. Compute attention loss
loss_att = self.losses["synthesis"](decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.unit_vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
if return_hs or return_all_hs:
return loss_att, acc_att, decoder_hidden, decoder_out_lengths
else:
return loss_att, acc_att
def _calc_st_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
return_hs: bool = False,
return_all_hs: bool = False,
):
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
assert (
not return_hs or not return_all_hs
), "cannot return both last hiddens or all hiddens"
# 1. Forward decoder
(
decoder_outs,
decoder_out_lengths,
) = self.st_decoder(
encoder_out,
encoder_out_lens,
ys_in_pad,
ys_in_lens,
return_hs=return_hs,
return_all_hs=return_all_hs,
)
if return_hs or return_all_hs:
(decoder_out, decoder_hidden) = decoder_outs
else:
decoder_out = decoder_outs
# 2. Compute attention loss
loss_att = self.losses["tgt_attn"](decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.tgt_vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
# Compute cer/wer using attention-decoder
if self.training or self.mt_error_calculator is None:
bleu_att = None
else:
ys_hat = decoder_out.argmax(dim=-1)
bleu_att = self.mt_error_calculator(ys_hat.cpu(), ys_pad.cpu())
if return_hs:
return loss_att, acc_att, bleu_att, decoder_hidden, decoder_out_lengths
else:
return loss_att, acc_att, bleu_att
def _calc_asr_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
ys_in_pad, ys_out_pad = add_sos_eos(
ys_pad, self.src_sos, self.src_eos, self.ignore_id
)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.asr_decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
)
# 2. Compute attention loss
loss_att = self.losses["src_attn"](decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.src_vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
# Compute cer/wer using attention-decoder
if self.training or self.asr_error_calculator is None:
cer_att, wer_att = None, None
else:
ys_hat = decoder_out.argmax(dim=-1)
cer_att, wer_att = self.asr_error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, acc_att, cer_att, wer_att
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
ctc_type: str,
):
if ctc_type == "asr":
ctc = self.asr_ctc
elif ctc_type == "st":
ctc = self.st_ctc
else:
raise RuntimeError(
"Cannot recognize the ctc-type: need 'src'/'tgt', but found {}".format(
ctc_type
)
)
# Calc CTC loss
loss_ctc = ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
# Calc CER using CTC
cer_ctc = None
if not self.training and self.asr_error_calculator is not None:
ys_hat = ctc.argmax(encoder_out).data
cer_ctc = self.asr_error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc
@property
def require_vocoder(self):
"""Return whether or not vocoder is required."""
return True
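# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original module). It
# assumes a fully constructed `model` (frontend, encoder, synthesizer,
# losses, ...) built elsewhere, e.g. by the espnet2 task/config machinery,
# and only shows the tensor shapes expected by a training forward pass:
#
#     batch = 2
#     src_speech = torch.randn(batch, 16000)              # raw source waveforms
#     src_speech_lengths = torch.tensor([16000, 12000])
#     tgt_speech = torch.randn(batch, 18000)               # target waveforms (or unit ids)
#     tgt_speech_lengths = torch.tensor([18000, 15000])
#     # tgt_text / src_text are padded token-id tensors with matching
#     # *_lengths tensors (required by the auxiliary ASR/ST losses).
#
#     loss, stats, weight = model(
#         src_speech,
#         src_speech_lengths,
#         tgt_speech,
#         tgt_speech_lengths,
#         tgt_text=tgt_text,
#         tgt_text_lengths=tgt_text_lengths,
#         src_text=src_text,
#         src_text_lengths=src_text_lengths,
#     )
#     loss.backward()  # `stats` holds per-loss scalars, `weight` the batch size
# ---------------------------------------------------------------------------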