"""Task definition for joint speech enhancement and speech-to-text (espnet2.tasks.enh_s2t)."""

import argparse
import copy
import logging
from typing import Callable, Collection, Dict, List, Optional, Tuple

import numpy as np
import torch
from typeguard import typechecked

from espnet2.asr.ctc import CTC
from espnet2.asr.espnet_model import ESPnetASRModel
from espnet2.diar.espnet_model import ESPnetDiarizationModel
from espnet2.enh.espnet_enh_s2t_model import ESPnetEnhS2TModel
from espnet2.enh.espnet_model import ESPnetEnhancementModel
from espnet2.tasks.abs_task import AbsTask
from espnet2.tasks.asr import ASRTask
from espnet2.tasks.asr import decoder_choices as asr_decoder_choices_
from espnet2.tasks.asr import encoder_choices as asr_encoder_choices_
from espnet2.tasks.asr import frontend_choices, normalize_choices
from espnet2.tasks.asr import postencoder_choices as asr_postencoder_choices_
from espnet2.tasks.asr import preencoder_choices as asr_preencoder_choices_
from espnet2.tasks.asr import specaug_choices
from espnet2.tasks.diar import DiarizationTask
from espnet2.tasks.diar import attractor_choices as diar_attractor_choices_
from espnet2.tasks.diar import decoder_choices as diar_decoder_choices_
from espnet2.tasks.diar import encoder_choices as diar_encoder_choices_
from espnet2.tasks.diar import frontend_choices as diar_front_end_choices_
from espnet2.tasks.diar import label_aggregator_choices
from espnet2.tasks.diar import normalize_choices as diar_normalize_choices_
from espnet2.tasks.diar import specaug_choices as diar_specaug_choices_
from espnet2.tasks.enh import EnhancementTask
from espnet2.tasks.enh import decoder_choices as enh_decoder_choices_
from espnet2.tasks.enh import encoder_choices as enh_encoder_choices_
from espnet2.tasks.enh import mask_module_choices as enh_mask_module_choices_
from espnet2.tasks.enh import separator_choices as enh_separator_choices_
from espnet2.tasks.st import STTask
from espnet2.tasks.st import decoder_choices as st_decoder_choices_
from espnet2.tasks.st import encoder_choices as st_encoder_choices_
from espnet2.tasks.st import extra_asr_decoder_choices as st_extra_asr_decoder_choices_
from espnet2.tasks.st import extra_mt_decoder_choices as st_extra_mt_decoder_choices_
from espnet2.tasks.st import postencoder_choices as st_postencoder_choices_
from espnet2.tasks.st import preencoder_choices as st_preencoder_choices_
from espnet2.text.phoneme_tokenizer import g2p_choices
from espnet2.torch_utils.initialize import initialize
from espnet2.train.collate_fn import CommonCollateFn
from espnet2.train.preprocessor import (
    CommonPreprocessor,
    CommonPreprocessor_multi,
    MutliTokenizerCommonPreprocessor,
)
from espnet2.train.trainer import Trainer
from espnet2.utils.get_default_kwargs import get_default_kwargs
from espnet2.utils.nested_dict_action import NestedDictAction
from espnet2.utils.types import int_or_none, str2bool, str_or_none

# Enhancement
enh_encoder_choices = copy.deepcopy(enh_encoder_choices_)
enh_encoder_choices.name = "enh_encoder"
enh_decoder_choices = copy.deepcopy(enh_decoder_choices_)
enh_decoder_choices.name = "enh_decoder"
enh_separator_choices = copy.deepcopy(enh_separator_choices_)
enh_separator_choices.name = "enh_separator"
enh_mask_module_choices = copy.deepcopy(enh_mask_module_choices_)
enh_mask_module_choices.name = "enh_mask_module"

# ASR (also SLU)
asr_preencoder_choices = copy.deepcopy(asr_preencoder_choices_)
asr_preencoder_choices.name = "asr_preencoder"
asr_encoder_choices = copy.deepcopy(asr_encoder_choices_)
asr_encoder_choices.name = "asr_encoder"
asr_postencoder_choices = copy.deepcopy(asr_postencoder_choices_)
asr_postencoder_choices.name = "asr_postencoder"
asr_decoder_choices = copy.deepcopy(asr_decoder_choices_)
asr_decoder_choices.name = "asr_decoder"

# ST
st_preencoder_choices = copy.deepcopy(st_preencoder_choices_)
st_preencoder_choices.name = "st_preencoder"
st_encoder_choices = copy.deepcopy(st_encoder_choices_)
st_encoder_choices.name = "st_encoder"
st_postencoder_choices = copy.deepcopy(st_postencoder_choices_)
st_postencoder_choices.name = "st_postencoder"
st_decoder_choices = copy.deepcopy(st_decoder_choices_)
st_decoder_choices.name = "st_decoder"
st_extra_asr_decoder_choices = copy.deepcopy(st_extra_asr_decoder_choices_)
st_extra_asr_decoder_choices.name = "st_extra_asr_decoder"
st_extra_mt_decoder_choices = copy.deepcopy(st_extra_mt_decoder_choices_)
st_extra_mt_decoder_choices.name = "st_extra_mt_decoder"

# DIAR
diar_frontend_choices = copy.deepcopy(diar_front_end_choices_)
diar_frontend_choices.name = "diar_frontend"
diar_specaug_choices = copy.deepcopy(diar_specaug_choices_)
diar_specaug_choices.name = "diar_specaug"
diar_normalize_choices = copy.deepcopy(diar_normalize_choices_)
diar_normalize_choices.name = "diar_normalize"
diar_encoder_choices = copy.deepcopy(diar_encoder_choices_)
diar_encoder_choices.name = "diar_encoder"
diar_decoder_choices = copy.deepcopy(diar_decoder_choices_)
diar_decoder_choices.name = "diar_decoder"
diar_attractor_choices = copy.deepcopy(diar_attractor_choices_)
diar_attractor_choices.name = "diar_attractor"


MAX_REFERENCE_NUM = 100

name2task = dict(
    enh=EnhancementTask,
    asr=ASRTask,
    st=STTask,
    diar=DiarizationTask,
)

# More can be added to the following attributes
enh_attributes = [
    "encoder",
    "encoder_conf",
    "separator",
    "separator_conf",
    "mask_module",
    "mask_module_conf",
    "decoder",
    "decoder_conf",
    "criterions",
]

asr_attributes = [
    "token_list",
    "input_size",
    "frontend",
    "frontend_conf",
    "specaug",
    "specaug_conf",
    "normalize",
    "normalize_conf",
    "preencoder",
    "preencoder_conf",
    "encoder",
    "encoder_conf",
    "postencoder",
    "postencoder_conf",
    "decoder",
    "decoder_conf",
    "ctc_conf",
]

st_attributes = [
    "token_list",
    "src_token_list",
    "input_size",
    "frontend",
    "frontend_conf",
    "specaug",
    "specaug_conf",
    "normalize",
    "normalize_conf",
    "preencoder",
    "preencoder_conf",
    "encoder",
    "encoder_conf",
    "postencoder",
    "postencoder_conf",
    "decoder",
    "decoder_conf",
    "ctc_conf",
    "extra_asr_decoder",
    "extra_asr_decoder_conf",
    "extra_mt_decoder",
    "extra_mt_decoder_conf",
]

diar_attributes = [
    "input_size",
    "num_spk",
    "frontend",
    "frontend_conf",
    "specaug",
    "specaug_conf",
    "normalize",
    "normalize_conf",
    "encoder",
    "encoder_conf",
    "decoder",
    "decoder_conf",
    "attractor",
    "attractor_conf",
    "label_aggregator",
    "label_aggregator_conf",
]


[docs]class EnhS2TTask(AbsTask): # If you need more than one optimizers, change this value num_optimizers: int = 1 # Add variable objects configurations class_choices_list = [ # --enh_encoder and --enh_encoder_conf enh_encoder_choices, # --enh_separator and --enh_separator_conf enh_separator_choices, # --enh_decoder and --enh_decoder_conf enh_decoder_choices, # --enh_mask_module and --enh_mask_module_conf enh_mask_module_choices, # --frontend and --frontend_conf frontend_choices, # --specaug and --specaug_conf specaug_choices, # --normalize and --normalize_conf normalize_choices, # --asr_preencoder and --asr_preencoder_conf asr_preencoder_choices, # --asr_encoder and --asr_encoder_conf asr_encoder_choices, # --asr_postencoder and --asr_postencoder_conf asr_postencoder_choices, # --asr_decoder and --asr_decoder_conf asr_decoder_choices, # --st_preencoder and --st_preencoder_conf st_preencoder_choices, # --st_encoder and --st_encoder_conf st_encoder_choices, # --st_postencoder and --st_postencoder_conf st_postencoder_choices, # --st_decoder and --st_decoder_conf st_decoder_choices, # --st_extra_asr_decoder and --st_extra_asr_decoder_conf st_extra_asr_decoder_choices, # --st_extra_mt_decoder and --st_extra_mt_decoder_conf st_extra_mt_decoder_choices, # --diar_frontend and --diar_frontend_conf diar_frontend_choices, # --diar_specaug and --diar_specaug_conf diar_specaug_choices, # --diar_normalize and --diar_normalize_conf diar_normalize_choices, # --diar_encoder and --diar_encoder_conf diar_encoder_choices, # --diar_decoder and --diar_decoder_conf diar_decoder_choices, # --label_aggregator and --label_aggregator_conf label_aggregator_choices, # --diar_attractor and --diar_attractor_conf diar_attractor_choices, ] # If you need to modify train() or eval() procedures, change Trainer class here trainer = Trainer
[docs] @classmethod def add_task_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group(description="Task related") group.add_argument( "--token_list", type=str_or_none, default=None, help="A text mapping int-id to token", ) group.add_argument( "--src_token_list", type=str_or_none, default=None, help="A text mapping int-id to token (for source language)", ) group.add_argument( "--init", type=lambda x: str_or_none(x.lower()), default=None, help="The initialization method", choices=[ "chainer", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal", None, ], ) group.add_argument( "--input_size", type=int_or_none, default=None, help="The number of input dimension of the feature", ) group.add_argument( "--ctc_conf", action=NestedDictAction, default=get_default_kwargs(CTC), help="The keyword arguments for CTC class.", ) group.add_argument( "--enh_criterions", action=NestedDictAction, default=[ { "name": "si_snr", "conf": {}, "wrapper": "fixed_order", "wrapper_conf": {}, }, ], help="The criterions binded with the loss wrappers.", ) group.add_argument( "--diar_num_spk", type=int_or_none, default=None, help="The number of speakers (for each recording) for diar submodel class", ) group.add_argument( "--diar_input_size", type=int_or_none, default=None, help="The number of input dimension of the feature", ) group.add_argument( "--enh_model_conf", action=NestedDictAction, default=get_default_kwargs(ESPnetEnhancementModel), help="The keyword arguments for enh submodel class.", ) group.add_argument( "--asr_model_conf", action=NestedDictAction, default=get_default_kwargs(ESPnetASRModel), help="The keyword arguments for asr submodel class.", ) group.add_argument( "--st_model_conf", action=NestedDictAction, default=get_default_kwargs(ESPnetEnhancementModel), help="The keyword arguments for st submodel class.", ) group.add_argument( "--diar_model_conf", action=NestedDictAction, default=get_default_kwargs(ESPnetDiarizationModel), help="The 
keyword arguments for diar submodel class.", ) group.add_argument( "--subtask_series", type=str, nargs="+", default=("enh", "asr"), choices=["enh", "asr", "st", "diar"], help="The series of subtasks in the pipeline.", ) group.add_argument( "--model_conf", action=NestedDictAction, default=get_default_kwargs(ESPnetEnhS2TModel), help="The keyword arguments for model class.", ) group = parser.add_argument_group(description="Preprocess related") group.add_argument( "--use_preprocessor", type=str2bool, default=False, help="Apply preprocessing to data or not", ) group.add_argument( "--token_type", type=str, default="bpe", choices=["bpe", "char", "word", "phn"], help="The text will be tokenized " "in the specified level token", ) group.add_argument( "--bpemodel", type=str_or_none, default=None, help="The model file of sentencepiece", ) group.add_argument( "--src_token_type", type=str, default="bpe", choices=["bpe", "char", "word", "phn"], help="The source text will be tokenized " "in the specified level token", ) group.add_argument( "--src_bpemodel", type=str_or_none, default=None, help="The model file of sentencepiece (for source language)", ) group.add_argument( "--non_linguistic_symbols", type=str_or_none, help="non_linguistic_symbols file path", ) group.add_argument( "--cleaner", type=str_or_none, choices=[None, "tacotron", "jaconv", "vietnamese"], default=None, help="Apply text cleaning", ) group.add_argument( "--g2p", type=str_or_none, choices=g2p_choices, default=None, help="Specify g2p method if --token_type=phn", ) group.add_argument( "--text_name", nargs="+", default=["text"], type=str, help="Specify the text_name attribute used in the preprocessor", ) for class_choices in cls.class_choices_list: # Append --<name> and --<name>_conf. # e.g. --encoder and --encoder_conf class_choices.add_arguments(group)
[docs] @classmethod @typechecked def build_collate_fn(cls, args: argparse.Namespace, train: bool) -> Callable[ [Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[List[str], Dict[str, torch.Tensor]], ]: # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
[docs] @classmethod @typechecked def build_preprocess_fn( cls, args: argparse.Namespace, train: bool ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: if args.use_preprocessor: if "st" in args.subtask_series: retval = MutliTokenizerCommonPreprocessor( train=train, token_type=[args.token_type, args.src_token_type], token_list=[args.token_list, args.src_token_list], bpemodel=[args.bpemodel, args.src_bpemodel], non_linguistic_symbols=args.non_linguistic_symbols, text_cleaner=args.cleaner, g2p_type=args.g2p, # NOTE(kamo): Check attribute existence for backward compatibility rir_scp=getattr(args, "rir_scp", None), rir_apply_prob=getattr(args, "rir_apply_prob", 1.0), noise_scp=getattr(args, "noise_scp", None), noise_apply_prob=getattr(args, "noise_apply_prob", 1.0), noise_db_range=getattr(args, "noise_db_range", "13_15"), short_noise_thres=getattr(args, "short_noise_thres", 0.5), speech_volume_normalize=getattr( args, "speech_volume_normalize", None ), speech_name="speech", text_name=["text", "src_text"], **getattr(args, "preprocessor_conf", {}), ) elif "diar" in args.subtask_series: retval = CommonPreprocessor( train=train, **getattr(args, "preprocessor_conf", {}) ) else: retval = CommonPreprocessor_multi( train=train, token_type=args.token_type, token_list=args.token_list, bpemodel=args.bpemodel, non_linguistic_symbols=args.non_linguistic_symbols, text_name=getattr(args, "text_name", ["text"]), text_cleaner=args.cleaner, g2p_type=args.g2p, **getattr(args, "preprocessor_conf", {}), ) else: retval = None return retval
[docs] @classmethod def required_data_names( cls, train: bool = True, inference: bool = False ) -> Tuple[str, ...]: if not inference: retval = ("speech", "speech_ref1") else: # Recognition mode retval = ("speech",) return retval
[docs] @classmethod def optional_data_names( cls, train: bool = True, inference: bool = False ) -> Tuple[str, ...]: retval = ["text", "dereverb_ref1"] st = 2 if "speech_ref1" in retval else 1 retval += ["speech_ref{}".format(n) for n in range(st, MAX_REFERENCE_NUM + 1)] retval += ["noise_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)] retval += ["text_spk{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)] retval += ["src_text"] retval = tuple(retval) return retval
[docs] @classmethod @typechecked def build_model(cls, args: argparse.Namespace) -> ESPnetEnhS2TModel: # Build submodels in the order of subtask_series model_conf = args.model_conf.copy() for _, subtask in enumerate(args.subtask_series): subtask_conf = dict( init=None, model_conf=eval(f"args.{subtask}_model_conf") ) for attr in eval(f"{subtask}_attributes"): subtask_conf[attr] = ( getattr(args, subtask + "_" + attr, None) if getattr(args, subtask + "_" + attr, None) is not None else getattr(args, attr, None) ) if subtask in ["asr", "st", "diar"]: m_subtask = "s2t" elif subtask in ["enh"]: m_subtask = subtask else: raise ValueError(f"{subtask} not supported.") logging.info(f"Building {subtask} task model, using config: {subtask_conf}") model_conf[f"{m_subtask}_model"] = name2task[subtask].build_model( argparse.Namespace(**subtask_conf) ) # 8. Build model model = ESPnetEnhS2TModel(**model_conf) # FIXME(kamo): Should be done in model? # 9. Initialize if args.init is not None: initialize(model, args.init) return model