Source code for espnet2.samplers.category_balanced_sampler

# Sampler that keeps equally distributed categories (i.e., classes) within
# each minibatch. If the batch_size is smaller than the number of classes,
# all samples in the minibatch will belong to different classes.
# Cross-checked with https://github.com/clovaai/voxceleb_trainer/blob/master/
# DatasetLoader.py
# 'key_file' is just a text file which lists each sample name:
#
#     utterance_id_a
#     utterance_id_b
#     utterance_id_c
#
# Only the first column is read, so a 'shape file' can be used, too:
#
#     utterance_id_a 100,80
#     utterance_id_b 400,80
#     utterance_id_c 512,80
import random
from collections import Counter
from typing import Iterator, Optional, Tuple

from typeguard import typechecked

from espnet2.fileio.read_text import read_2columns_text
from espnet2.samplers.abs_sampler import AbsSampler


def round_down(num, divisor):
    """Return ``num`` rounded down to a multiple of ``divisor``.

    The result is ``num`` minus its remainder modulo ``divisor``, so for a
    positive ``divisor`` it is the largest multiple of ``divisor`` that does
    not exceed ``num``.

    Args:
        num: Value to round down.
        divisor: Step size; must be non-zero.

    Returns:
        ``num - (num % divisor)``.

    Raises:
        ZeroDivisionError: If ``divisor`` is zero.
    """
    return num - (num % divisor)
class CategoryBalancedSampler(AbsSampler):
    """Batch sampler that spreads categories evenly over each minibatch.

    All batches are built once, in ``__init__``, and stored in
    ``self.batch_list``. Within a batch no category contributes more than a
    fixed number of utterances: one when there are at least ``batch_size``
    categories, otherwise ``batch_size // n_categories + 1``.

    Utterances that would exceed this per-category cap while a batch is being
    filled are skipped and do not appear in any batch of this sampler.

    Args:
        batch_size: Number of utterances per full batch; must be positive.
        min_batch_size: Minimum size the trailing partial batch must have in
            order to be kept.
        drop_last: If True, the trailing partial batch is always discarded.
        category2utt_file: Path to a text file whose first column is a
            category name and whose remainder is a space-separated list of
            utterance ids belonging to that category. Required.
        epoch: Seed passed to ``random.seed``, so a given epoch always
            produces the same batches.
        **kwargs: Accepted and ignored.

    Raises:
        AssertionError: If ``batch_size`` is not positive or
            ``category2utt_file`` is None.
        RuntimeError: If ``category2utt_file`` contains no categories.
    """

    @typechecked
    def __init__(
        self,
        batch_size: int,
        min_batch_size: int = 1,
        drop_last: bool = False,
        category2utt_file: Optional[str] = None,
        epoch: int = 1,
        **kwargs,
    ):
        assert batch_size > 0, f"batch_size must be positive: {batch_size}"
        # NOTE: this reseeds the process-wide ``random`` module state, not a
        # private generator; kept as-is so existing runs stay reproducible.
        random.seed(epoch)

        self.batch_size = batch_size
        self.min_batch_size = min_batch_size
        self.drop_last = drop_last

        assert category2utt_file is not None, "category2utt_file is required"

        # dictionary with categories as keys and corresponding utterances
        # as values
        category2utt = read_2columns_text(category2utt_file)

        categories = list(category2utt.keys())
        if not categories:
            # Fail with a clear message instead of dividing by zero below.
            raise RuntimeError(f"0 lines found: {category2utt_file}")
        random.shuffle(categories)

        if len(categories) >= self.batch_size:
            n_utt_per_category_in_batch = 1
        else:
            n_utt_per_category_in_batch = self.batch_size // len(categories) + 1

        # One entry per utterance, holding that utterance's category; drawing
        # from this list picks categories in proportion to their size.
        flattened_cats = []
        for cat in categories:
            category2utt[cat] = category2utt[cat].split(" ")
            flattened_cats.extend([cat] * len(category2utt[cat]))
            random.shuffle(category2utt[cat])

        rand_idx = list(range(len(flattened_cats)))
        random.shuffle(rand_idx)

        self.batch_list = []
        current_batch = []
        current_batch_stats = Counter()

        # make minibatches
        for idx in rand_idx:
            cat = flattened_cats[idx]

            # don't allow more samples of one category in a batch than
            # n_utt_per_category_in_batch
            if current_batch_stats[cat] >= n_utt_per_category_in_batch:
                continue

            # Each category occurs in flattened_cats exactly as many times as
            # it has utterances, so this pop can never hit an empty list.
            current_batch.append(category2utt[cat].pop())
            current_batch_stats[cat] += 1

            # append batch to batch list
            if len(current_batch) == self.batch_size:
                self.batch_list.append(current_batch)
                current_batch = []
                current_batch_stats = Counter()

        # Keep the trailing partial batch unless the caller asked to drop it
        # or it is smaller than the configured minimum.
        if (
            current_batch
            and not self.drop_last
            and len(current_batch) >= self.min_batch_size
        ):
            self.batch_list.append(current_batch)

    def __repr__(self):
        """Return a short description with the batch count and batch size."""
        return (
            f"{self.__class__.__name__}("
            f"N-batch={len(self)}, "
            f"batch_size={self.batch_size})"
        )

    def __len__(self):
        """Return the number of precomputed batches."""
        return len(self.batch_list)

    def __iter__(self) -> Iterator[Tuple[str, ...]]:
        """Iterate over the precomputed batches of utterance ids."""
        return iter(self.batch_list)