Source code for espnet2.uasr.espnet_model

import argparse
import logging
from contextlib import contextmanager
from typing import Dict, Optional, Tuple

import editdistance
import torch
import torch.nn.functional as F
from packaging.version import parse as V
from typeguard import typechecked

from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.text.token_id_converter import TokenIDConverter
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet2.uasr.discriminator.abs_discriminator import AbsDiscriminator
from espnet2.uasr.generator.abs_generator import AbsGenerator
from espnet2.uasr.loss.abs_loss import AbsUASRLoss
from espnet2.uasr.segmenter.abs_segmenter import AbsSegmenter
from espnet2.utils.types import str2bool
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask

if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Fall back to a no-op autocast if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


try:
    import kenlm  # for CI import
except (ImportError, ModuleNotFoundError):
    kenlm = None


class ESPnetUASRModel(AbsESPnetModel):
    """Unsupervised ASR model.

    The source code is from FAIRSEQ:
    https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec/unsupervised
    """

    @typechecked
    def __init__(
        self,
        frontend: Optional[AbsFrontend],
        segmenter: Optional[AbsSegmenter],
        generator: AbsGenerator,
        discriminator: AbsDiscriminator,
        losses: Dict[str, AbsUASRLoss],
        kenlm_path: Optional[str],
        token_list: Optional[list],
        max_epoch: Optional[int],
        vocab_size: int,
        cfg: Optional[Dict] = None,
        pad: int = 1,
        sil_token: str = "<SIL>",
        sos_token: str = "<s>",
        eos_token: str = "</s>",
        skip_softmax: str2bool = False,
        use_gumbel: str2bool = False,
        use_hard_gumbel: str2bool = True,
        min_temperature: float = 0.1,
        max_temperature: float = 2.0,
        decay_temperature: float = 0.99995,
        use_collected_training_feats: str2bool = False,
    ):
        super().__init__()

        # note that eos is the same as sos (equivalent ID)
        self.frontend = frontend
        self.segmenter = segmenter
        self.use_segmenter = segmenter is not None
        self.generator = generator
        self.discriminator = discriminator
        self.pad = pad

        # a fairseq-style config dict, if given, overrides the keyword arguments
        if cfg is not None:
            cfg = argparse.Namespace(**cfg)
            self.skip_softmax = cfg.no_softmax
            self.use_gumbel = cfg.gumbel
            self.use_hard_gumbel = cfg.hard_gumbel
        else:
            self.skip_softmax = skip_softmax
            self.use_gumbel = use_gumbel
            self.use_hard_gumbel = use_hard_gumbel
        self.use_collected_training_feats = use_collected_training_feats

        self.min_temperature = min_temperature
        self.max_temperature = max_temperature
        self.decay_temperature = decay_temperature
        self.current_temperature = max_temperature

        self._number_updates = 0
        self._number_epochs = 0
        self.max_epoch = max_epoch

        # for loss registration
        self.losses = torch.nn.ModuleDict(losses)

        # for validation
        self.vocab_size = vocab_size
        self.token_list = token_list
        self.token_id_converter = TokenIDConverter(token_list=token_list)
        self.sil = self.token_id_converter.tokens2ids([sil_token])[0]
        self.sos = self.token_id_converter.tokens2ids([sos_token])[0]
        self.eos = self.token_id_converter.tokens2ids([eos_token])[0]

        self.kenlm = None
        assert (
            kenlm is not None
        ), "kenlm is not installed, please install from tools/installers"
        if kenlm_path:
            self.kenlm = kenlm.Model(kenlm_path)

    @property
    def number_updates(self):
        return self._number_updates

    @number_updates.setter
    @typechecked
    def number_updates(self, iiter: int):
        assert iiter >= 0
        self._number_updates = iiter

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: Optional[torch.Tensor] = None,
        text_lengths: Optional[torch.Tensor] = None,
        pseudo_labels: Optional[torch.Tensor] = None,
        pseudo_labels_lengths: Optional[torch.Tensor] = None,
        do_validation: Optional[str2bool] = False,
        print_hyp: Optional[str2bool] = False,
        **kwargs,
    ) -> Tuple[
        torch.Tensor, Dict[str, torch.Tensor], torch.Tensor, Optional[torch.Tensor]
    ]:
        """Frontend + Segmenter + Generator + Discriminator + Calc Loss

        Args:
            speech: (Batch, NSamples)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
            pseudo_labels: optional frame-level labels for the pseudo-label loss
            pseudo_labels_lengths: (Batch,)
            do_validation: whether to compute validation metrics
            print_hyp: whether to log the last hypothesis/reference pair
        """
        stats = {}
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (
            speech.shape,
            speech_lengths.shape,
            text.shape,
            text_lengths.shape,
        )
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Feats encode (Extract feats + Apply segmenter)
        feats, padding_mask = self.encode(speech, speech_lengths)

        # 2. Generate fake samples
        (
            generated_sample,
            real_sample,
            x_inter,
            generated_sample_padding_mask,
        ) = self.generator(feats, text, padding_mask)

        # 3. Reprocess segments
        if self.use_segmenter:
            (
                generated_sample,
                generated_sample_padding_mask,
            ) = self.segmenter.logit_segment(
                generated_sample, generated_sample_padding_mask
            )

        # keep the raw logits for the phoneme diversity and smoothness losses
        generated_sample_logits = generated_sample
        if not self.skip_softmax:
            if self.training and self.use_gumbel:
                generated_sample = F.gumbel_softmax(
                    generated_sample_logits.float(),
                    tau=self.current_temperature,
                    hard=self.use_hard_gumbel,
                ).type_as(generated_sample_logits)
            else:
                generated_sample = generated_sample_logits.softmax(-1)

        # for validation
        vocab_seen = None
        if do_validation:
            batch_num_errors = 0
            batched_hyp_ids = generated_sample.argmax(-1)
            batched_hyp_ids[generated_sample_padding_mask] = self.pad

            # for kenlm ppl metric
            batch_lm_log_prob = 0
            batch_num_hyp_tokens = 0
            vocab_seen = torch.zeros(self.vocab_size - 4, dtype=torch.bool)

            for hyp_ids, ref_ids in zip(batched_hyp_ids, text):
                # remove <pad> and <unk>
                hyp_ids = hyp_ids[hyp_ids >= 4]
                # remove duplicate tokens
                hyp_ids = hyp_ids.unique_consecutive()
                # remove silence
                hyp_ids_nosil = hyp_ids[hyp_ids != self.sil]
                hyp_ids_nosil_list = hyp_ids_nosil.tolist()

                if self.kenlm:
                    hyp_token_list = self.token_id_converter.ids2tokens(
                        integers=hyp_ids
                    )
                    hyp_tokens = " ".join(hyp_token_list)
                    lm_log_prob = self.kenlm.score(hyp_tokens)
                    batch_lm_log_prob += lm_log_prob
                    batch_num_hyp_tokens += len(hyp_token_list)

                hyp_tokens_index = hyp_ids[hyp_ids >= 4]
                vocab_seen[hyp_tokens_index - 4] = True

                ref_ids = ref_ids[ref_ids != self.pad]
                ref_ids_list = ref_ids.tolist()

                num_errors = editdistance.eval(hyp_ids_nosil_list, ref_ids_list)
                batch_num_errors += num_errors

            stats["batch_num_errors"] = batch_num_errors
            stats["batch_num_ref_tokens"] = text_lengths.sum().item()
            if self.kenlm:
                stats["batch_lm_log_prob"] = batch_lm_log_prob
                stats["batch_num_hyp_tokens"] = batch_num_hyp_tokens
            stats["batch_size"] = batch_size

            # print the last sample in the batch
            if print_hyp:
                hyp_token_list = self.token_id_converter.ids2tokens(
                    integers=hyp_ids_nosil
                )
                hyp_tokens = " ".join(hyp_token_list)
                ref_token_list = self.token_id_converter.ids2tokens(integers=ref_ids)
                ref_tokens = " ".join(ref_token_list)
                logging.info(f"[REF]: {ref_tokens}")
                logging.info(f"[HYP]: {hyp_tokens}")

        real_sample_padding_mask = text == self.pad

        # 4. Discriminator prediction on generated and real samples
        generated_sample_prediction = self.discriminator(
            generated_sample, generated_sample_padding_mask
        )
        real_sample_prediction = self.discriminator(
            real_sample, real_sample_padding_mask
        )

        is_discriminative_step = self.is_discriminative_step()

        # 5. Calculate losses
        loss_info = []
        if "discriminator_loss" in self.losses.keys():
            (
                generated_sample_prediction_loss,
                real_sample_prediction_loss,
            ) = self.losses["discriminator_loss"](
                generated_sample_prediction,
                real_sample_prediction,
                is_discriminative_step,
            )
            loss_info.append(
                generated_sample_prediction_loss
                * self.losses["discriminator_loss"].weight
            )
            if is_discriminative_step:
                loss_info.append(
                    real_sample_prediction_loss
                    * self.losses["discriminator_loss"].weight
                )
        else:
            generated_sample_prediction_loss, real_sample_prediction_loss = None, None

        if "gradient_penalty" in self.losses.keys():
            gp = self.losses["gradient_penalty"](
                generated_sample,
                real_sample,
                self.training,
                is_discriminative_step,
            )
            loss_info.append(gp * self.losses["gradient_penalty"].weight)
        else:
            gp = None

        if "phoneme_diversity_loss" in self.losses.keys():
            pdl = self.losses["phoneme_diversity_loss"](
                generated_sample_logits, batch_size, is_discriminative_step
            )
            loss_info.append(pdl * self.losses["phoneme_diversity_loss"].weight)
        else:
            pdl = None

        if "smoothness_penalty" in self.losses.keys():
            sp = self.losses["smoothness_penalty"](
                generated_sample_logits,
                generated_sample_padding_mask,
                batch_size,
                is_discriminative_step,
            )
            loss_info.append(sp * self.losses["smoothness_penalty"].weight)
        else:
            sp = None

        if "pseudo_label_loss" in self.losses.keys() and pseudo_labels is not None:
            mmi = self.losses["pseudo_label_loss"](
                x_inter, pseudo_labels, is_discriminative_step
            )
            loss_info.append(mmi * self.losses["pseudo_label_loss"].weight)
        else:
            mmi = None

        # Update temperature
        self._change_temperature()
        self.number_updates += 1

        loss = sum(loss_info)

        # Collect total loss stats
        stats["loss"] = loss.detach()
        stats["generated_sample_prediction_loss"] = generated_sample_prediction_loss
        stats["real_sample_prediction_loss"] = real_sample_prediction_loss
        stats["gp"] = gp
        stats["sp"] = sp
        stats["pdl"] = pdl
        stats["mmi"] = mmi

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight, vocab_seen
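
    # Illustrative note (not part of the original source): the statistics
    # accumulated under ``do_validation`` are meant to be aggregated across
    # batches. The unsupervised error rate and, since ``kenlm.score`` returns
    # a log10 probability, a hypothesis perplexity then follow as
    #
    #     uer = batch_num_errors / batch_num_ref_tokens
    #     ppl = 10 ** (-batch_lm_log_prob / batch_num_hyp_tokens)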

    def inference(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        # 1. Feats encode (Extract feats + Apply segmenter)
        feats, padding_mask = self.encode(speech, speech_lengths)

        # 2. Generate fake samples
        (
            generated_sample,
            _,
            x_inter,
            generated_sample_padding_mask,
        ) = self.generator(feats, None, padding_mask)
        # generated_sample = generated_sample.softmax(-1)

        return generated_sample, generated_sample_padding_mask
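
    # A usage sketch for ``inference`` (illustrative; a trained ``model`` and
    # ``speech``/``speech_lengths`` tensors are assumed to exist):
    #
    #     logits, pad_mask = model.inference(speech, speech_lengths)
    #     hyp_ids = logits.argmax(-1)    # (Batch, NFrames) token ids
    #     hyp_ids[pad_mask] = model.pad  # mask padded frames, as in forward()
    #     hyp = hyp_ids[0].unique_consecutive()  # collapse repeated tokens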

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: Optional[torch.Tensor] = None,
        text_lengths: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            speech = F.layer_norm(speech, speech.shape)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None and not self.use_collected_training_feats:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            speech = F.layer_norm(speech, speech.shape)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract (usually with pre-extracted feats)
            # logging.info("use existing features")
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
            padding_mask = make_pad_mask(feats_lengths).to(feats.device)

            # 2. Apply segmenter to the feats
            if self.use_segmenter:
                feats, padding_mask = self.segmenter.pre_segment(feats, padding_mask)

        return feats, padding_mask
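
    # Note: ``make_pad_mask`` marks padded frames as ``True``, so the
    # ``padding_mask`` returned here flags the positions that the segmenter,
    # generator, and discriminator should ignore.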

    def is_discriminative_step(self):
        """Whether the current update trains the discriminator (odd updates)."""
        return self.number_updates % 2 == 1

    def get_optim_index(self):
        return self.number_updates % 2
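
    # Note (illustrative, not from the original source): ``get_optim_index``
    # pairs with ``is_discriminative_step`` so that a trainer holding two
    # optimizers can alternate GAN updates, e.g. optimizer 0 for the generator
    # on even updates and optimizer 1 for the discriminator on odd updates.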

    def _change_temperature(self):
        # exponential decay of the gumbel-softmax temperature, clipped from
        # below at min_temperature
        self.current_temperature = max(
            self.max_temperature * self.decay_temperature**self.number_updates,
            self.min_temperature,
        )
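

# The gumbel-softmax temperature above follows the exponential schedule
#     tau(t) = max(max_temperature * decay_temperature**t, min_temperature),
# so with the defaults (2.0, 0.99995, 0.1) it bottoms out after roughly 60k
# updates. A minimal, self-contained sketch of that arithmetic (illustrative
# only, not part of the original module):
if __name__ == "__main__":
    import math

    max_temp, min_temp, decay = 2.0, 0.1, 0.99995
    # number of updates until the clipped schedule reaches min_temp
    t_floor = math.log(min_temp / max_temp) / math.log(decay)
    print(f"temperature reaches {min_temp} after ~{t_floor:,.0f} updates")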