espnet.utils package

Initialize sub package.

espnet.utils.fill_missing_args

espnet.utils.fill_missing_args.fill_missing_args(args, add_arguments)[source]

Fill missing arguments in args.

Parameters
  • args (Namespace or None) – Namespace containing hyperparameters.

  • add_arguments (function) – Function to add arguments.

Returns

Arguments with any missing entries filled with their default values.

Return type

Namespace

Examples

>>> from argparse import Namespace
>>> from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2
>>> args = Namespace()
>>> fill_missing_args(args, Tacotron2.add_arguments_fn)
Namespace(aconv_chans=32, aconv_filts=15, adim=512, atype='location', ...)
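The filling step above can be sketched with the standard argparse module. This is a minimal sketch, not ESPnet's implementation: it assumes that add_arguments takes an argparse.ArgumentParser, registers its options on it and returns the parser (as the Tacotron2 example suggests). fill_missing_args_sketch and add_demo_arguments are hypothetical names.

```python
import argparse
from typing import Callable, Optional


def fill_missing_args_sketch(
    args: Optional[argparse.Namespace],
    add_arguments: Callable[[argparse.ArgumentParser], argparse.ArgumentParser],
) -> argparse.Namespace:
    """Return a copy of args where every option it lacks gets its default value."""
    # Collect defaults by parsing an empty command line with the registered options.
    parser = add_arguments(argparse.ArgumentParser())
    defaults, _ = parser.parse_known_args([])

    # Start from the user's values (or nothing) and only add keys that are missing.
    filled = {} if args is None else dict(vars(args))
    for key, value in vars(defaults).items():
        filled.setdefault(key, value)
    return argparse.Namespace(**filled)


def add_demo_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Hypothetical options standing in for a model's add_arguments_fn.
    parser.add_argument("--adim", type=int, default=512)
    parser.add_argument("--atype", type=str, default="location")
    return parser


args = fill_missing_args_sketch(argparse.Namespace(adim=256), add_demo_arguments)
print(args)  # Namespace(adim=256, atype='location')
```

Parsing an empty list rather than sys.argv keeps the defaults independent of whatever was passed on the real command line, and values already present in args are never overwritten.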

espnet.utils.spec_augment

This implementation is modified from https://github.com/zcaceres/spec_augment

MIT License

Copyright (c) 2019 Zach Caceres

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

espnet.utils.spec_augment.apply_interpolation(query_points, train_points, w, v, order)[source]

Apply polyharmonic interpolation model to data.

Notes

Given coefficients w and v for the interpolation model, we evaluate interpolated function values at query_points.

Parameters
  • query_points – [b, m, d] x values to evaluate the interpolation at

  • train_points – [b, n, d] x values that act as the interpolation centers (the c variables in the Wikipedia article)

  • w – [b, n, k] weights on each interpolation center

  • v – [b, d, k] weights on each input dimension

  • order – order of the interpolation

Returns

Polyharmonic interpolation evaluated at points defined in query_points.

espnet.utils.spec_augment.create_dense_flows(flattened_flows, batch_size, image_height, image_width)[source]
espnet.utils.spec_augment.cross_squared_distance_matrix(x, y)[source]

Pairwise squared distance between two (batch) matrices’ rows (2nd dim).

Computes the pairwise distances between rows of x and rows of y.

Parameters
  • x – [batch_size, n, d] float Tensor

  • y – [batch_size, m, d] float Tensor

Returns

squared_dists: [batch_size, n, m] float Tensor, where squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2
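For intuition, the definition above can be written out in plain Python on nested lists. This is an illustrative sketch (cross_squared_distance_sketch and expanded_sketch are hypothetical names), not the tensor implementation; the second function uses the expansion ||x||^2 - 2<x, y> + ||y||^2, which is how such matrices are typically computed with matrix products.

```python
from typing import List

Matrix = List[List[float]]  # rows x dims


def cross_squared_distance_sketch(x: List[Matrix], y: List[Matrix]) -> List[Matrix]:
    """squared_dists[b][i][j] = ||x[b][i] - y[b][j]||^2, computed directly."""
    return [
        [[sum((a - c) ** 2 for a, c in zip(row_x, row_y)) for row_y in yb] for row_x in xb]
        for xb, yb in zip(x, y)  # one [n, m] matrix per batch element
    ]


def expanded_sketch(x: List[Matrix], y: List[Matrix]) -> List[Matrix]:
    """Same result via ||x||^2 - 2 <x, y> + ||y||^2."""

    def dot(u: List[float], v: List[float]) -> float:
        return sum(a * c for a, c in zip(u, v))

    return [
        [[dot(rx, rx) - 2 * dot(rx, ry) + dot(ry, ry) for ry in yb] for rx in xb]
        for xb, yb in zip(x, y)
    ]


x = [[[0.0, 0.0], [1.0, 2.0]]]  # batch_size=1, n=2, d=2
y = [[[1.0, 0.0]]]  # batch_size=1, m=1, d=2
print(cross_squared_distance_sketch(x, y))  # [[[1.0], [4.0]]]
```

The expanded form can yield tiny negative values through floating-point cancellation, so implementations often clamp the result at zero.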

espnet.utils.spec_augment.dense_image_warp(image, flow)[source]

Image warping using per-pixel flow vectors.

Apply a non-linear warp to the image, where the warp is specified by a dense flow field of offset vectors that define the correspondences of pixel values in the output image back to locations in the source image. Specifically, the pixel value at output[b, j, i, c] is image[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c]. The locations specified by this formula do not necessarily map to an integer index. Therefore, the pixel value is obtained by bilinear interpolation of the 4 nearest pixels around (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside of the image, the nearest pixel values at the image boundary are used.

Parameters
  • image – 4-D float Tensor with shape [batch, height, width, channels].

  • flow – 4-D float Tensor with shape [batch, height, width, 2].

Note that image and flow can be of type tf.half, tf.float32, or tf.float64, and do not necessarily have to be the same type.

Returns

A 4-D float Tensor with shape [batch, height, width, channels] and the same type as the input image.

Raises

ValueError – if height < 2 or width < 2 or the inputs have the wrong number of dimensions.

espnet.utils.spec_augment.flatten_grid_locations(grid_locations, image_height, image_width)[source]
espnet.utils.spec_augment.freq_mask(spec, F=30, num_masks=1, replace_with_zero=False)[source]

Frequency masking

Parameters
  • spec (torch.Tensor) – input tensor with shape (T, dim)

  • F (int) – maximum width of each mask

  • num_masks (int) – number of masks

  • replace_with_zero (bool) – if True, masked parts will be filled with 0, if False, filled with mean
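The masking procedure can be sketched in plain Python on a (T, dim) list of frames. This is a hedged illustration, not ESPnet's tensor code: freq_mask_sketch and its rng argument are hypothetical, and the mask width is drawn from [0, min(F, dim)] following the SpecAugment paper's description, which may differ from the exact sampling used in this module.

```python
import random
from typing import List, Optional


def freq_mask_sketch(
    spec: List[List[float]],
    F: int = 30,
    num_masks: int = 1,
    replace_with_zero: bool = False,
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    """Mask num_masks bands of consecutive frequency bins in a (T, dim) spectrogram."""
    rng = rng or random.Random()
    out = [row[:] for row in spec]  # copy so the input is left untouched
    dim = len(out[0])
    for _ in range(num_masks):
        f = rng.randrange(0, min(F, dim) + 1)  # band width in [0, min(F, dim)]
        f0 = rng.randrange(0, dim - f + 1)  # first masked bin
        # Fill value: 0, or the mean of the (possibly already masked) spectrogram.
        fill = 0.0 if replace_with_zero else sum(map(sum, out)) / (len(out) * dim)
        for row in out:
            for k in range(f0, f0 + f):
                row[k] = fill
    return out
```

Time masking follows the same pattern with the roles of the two axes swapped: whole frames are overwritten instead of whole frequency bins.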

espnet.utils.spec_augment.get_flat_grid_locations(image_height, image_width, device)[source]
espnet.utils.spec_augment.get_grid_locations(image_height, image_width, device)[source]
espnet.utils.spec_augment.interpolate_bilinear(grid, query_points, name='interpolate_bilinear', indexing='ij')[source]

Similar to Matlab’s interp2 function.

Notes

Finds values for query points on a grid using bilinear interpolation.

Parameters
  • grid – a 4-D float Tensor of shape [batch, height, width, channels].

  • query_points – a 3-D float Tensor of N points with shape [batch, N, 2].

  • name – a name for the operation (optional).

  • indexing – whether the query points are specified as row and column (ij), or Cartesian coordinates (xy).

Returns

values – a 3-D Tensor with shape [batch, N, channels]

Raises

ValueError – if the indexing mode is invalid, or if the shape of the inputs is invalid.
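A minimal sketch of bilinear interpolation with "ij" indexing on a single 2-D grid, assuming height and width are at least 2. bilinear_sketch is a hypothetical name; clamping the floor index to [0, size - 2] and the weights to [0, 1] is one common way to make out-of-range queries fall back to border values, and is an assumption here rather than a description of this module's code.

```python
import math
from typing import List


def bilinear_sketch(grid: List[List[float]], i: float, j: float) -> float:
    """Bilinearly interpolate grid[height][width] at the fractional (row, col) = (i, j)."""
    height, width = len(grid), len(grid[0])
    # Clamp the floor so that floor + 1 stays inside the grid (needs height, width >= 2).
    i0 = min(max(math.floor(i), 0), height - 2)
    j0 = min(max(math.floor(j), 0), width - 2)
    # Weights clamped to [0, 1], so points outside the grid take the border values.
    a = min(max(i - i0, 0.0), 1.0)
    b = min(max(j - j0, 0.0), 1.0)
    top = (1 - b) * grid[i0][j0] + b * grid[i0][j0 + 1]
    bottom = (1 - b) * grid[i0 + 1][j0] + b * grid[i0 + 1][j0 + 1]
    return (1 - a) * top + a * bottom


grid = [[0.0, 1.0], [2.0, 3.0]]
print(bilinear_sketch(grid, 0.5, 0.5))  # 1.5
```

A batched version repeats this per batch element, per query point and per channel.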

espnet.utils.spec_augment.interpolate_spline(train_points, train_values, query_points, order, regularization_weight=0.0)[source]
espnet.utils.spec_augment.phi(r, order)[source]

Coordinate-wise nonlinearity used to define the order of the interpolation.

See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.

Parameters
  • r – input op

  • order – interpolation order

Returns

phi_k evaluated coordinate-wise on r, for k = order
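The Wikipedia kernel is phi_k(d) = d^k for odd k and d^k ln(d) for even k. Assuming, as in TensorFlow's interpolate_spline (from which the docstrings here appear to derive), that r holds squared distances r = d^2, it can be evaluated without a square root for even orders. phi_sketch and the 1e-10 floor are illustrative choices, not taken from this module.

```python
import math


def phi_sketch(r: float, order: int) -> float:
    """Polyharmonic kernel evaluated on a squared distance r = d**2."""
    r = max(r, 1e-10)  # keeps log and sqrt well-defined at r = 0
    if order == 1:
        return math.sqrt(r)  # d
    if order == 2:
        return 0.5 * r * math.log(r)  # d**2 * ln(d)
    if order == 4:
        return 0.5 * r * r * math.log(r)  # d**4 * ln(d)
    if order % 2 == 0:
        return 0.5 * r ** (0.5 * order) * math.log(r)  # d**k * ln(d)
    return r ** (0.5 * order)  # d**k for odd k
```

The factor 0.5 comes from ln(d) = 0.5 * ln(d^2).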

espnet.utils.spec_augment.solve_interpolation(train_points, train_values, order, regularization_weight)[source]
espnet.utils.spec_augment.sparse_image_warp(img_tensor, source_control_point_locations, dest_control_point_locations, interpolation_order=2, regularization_weight=0.0, num_boundaries_points=0)[source]
espnet.utils.spec_augment.specaug(spec, W=5, F=30, T=40, num_freq_masks=2, num_time_masks=2, replace_with_zero=False)[source]

SpecAugment

Reference:

SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition (https://arxiv.org/pdf/1904.08779.pdf)

This implementation is modified from https://github.com/zcaceres/spec_augment

Parameters
  • spec (torch.Tensor) – input tensor with the shape (T, dim)

  • W (int) – time warp parameter

  • F (int) – maximum width of each freq mask

  • T (int) – maximum width of each time mask

  • num_freq_masks (int) – number of frequency masks

  • num_time_masks (int) – number of time masks

  • replace_with_zero (bool) – if True, masked parts will be filled with 0, if False, filled with mean

espnet.utils.spec_augment.time_mask(spec, T=40, num_masks=1, replace_with_zero=False)[source]

Time masking

Parameters
  • spec (torch.Tensor) – input tensor with shape (T, dim)

  • T (int) – maximum width of each mask

  • num_masks (int) – number of masks

  • replace_with_zero (bool) – if True, masked parts will be filled with 0, if False, filled with mean

espnet.utils.spec_augment.time_warp(spec, W=5)[source]

Time warping

Parameters
  • spec (torch.Tensor) – input tensor with shape (T, dim)

  • W (int) – time warp parameter
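SpecAugment's time warp moves a randomly chosen frame along the time axis by up to W frames and stretches the rest of the spectrogram to fit. The sketch below swaps the spline-based warp used here (see sparse_image_warp) for a simpler piecewise-linear resampling, so it illustrates the idea rather than reproducing this function's output. time_warp_sketch and rng are hypothetical names.

```python
import math
import random
from typing import List, Optional


def time_warp_sketch(
    spec: List[List[float]], W: int = 5, rng: Optional[random.Random] = None
) -> List[List[float]]:
    """Piecewise-linear time warp of a (T, dim) spectrogram (illustrative only)."""
    rng = rng or random.Random()
    T = len(spec)
    if W <= 0 or T <= 2 * W + 2:
        return [row[:] for row in spec]  # too short to warp safely
    c = rng.randrange(W + 1, T - W - 1)  # anchor frame, away from both edges
    w = rng.randint(-W, W)  # displacement of the anchor, in frames
    target = c + w  # output frame where the anchor's content ends up
    out = []
    for t in range(T):
        # Map output frame t back to a fractional source position; frames 0 and T-1 stay fixed.
        if t <= target:
            src = t * c / target
        else:
            src = c + (t - target) * (T - 1 - c) / (T - 1 - target)
        i0 = min(math.floor(src), T - 1)
        i1 = min(i0 + 1, T - 1)
        a = src - i0
        # Linearly interpolate between the two neighbouring source frames.
        out.append([(1 - a) * x0 + a * x1 for x0, x1 in zip(spec[i0], spec[i1])])
    return out
```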

espnet.utils.cli_readers

class espnet.utils.cli_readers.HDF5Reader(rspecifier, return_shape=False)[source]

Bases: object

class espnet.utils.cli_readers.KaldiReader(rspecifier, return_shape=False, segments=None)[source]

Bases: object

class espnet.utils.cli_readers.SoundHDF5Reader(rspecifier, return_shape=False)[source]

Bases: object

class espnet.utils.cli_readers.SoundReader(rspecifier, return_shape=False)[source]

Bases: object

espnet.utils.cli_readers.file_reader_helper(rspecifier: str, filetype: str = 'mat', return_shape: bool = False, segments: str = None)[source]

Read uttid and array in kaldi style

This function might be a bit confusing as “ark” is used for HDF5 to imitate “kaldi-rspecifier”.

Parameters
  • rspecifier – Give as “ark:feats.ark” or “scp:feats.scp”

  • filetype – “mat” for kaldi-matrix, “hdf5” for HDF5

  • return_shape – Return the shape of the matrix, instead of the matrix. This can reduce IO cost for HDF5.

Returns

Return type

Generator[Tuple[str, np.ndarray], None, None]

Examples

Read from kaldi-matrix ark file:

>>> for u, array in file_reader_helper('ark:feats.ark', 'mat'):
...     array

Read from HDF5 file:

>>> for u, array in file_reader_helper('ark:feats.h5', 'hdf5'):
...     array

espnet.utils.io_utils

class espnet.utils.io_utils.LoadInputsAndTargets(mode='asr', preprocess_conf=None, load_input=True, load_output=True, sort_in_input_length=True, use_speaker_embedding=False, use_second_target=False, preprocess_args=None, keep_all_data_on_mem=False)[source]

Bases: object

Create a mini-batch from a list of dicts

>>> batch = [('utt1',
...           dict(input=[dict(feat='some.ark:123',
...                            filetype='mat',
...                            name='input1',
...                            shape=[100, 80])],
...                output=[dict(tokenid='1 2 3 4',
...                             name='target1',
...                             shape=[4, 31])]))]
>>> l = LoadInputsAndTargets()
>>> feat, target = l(batch)

Parameters
  • mode (str) – Specify the task mode, “asr” or “tts”

  • preprocess_conf (str) – The path of a json file for pre-processing

  • load_input (bool) – If False, do not load the input data

  • load_output (bool) – If False, do not load the output data

  • sort_in_input_length (bool) – Sort the mini-batch in descending order of the input length

  • use_speaker_embedding (bool) – Used for tts mode only

  • use_second_target (bool) – Used for tts mode only

  • preprocess_args (dict) – Set some optional arguments for preprocessing

  • preprocess_args (Optional[dict]) – Used for tts mode only

class espnet.utils.io_utils.SoundHDF5File(filepath, mode='r+', format=None, dtype='int16', **kwargs)[source]

Bases: object

Collect sound files into an HDF5 file

>>> import numpy as np
>>> f = SoundHDF5File('a.flac.h5', mode='a')
>>> array = np.random.randint(0, 100, 100, dtype=np.int16)
>>> f['id'] = (array, 16000)
>>> array, rate = f['id']

Parameters
  • filepath (str)

  • mode (str)

  • format (str) – The type used when saving wav: flac, nist, htk, etc.

  • dtype (str)

close()[source]
create_dataset(name, shape=None, data=None, **kwds)[source]
items()[source]
keys()[source]
values()[source]

espnet.utils.deterministic_utils

espnet.utils.deterministic_utils.set_deterministic_chainer(args)[source]

Ensures chainer produces deterministic results depending on the program arguments

Parameters

args (Namespace) – The program arguments

espnet.utils.deterministic_utils.set_deterministic_pytorch(args)[source]

Ensures pytorch produces deterministic results depending on the program arguments

Parameters

args (Namespace) – The program arguments

espnet.utils.check_kwargs

espnet.utils.check_kwargs.check_kwargs(func, kwargs, name=None)[source]

Check that kwargs are valid for func.

If kwargs are invalid, raise TypeError in the same way as Python does by default.

Parameters
  • func (function) – function to be validated

  • kwargs (dict) – keyword arguments for func

  • name (str) – name used in TypeError (default is func name)
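A minimal sketch of such a check built on inspect.signature; check_kwargs_sketch is a hypothetical name, and skipping functions that accept **kwargs or expose no signature are assumptions of this sketch. The error text mirrors the wording Python uses for an unexpected keyword.

```python
import inspect
from typing import Any, Callable, Dict, Optional


def check_kwargs_sketch(func: Callable, kwargs: Dict[str, Any], name: Optional[str] = None) -> None:
    """Raise TypeError if func would reject one of the keyword arguments in kwargs."""
    try:
        params = inspect.signature(func).parameters
    except (TypeError, ValueError):
        return  # no introspectable signature (e.g. some builtins): nothing to check
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return  # func takes **kwargs, so any keyword is accepted
    name = name or func.__name__
    for key in kwargs:
        param = params.get(key)
        # Positional-only and *args parameters cannot be passed by keyword either.
        if param is None or param.kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.VAR_POSITIONAL,
        ):
            raise TypeError(f"{name}() got an unexpected keyword argument '{key}'")
```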

espnet.utils.dynamic_import

espnet.utils.dynamic_import.dynamic_import(import_path, alias={})[source]

Dynamically import a module and class

Parameters
  • import_path (str) – syntax ‘module_name:class_name’ e.g., ‘espnet.transform.add_deltas:AddDeltas’

  • alias (dict) – shortcuts for registered classes

Returns

imported class
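The ‘module_name:class_name’ convention can be resolved with importlib. A minimal sketch, assuming aliases map a short name to a full ‘module:class’ path; dynamic_import_sketch is a hypothetical name, and it takes alias=None instead of a mutable {} default.

```python
import importlib
from typing import Any, Dict, Optional


def dynamic_import_sketch(import_path: str, alias: Optional[Dict[str, str]] = None) -> Any:
    """Import 'module_name:object_name' (or a registered alias) and return the object."""
    alias = alias or {}
    if ":" not in import_path:
        if import_path not in alias:
            raise ValueError(f"expected 'module_name:class_name' or a registered alias, got {import_path!r}")
        import_path = alias[import_path]  # resolve the shortcut to a full path
    module_name, obj_name = import_path.split(":", 1)
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)
```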

espnet.utils.cli_writers

class espnet.utils.cli_writers.BaseWriter[source]

Bases: object

close()[source]
class espnet.utils.cli_writers.HDF5Writer(wspecifier, write_num_frames=None, compress=False)[source]

Bases: espnet.utils.cli_writers.BaseWriter

Examples

>>> with HDF5Writer('ark:out.h5', compress=True) as f:
...     f['key'] = array
class espnet.utils.cli_writers.KaldiWriter(wspecifier, write_num_frames=None, compress=False, compression_method=2)[source]

Bases: espnet.utils.cli_writers.BaseWriter

class espnet.utils.cli_writers.SoundHDF5Writer(wspecifier, write_num_frames=None, pcm_format='wav')[source]

Bases: espnet.utils.cli_writers.BaseWriter

Examples

>>> fs = 16000
>>> with SoundHDF5Writer('ark:out.h5') as f:
...     f['key'] = fs, array
class espnet.utils.cli_writers.SoundWriter(wspecifier, write_num_frames=None, pcm_format='wav')[source]

Bases: espnet.utils.cli_writers.BaseWriter

Examples

>>> fs = 16000
>>> with SoundWriter('ark,scp:outdir,out.scp') as f:
...     f['key'] = fs, array
espnet.utils.cli_writers.file_writer_helper(wspecifier: str, filetype: str = 'mat', write_num_frames: str = None, compress: bool = False, compression_method: int = 2, pcm_format: str = 'wav')[source]

Write matrices in kaldi style

Parameters
  • wspecifier – e.g. ark,scp:out.ark,out.scp

  • filetype – “mat” for kaldi-matrix, “hdf5” for HDF5

  • write_num_frames – e.g. ‘ark,t:num_frames.txt’

  • compress – Compress or not

  • compression_method – Specify compression level

Write in kaldi-matrix-ark with “kaldi-scp” file:

>>> with file_writer_helper('ark,scp:out.ark,out.scp') as f:
...     f['uttid'] = array

This “scp” has the following format:

uttidA out.ark:1234
uttidB out.ark:2222

where 1234 and 2222 point to the starting byte address of each matrix. (For details, see the official Kaldi documentation.)
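As an illustration of the index format above, the sketch below reads such an “scp” listing into a dict. parse_ark_scp_sketch is a hypothetical helper; it only handles the plain “uttid path:offset” form and does not read the matrices themselves. It does not apply to the HDF5 “scp”, where the part after the colon is a key rather than a byte offset.

```python
from typing import Dict, Tuple


def parse_ark_scp_sketch(scp_text: str) -> Dict[str, Tuple[str, int]]:
    """Map each uttid in an 'uttid path:offset' listing to (path, byte offset)."""
    table: Dict[str, Tuple[str, int]] = {}
    for line in scp_text.splitlines():
        if not line.strip():
            continue  # skip blank lines
        uttid, location = line.split(maxsplit=1)
        path, offset = location.rsplit(":", 1)  # the offset follows the last colon
        table[uttid] = (path, int(offset))
    return table


scp = "uttidA out.ark:1234\nuttidB out.ark:2222\n"
print(parse_ark_scp_sketch(scp))  # {'uttidA': ('out.ark', 1234), 'uttidB': ('out.ark', 2222)}
```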

Write in HDF5 with “scp” file:

>>> with file_writer_helper('ark,scp:out.h5,out.scp', 'hdf5') as f:
...     f['uttid'] = array

This “scp” file is created as:

uttidA out.h5:uttidA
uttidB out.h5:uttidB

Unlike “kaldi-ark”, HDF5 supports random access to any key, so an “scp” file is not strictly required for random reading. Nevertheless, we create an “scp” file for HDF5 because it is useful in some use cases, e.g. concatenation and splitting.

espnet.utils.cli_writers.get_num_frames_writer(write_num_frames: str)[source]

Examples

>>> get_num_frames_writer('ark,t:num_frames.txt')
espnet.utils.cli_writers.parse_wspecifier(wspecifier: str) → Dict[str, str][source]

Parse wspecifier to dict

Examples

>>> parse_wspecifier('ark,scp:out.ark,out.scp')
{'ark': 'out.ark', 'scp': 'out.scp'}
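A minimal sketch of this parsing, pairing the comma-separated kinds before the first colon with the comma-separated paths after it. parse_wspecifier_sketch is a hypothetical name; it does not handle Kaldi option flags such as “t” in ‘ark,t:num_frames.txt’, which it would reject as a mismatch.

```python
from typing import Dict


def parse_wspecifier_sketch(wspecifier: str) -> Dict[str, str]:
    """Split 'ark,scp:out.ark,out.scp' into {'ark': 'out.ark', 'scp': 'out.scp'}."""
    kinds, paths = wspecifier.split(":", 1)
    kind_list = kinds.split(",")
    path_list = paths.split(",")
    if len(kind_list) != len(path_list):
        raise ValueError(f"Mismatch between {kinds!r} and {paths!r}")
    return dict(zip(kind_list, path_list))
```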

espnet.utils.dataset

pytorch dataset and dataloader implementation for chainer training.

class espnet.utils.dataset.ChainerDataLoader(**kwargs)[source]

Bases: object

Pytorch dataloader in chainer style.

Parameters

**kwargs – all arguments are passed to torch.utils.data.dataloader.DataLoader

Init function.

property epoch_detail

Epoch_detail required by chainer.

finalize()[source]

Implement finalize function.

next()[source]

Implement next function.

serialize(serializer)[source]

Serialize and deserialize function.

start_shuffle()[source]

Shuffle function for sortagrad.

class espnet.utils.dataset.TransformDataset(data, transform)[source]

Bases: torch.utils.data.dataset.Dataset

Transform Dataset for pytorch backend.

Parameters
  • data – list object from make_batchset

  • transform – transform function

Init function.
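The idea is a lazy, map-style dataset: item idx is transform(data[idx]), computed only when requested. The sketch below is a plain-Python illustration of that protocol (__len__ and __getitem__) without importing torch; TransformDatasetSketch is a hypothetical name, not the class documented here.

```python
from typing import Any, Callable, Sequence


class TransformDatasetSketch:
    """Map-style dataset whose items are transform(data[idx])."""

    def __init__(self, data: Sequence[Any], transform: Callable[[Any], Any]) -> None:
        self.data = data  # e.g. the minibatch list produced by make_batchset
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Any:
        # The transform (e.g. feature loading) runs only when this item is requested.
        return self.transform(self.data[idx])
```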

espnet.utils.__init__

Initialize sub package.

espnet.utils.cli_utils

espnet.utils.cli_utils.assert_scipy_wav_style(value)[source]
espnet.utils.cli_utils.get_commandline_args()[source]
espnet.utils.cli_utils.is_scipy_wav_style(value)[source]
espnet.utils.cli_utils.strtobool(x)[source]

espnet.utils.training.batchfy

espnet.utils.training.batchfy.batchfy_by_bin(sorted_data, batch_bins, num_batches=0, min_batch_size=1, shortest_first=False, ikey='input', okey='output')[source]

Make a variably sized batch set, which maximizes the number of bins up to batch_bins.

Parameters
  • sorted_data (Dict[str, Dict[str, Any]]) – dictionary loaded from data.json

  • batch_bins (int) – Maximum frames of a batch

  • num_batches (int) – # number of batches to use (for debug)

  • min_batch_size (int) – minimum batch size (for multi-gpu)

  • test (int) – Return only every test batches

  • shortest_first (bool) – Sort from batch with shortest samples to longest if true, otherwise reverse

  • ikey (str) – key to access input (for ASR ikey=”input”, for TTS ikey=”output”.)

  • okey (str) – key to access output (for ASR okey=”output”. for TTS okey=”input”.)

Returns

List[Tuple[str, Dict[str, List[Dict[str, Any]]]]] list of batches
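A simplified sketch of bin-based batching: utterances are sorted largest first and packed greedily while the padded size (largest bins in the batch times batch size) stays within batch_bins. This is not ESPnet's exact rule; it assumes a hypothetical (uttid, num_frames, feat_dim) tuple instead of the data.json dict, counts only one side (input or output), and ignores num_batches and shortest_first. batchfy_by_bin_sketch is a hypothetical name.

```python
from typing import List, Tuple

Utt = Tuple[str, int, int]  # hypothetical (uttid, num_frames, feat_dim)


def batchfy_by_bin_sketch(utts: List[Utt], batch_bins: int, min_batch_size: int = 1) -> List[List[str]]:
    """Greedily pack utterances so padded bins (max frames x dim x batch size) stay within batch_bins."""
    # Largest first, so the first utterance of each batch fixes its padded size.
    ordered = sorted(utts, key=lambda u: u[1] * u[2], reverse=True)
    batches: List[List[str]] = []
    current: List[str] = []
    max_bins = 0
    for uttid, frames, dim in ordered:
        bins = frames * dim
        padded = max(max_bins, bins) * (len(current) + 1)
        # Close the batch if adding this utterance would overflow, unless it is still too small.
        if current and padded > batch_bins and len(current) >= min_batch_size:
            batches.append(current)
            current, max_bins = [], 0
        current.append(uttid)
        max_bins = max(max_bins, bins)
    if current:
        batches.append(current)
    return batches
```

Counting padded bins rather than raw frames lets the batch size grow for short, low-dimensional utterances and shrink for long ones under a fixed memory budget.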

espnet.utils.training.batchfy.batchfy_by_frame(sorted_data, max_frames_in, max_frames_out, max_frames_inout, num_batches=0, min_batch_size=1, shortest_first=False, ikey='input', okey='output')[source]

Make a variably sized batch set, which maximizes the number of frames up to max_batch_frame.

Parameters
  • sorted_data (Dict[str, Dict[str, Any]]) – dictionary loaded from data.json

  • max_frames_in (int) – Maximum input frames of a batch

  • max_frames_out (int) – Maximum output frames of a batch

  • max_frames_inout (int) – Maximum input+output frames of a batch

  • num_batches (int) – # number of batches to use (for debug)

  • min_batch_size (int) – minimum batch size (for multi-gpu)

  • test (int) – Return only every test batches

  • shortest_first (bool) – Sort from batch with shortest samples to longest if true, otherwise reverse

  • ikey (str) – key to access input (for ASR ikey=”input”, for TTS ikey=”output”.)

  • okey (str) – key to access output (for ASR okey=”output”. for TTS okey=”input”.)

Returns

List[Tuple[str, Dict[str, List[Dict[str, Any]]]]] list of batches

espnet.utils.training.batchfy.batchfy_by_seq(sorted_data, batch_size, max_length_in, max_length_out, min_batch_size=1, shortest_first=False, ikey='input', iaxis=0, okey='output', oaxis=0)[source]

Make batch set from json dictionary

Parameters
  • sorted_data (Dict[str, Dict[str, Any]]) – dictionary loaded from data.json

  • batch_size (int) – batch size

  • max_length_in (int) – maximum length of input to decide adaptive batch size

  • max_length_out (int) – maximum length of output to decide adaptive batch size

  • min_batch_size (int) – minimum batch size (for multi-gpu)

  • shortest_first (bool) – Sort from batch with shortest samples to longest if true, otherwise reverse

  • ikey (str) – key to access input (for ASR ikey=”input”, for TTS, MT ikey=”output”.)

  • iaxis (int) – dimension to access input (for ASR, TTS iaxis=0, for MT iaxis=”1”.)

  • okey (str) – key to access output (for ASR, MT okey=”output”. for TTS okey=”input”.)

  • oaxis (int) – dimension to access output (for ASR, TTS, MT oaxis=0, reserved for future research, -1 means all axes.)

Returns

List[List[Tuple[str, dict]]] list of batches
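“Adaptive batch size” can be illustrated as follows: the longest utterance at the start of each batch decides how much the nominal batch_size is divided, based on how many times its lengths exceed max_length_in and max_length_out. This simplified sketch assumes a hypothetical (uttid, input_length, output_length) tuple instead of the data.json dict and omits the handling of ikey, iaxis, okey, oaxis and shortest_first; batchfy_by_seq_sketch is a hypothetical name.

```python
from typing import List, Tuple

Utt = Tuple[str, int, int]  # hypothetical (uttid, input_length, output_length)


def batchfy_by_seq_sketch(
    utts: List[Utt],
    batch_size: int,
    max_length_in: int,
    max_length_out: int,
    min_batch_size: int = 1,
) -> List[List[str]]:
    """Cut length-sorted utterances into batches that shrink when sequences are long."""
    ordered = sorted(utts, key=lambda u: u[1], reverse=True)  # longest input first
    batches: List[List[str]] = []
    start = 0
    while start < len(ordered):
        _, ilen, olen = ordered[start]  # the longest utterance in the batch sets its size
        factor = max(ilen // max_length_in, olen // max_length_out)
        bs = max(1, min_batch_size, batch_size // (1 + factor))
        batches.append([uttid for uttid, _, _ in ordered[start:start + bs]])
        start += bs
    return batches


utts = [("u1", 1000, 10), ("u2", 900, 10), ("u3", 500, 10), ("u4", 400, 10), ("u5", 300, 10)]
print(batchfy_by_seq_sketch(utts, 4, 800, 150))  # [['u1', 'u2'], ['u3', 'u4', 'u5']]
```

Here u1 is longer than max_length_in, so its batch is halved from 4 to 2, while the remaining shorter utterances keep the full batch size.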

espnet.utils.training.batchfy.batchfy_shuffle(data, batch_size, min_batch_size, num_batches, shortest_first)[source]
espnet.utils.training.batchfy.make_batchset(data, batch_size=0, max_length_in=inf, max_length_out=inf, num_batches=0, min_batch_size=1, shortest_first=False, batch_sort_key='input', swap_io=False, mt=False, count='auto', batch_bins=0, batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0, iaxis=0, oaxis=0)[source]

Make batch set from json dictionary

If utterances have a “category” value, minibatches are built separately for each category:

>>> data = {'utt1': {'category': 'A', 'input': ...},
...         'utt2': {'category': 'B', 'input': ...},
...         'utt3': {'category': 'B', 'input': ...},
...         'utt4': {'category': 'A', 'input': ...}}
>>> make_batchset(data, batch_size=2, ...)
[[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3', ...)]]

Note that if any utterance does not have a “category”, this function behaves the same as batchfy_by_{count}.

Parameters
  • data (Dict[str, Dict[str, Any]]) – dictionary loaded from data.json

  • batch_size (int) – maximum number of sequences in a minibatch.

  • batch_bins (int) – maximum number of bins (frames x dim) in a minibatch.

  • batch_frames_in (int) – maximum number of input frames in a minibatch.

  • batch_frames_out (int) – maximum number of output frames in a minibatch.

  • batch_frames_inout (int) – maximum number of input+output frames in a minibatch.

  • count (str) – strategy to count maximum size of batch. For choices, see espnet.asr.batchfy.BATCH_COUNT_CHOICES

  • max_length_in (int) – maximum length of input to decide adaptive batch size

  • max_length_out (int) – maximum length of output to decide adaptive batch size

  • num_batches (int) – # number of batches to use (for debug)

  • min_batch_size (int) – minimum batch size (for multi-gpu)

  • shortest_first (bool) – Sort from batch with shortest samples to longest if true, otherwise reverse

  • batch_sort_key (str) – how to sort data before creating minibatches [“input”, “output”, “shuffle”]

  • swap_io (bool) – if True, use “input” as output and “output” as input in data dict

  • mt (bool) – if True, use 0-axis of “output” as output and 1-axis of “output” as input in data dict

  • iaxis (int) – dimension to access input (for ASR, TTS iaxis=0, for MT iaxis=”1”.)

  • oaxis (int) – dimension to access output (for ASR, TTS, MT oaxis=0, reserved for future research, -1 means all axes.)

Returns

List[List[Tuple[str, dict]]] list of batches
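The category behaviour described above can be sketched as: group utterances by “category”, batch each group on its own, and fall back to a single pass when any utterance lacks a category. batch_by_category_sketch and the batchfy callback are hypothetical; the callback stands in for whichever batchfy_by_{count} strategy is selected, and visiting categories in sorted order is a choice of this sketch.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, List

Data = Dict[str, Dict[str, Any]]  # uttid -> info dict, as in data.json


def batch_by_category_sketch(data: Data, batchfy: Callable[[Data], List[List[str]]]) -> List[List[str]]:
    """Batch each "category" separately so that no minibatch mixes categories."""
    if any("category" not in info for info in data.values()):
        return batchfy(data)  # fall back to a single, category-agnostic pass
    groups: Dict[str, Data] = defaultdict(dict)
    for uttid, info in data.items():
        groups[info["category"]][uttid] = info
    batches: List[List[str]] = []
    for category in sorted(groups):
        batches.extend(batchfy(groups[category]))
    return batches


def pairs(d: Data) -> List[List[str]]:
    # Hypothetical stand-in strategy: fixed batches of two, in key order.
    keys = sorted(d)
    return [keys[i:i + 2] for i in range(0, len(keys), 2)]


data = {
    "utt1": {"category": "A"},
    "utt2": {"category": "B"},
    "utt3": {"category": "B"},
    "utt4": {"category": "A"},
}
print(batch_by_category_sketch(data, pairs))  # [['utt1', 'utt4'], ['utt2', 'utt3']]
```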

espnet.utils.training.train_utils

espnet.utils.training.train_utils.check_early_stop(trainer, epochs)[source]

Checks whether the early stopping trigger was hit and warns the user if so

Parameters
  • trainer – The trainer used for training

  • epochs – The maximum number of epochs

espnet.utils.training.train_utils.set_early_stop(trainer, args, is_lm=False)[source]

Sets the early stop trigger given the program arguments

Parameters
  • trainer – The trainer used for training

  • args – The program arguments

  • is_lm – If the trainer is for an LM (epoch instead of epochs)

espnet.utils.training.evaluator

class espnet.utils.training.evaluator.BaseEvaluator(iterator, target, converter=<function concat_examples>, device=None, eval_hook=None, eval_func=None)[source]

Bases: chainer.training.extensions.evaluator.Evaluator

Base Evaluator in ESPnet

espnet.utils.training.tensorboard_logger

class espnet.utils.training.tensorboard_logger.TensorboardLogger(logger, att_reporter=None, ctc_reporter=None, entries=None, epoch=0)[source]

Bases: chainer.training.extension.Extension

A tensorboard logger extension

Init the extension

Parameters
  • logger (SummaryWriter) – The logger to use

  • att_reporter (PlotAttentionReporter) – The (optional) PlotAttentionReporter

  • entries – The entries to watch

  • epoch (int) – The starting epoch

default_name = 'espnet_tensorboard_logger'

espnet.utils.training.__init__

Initialize sub package.

espnet.utils.training.iterators

class espnet.utils.training.iterators.ShufflingEnabler(iterators)[source]

Bases: chainer.training.extension.Extension

An extension enabling shuffling on an Iterator

Inits the ShufflingEnabler

Parameters

iterators (list[Iterator]) – The iterators to enable shuffling on

class espnet.utils.training.iterators.ToggleableShufflingMultiprocessIterator(dataset, batch_size, repeat=True, shuffle=True, n_processes=None, n_prefetch=1, shared_mem=None, maxtasksperchild=20)[source]

Bases: chainer.iterators.multiprocess_iterator.MultiprocessIterator

A MultiprocessIterator having its shuffling property activated during training

Init the iterator

Parameters
  • dataset (torch.nn.Tensor) – The dataset to take batches from

  • batch_size (int) – The batch size

  • repeat (bool) – Whether to repeat batches or not (enables multiple epochs)

  • shuffle (bool) – Whether to shuffle the order of the batches

  • n_processes (int) – How many processes to use

  • n_prefetch (int) – The number of batches to prefetch

  • shared_mem (int) – How much memory to share between processes

  • maxtasksperchild (int) – Maximum number of tasks per child

start_shuffle()[source]

Starts shuffling (or reshuffles) the batches

class espnet.utils.training.iterators.ToggleableShufflingSerialIterator(dataset, batch_size, repeat=True, shuffle=True)[source]

Bases: chainer.iterators.serial_iterator.SerialIterator

A SerialIterator having its shuffling property activated during training

Init the Iterator

Parameters
  • dataset (torch.nn.Tensor) – The dataset to take batches from

  • batch_size (int) – The batch size

  • repeat (bool) – Whether to repeat data (allow multiple epochs)

  • shuffle (bool) – Whether to shuffle the batches

start_shuffle()[source]

Starts shuffling (or reshuffles) the batches