#!/usr/bin/env python3
# encoding: utf-8

# Copyright 2020 Johns Hopkins University (Xuankai Chang)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""
Transformer speech recognition model for single-channel multi-speaker mixture speech.

It is a fusion of `e2e_asr_mix.py` and `e2e_asr_transformer.py`. Refer to:
    https://arxiv.org/pdf/2002.03921.pdf
1. The Transformer-based Encoder now consists of three stages:
     (a): Enc_mix: encoding input mixture speech;
     (b): Enc_SD: separating mixed speech representations;
     (c): Enc_rec: transforming each separated speech representation.
2. PIT is used in CTC to determine the permutation with minimum loss.
"""

import logging
import math
from argparse import Namespace

import numpy
import torch

from espnet.nets.asr_interface import ASRInterface
from espnet.nets.ctc_prefix_score import CTCPrefixScore
from espnet.nets.e2e_asr_common import end_detect
from espnet.nets.pytorch_backend.ctc import CTC
from espnet.nets.pytorch_backend.e2e_asr import CTC_LOSS_THRESHOLD
from espnet.nets.pytorch_backend.e2e_asr_mix import E2E as E2EASRMIX
from espnet.nets.pytorch_backend.e2e_asr_mix import PIT
from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E as E2EASR
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask, th_accuracy
from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
from espnet.nets.pytorch_backend.transformer.encoder_mix import EncoderMix
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask, target_mask
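
# The module docstring above describes a three-stage encoder. As a rough,
# illustrative sketch of the data flow (an assumption drawn from the docstring
# and the encoder construction below, not additional code from this module):
#
#     xs_pad (B, Tmax, idim)
#         -> Enc_mix  (shared encoding of the mixture)
#         -> Enc_SD_1 ... Enc_SD_num_spkrs  (speaker-differentiating branches)
#         -> Enc_rec  (recognition encoding applied to each branch)
#         => hs_pad: list of num_spkrs tensors, one per separated stream
#
# The per-stream CTC losses over these representations are then combined with
# permutation-invariant training (PIT) in the forward pass.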


class E2E(E2EASR, ASRInterface, torch.nn.Module):
    """E2E module.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """
    @staticmethod
    def add_arguments(parser):
        """Add arguments."""
        E2EASR.add_arguments(parser)
        E2EASRMIX.encoder_mix_add_arguments(parser)
        return parser
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        super(E2E, self).__init__(idim, odim, args, ignore_id=-1)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = EncoderMix(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks_sd=args.elayers_sd,
            num_blocks_rec=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            num_spkrs=args.num_spkrs,
        )
        if args.mtlalpha > 0.0:
            self.ctc = CTC(
                odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=False
            )
        else:
            self.ctc = None
        self.num_spkrs = args.num_spkrs
        self.pit = PIT(self.num_spkrs)
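    # A minimal sketch of what permutation-invariant training (PIT) does with
    # the (B, num_spkrs ** 2) CTC loss matrix built in forward() below. This is
    # an illustrative brute-force version, not the espnet implementation (see
    # espnet.nets.pytorch_backend.e2e_asr_mix.PIT for the real one):
    #
    #     from itertools import permutations
    #
    #     def min_perm_loss(losses_b, num_spkrs):
    #         # losses_b[h * num_spkrs + r] is the CTC loss of hypothesis
    #         # stream h scored against reference r for one utterance.
    #         best_loss, best_perm = None, None
    #         for perm in permutations(range(num_spkrs)):
    #             total = sum(
    #                 losses_b[h * num_spkrs + perm[h]] for h in range(num_spkrs)
    #             )
    #             if best_loss is None or total < best_loss:
    #                 best_loss, best_perm = total, perm
    #         return best_loss, best_perm
    #
    # pit_process performs this reduction for the whole batch and returns the
    # reduced loss together with the chosen permutation per utterance, which
    # forward() uses to reorder the reference labels.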
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, num_spkrs, Lmax)
        :return: CTC loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)  # list: speaker differentiate
        self.hs_pad = hs_pad

        # 2. ctc
        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        assert self.mtlalpha > 0.0
        batch_size = xs_pad.size(0)
        ys_pad = ys_pad.transpose(0, 1)  # (num_spkrs, B, Lmax)
        hs_len = [hs_mask[i].view(batch_size, -1).sum(1) for i in range(self.num_spkrs)]
        loss_ctc_perm = torch.stack(
            [
                self.ctc(
                    hs_pad[i // self.num_spkrs].view(batch_size, -1, self.adim),
                    hs_len[i // self.num_spkrs],
                    ys_pad[i % self.num_spkrs],
                )
                for i in range(self.num_spkrs**2)
            ],
            dim=1,
        )  # (B, num_spkrs^2)
        loss_ctc, min_perm = self.pit.pit_process(loss_ctc_perm)
        logging.info("ctc loss:" + str(float(loss_ctc)))

        # Permute the labels according to loss
        for b in range(batch_size):  # B
            ys_pad[:, b] = ys_pad[min_perm[b], b]  # (num_spkrs, B, Lmax)
        ys_out_len = [
            float(torch.sum(ys_pad[i] != self.ignore_id)) for i in range(self.num_spkrs)
        ]

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        if self.error_calculator is not None:
            cer_ctc = []
            for i in range(self.num_spkrs):
                ys_hat = self.ctc.argmax(hs_pad[i].view(batch_size, -1, self.adim)).data
                cer_ctc.append(
                    self.error_calculator(ys_hat.cpu(), ys_pad[i].cpu(), is_ctc=True)
                )
            cer_ctc = sum(map(lambda x: x[0] * x[1], zip(cer_ctc, ys_out_len))) / sum(
                ys_out_len
            )
        else:
            cer_ctc = None

        # 3. forward decoder
        if self.mtlalpha == 1.0:
            loss_att, self.acc, cer, wer = None, None, None, None
        else:
            pred_pad, pred_mask = [None] * self.num_spkrs, [None] * self.num_spkrs
            loss_att, acc = [None] * self.num_spkrs, [None] * self.num_spkrs
            for i in range(self.num_spkrs):
                (
                    pred_pad[i],
                    pred_mask[i],
                    loss_att[i],
                    acc[i],
                ) = self.decoder_and_attention(
                    hs_pad[i], hs_mask[i], ys_pad[i], batch_size
                )

            # 4. compute attention loss
            # The following is just an approximation
            loss_att = sum(map(lambda x: x[0] * x[1], zip(loss_att, ys_out_len))) / sum(
                ys_out_len
            )
            self.acc = sum(map(lambda x: x[0] * x[1], zip(acc, ys_out_len))) / sum(
                ys_out_len
            )

            # 5. compute cer/wer
            if self.training or self.error_calculator is None:
                cer, wer = None, None
            else:
                ys_hat = pred_pad.argmax(dim=-1)
                cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
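    # Note: with 0 < mtlalpha < 1 the objective combined above is
    #     loss = mtlalpha * loss_ctc + (1 - mtlalpha) * loss_att,
    # where loss_att and the reported accuracy are per-speaker values averaged
    # with weights proportional to the number of non-ignored target tokens
    # (ys_out_len), so speakers with longer references contribute more.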
    def decoder_and_attention(self, hs_pad, hs_mask, ys_pad, batch_size):
        """Forward decoder and attention loss."""
        # forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

        # compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
        )
        return pred_pad, pred_mask, loss_att, acc
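    # decoder_and_attention is the per-speaker half of the attention branch:
    # add_sos_eos builds <sos>-prefixed decoder inputs and <eos>-appended
    # targets, target_mask combines the padding and causal masks, and the
    # attention criterion / th_accuracy are evaluated against ys_out_pad.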
    def encode(self, x):
        """Encode acoustic features.

        :param ndarray x: source acoustic feature (T, D)
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()
        x = torch.as_tensor(x).unsqueeze(0)
        enc_output, _ = self.encoder(x, None)
        return enc_output
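    # Unlike the single-speaker model, encode() returns a list with one encoded
    # sequence per speaker, because EncoderMix separates the mixture inside the
    # encoder; recognize() below decodes each element of that list independently.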
    def recog(self, enc_output, recog_args, char_list=None, rnnlm=None, use_jit=False):
        """Recognize input speech of each speaker.

        :param ndarray enc_output: encoder outputs (B, T, D) or (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(enc_output)
            lpz = lpz.squeeze(0)
        else:
            lpz = None

        h = enc_output.squeeze(0)

        logging.info("input lengths: " + str(h.size(0)))
        # search params
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # prepare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None}
        else:
            hyp = {"score": 0.0, "yseq": [y]}
        if lpz is not None:
            ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy)
            hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
            hyp["ctc_score_prev"] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        traced_decoder = None
        for i in range(maxlen):
            logging.debug("position " + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy[0] = hyp["yseq"][i]

                # get nbest local scores and their ids
                ys_mask = subsequent_mask(i + 1).unsqueeze(0)
                ys = torch.tensor(hyp["yseq"]).unsqueeze(0)
                # FIXME: jit does not match non-jit result
                if use_jit:
                    if traced_decoder is None:
                        traced_decoder = torch.jit.trace(
                            self.decoder.forward_one_step, (ys, ys_mask, enc_output)
                        )
                    local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
                else:
                    local_att_scores = self.decoder.forward_one_step(
                        ys, ys_mask, enc_output
                    )[0]

                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                    local_scores = (
                        local_att_scores + recog_args.lm_weight * local_lm_scores
                    )
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1
                    )
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"]
                    )
                    local_scores = (1.0 - ctc_weight) * local_att_scores[
                        :, local_best_ids[0]
                    ] + ctc_weight * torch.from_numpy(
                        ctc_scores - hyp["ctc_score_prev"]
                    )
                    if rnnlm:
                        local_scores += (
                            recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                        )
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1
                    )
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam, dim=1
                    )

                for j in range(beam):
                    new_hyp = {}
                    new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j])
                    new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                    new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                    new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j])
                    if rnnlm:
                        new_hyp["rnnlm_prev"] = rnnlm_state
                    if lpz is not None:
                        new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[0, j]]
                        new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x["score"], reverse=True
                )[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug("number of pruned hypotheses: " + str(len(hyps)))
            if char_list is not None:
                logging.debug(
                    "best hypo: "
                    + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
                )

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info("adding <eos> in the last position in the loop")
                for hyp in hyps:
                    hyp["yseq"].append(self.eos)

            # add ended hypotheses to a final list, and remove them from current hypotheses
            # (this will be a problem, number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp["yseq"][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp["yseq"]) > minlen:
                        hyp["score"] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp["score"] += recog_args.lm_weight * rnnlm.final(
                                hyp["rnnlm_prev"]
                            )
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info("end detected at %d", i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug("remaining hypotheses: " + str(len(hyps)))
            else:
                logging.info("no hypothesis. Finish decoding.")
                break

            if char_list is not None:
                for hyp in hyps:
                    logging.debug(
                        "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
                    )

            logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))

        nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
            : min(len(ended_hyps), recog_args.nbest)
        ]

        # check the number of hypotheses
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best result, perform recognition "
                "again with smaller minlenratio."
            )
            # should copy because Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            return self.recog(enc_output, recog_args, char_list, rnnlm)

        logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
        logging.info(
            "normalized log probability: "
            + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
        )
        return nbest_hyps
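    # recog() above runs single-stream joint CTC/attention beam search: at each
    # step the candidate score is
    #     (1 - ctc_weight) * attention_score + ctc_weight * ctc_prefix_score
    #     (+ lm_weight * rnnlm_score when an RNNLM is given),
    # with the CTC prefix scorer restricted to the ctc_beam best attention
    # candidates for efficiency.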
    def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):
        """Recognize input speech of each speaker.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        # Encoder
        enc_output = self.encode(x)

        # Decoder
        nbest_hyps = []
        for enc_out in enc_output:
            nbest_hyps.append(
                self.recog(enc_out, recog_args, char_list, rnnlm, use_jit)
            )
        return nbest_hyps
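
# Illustrative usage sketch. The names below (`model`, `feats`, `char_list`,
# `recog_args`) are placeholders assumed to be set up by the surrounding espnet
# recipe; `recog_args` must carry the beam-search options read above, e.g.
# beam_size, penalty, ctc_weight, maxlenratio, minlenratio, nbest (and
# lm_weight when an RNNLM is used):
#
#     import torch
#
#     with torch.no_grad():
#         nbest_per_spkr = model.recognize(feats, recog_args, char_list=char_list)
#     # nbest_per_spkr[s][0]["yseq"] is the best token-id sequence (including
#     # <sos>/<eos>) for speaker s.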