#!/usr/bin/env python
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""pytorch dataset and dataloader implementation for chainer training."""
import torch
import torch.utils.data
class ChainerDataLoader(object):
    """Pytorch dataloader in chainer style.

    Wraps ``torch.utils.data.dataloader.DataLoader`` and exposes the members a
    chainer-style iterator provides (``next``, ``epoch``, ``epoch_detail``,
    ``serialize``, ``finalize``). ``next`` repeats indefinitely: when one pass
    over the wrapped loader is exhausted, a new pass is started.

    Args:
        all args for torch.utils.data.dataloader.DataLoader

    """

    @staticmethod
    def get_first_element(x):
        """Get first element of a given array-like object."""
        return x[0]

    def __init__(self, **kwargs):
        """Init function.

        Args:
            **kwargs: Keyword arguments forwarded unchanged to
                ``torch.utils.data.dataloader.DataLoader``. They are also
                kept in ``self.kwargs`` so ``start_shuffle`` can rebuild the
                loader.

        """
        self.loader = torch.utils.data.dataloader.DataLoader(**kwargs)
        if hasattr(self.loader, "__len__"):
            # To support DistributedSampler.
            # When using DDP, the size of the dataset itself differs from the
            # number of batches returned by the DataLoader. Counting against
            # the length of the DataLoader is what lets this class recognize
            # the end of each epoch.
            self.len = len(self.loader)
        else:
            self.len = len(kwargs["dataset"])
        # Number of batches returned so far in the current epoch.
        self.current_position = 0
        # Number of completed epochs.
        self.epoch = 0
        # Active iterator over ``self.loader``; created lazily by ``next``.
        self.iter = None
        self.kwargs = kwargs

    def next(self):
        """Return the next batch, starting a new pass when one is exhausted.

        Every call advances ``current_position``; once it reaches
        ``self.len``, ``epoch`` is incremented and ``current_position`` is
        reset to 0.

        Returns:
            The next batch produced by the wrapped DataLoader.

        Raises:
            StopIteration: If a freshly created iterator over the wrapped
                loader yields nothing (e.g. an empty dataset). Retrying could
                never succeed, so this fails instead of looping forever.

        """
        if self.iter is None:
            self.iter = iter(self.loader)
        try:
            ret = next(self.iter)
        except StopIteration:
            # The current pass is exhausted: start a new one.
            self.iter = iter(self.loader)
            try:
                ret = next(self.iter)
            except StopIteration:
                self.iter = None
                raise StopIteration(
                    "ChainerDataLoader: the wrapped DataLoader yields no batches"
                ) from None
        self.current_position += 1
        if self.current_position == self.len:
            self.epoch = self.epoch + 1
            self.current_position = 0
        return ret

    def __iter__(self):
        """Iterate once over the wrapped DataLoader.

        This does not touch ``epoch`` / ``current_position``; only ``next``
        updates the chainer-style progress counters.
        """
        yield from self.loader

    @property
    def epoch_detail(self):
        """Epoch_detail required by chainer.

        Returns:
            float: ``epoch`` plus the fraction of the current epoch already
            consumed. With a zero-length loader the fraction is undefined,
            so ``float(epoch)`` is returned instead of dividing by zero.

        """
        if self.len == 0:
            return float(self.epoch)
        return self.epoch + self.current_position / self.len

    def serialize(self, serializer):
        """Serialize and deserialize the iteration state.

        Args:
            serializer: Chainer-style serializer, called as
                ``serializer(key, value)``; its return value is stored back
                into ``epoch`` / ``current_position``.

        """
        epoch = serializer("epoch", self.epoch)
        current_position = serializer("current_position", self.current_position)
        self.epoch = epoch
        self.current_position = current_position

    def start_shuffle(self):
        """Shuffle function for sortagrad.

        Rebuilds the wrapped DataLoader with ``shuffle=True``. The iterator
        of the current pass (``self.iter``) is left as is, so the shuffled
        order takes effect when ``next`` starts the next pass.

        NOTE: torch's DataLoader rejects ``shuffle=True`` together with a
        custom ``sampler`` or ``batch_sampler``, so this raises if one of
        those was passed to ``__init__``.
        """
        self.kwargs["shuffle"] = True
        self.loader = torch.utils.data.dataloader.DataLoader(**self.kwargs)

    def finalize(self):
        """Release the wrapped DataLoader and any active iterator.

        Safe to call more than once.
        """
        # Drop the active iterator too: a multi-process DataLoader iterator
        # keeps its worker processes alive until it is garbage-collected,
        # which ``del self.loader`` alone would not trigger.
        self.iter = None
        if hasattr(self, "loader"):
            del self.loader