Source code for espnet2.bin.spk_embed_extract

#!/usr/bin/env python3
import argparse
import logging
import os
import sys
from glob import glob

import humanfriendly
import numpy as np
import torch
from torch.multiprocessing.spawn import ProcessContext

from espnet2.samplers.build_batch_sampler import BATCH_TYPES
from espnet2.tasks.spk import SpeakerTask
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed
from espnet2.train.distributed_utils import (
    DistributedOption,
    free_port,
    get_master_port,
    get_node_rank,
    get_num_nodes,
    resolve_distributed_mode,
)
from espnet2.train.reporter import Reporter
from espnet2.utils import config_argparse
from espnet2.utils.build_dataclass import build_dataclass
from espnet2.utils.nested_dict_action import NestedDictAction
from espnet2.utils.types import (
    humanfriendly_parse_size_or_none,
    int_or_none,
    str2bool,
    str2triple_str,
    str_or_none,
)
from espnet.utils.cli_utils import get_commandline_args

def extract_embed(args):
    """Extract speaker embeddings with a trained speaker model.

    Steps:
      1. Configure logging; ranks other than 0 only log ERROR and above.
      2. Seed all RNGs and load the model from ``args.spk_train_config`` /
         ``args.spk_model_file``.
      3. Merge the training-time arguments with the inference options.
      4. Build a "valid"-mode iterator over ``args.data_path_and_name_and_type``
         and run ``SpeakerTask.trainer.extract_embed``.
      5. On the main process, merge every ``embeddings*.npz`` found in
         ``args.output_dir`` into ``<output_dir>/<set_name>_embeddings.npz``
         and delete the partial files.

    Args:
        args: Namespace produced by ``get_parser().parse_args``. It is not
            modified; a merged copy is used internally.
    """
    distributed_option = build_dataclass(DistributedOption, args)
    distributed_option.init_options()
    is_main_process = (
        not distributed_option.distributed or distributed_option.dist_rank == 0
    )

    hostname = os.uname()[1].split(".")[0]
    message_format = (
        " %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    )
    if distributed_option.distributed:
        log_prefix = (
            f"[{hostname}:{distributed_option.dist_rank}/"
            f"{distributed_option.dist_world_size}]"
        )
    else:
        log_prefix = f"[{hostname}]"
    # Suppress logging below ERROR if RANK != 0.
    logging.basicConfig(
        level=args.log_level if is_main_process else "ERROR",
        format=log_prefix + message_format,
    )
    # Invoking torch.distributed.init_process_group
    distributed_option.init_torch_distributed()

    device = "cuda" if args.ngpu >= 1 else "cpu"

    # 1. Set random-seed
    set_all_random_seed(args.seed)

    # 2. Build the model and recover its training-time arguments
    spk_model, spk_train_args = SpeakerTask.build_model_from_file(
        args.spk_train_config, args.spk_model_file, device
    )

    # 3. Overwrite training args with inference args. ``vars()`` returns the
    # live ``__dict__``, so copy first to leave the caller's namespace and
    # ``spk_train_args`` untouched.
    inference_args = dict(vars(args))
    inference_args["valid_data_path_and_name_and_type"] = inference_args[
        "data_path_and_name_and_type"
    ]
    inference_args["valid_shape_file"] = inference_args["shape_file"]
    # Replace the training preprocessor config; noise/RIR application
    # probabilities are forced to 0 for extraction.
    inference_args["preprocessor_conf"] = {
        "target_duration": inference_args["target_duration"],
        "num_eval": inference_args["num_eval"],
        "noise_apply_prob": 0.0,
        "rir_apply_prob": 0.0,
    }
    merged_args = dict(vars(spk_train_args))
    merged_args.update(inference_args)
    args = argparse.Namespace(**merged_args)

    # 4. Build data-iterator
    # NOTE(jeeweon): Temporarily disable distributed to let loader include all trials
    org_distributed = distributed_option.distributed
    distributed_option.distributed = False
    try:
        iterator = SpeakerTask.build_iter_factory(
            args=args,
            distributed_option=distributed_option,
            mode="valid",
        )
    finally:
        # Restore the flag even if building the iterator fails.
        distributed_option.distributed = org_distributed
    loader = iterator.build_iter(0)

    # Per-process batch size. max(1, ngpu) avoids ZeroDivisionError in
    # CPU-only distributed runs (ngpu == 0).
    if distributed_option.distributed:
        bs = args.valid_batch_size // max(1, args.ngpu)
    else:
        bs = args.valid_batch_size

    trainer_options = SpeakerTask.trainer.build_options(args)
    reporter = Reporter()

    # 5. Run inference for EER and minDCF calculation
    with reporter.observe("valid") as sub_reporter:
        SpeakerTask.trainer.extract_embed(
            model=spk_model,
            iterator=loader,
            reporter=sub_reporter,
            options=trainer_options,
            distributed_option=distributed_option,
            output_dir=args.output_dir,
            custom_bs=bs,
            average=args.average_embd,
        )
    if distributed_option.distributed:
        # Make every rank finish before rank 0 collects the partial files.
        torch.distributed.barrier()

    if distributed_option.distributed and distributed_option.dist_rank != 0:
        return

    # Combine the partial embedding dictionaries into one archive. Sorting
    # makes the merge order (and thus duplicate-key resolution) deterministic.
    npzs = sorted(glob(os.path.join(args.output_dir, "embeddings*.npz")))
    logging.info(f"{npzs}")
    embd_dic = {}
    for npz in npzs:
        # np.load on an .npz returns a lazily-read NpzFile that keeps the file
        # open; close it before the file is removed below.
        with np.load(npz) as npz_file:
            embd_dic.update({key: npz_file[key] for key in npz_file.files})
    # Parent directory name of the first data path, e.g. "dump/raw/<set>/wav.scp".
    set_name = args.data_path_and_name_and_type[0][0].split("/")[-2]
    # np.savez appends the ".npz" suffix.
    np.savez(os.path.join(args.output_dir, f"{set_name}_embeddings"), **embd_dic)
    for npz in npzs:
        os.remove(npz)
def get_parser():
    """Build the command-line parser for speaker embedding extraction.

    Option groups: input data / loader options, model files, distributed
    launch options, trainer initialization options (mostly consumed through
    ``SpeakerTask.trainer.build_options``), cuDNN flags and inference
    hyperparameters.

    Returns:
        A ``config_argparse.ArgumentParser`` instance.
    """
    parser = config_argparse.ArgumentParser(
        description="speaker embedding extraction",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=True,
        action="append",
    )
    _batch_type_help = "".join(
        f'"{key}":\n{value}\n' for key, value in BATCH_TYPES.items()
    )
    group.add_argument(
        "--batch_type",
        type=str,
        default="folded",
        choices=list(BATCH_TYPES),
        help=_batch_type_help,
    )
    group.add_argument(
        "--batch_bins",
        type=int,
        default=1000000,
        help="The number of batch bins. Used if batch_type='length' or 'numel'",
    )
    group.add_argument(
        "--valid_batch_bins",
        type=int_or_none,
        default=None,
        help="If not given, the value of --batch_bins is used",
    )
    group.add_argument(
        "--valid_batch_type",
        type=str_or_none,
        default=None,
        choices=list(BATCH_TYPES) + [None],
        help="If not given, the value of --batch_type is used",
    )
    group.add_argument(
        "--max_cache_size",
        type=humanfriendly.parse_size,
        default=0.0,
        help="The maximum cache size for data loader. e.g. 10MB, 20GB.",
    )
    group.add_argument(
        "--max_cache_fd",
        type=int,
        default=32,
        help="The maximum number of file descriptors to be kept "
        "as opened for ark files. "
        "This feature is only valid when data type is 'kaldi_ark'.",
    )
    group.add_argument(
        "--allow_multi_rates",
        type=str2bool,
        default=False,
        help="Whether to allow audios to have different sampling rates",
    )
    group.add_argument(
        "--valid_max_cache_size",
        type=humanfriendly_parse_size_or_none,
        default=None,
        help="The maximum cache size for validation data loader. e.g. 10MB, 20GB. "
        "If None, the 5 percent size of --max_cache_size",
    )
    group.add_argument("--shape_file", type=str, action="append", default=[])
    group.add_argument(
        "--input_size",
        type=int_or_none,
        default=None,
        help="The number of input dimension of the feature",
    )
    group.add_argument(
        "--num_cohort_spk",
        type=int,
        default=5994,
        help="The number of cohort speakers in score norm",
    )
    group.add_argument(
        "--num_utt_per_spk",
        type=int,
        default=10,
        help="The number of utterances per speaker in score norm",
    )
    group.add_argument(
        "--utt_select_sec",
        type=int,
        default=8,
        help="Minimum duration for including the utt in cohort set in score norm",
    )
    group.add_argument(
        "--average_spk",
        type=str2bool,
        default=False,
        help="whether to average cohort embeds per speaker in score norm",
    )
    group.add_argument(
        "--adaptive_cohort_size",
        type=int,
        default=400,
        help="top-k cohort size in score norm",
    )
    group.add_argument(
        "--qmf_dur_thresh",
        type=int,
        default=6,
        help="threshold of duration to be considered as long in qmf trainset",
    )
    group.add_argument(
        "--qmf_num_trial_per_condition",
        type=int,
        default=5000,
        help="number of trials per condition in qmf trainset",
    )
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
    group.add_argument("--average_embd", type=str2bool, default=False)
    group.add_argument(
        "--train_dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type for training.",
    )
    group.add_argument(
        "--use_amp",
        type=str2bool,
        default=False,
        help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
    )
    group.add_argument(
        "--no_forward_run",
        type=str2bool,
        default=False,
        help="Just only iterating data loading without "
        "model forwarding and training",
    )
    group.add_argument(
        "--sort_in_batch",
        type=str,
        default="descending",
        choices=["descending", "ascending"],
        help="Sort the samples in each mini-batches by the sample "
        'lengths. To enable this, "shape_file" must have the length information.',
    )
    group.add_argument(
        "--sort_batch",
        type=str,
        default="descending",
        choices=["descending", "ascending"],
        help="Sort mini-batches by the sample lengths",
    )
    group.add_argument(
        "--drop_last_iter",
        type=str2bool,
        default=False,
        help="Exclude the minibatch with leftovers.",
    )

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--spk_train_config",
        type=str,
        help="SPK training configuration",
    )
    group.add_argument(
        "--spk_model_file",
        type=str,
        help="SPK model parameter file",
    )
    # NOTE(review): extract_embed() in this file never reads --model_tag, so
    # the overwrite described in the help text does not happen here — confirm.
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If specify this option, *_train_config and "
        "*_file will be overwritten",
    )

    group = parser.add_argument_group("distributed training related")
    group.add_argument(
        "--dist_backend",
        default="nccl",
        type=str,
        help="distributed backend",
    )
    group.add_argument(
        "--dist_init_method",
        type=str,
        default="env://",
        help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
        '"WORLD_SIZE", and "RANK" are referred.',
    )
    group.add_argument(
        "--dist_world_size",
        default=None,
        type=int_or_none,
        help="number of nodes for distributed training",
    )
    group.add_argument(
        "--dist_rank",
        type=int_or_none,
        default=None,
        help="node rank for distributed training",
    )
    group.add_argument(
        # Not starting with "dist_" for compatibility to launch.py
        "--local_rank",
        type=int_or_none,
        default=None,
        help="local rank for distributed training. This option is used if "
        "--multiprocessing_distributed=false",
    )
    group.add_argument(
        "--dist_master_addr",
        default=None,
        type=str_or_none,
        help="The master address for distributed training. "
        "This value is used when dist_init_method == 'env://'",
    )
    group.add_argument(
        "--dist_master_port",
        default=None,
        type=int_or_none,
        help="The master port for distributed training. "
        "This value is used when dist_init_method == 'env://'",
    )
    group.add_argument(
        "--dist_launcher",
        default=None,
        type=str_or_none,
        choices=["slurm", "mpi", None],
        help="The launcher type for distributed training",
    )
    group.add_argument(
        "--multiprocessing_distributed",
        default=False,
        type=str2bool,
        help="Use multi-processing distributed training to launch "
        "N processes per node, which has N GPUs. This is the "
        "fastest way to use PyTorch for either single node or "
        "multi node data parallel training",
    )
    group.add_argument(
        "--unused_parameters",
        type=str2bool,
        default=False,
        help="Whether to use the find_unused_parameters in "
        "torch.nn.parallel.DistributedDataParallel ",
    )
    group.add_argument(
        "--sharded_ddp",
        default=False,
        type=str2bool,
        help="Enable sharded training provided by fairscale",
    )

    group = parser.add_argument_group("trainer initialization related")
    group.add_argument(
        "--use_matplotlib",
        type=str2bool,
        default=True,
        help="Enable matplotlib logging",
    )
    group.add_argument(
        "--use_tensorboard",
        type=str2bool,
        default=True,
        help="Enable tensorboard logging",
    )
    group.add_argument(
        "--create_graph_in_tensorboard",
        type=str2bool,
        default=False,
        help="Whether to create graph in tensorboard",
    )
    group.add_argument(
        "--use_wandb",
        type=str2bool,
        default=False,
        help="Enable wandb logging",
    )
    group.add_argument(
        "--wandb_project",
        type=str,
        default=None,
        help="Specify wandb project",
    )
    group.add_argument(
        "--wandb_id",
        type=str,
        default=None,
        help="Specify wandb id",
    )
    group.add_argument(
        "--wandb_entity",
        type=str,
        default=None,
        help="Specify wandb entity",
    )
    group.add_argument(
        "--wandb_name",
        type=str,
        default=None,
        help="Specify wandb run name",
    )
    group.add_argument(
        "--wandb_model_log_interval",
        type=int,
        default=-1,
        help="Set the model log period",
    )
    group.add_argument(
        "--detect_anomaly",
        type=str2bool,
        default=False,
        help="Set torch.autograd.set_detect_anomaly",
    )
    group.add_argument(
        "--use_lora",
        type=str2bool,
        default=False,
        help="Enable LoRA based finetuning, see (https://arxiv.org/abs/2106.09685) "
        "for large pre-trained foundation models, like Whisper",
    )
    group.add_argument(
        "--save_lora_only",
        type=str2bool,
        default=True,
        help="Only save LoRA parameters or save all model parameters",
    )
    group.add_argument(
        "--lora_conf",
        action=NestedDictAction,
        default=dict(),
        help="Configuration for LoRA based finetuning",
    )

    group = parser.add_argument_group("cudnn mode related")
    group.add_argument(
        "--cudnn_enabled",
        type=str2bool,
        default=torch.backends.cudnn.enabled,
        help="Enable CUDNN",
    )
    group.add_argument(
        "--cudnn_benchmark",
        type=str2bool,
        default=torch.backends.cudnn.benchmark,
        help="Enable cudnn-benchmark mode",
    )
    group.add_argument(
        "--cudnn_deterministic",
        type=str2bool,
        default=True,
        help="Enable cudnn-deterministic mode",
    )

    group = parser.add_argument_group("The inference hyperparameter related")
    group.add_argument(
        "--valid_batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    group.add_argument(
        "--target_duration",
        type=float,
        default=3.0,
        help="Duration (in seconds) of samples in a minibatch",
    )
    group.add_argument(
        "--num_eval",
        type=int,
        default=10,
        help="Number of segments to make from one utterance in the inference phase",
    )
    group.add_argument("--fold_length", type=int, action="append", default=[])
    group.add_argument(
        "--use_preprocessor",
        type=str2bool,
        default=True,
        help="Apply preprocessing to data or not",
    )
    return parser
def main(cmd=None):
    """Parse command-line arguments and run speaker embedding extraction.

    If distributed mode with ``--multiprocessing_distributed`` is resolved,
    one child process per GPU is spawned on this node. Each child runs
    ``extract_embed`` with ``ngpu=1`` and its own ``local_rank`` /
    ``dist_rank``. Otherwise ``extract_embed`` runs in the current process.

    Args:
        cmd: Optional list of argument strings forwarded to ``parse_args``;
            ``None`` lets the parser read the process command line.
    """
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)

    # "distributed" is decided using the other command args
    resolve_distributed_mode(args)
    if not args.distributed or not args.multiprocessing_distributed:
        extract_embed(args)
        return

    assert args.ngpu > 1, args.ngpu
    # Multi-processing distributed mode: e.g. 2node-4process-4GPU
    # |   Host1     |    Host2    |
    # |   Process1  |   Process2  |  <= Spawn processes
    # |Child1|Child2|Child1|Child2|
    # |GPU1  |GPU2  |GPU1  |GPU2  |

    # See also the following usage of --multiprocessing-distributed:
    # https://github.com/pytorch/examples/blob/master/imagenet/main.py
    num_nodes = get_num_nodes(args.dist_world_size, args.dist_launcher)
    if num_nodes == 1:
        args.dist_master_addr = "localhost"
        args.dist_rank = 0
        # Single node distributed training with multi-GPUs
        if (
            args.dist_init_method == "env://"
            and get_master_port(args.dist_master_port) is None
        ):
            # Get the unused port
            args.dist_master_port = free_port()

    # Assume that nodes use same number of GPUs each other
    args.dist_world_size = args.ngpu * num_nodes
    node_rank = get_node_rank(args.dist_rank, args.dist_launcher)

    # The following block is adapted from torch.multiprocessing.spawn:
    # https://github.com/pytorch/pytorch/blob/master/torch/multiprocessing/spawn.py
    # NOTE(review): the queues are never handed to the children, so a child's
    # traceback is not forwarded; join() still raises on a non-zero exit code.
    # Newer torch releases' ProcessContext may expect error files instead of
    # queues — confirm against the installed torch version.
    error_queues = []
    processes = []
    mp = torch.multiprocessing.get_context("spawn")
    for i in range(args.ngpu):
        # Shallow-copy args so each child gets its own rank settings
        local_args = argparse.Namespace(**vars(args))
        local_args.local_rank = i
        local_args.dist_rank = args.ngpu * node_rank + i
        local_args.ngpu = 1

        process = mp.Process(
            target=extract_embed,
            args=(local_args,),
            daemon=False,
        )
        process.start()
        processes.append(process)
        error_queues.append(mp.SimpleQueue())

    # Build the context once: join() drops finished children from its
    # internal sentinel map. Re-creating the context every iteration re-adds
    # finished children, so the wait returns immediately and the loop spins.
    context = ProcessContext(processes, error_queues)
    # Loop on join until it returns True or raises an exception.
    while not context.join():
        pass
if __name__ == "__main__":
    # Executed as a script: parse the command line and extract embeddings.
    main()