Source code for espnet2.iterators.sequence_iter_factory

import itertools
import random
from functools import partial
from typing import Any, Optional, Sequence, Union

import numpy as np
from torch.utils.data import DataLoader
from typeguard import typechecked

from espnet2.iterators.abs_iter_factory import AbsIterFactory
from espnet2.samplers.abs_sampler import AbsSampler


def worker_init_fn(worker_id, base_seed=0):
    """Set random seed for each worker in DataLoader."""
    seed = base_seed + worker_id
    random.seed(seed)
    np.random.seed(seed)
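
As a rough usage sketch (not part of this module), worker_init_fn is meant to be
wrapped with functools.partial and handed to a DataLoader so that every worker
process seeds random and numpy with a distinct, base-seed-dependent value; the toy
dataset below is purely illustrative:

from functools import partial

from torch.utils.data import DataLoader, Dataset


class _ToyDataset(Dataset):
    # Hypothetical dataset used only for this illustration.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return idx


loader = DataLoader(
    _ToyDataset(),
    batch_size=2,
    num_workers=2,
    # base_seed would normally encode the epoch, as build_iter() does below.
    worker_init_fn=partial(worker_init_fn, base_seed=42),
)
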
class RawSampler(AbsSampler):
    """Sampler that simply wraps a pre-built list of batches."""

    def __init__(self, batches):
        self.batches = batches

    def __len__(self):
        return len(self.batches)

    def __iter__(self):
        return iter(self.batches)

    def generate(self, seed):
        # The seed is ignored: the pre-built batches are returned as a copy
        # so that callers may shuffle them in place.
        return list(self.batches)
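
For example (an illustrative sketch, not part of the module), RawSampler just wraps
a pre-computed list of mini-batches, where each mini-batch is a sequence of
utterance keys, so a plain list can be used wherever an AbsSampler is expected:

batches = [["utt1", "utt2"], ["utt3", "utt4"], ["utt5"]]
sampler = RawSampler(batches)

assert len(sampler) == 3
assert list(iter(sampler)) == batches
# generate() ignores the seed and returns a shallow copy that is safe to
# shuffle without mutating the original list.
assert sampler.generate(seed=0) == batches
assert sampler.generate(seed=0) is not batches
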
class SequenceIterFactory(AbsIterFactory):
    """Build an iterator for each epoch.

    This class simply creates a PyTorch DataLoader, except for the following points:

    - The random seed is determined by the epoch number, which guarantees
      reproducibility when resuming from the middle of a training run.
    - The number of iterations per epoch can be restricted, which controls
      the interval between training and evaluation.
    """

    @typechecked
    def __init__(
        self,
        dataset,
        batches: Union[AbsSampler, Sequence[Sequence[Any]]],
        num_iters_per_epoch: Optional[int] = None,
        seed: int = 0,
        shuffle: bool = False,
        shuffle_within_batch: bool = False,
        num_workers: int = 0,
        collate_fn=None,
        pin_memory: bool = False,
    ):
        if not isinstance(batches, AbsSampler):
            self.sampler = RawSampler(batches)
        else:
            self.sampler = batches

        self.dataset = dataset
        self.num_iters_per_epoch = num_iters_per_epoch
        self.shuffle = shuffle
        self.shuffle_within_batch = shuffle_within_batch
        self.seed = seed
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        # https://discuss.pytorch.org/t/what-is-the-disadvantage-of-using-pin-memory/1702
        self.pin_memory = pin_memory

    def build_iter(self, epoch: int, shuffle: Optional[bool] = None) -> DataLoader:
        if shuffle is None:
            shuffle = self.shuffle

        if self.num_iters_per_epoch is not None:
            N = len(self.sampler)
            # Case 1: the corpus provides more batches than num_iters_per_epoch
            if self.num_iters_per_epoch < N:
                N = len(self.sampler)
                real_epoch, offset = divmod(self.num_iters_per_epoch * epoch, N)

                if offset >= self.num_iters_per_epoch:
                    current_batches = self.sampler.generate(real_epoch + self.seed)
                    if shuffle:
                        np.random.RandomState(real_epoch + self.seed).shuffle(
                            current_batches
                        )
                    batches = current_batches[
                        offset - self.num_iters_per_epoch : offset
                    ]
                else:
                    prev_batches = self.sampler.generate(real_epoch - 1 + self.seed)
                    current_batches = self.sampler.generate(real_epoch + self.seed)
                    if shuffle:
                        np.random.RandomState(real_epoch - 1 + self.seed).shuffle(
                            prev_batches
                        )
                        np.random.RandomState(real_epoch + self.seed).shuffle(
                            current_batches
                        )
                    batches = (
                        prev_batches[offset - self.num_iters_per_epoch :]
                        + current_batches[:offset]
                    )

            # Case 2: the corpus provides at most num_iters_per_epoch batches
            else:
                _epoch, _cursor = divmod(self.num_iters_per_epoch * (epoch - 1), N)
                _remain = self.num_iters_per_epoch
                batches = []
                current_batches = self.sampler.generate(_epoch + self.seed)
                if shuffle:
                    np.random.RandomState(_epoch + self.seed).shuffle(current_batches)
                while _remain > 0:
                    _batches = current_batches[_cursor : _cursor + _remain]
                    batches += _batches
                    if _cursor + _remain >= N:
                        _epoch += 1
                        _cursor = 0
                        current_batches = self.sampler.generate(_epoch + self.seed)
                        if shuffle:
                            np.random.RandomState(_epoch + self.seed).shuffle(
                                current_batches
                            )
                    else:
                        _cursor = _cursor + _remain
                    _remain -= len(_batches)

                assert len(batches) == self.num_iters_per_epoch
        else:
            batches = self.sampler.generate(epoch + self.seed)
            if shuffle:
                np.random.RandomState(epoch + self.seed).shuffle(batches)

        # For backward compatibility with PyTorch's DataLoader,
        # pass collate_fn only when it is given.
        if self.collate_fn is not None:
            kwargs = dict(collate_fn=self.collate_fn)
        else:
            kwargs = {}

        # Reshuffle the whole of 'batches' so that elements within a batch
        # can move between different batches.
        if self.shuffle_within_batch:
            _bs = len(batches[0])
            batches = list(itertools.chain(*batches))
            np.random.RandomState(epoch + self.seed).shuffle(batches)
            _batches = []
            for ii in range(0, len(batches), _bs):
                _batches.append(batches[ii : ii + _bs])
            batches = _batches
            del _batches

        return DataLoader(
            dataset=self.dataset,
            batch_sampler=batches,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            worker_init_fn=partial(worker_init_fn, base_seed=epoch + self.seed),
            **kwargs,
        )
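
As a final usage sketch (the dataset, data, and collate_fn below are hypothetical
stand-ins, not ESPnet's actual classes): a SequenceIterFactory built from a fixed
list of batches returns a DataLoader whose batch order depends only on the epoch
and seed, which is what makes resuming mid-training reproducible:

from torch.utils.data import Dataset


class _KeyedDataset(Dataset):
    # Hypothetical map-style dataset keyed by utterance id.
    def __init__(self, data):
        self.data = data

    def __getitem__(self, key):
        return self.data[key]

    def __len__(self):
        return len(self.data)


data = {f"utt{i}": np.full(3, i) for i in range(6)}
batches = [["utt0", "utt1"], ["utt2", "utt3"], ["utt4", "utt5"]]

factory = SequenceIterFactory(
    dataset=_KeyedDataset(data),
    batches=batches,
    seed=0,
    shuffle=True,
    collate_fn=lambda samples: np.stack(samples),
)

# Building the loader twice for the same epoch reproduces the same batch
# order, because the shuffling seed is derived from epoch + seed.
order_a = [batch.tolist() for batch in factory.build_iter(epoch=3)]
order_b = [batch.tolist() for batch in factory.build_iter(epoch=3)]
assert order_a == order_b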