Source code for espnet2.tasks.asvspoof

import argparse
from typing import Callable, Collection, Dict, List, Optional, Tuple

import numpy as np
import torch
from typeguard import typechecked

from espnet2.asr.encoder.abs_encoder import AbsEncoder

# TODO(checkpoint1): import the conformer encoder class
from espnet2.asr.encoder.transformer_encoder import TransformerEncoder
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.frontend.default import DefaultFrontend
from espnet2.asr.frontend.fused import FusedFrontends
from espnet2.asr.frontend.s3prl import S3prlFrontend
from espnet2.asr.frontend.windowing import SlidingWindow
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder
from espnet2.asr.preencoder.linear import LinearProjection
from espnet2.asr.preencoder.sinc import LightweightSincConvs
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.asr.specaug.specaug import SpecAug
from espnet2.asvspoof.decoder.abs_decoder import AbsDecoder
from espnet2.asvspoof.decoder.linear_decoder import LinearDecoder
from espnet2.asvspoof.espnet_model import ESPnetASVSpoofModel
from espnet2.asvspoof.loss.abs_loss import AbsASVSpoofLoss
from espnet2.asvspoof.loss.am_softmax_loss import ASVSpoofAMSoftmaxLoss
from espnet2.asvspoof.loss.binary_loss import ASVSpoofBinaryLoss
from espnet2.asvspoof.loss.oc_softmax_loss import ASVSpoofOCSoftmaxLoss
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.layers.global_mvn import GlobalMVN
from espnet2.layers.utterance_mvn import UtteranceMVN
from espnet2.tasks.abs_task import AbsTask
from espnet2.torch_utils.initialize import initialize
from espnet2.train.class_choices import ClassChoices
from espnet2.train.collate_fn import CommonCollateFn
from espnet2.train.preprocessor import CommonPreprocessor
from espnet2.train.trainer import Trainer
from espnet2.utils.nested_dict_action import NestedDictAction
from espnet2.utils.types import int_or_none, str2bool, str_or_none

# Registries mapping a command-line choice name to the class implementing it.
# Every registry except ``losses_choices`` is listed in
# ``ASVSpoofTask.class_choices_list``; ``losses_choices`` is only looked up by
# name inside ``ASVSpoofTask.build_model``.

# --frontend and --frontend_conf
frontend_choices = ClassChoices(
    name="frontend",
    classes={
        "default": DefaultFrontend,
        "sliding_window": SlidingWindow,
        "s3prl": S3prlFrontend,
        "fused": FusedFrontends,
    },
    type_check=AbsFrontend,
    default="default",
)

# --specaug and --specaug_conf
specaug_choices = ClassChoices(
    name="specaug",
    classes={"specaug": SpecAug},
    type_check=AbsSpecAug,
    default=None,
    optional=True,
)

# --normalize and --normalize_conf
normalize_choices = ClassChoices(
    name="normalize",
    classes={
        "global_mvn": GlobalMVN,
        "utterance_mvn": UtteranceMVN,
    },
    type_check=AbsNormalize,
    default="utterance_mvn",
    optional=True,
)

# --preencoder and --preencoder_conf
preencoder_choices = ClassChoices(
    name="preencoder",
    classes={
        "sinc": LightweightSincConvs,
        "linear": LinearProjection,
    },
    type_check=AbsPreEncoder,
    default=None,
    optional=True,
)

# --encoder and --encoder_conf
encoder_choices = ClassChoices(
    name="encoder",
    classes={
        # TODO(checkpoint2): add conformer option in encoder
        "transformer": TransformerEncoder,
    },
    type_check=AbsEncoder,
    default="transformer",
)

# --decoder and --decoder_conf
decoder_choices = ClassChoices(
    name="decoder",
    classes={"linear": LinearDecoder},
    type_check=AbsDecoder,
    default="linear",
)

# Loss classes selectable through the "name" field of each --losses entry.
losses_choices = ClassChoices(
    name="losses",
    classes={
        "binary_loss": ASVSpoofBinaryLoss,
        "am_softmax_loss": ASVSpoofAMSoftmaxLoss,
        "oc_softmax_loss": ASVSpoofOCSoftmaxLoss,
    },
    type_check=AbsASVSpoofLoss,
    default=None,
)


class ASVSpoofTask(AbsTask):
    """Task wiring for the ASV spoofing model.

    Declares the task-specific command-line options and assembles an
    ``ESPnetASVSpoofModel`` from the frontend / specaug / normalize /
    preencoder / encoder / decoder choices and the configured losses.
    """

    # If you need more than one optimizers, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Register the task-specific options on ``parser``.

        Adds ``--init`` and ``--input_size`` to a "Task related" group, then
        ``--use_preprocessor``, ``--losses`` and one ``--<name>`` /
        ``--<name>_conf`` pair per entry of ``class_choices_list`` to a
        "Preprocess related" group.

        Args:
            parser: Argument parser to extend in place.
        """
        group = parser.add_argument_group(description="Task related")

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode, so no option here is marked required.
        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimension of the feature",
        )

        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Apply preprocessing to data or not",
        )
        group.add_argument(
            "--losses",
            action=NestedDictAction,
            default=[
                {
                    # Must be a key of ``losses_choices``: build_model() looks
                    # the name up there. The former default "sigmoid_loss" is
                    # not registered.
                    "name": "binary_loss",
                    "conf": {},
                },
            ],
            help="The criterions binded with the loss wrappers.",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)

    @classmethod
    @typechecked
    def build_collate_fn(cls, args: argparse.Namespace, train: bool) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Return the collate function used to batch samples.

        Args:
            args: Parsed command-line options (not read here).
            train: Training/evaluation flag (not read here).

        Returns:
            A ``CommonCollateFn`` padding floats with 0.0 and ints with -1.
        """
        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    @classmethod
    @typechecked
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
        """Return the per-sample preprocessor, or ``None`` when disabled.

        Args:
            args: Parsed command-line options; only ``use_preprocessor`` is read.
            train: Forwarded to ``CommonPreprocessor``.

        Returns:
            ``CommonPreprocessor(train=train)`` if ``args.use_preprocessor`` is
            truthy, otherwise ``None``.
        """
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
            )
        else:
            retval = None
        return retval

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return the data keys every sample must provide.

        Args:
            train: Not read here.
            inference: When true, the label is not required.

        Returns:
            ``("speech", "label")`` unless ``inference`` is set, in which case
            ``("speech",)``.
        """
        if not inference:
            retval = ("speech", "label")
        else:
            # Inference mode: only the input speech is needed.
            retval = ("speech",)
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return the optional data keys; this task declares none."""
        retval = ()
        return retval

    @classmethod
    @typechecked
    def build_model(cls, args: argparse.Namespace) -> ESPnetASVSpoofModel:
        """Assemble an ``ESPnetASVSpoofModel`` from the parsed options.

        Args:
            args: Parsed command-line options. When ``args.input_size`` is not
                ``None``, ``args.frontend`` and ``args.frontend_conf`` are
                overwritten in place with ``None`` and ``{}``.

        Returns:
            The constructed model, initialized with ``args.init`` when given.
        """
        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)
        encoder_output_size = encoder.output_size()

        # 6. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )

        # 7. Loss definition
        losses = {}
        if getattr(args, "losses", None) is not None:
            # This check is for the compatibility when load models
            # that packed by older version
            for ctr in args.losses:
                loss_name = ctr["name"]
                # Tolerate entries that omit "conf" (or set it to None).
                loss_conf = ctr.get("conf") or {}
                loss_class = losses_choices.get_class(loss_name)
                if "softmax" in loss_name:
                    # Losses with "softmax" in their name additionally
                    # receive the encoder output size as ``enc_dim``.
                    loss = loss_class(enc_dim=encoder_output_size, **loss_conf)
                else:
                    loss = loss_class(**loss_conf)
                losses[loss_name] = loss

        # 8. Build model
        model = ESPnetASVSpoofModel(
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            decoder=decoder,
            losses=losses,
        )

        # 9. Initialize
        if args.init is not None:
            initialize(model, args.init)

        return model