Source code for espnet2.train.spk_trainer

"""
Trainer module for speaker recognition.
In speaker recognition (embedding extractor training/inference),
calculating the validation loss on a closed set is not informative, since
generalization to unseen utterances from known speakers is usually good.
Thus, we measure the open-set equal error rate (EER) on trials involving
unknown speakers by overriding validate_one_epoch.
"""

from typing import Dict, Iterable

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim
from typeguard import typechecked

from espnet2.torch_utils.device_funcs import to_device
from espnet2.train.distributed_utils import DistributedOption
from espnet2.train.reporter import SubReporter
from espnet2.train.trainer import Trainer, TrainerOptions
from espnet2.utils.eer import ComputeErrorRates, ComputeMinDcf, tuneThresholdfromScore

if torch.distributed.is_available():
    from torch.distributed import ReduceOp


class SpkTrainer(Trainer):
    """Trainer designed for speaker recognition.

    Training will be done as closed set classification.
    Validation will be open set EER calculation.
    """

    def __init__(self):
        raise RuntimeError("This class can't be instantiated.")
    @classmethod
    @torch.no_grad()
    @typechecked
    def validate_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Dict[str, torch.Tensor]],
        reporter: SubReporter,
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> None:
        ngpu = options.ngpu
        distributed = distributed_option.distributed

        model.eval()

        scores = []
        labels = []
        spk_embd_dic = {}
        bs = 0

        # [For distributed] Because iteration counts are not always equal between
        # processes, send a stop-flag to the other processes when this process's
        # iterator is exhausted.
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")

        # fill dictionary with speech samples
        utt_id_list = []
        speech_list = []
        task_token = None
        for utt_id, batch in iterator:
            bs = max(bs, len(utt_id))
            if "task_tokens" in batch:
                task_token = batch["task_tokens"][0]
            assert isinstance(batch, dict), type(batch)
            for _utt_id, _speech, _speech2 in zip(
                utt_id, batch["speech"], batch["speech2"]
            ):
                _utt_id_1, _utt_id_2 = _utt_id.split("*")
                if _utt_id_1 not in utt_id_list:
                    utt_id_list.append(_utt_id_1)
                    speech_list.append(
                        to_device(_speech, "cuda" if ngpu > 0 else "cpu")
                    )
                if _utt_id_2 not in utt_id_list:
                    utt_id_list.append(_utt_id_2)
                    speech_list.append(
                        to_device(_speech2, "cuda" if ngpu > 0 else "cpu")
                    )

        # extract speaker embeddings
        n_utt = len(utt_id_list)
        for ii in range(0, n_utt, bs):
            _utt_ids = utt_id_list[ii : ii + bs]
            _speechs = speech_list[ii : ii + bs]
            _speechs = torch.stack(_speechs, dim=0)
            org_shape = (_speechs.size(0), _speechs.size(1))
            _speechs = _speechs.flatten(0, 1)
            _speechs = to_device(_speechs, "cuda" if ngpu > 0 else "cpu")

            if task_token is None:
                task_tokens = None
            else:
                task_tokens = to_device(
                    task_token.repeat(_speechs.size(0)),
                    "cuda" if ngpu > 0 else "cpu",
                ).unsqueeze(1)
            spk_embds = model(
                speech=_speechs,
                spk_labels=None,
                extract_embd=True,
                task_tokens=task_tokens,
            )
            spk_embds = F.normalize(spk_embds, p=2, dim=1)
            spk_embds = spk_embds.view(org_shape[0], org_shape[1], -1)

            for _utt_id, _spk_embd in zip(_utt_ids, spk_embds):
                spk_embd_dic[_utt_id] = _spk_embd
        del utt_id_list
        del speech_list

        # calculate similarity scores
        for utt_id, batch in iterator:
            batch["spk_labels"] = to_device(
                batch["spk_labels"], "cuda" if ngpu > 0 else "cpu"
            )
            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break
            for _utt_id in utt_id:
                _utt_id_1, _utt_id_2 = _utt_id.split("*")
                score = torch.cdist(
                    spk_embd_dic[_utt_id_1], spk_embd_dic[_utt_id_2]
                )
                score = -1.0 * torch.mean(score)
                scores.append(score.view(1))  # 0-dim to 1-dim tensor for cat
            labels.append(batch["spk_labels"])
        else:
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
        torch.cuda.empty_cache()

        scores = torch.cat(scores).type(torch.float32)
        labels = torch.cat(labels).type(torch.int32).flatten()

        if distributed:
            # get the number of trials assigned to each GPU
            length = to_device(
                torch.tensor([labels.size(0)], dtype=torch.int32), "cuda"
            )
            lengths_all = [
                to_device(torch.zeros(1, dtype=torch.int32), "cuda")
                for _ in range(torch.distributed.get_world_size())
            ]
            torch.distributed.all_gather(lengths_all, length)

            scores_all = [
                to_device(torch.zeros(i, dtype=torch.float32), "cuda")
                for i in lengths_all
            ]
            torch.distributed.all_gather(scores_all, scores)
            scores = torch.cat(scores_all)

            labels_all = [
                to_device(torch.zeros(i, dtype=torch.int32), "cuda")
                for i in lengths_all
            ]
            torch.distributed.all_gather(labels_all, labels)
            labels = torch.cat(labels_all)

            torch.distributed.barrier()
        scores = scores.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()

        # calculate statistics of the target and nontarget score distributions
        n_trials = len(scores)
        scores_trg = []
        scores_nontrg = []
        for _s, _l in zip(scores, labels):
            if _l == 1:
                scores_trg.append(_s)
            elif _l == 0:
                scores_nontrg.append(_s)
            else:
                raise ValueError(f"{_l}, {type(_l)}")
        trg_mean = float(np.mean(scores_trg))
        trg_std = float(np.std(scores_trg))
        nontrg_mean = float(np.mean(scores_nontrg))
        nontrg_std = float(np.std(scores_nontrg))

        # exception for collect_stats
        if len(scores) == 1:
            reporter.register(stats=dict(eer=1.0, mindcf=1.0))
            return

        # predictions, ground truth, and the target false-acceptance rates
        results = tuneThresholdfromScore(scores, labels, [1, 0.1])
        eer = results[1]

        fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)

        # p_target, c_miss, and c_falsealarm in the NIST minDCF calculation
        p_trg, c_miss, c_fa = 0.05, 1, 1
        mindcf, _ = ComputeMinDcf(fnrs, fprs, thresholds, p_trg, c_miss, c_fa)

        reporter.register(
            stats=dict(
                eer=eer,
                mindcf=mindcf,
                n_trials=n_trials,
                trg_mean=trg_mean,
                trg_std=trg_std,
                nontrg_mean=nontrg_mean,
                nontrg_std=nontrg_std,
            )
        )

        # Free cached GPU memory to reduce peak usage; commenting this out may
        # give a minor speed boost when GPU memory is not fully used.
        torch.cuda.empty_cache()
    @classmethod
    @torch.no_grad()
    @typechecked
    def extract_embed(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Dict[str, torch.Tensor]],
        reporter: SubReporter,
        options: TrainerOptions,
        distributed_option: DistributedOption,
        output_dir: str,
        custom_bs: int,
        average: bool = False,
    ) -> None:
        ngpu = options.ngpu
        distributed = distributed_option.distributed

        model.eval()
        spk_embd_dic = {}

        # fill dictionary with speech samples
        utt_id_list = []
        utt_id_whole_list = []
        speech_list = []
        task_token = None
        if distributed:
            rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
        else:
            rank = 0
            world_size = 1

        idx = 0
        for utt_id, batch in iterator:
            if "task_tokens" in batch:
                task_token = batch["task_tokens"][0]
            assert isinstance(batch, dict), type(batch)
            for _utt_id, _speech, _speech2 in zip(
                utt_id, batch["speech"], batch["speech2"]
            ):
                _utt_id_1, _utt_id_2 = _utt_id.split("*")
                if _utt_id_1 not in utt_id_whole_list:
                    utt_id_whole_list.append(_utt_id_1)
                    if idx % world_size == rank:
                        utt_id_list.append(_utt_id_1)
                        speech_list.append(_speech)
                        if len(utt_id_list) == custom_bs:
                            speech_list = torch.stack(speech_list, dim=0)
                            org_shape = (speech_list.size(0), speech_list.size(1))
                            speech_list = speech_list.flatten(0, 1)
                            speech_list = to_device(
                                speech_list, "cuda" if ngpu > 0 else "cpu"
                            )
                            if task_token is None:
                                task_tokens = None
                            else:
                                task_tokens = to_device(
                                    task_token.repeat(speech_list.size(0)),
                                    "cuda" if ngpu > 0 else "cpu",
                                ).unsqueeze(1)
                            spk_embds = model(
                                speech=speech_list,
                                spk_labels=None,
                                extract_embd=True,
                                task_tokens=task_tokens,
                            )
                            # normalization removed to keep the embedding
                            # magnitude for use in QMF
                            # spk_embds = F.normalize(spk_embds, p=2, dim=1)
                            spk_embds = spk_embds.view(
                                org_shape[0], org_shape[1], -1
                            )
                            for uid, _spk_embd in zip(utt_id_list, spk_embds):
                                if average:
                                    spk_embd_dic[uid] = (
                                        _spk_embd.mean(0).detach().cpu().numpy()
                                    )
                                else:
                                    spk_embd_dic[uid] = (
                                        _spk_embd.detach().cpu().numpy()
                                    )
                            utt_id_list = []
                            speech_list = []
                    idx += 1
                if _utt_id_2 not in utt_id_whole_list:
                    utt_id_whole_list.append(_utt_id_2)
                    if idx % world_size == rank:
                        utt_id_list.append(_utt_id_2)
                        speech_list.append(_speech2)
                        if len(utt_id_list) == custom_bs:
                            speech_list = torch.stack(speech_list, dim=0)
                            org_shape = (speech_list.size(0), speech_list.size(1))
                            speech_list = speech_list.flatten(0, 1)
                            speech_list = to_device(
                                speech_list, "cuda" if ngpu > 0 else "cpu"
                            )
                            if task_token is None:
                                task_tokens = None
                            else:
                                task_tokens = to_device(
                                    task_token.repeat(speech_list.size(0)),
                                    "cuda" if ngpu > 0 else "cpu",
                                ).unsqueeze(1)
                            spk_embds = model(
                                speech=speech_list,
                                spk_labels=None,
                                extract_embd=True,
                                task_tokens=task_tokens,
                            )
                            # normalization removed to keep the embedding
                            # magnitude for use in QMF
                            # spk_embds = F.normalize(spk_embds, p=2, dim=1)
                            spk_embds = spk_embds.view(
                                org_shape[0], org_shape[1], -1
                            )
                            for uid, _spk_embd in zip(utt_id_list, spk_embds):
                                if average:
                                    spk_embd_dic[uid] = (
                                        _spk_embd.mean(0).detach().cpu().numpy()
                                    )
                                else:
                                    spk_embd_dic[uid] = (
                                        _spk_embd.detach().cpu().numpy()
                                    )
                            utt_id_list = []
                            speech_list = []
                    idx += 1

        # process the leftover batch (fewer than custom_bs utterances)
        if len(utt_id_list) != 0:
            speech_list = torch.stack(speech_list, dim=0)
            org_shape = (speech_list.size(0), speech_list.size(1))
            speech_list = speech_list.flatten(0, 1)
            speech_list = to_device(speech_list, "cuda" if ngpu > 0 else "cpu")
            if task_token is None:
                task_tokens = None
            else:
                task_tokens = to_device(
                    task_token.repeat(speech_list.size(0)),
                    "cuda" if ngpu > 0 else "cpu",
                ).unsqueeze(1)
            spk_embds = model(
                speech=speech_list,
                spk_labels=None,
                extract_embd=True,
                task_tokens=task_tokens,
            )
            # normalization removed to keep the embedding magnitude for use
            # in QMF, consistent with the batched branches above
            # spk_embds = F.normalize(spk_embds, p=2, dim=1)
            spk_embds = spk_embds.view(org_shape[0], org_shape[1], -1)
            for uid, _spk_embd in zip(utt_id_list, spk_embds):
                if average:
                    spk_embd_dic[uid] = _spk_embd.mean(0).detach().cpu().numpy()
                else:
                    spk_embd_dic[uid] = _spk_embd.detach().cpu().numpy()

        np.savez(output_dir + f"/embeddings{rank}", **spk_embd_dic)
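
As a reference for how validate_one_epoch turns embeddings into trial scores:
each utterance is represented by a set of segment embeddings, and a trial score
is the negative mean of all pairwise Euclidean distances between the two sets
(higher means more similar). A minimal self-contained sketch, using random
tensors in place of real model embeddings:

import torch
import torch.nn.functional as F

# Two utterances, each represented by 5 segment embeddings of dimension 192
# (shapes are illustrative; the real sizes depend on the model config).
embd_a = F.normalize(torch.randn(5, 192), p=2, dim=1)
embd_b = F.normalize(torch.randn(5, 192), p=2, dim=1)

# Negative mean pairwise L2 distance, matching the scoring in
# validate_one_epoch above.
score = -1.0 * torch.mean(torch.cdist(embd_a, embd_b))
print(float(score))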
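
The metric stage can likewise be exercised in isolation. The sketch below uses
only the espnet2.utils.eer call signatures already visible above; the scores
and labels are synthetic stand-ins, and the minDCF parameters (p_target=0.05,
c_miss=1, c_fa=1) mirror the values used in validate_one_epoch:

import numpy as np

from espnet2.utils.eer import (
    ComputeErrorRates,
    ComputeMinDcf,
    tuneThresholdfromScore,
)

rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=1000)   # 1 = target trial, 0 = nontarget trial
scores = rng.normal(size=1000) + labels  # target trials score higher on average

# Index 1 of the result is the EER, matching its use in validate_one_epoch.
results = tuneThresholdfromScore(scores, labels, [1, 0.1])
eer = results[1]

fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)
mindcf, _ = ComputeMinDcf(fnrs, fprs, thresholds, 0.05, 1, 1)
print("eer:", eer, "mindcf:", mindcf)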
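
Embeddings written by extract_embed land in one .npz file per rank, keyed by
utterance ID. A loading sketch (the directory exp/spk_embd and rank 0 are
hypothetical, not part of the trainer itself):

import numpy as np

# Values are (n_segments, embd_dim) arrays, or (embd_dim,) when average=True.
embd_dic = dict(np.load("exp/spk_embd/embeddings0.npz"))
for utt_id, embd in list(embd_dic.items())[:3]:
    print(utt_id, embd.shape)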