"""Transducer speech recognition model (pytorch)."""

import logging
import math
from argparse import ArgumentParser, Namespace
from dataclasses import asdict
from typing import List

import chainer
import numpy
import torch

from espnet.nets.asr_interface import ASRInterface
from espnet.nets.beam_search_transducer import BeamSearchTransducer
from espnet.nets.pytorch_backend.nets_utils import get_subsample, make_non_pad_mask
from espnet.nets.pytorch_backend.transducer.arguments import (  # noqa: H301
    add_auxiliary_task_arguments,
    add_custom_decoder_arguments,
    add_custom_encoder_arguments,
    add_custom_training_arguments,
    add_decoder_general_arguments,
    add_encoder_general_arguments,
    add_rnn_decoder_arguments,
    add_rnn_encoder_arguments,
    add_transducer_arguments,
)
from espnet.nets.pytorch_backend.transducer.custom_decoder import CustomDecoder
from espnet.nets.pytorch_backend.transducer.custom_encoder import CustomEncoder
from espnet.nets.pytorch_backend.transducer.error_calculator import ErrorCalculator
from espnet.nets.pytorch_backend.transducer.initializer import initializer
from espnet.nets.pytorch_backend.transducer.rnn_decoder import RNNDecoder
from espnet.nets.pytorch_backend.transducer.rnn_encoder import encoder_for
from espnet.nets.pytorch_backend.transducer.transducer_tasks import TransducerTasks
from espnet.nets.pytorch_backend.transducer.utils import (
    get_decoder_input,
    valid_aux_encoder_output_layers,
)
from espnet.nets.pytorch_backend.transformer.attention import (  # noqa: H301
    MultiHeadedAttention,
    RelPositionMultiHeadedAttention,
)
from espnet.nets.pytorch_backend.transformer.mask import target_mask
from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport
from espnet.utils.fill_missing_args import fill_missing_args


class Reporter(chainer.Chain):
    """A chainer reporter wrapper for Transducer models."""

    def report(
        self,
        loss: float,
        loss_trans: float,
        loss_ctc: float,
        loss_aux_trans: float,
        loss_symm_kl_div: float,
        loss_lm: float,
        cer: float,
        wer: float,
    ):
        """Report loss values and error rates for the current step.

        Args:
            loss: Model loss.
            loss_trans: Main Transducer loss.
            loss_ctc: CTC loss.
            loss_aux_trans: Auxiliary Transducer loss.
            loss_symm_kl_div: Symmetric KL-divergence loss.
            loss_lm: Label smoothing loss.
            cer: Character Error Rate.
            wer: Word Error Rate.

        """
        chainer.reporter.report({"loss": loss}, self)
        chainer.reporter.report({"loss_trans": loss_trans}, self)
        chainer.reporter.report({"loss_ctc": loss_ctc}, self)
        chainer.reporter.report({"loss_lm": loss_lm}, self)
        chainer.reporter.report({"loss_aux_trans": loss_aux_trans}, self)
        chainer.reporter.report({"loss_symm_kl_div": loss_symm_kl_div}, self)
        chainer.reporter.report({"cer": cer}, self)
        chainer.reporter.report({"wer": wer}, self)

        logging.info("loss:" + str(loss))
class E2E(ASRInterface, torch.nn.Module):
    """E2E module for Transducer models.

    Args:
        idim: Dimension of inputs.
        odim: Dimension of outputs.
        args: Namespace containing model options.
        ignore_id: Padding symbol ID.
        blank_id: Blank symbol ID.
        training: Whether the model is initialized in training or inference mode.

    """

    @staticmethod
    def add_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Add arguments for Transducer model."""
        E2E.encoder_add_general_arguments(parser)
        E2E.encoder_add_rnn_arguments(parser)
        E2E.encoder_add_custom_arguments(parser)

        E2E.decoder_add_general_arguments(parser)
        E2E.decoder_add_rnn_arguments(parser)
        E2E.decoder_add_custom_arguments(parser)

        E2E.training_add_custom_arguments(parser)
        E2E.transducer_add_arguments(parser)
        E2E.auxiliary_task_add_arguments(parser)

        return parser
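    # A minimal usage sketch (not part of the original source): wiring the
    # Transducer options into a standard-library parser. `parse_args([])`
    # simply materializes the defaults; any flag registered by the helpers
    # above can be overridden on the command line.
    #
    #   parser = ArgumentParser()
    #   parser = E2E.add_arguments(parser)
    #   args = parser.parse_args([])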
    @staticmethod
    def encoder_add_general_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Add general arguments for encoder."""
        group = parser.add_argument_group("Encoder general arguments")
        group = add_encoder_general_arguments(group)

        return parser

    @staticmethod
    def encoder_add_rnn_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Add arguments for RNN encoder."""
        group = parser.add_argument_group("RNN encoder arguments")
        group = add_rnn_encoder_arguments(group)

        return parser

    @staticmethod
    def encoder_add_custom_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Add arguments for Custom encoder."""
        group = parser.add_argument_group("Custom encoder arguments")
        group = add_custom_encoder_arguments(group)

        return parser

    @staticmethod
    def decoder_add_general_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Add general arguments for decoder."""
        group = parser.add_argument_group("Decoder general arguments")
        group = add_decoder_general_arguments(group)

        return parser

    @staticmethod
    def decoder_add_rnn_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Add arguments for RNN decoder."""
        group = parser.add_argument_group("RNN decoder arguments")
        group = add_rnn_decoder_arguments(group)

        return parser

    @staticmethod
    def decoder_add_custom_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Add arguments for Custom decoder."""
        group = parser.add_argument_group("Custom decoder arguments")
        group = add_custom_decoder_arguments(group)

        return parser
    @staticmethod
    def training_add_custom_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Add arguments for Custom architecture training."""
        group = parser.add_argument_group("Training arguments for custom architecture")
        group = add_custom_training_arguments(group)

        return parser
    @staticmethod
    def transducer_add_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Add arguments for Transducer model."""
        group = parser.add_argument_group("Transducer model arguments")
        group = add_transducer_arguments(group)

        return parser

    @staticmethod
    def auxiliary_task_add_arguments(parser: ArgumentParser) -> ArgumentParser:
        """Add arguments for auxiliary task."""
        group = parser.add_argument_group("Auxiliary task arguments")
        group = add_auxiliary_task_arguments(group)

        return parser

    @property
    def attention_plot_class(self):
        """Get attention plot class."""
        return PlotAttentionReport

    def get_total_subsampling_factor(self) -> float:
        """Get total subsampling factor."""
        if self.etype == "custom":
            return self.encoder.conv_subsampling_factor * int(
                numpy.prod(self.subsample)
            )
        else:
            return self.enc.conv_subsampling_factor * int(numpy.prod(self.subsample))
    def __init__(
        self,
        idim: int,
        odim: int,
        args: Namespace,
        ignore_id: int = -1,
        blank_id: int = 0,
        training: bool = True,
    ):
        """Construct an E2E object for Transducer model."""
        torch.nn.Module.__init__(self)

        args = fill_missing_args(args, self.add_arguments)

        self.is_transducer = True

        self.use_auxiliary_enc_outputs = bool(
            training and args.use_aux_transducer_loss
        )

        self.subsample = get_subsample(
            args, mode="asr", arch="transformer" if args.etype == "custom" else "rnn-t"
        )

        if self.use_auxiliary_enc_outputs:
            n_layers = (
                ((len(args.enc_block_arch) * args.enc_block_repeat) - 1)
                if args.enc_block_arch is not None
                else (args.elayers - 1)
            )

            aux_enc_output_layers = valid_aux_encoder_output_layers(
                args.aux_transducer_loss_enc_output_layers,
                n_layers,
                args.use_symm_kl_div_loss,
                self.subsample,
            )
        else:
            aux_enc_output_layers = []

        if args.etype == "custom":
            if args.enc_block_arch is None:
                raise ValueError(
                    "When specifying custom encoder type, --enc-block-arch "
                    "should be set in training config."
                )

            self.encoder = CustomEncoder(
                idim,
                args.enc_block_arch,
                args.custom_enc_input_layer,
                repeat_block=args.enc_block_repeat,
                self_attn_type=args.custom_enc_self_attn_type,
                positional_encoding_type=args.custom_enc_positional_encoding_type,
                positionwise_activation_type=args.custom_enc_pw_activation_type,
                conv_mod_activation_type=args.custom_enc_conv_mod_activation_type,
                aux_enc_output_layers=aux_enc_output_layers,
                input_layer_dropout_rate=args.custom_enc_input_dropout_rate,
                input_layer_pos_enc_dropout_rate=(
                    args.custom_enc_input_pos_enc_dropout_rate
                ),
            )
            encoder_out = self.encoder.enc_out
        else:
            self.enc = encoder_for(
                args,
                idim,
                self.subsample,
                aux_enc_output_layers=aux_enc_output_layers,
            )
            encoder_out = args.eprojs

        if args.dtype == "custom":
            if args.dec_block_arch is None:
                raise ValueError(
                    "When specifying custom decoder type, --dec-block-arch "
                    "should be set in training config."
                )

            self.decoder = CustomDecoder(
                odim,
                args.dec_block_arch,
                args.custom_dec_input_layer,
                repeat_block=args.dec_block_repeat,
                positionwise_activation_type=args.custom_dec_pw_activation_type,
                input_layer_dropout_rate=args.dropout_rate_embed_decoder,
                blank_id=blank_id,
            )
            decoder_out = self.decoder.dunits
        else:
            self.dec = RNNDecoder(
                odim,
                args.dtype,
                args.dlayers,
                args.dunits,
                args.dec_embed_dim,
                dropout_rate=args.dropout_rate_decoder,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
                blank_id=blank_id,
            )
            decoder_out = args.dunits

        self.transducer_tasks = TransducerTasks(
            encoder_out,
            decoder_out,
            args.joint_dim,
            odim,
            joint_activation_type=args.joint_activation_type,
            transducer_loss_weight=args.transducer_weight,
            ctc_loss=args.use_ctc_loss,
            ctc_loss_weight=args.ctc_loss_weight,
            ctc_loss_dropout_rate=args.ctc_loss_dropout_rate,
            lm_loss=args.use_lm_loss,
            lm_loss_weight=args.lm_loss_weight,
            lm_loss_smoothing_rate=args.lm_loss_smoothing_rate,
            aux_transducer_loss=args.use_aux_transducer_loss,
            aux_transducer_loss_weight=args.aux_transducer_loss_weight,
            aux_transducer_loss_mlp_dim=args.aux_transducer_loss_mlp_dim,
            aux_trans_loss_mlp_dropout_rate=args.aux_transducer_loss_mlp_dropout_rate,
            symm_kl_div_loss=args.use_symm_kl_div_loss,
            symm_kl_div_loss_weight=args.symm_kl_div_loss_weight,
            fastemit_lambda=args.fastemit_lambda,
            blank_id=blank_id,
            ignore_id=ignore_id,
            training=training,
        )

        if training and (args.report_cer or args.report_wer):
            self.error_calculator = ErrorCalculator(
                self.decoder if args.dtype == "custom" else self.dec,
                self.transducer_tasks.joint_network,
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None

        self.etype = args.etype
        self.dtype = args.dtype

        self.sos = odim - 1
        self.eos = odim - 1
        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.space = args.sym_space
        self.blank = args.sym_blank

        self.odim = odim

        self.reporter = Reporter()

        self.default_parameters(args)

        self.loss = None
        self.rnnlm = None
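    # A minimal usage sketch (not part of the original source): constructing
    # the model. Because __init__ runs fill_missing_args, a bare Namespace
    # should be completed with the parser defaults (the RNN encoder/decoder
    # path); the idim/odim values below are arbitrary illustrative choices.
    #
    #   model = E2E(idim=80, odim=52, args=Namespace())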
    def default_parameters(self, args: Namespace):
        """Initialize/reset parameters for Transducer.

        Args:
            args: Namespace containing model options.

        """
        initializer(self, args)
    def forward(
        self, feats: torch.Tensor, feats_len: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        """E2E forward.

        Args:
            feats: Feature sequences. (B, F, D_feats)
            feats_len: Feature sequences lengths. (B,)
            labels: Label ID sequences. (B, L)

        Returns:
            loss: Transducer loss value.

        """
        # 1. Encoder
        feats = feats[:, : max(feats_len)]

        if self.etype == "custom":
            feats_mask = (
                make_non_pad_mask(feats_len.tolist()).to(feats.device).unsqueeze(-2)
            )

            _enc_out, _enc_out_len = self.encoder(feats, feats_mask)
        else:
            _enc_out, _enc_out_len, _ = self.enc(feats, feats_len)

        if self.use_auxiliary_enc_outputs:
            enc_out, aux_enc_out = _enc_out[0], _enc_out[1]
            enc_out_len, aux_enc_out_len = _enc_out_len[0], _enc_out_len[1]
        else:
            enc_out, aux_enc_out = _enc_out, None
            enc_out_len, aux_enc_out_len = _enc_out_len, None

        # 2. Decoder
        dec_in = get_decoder_input(labels, self.blank_id, self.ignore_id)

        if self.dtype == "custom":
            self.decoder.set_device(enc_out.device)

            dec_in_mask = target_mask(dec_in, self.blank_id)
            dec_out, _ = self.decoder(dec_in, dec_in_mask)
        else:
            self.dec.set_device(enc_out.device)

            dec_out = self.dec(dec_in)

        # 3. Transducer task and auxiliary tasks computation
        losses = self.transducer_tasks(
            enc_out,
            aux_enc_out,
            dec_out,
            labels,
            enc_out_len,
            aux_enc_out_len,
        )

        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            cer, wer = self.error_calculator(
                enc_out, self.transducer_tasks.get_target()
            )

        self.loss = sum(losses)
        loss_data = float(self.loss)

        if not math.isnan(loss_data):
            self.reporter.report(
                loss_data,
                *[float(loss) for loss in losses],
                cer,
                wer,
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)

        return self.loss
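    # A minimal usage sketch (not part of the original source): one training
    # step with dummy tensors shaped as in the forward docstring. Batch size,
    # lengths, and vocabulary size are arbitrary; labels are drawn from
    # 1..odim-1 so they avoid the blank ID 0.
    #
    #   feats = torch.randn(4, 200, 80)                 # (B, F, D_feats)
    #   feats_len = torch.tensor([200, 180, 160, 150])  # (B,)
    #   labels = torch.randint(1, 52, (4, 12))          # (B, L)
    #   loss = model(feats, feats_len, labels)
    #   loss.backward()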
    def encode_custom(self, feats: numpy.ndarray) -> torch.Tensor:
        """Encode acoustic features.

        Args:
            feats: Feature sequence. (F, D_feats)

        Returns:
            enc_out: Encoded feature sequence. (T, D_enc)

        """
        feats = torch.as_tensor(feats).unsqueeze(0)
        enc_out, _ = self.encoder(feats, None)

        return enc_out.squeeze(0)

    def encode_rnn(self, feats: numpy.ndarray) -> torch.Tensor:
        """Encode acoustic features.

        Args:
            feats: Feature sequence. (F, D_feats)

        Returns:
            enc_out: Encoded feature sequence. (T, D_enc)

        """
        p = next(self.parameters())

        feats_len = [feats.shape[0]]

        feats = feats[:: self.subsample[0], :]
        feats = torch.as_tensor(feats, device=p.device, dtype=p.dtype)
        feats = feats.contiguous().unsqueeze(0)

        enc_out, _, _ = self.enc(feats, feats_len)

        return enc_out.squeeze(0)
    def recognize(
        self, feats: numpy.ndarray, beam_search: BeamSearchTransducer
    ) -> List:
        """Recognize input features.

        Args:
            feats: Feature sequence. (F, D_feats)
            beam_search: Beam search class.

        Returns:
            nbest_hyps: N-best decoding results.

        """
        self.eval()

        if self.etype == "custom":
            enc_out = self.encode_custom(feats)
        else:
            enc_out = self.encode_rnn(feats)

        nbest_hyps = beam_search(enc_out)

        return [asdict(n) for n in nbest_hyps]
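    # A minimal usage sketch (not part of the original source): decoding a
    # single utterance. The BeamSearchTransducer keyword arguments shown
    # (decoder, joint_network, beam_size) are assumptions about its interface,
    # and `model.dec` assumes the RNN decoder path; `yseq` is assumed to be the
    # hypothesis token-sequence field exposed through asdict.
    #
    #   beam_search = BeamSearchTransducer(
    #       decoder=model.dec,
    #       joint_network=model.transducer_tasks.joint_network,
    #       beam_size=5,
    #   )
    #   nbest_hyps = model.recognize(feats_np, beam_search)  # feats_np: (F, D_feats)
    #   best_token_ids = nbest_hyps[0]["yseq"]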
    def calculate_all_attentions(
        self, feats: torch.Tensor, feats_len: torch.Tensor, labels: torch.Tensor
    ) -> numpy.ndarray:
        """E2E attention calculation.

        Args:
            feats: Feature sequences. (B, F, D_feats)
            feats_len: Feature sequences lengths. (B,)
            labels: Label ID sequences. (B, L)

        Returns:
            ret: Attention weights with the following shape,
                1) multi-head case => attention weights. (B, D_att, U, T),
                2) other case => attention weights. (B, U, T)

        """
        self.eval()

        if self.etype != "custom" and self.dtype != "custom":
            return []
        else:
            with torch.no_grad():
                self.forward(feats, feats_len, labels)

                ret = dict()
                for name, m in self.named_modules():
                    if isinstance(
                        m, (MultiHeadedAttention, RelPositionMultiHeadedAttention)
                    ):
                        ret[name] = m.attn.cpu().numpy()

        self.train()

        return ret
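
# A minimal usage sketch (not part of the original source): collecting
# attention maps for inspection. Only models with a custom (self-attention)
# encoder or decoder return weights; keys are module names from
# named_modules().
#
#   att_ws = model.calculate_all_attentions(feats, feats_len, labels)
#   for name, w in att_ws.items():
#       print(name, w.shape)  # (B, D_att, U, T) in the multi-head case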