# Source code for benzina.torch.dataloader

# -*- coding: utf-8 -*-
import benzina.native
import gc
import numpy                           as np
import torch
import torch.utils.data

from   torch.utils.data.dataloader import default_collate
from   contextlib                  import suppress
from   .                           import operations as ops

from   benzina.utils.file import Track

class DataLoader(torch.utils.data.DataLoader):
    """
    Loads images from a :class:`benzina.torch.dataset.Dataset`. Encapsulates a
    sampler and data processing transformations.

    Args:
        dataset (:class:`benzina.torch.dataset.Dataset`): dataset from which to
            load the data.
        shape (int or tuple of ints): set the shape of the samples. Note that
            this does not imply a resize of the image but merely sets the shape
            of the tensor in which the data will be copied. A single int ``n``
            is expanded to ``(n, n)``.
        path (str, optional): path to the archive from which samples will be
            decoded. If not specified, the dataloader will attempt to get it
            from :attr:`dataset` (its ``filename`` attribute).
        batch_size (int, optional): how many samples per batch to load.
            (default: ``1``)
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch. (default: ``False``)
        sampler (:class:`torch.utils.data.Sampler`, optional): defines the
            strategy to draw samples from the dataset. If specified,
            :attr:`shuffle` must be ``False``.
        batch_sampler (:class:`torch.utils.data.Sampler`, optional): like
            sampler, but returns a batch of indices at a time. Mutually
            exclusive with :attr:`batch_size`, :attr:`shuffle`,
            :attr:`sampler`, and :attr:`drop_last`.
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete
            batch, if the dataset size is not divisible by the batch size. If
            ``False`` and the size of dataset is not divisible by the batch
            size, then the last batch will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for
            collecting a batch. Should always be non-negative.
            (default: ``0``)
        device (torch.device, optional): set the device to use. Note that only
            CUDA devices are supported for the moment.
        multibuffering (int, optional): set the size of the multibuffering
            buffer. (default: ``3``)
        seed (int, optional): set the seed for the random transformations. If
            not specified, a seed is drawn from PyTorch's CPU random generator.
        bias_transform (:class:`benzina.torch.operations.BiasTransform` or float, optional):
            set the bias transformation. Values to subtract a pixel's channels
            with. Note that this transformation is applied before
            :attr:`norm_transform`.
        norm_transform (:class:`benzina.torch.operations.NormTransform` or float or iterable of float, optional):
            set the normalization transformation. Values to multiply a pixel's
            channels with. Note that this transformation is applied after
            :attr:`bias_transform`.
        warp_transform (:class:`benzina.torch.operations.WarpTransform` or iterable of float, optional):
            set the warp transformation or use as the arguments to initialize a
            WarpTransform.
    """

    def __init__(self,
                 dataset,
                 shape,
                 path = None,
                 batch_size = 1,
                 shuffle = False,
                 sampler = None,
                 batch_sampler = None,
                 collate_fn = default_collate,
                 drop_last = False,
                 timeout = 0,
                 device = None,
                 multibuffering = 3,
                 seed = None,
                 bias_transform = None,
                 norm_transform = None,
                 warp_transform = None):
        # Worker processes are disabled (num_workers = 0): batches are produced
        # by the iterator returned from __iter__() rather than by PyTorch
        # workers.
        super().__init__(dataset,
                         batch_size = batch_size,
                         shuffle = shuffle,
                         sampler = sampler,
                         batch_sampler = batch_sampler,
                         num_workers = 0,
                         collate_fn = collate_fn,
                         pin_memory = True,
                         drop_last = drop_last,
                         timeout = float(timeout),
                         worker_init_fn = None)

        # A bare int is shorthand for a square sample shape.
        if isinstance(shape, int):
            shape = (shape, shape)

        if path is None:
            path = dataset.filename

        # Without an explicit seed, derive one from PyTorch's CPU generator so
        # that torch.manual_seed() also governs the random transformations.
        if seed is None:
            seed = torch.randint(low = 0,
                                 high = 2**32,
                                 size = (),
                                 dtype = torch.int64,
                                 device = "cpu")
            seed = int(seed)

        # Plain values (or None) are wrapped into constant transformations.
        if not isinstance(warp_transform, ops.WarpTransform):
            warp_transform = ops.ConstantWarpTransform(warp_transform)
        if not isinstance(norm_transform, ops.NormTransform):
            norm_transform = ops.ConstantNormTransform(norm_transform)
        if not isinstance(bias_transform, ops.BiasTransform):
            bias_transform = ops.ConstantBiasTransform(bias_transform)

        self.path = path
        self.device = device
        self.multibuffering = multibuffering
        self.shape = shape
        self.RNG = np.random.RandomState(seed)
        self.warp_transform = warp_transform
        self.color_transform = ops.ConstantColorTransform()
        self.oob_transform = ops.ConstantOOBTransform()
        self.norm_transform = norm_transform
        self.bias_transform = bias_transform

    def __iter__(self):
        """Return a fresh iterator over the batches of this loader."""
        return _DataLoaderIter(self)
class _DataLoaderIter:
    """
    Iterator over the batches of a loader.

    Batches of indices are drawn from the loader's batch sampler and pushed to
    a native iterator core, which writes the samples into a multibuffer tensor
    allocated on the target device. Completed batches are then pulled back
    from the core in :meth:`__next__`.
    """

    def __init__(self, loader):
        """
        Snapshot the loader's configuration.

        The multibuffer and the iterator core are *not* created here; they are
        created lazily on the first call to :meth:`__next__`.

        Args:
            loader: the data loader to iterate over. Its ``multibuffering``
                must be at least 1.
        """
        assert(loader.multibuffering >= 1)
        self.length = len(loader)
        self.dataset = loader.dataset
        self.dataset_core = benzina.native.DatasetCore(loader.path,
                                                       len(loader.dataset))
        self.batch_size = loader.batch_size
        self.batch_iter = iter(loader.batch_sampler)
        self.multibuffering = loader.multibuffering
        self.shape = loader.shape
        self.collate_fn = loader.collate_fn
        self.drop_last = loader.drop_last
        self.timeout = loader.timeout
        self.device = self._resolve_device(loader.device)
        # Each iterator gets its own generator, seeded from the loader's, so
        # successive iterators draw different random transformations.
        self.RNG = np.random.RandomState(loader.RNG.randint(2**32))
        self.warp_transform = loader.warp_transform
        self.color_transform = loader.color_transform
        self.oob_transform = loader.oob_transform
        self.norm_transform = loader.norm_transform
        self.bias_transform = loader.bias_transform
        self.multibuffer = None
        self.core = None
        self.first_indices = None
        self.stop_iteration = None

    @staticmethod
    def _resolve_device(device):
        """
        Turn the loader's ``device`` setting into the device to decode onto.

        ``None`` and the bare string ``"cuda"`` resolve to the current CUDA
        device. Other strings and ints are passed to ``torch.device``. An
        index-less CUDA ``torch.device`` is resolved to the current CUDA
        device as well, so that it is treated the same way as the bare string
        ``"cuda"``. Anything else is returned unchanged.
        """
        if device is None or device == "cuda":
            return torch.device(torch.cuda.current_device())
        if isinstance(device, (str, int)):
            return torch.device(device)
        if (isinstance(device, torch.device)
                and device.type == "cuda"
                and device.index is None):
            return torch.device("cuda", torch.cuda.current_device())
        return device

    def __del__(self):
        """
        Destroy the iterator and all its resources.

        Because extraneous and circular references can keep the large GPU
        multibuffer tensor allocated indefinitely, we:

            1. Forcibly destroy all our members, thereby losing all of the
               iterator's possible references to the multibuffer and the
               iterator core. Tensor deallocations may or may not happen at
               this moment.
            2. Invoke the garbage collector, which is capable of identifying
               cyclic trash and removing it. The iterator core object supports
               garbage collection and is capable of breaking all reference
               cycles involving it.
            3. Empty the PyTorch CUDA cache, returning the CUDA memory buffers
               to the allocation pool.

        Because data loaders are not intended to be created extremely often,
        the extra cycles spent here doing this are worth it.
        """
        del self.__dict__
        self.garbage_collect()

    def __iter__(self):
        """Return the iterator itself."""
        return self

    def __len__(self):
        """Return the number of batches, as reported by the loader."""
        return self.length

    def __next__(self):
        """
        Return the next batch.

        On the first call the core is created and filled with up to
        ``multibuffering`` batches; on later calls one more batch is pushed
        before one is pulled. Once iteration has ended, the resources are
        released and the same ``StopIteration`` is raised on every later call.
        """
        if self.stop_iteration is not None:
            raise self.stop_iteration
        try:
            if self.core_needs_init():
                self.pull_first_indices()
                self.init_core()
                self.push_first_indices()
                # Running out of indices while filling is not the end of the
                # iteration: already-pushed batches still have to be pulled.
                with suppress(StopIteration):
                    self.fill_core()
            else:
                with suppress(StopIteration):
                    self.fill_one_batch()
            return self.pull()
        except StopIteration as si:
            self.stop_iteration = si
            self.garbage_collect()
            raise self.stop_iteration

    def core_needs_init(self):
        """Return ``True`` while the iterator core has not been created."""
        return self.core is None

    def pull_first_indices(self):
        """Draw the first batch of indices and keep it for core creation."""
        self.first_indices = next(self.batch_iter)

    def init_core(self):
        """
        Initialize the iterator core.

        From the first batch drawn from the sample iterator, we know the
        maximum batch size. We allocate a multibuffer large enough to contain
        self.multibuffering batches of the maximum size.

        Before we do so, however, we trigger garbage collection and empty the
        tensor cache, in an attempt to ensure circular references keeping
        previous large multibuffers alive have been destroyed.
        """
        self.garbage_collect()
        self.check_or_set_batch_size(self.first_indices)
        self.multibuffer = torch.zeros([self.multibuffering,
                                        self.batch_size,
                                        3,
                                        self.shape[0],
                                        self.shape[1]],
                                       dtype = torch.float32,
                                       device = self.device)
        self.core = benzina.native.NvdecodeDataLoaderIterCore(
            self.dataset_core,
            str(self.device),
            self.multibuffer,
            self.multibuffer.data_ptr(),
            self.batch_size,
            self.multibuffering,
            self.shape[0],
            self.shape[1],
        )

    def push_first_indices(self):
        """Push the batch drawn by :meth:`pull_first_indices` and forget it."""
        self.push(self.__dict__.pop("first_indices"))

    def fill_core(self):
        """Push batches until the core holds ``multibuffering`` of them."""
        while self.core.pushes < self.core.multibuffering:
            self.fill_one_batch()

    def push(self, indices):
        """
        Submit one batch of samples to the iterator core.

        Args:
            indices: the dataset indices making up the batch.

        Raises:
            RuntimeError: if the batch is larger than the expected batch size.
            StopIteration: if the batch is incomplete and ``drop_last`` is set.
        """
        self.check_or_set_batch_size(indices)
        # Batches rotate through the multibuffer; a short batch only uses the
        # leading slots of its buffer.
        slot = self.core.pushes % self.core.multibuffering
        buffer = self.multibuffer[slot][:len(indices)]
        indices = [int(i) for i in indices]
        ptrs = [int(sample_buffer.data_ptr()) for sample_buffer in buffer]
        samples = [self.dataset[i] for i in indices]
        # NOTE(review): the second element of each pair was lost in the
        # extracted source; ``item.aux`` is reconstructed from the ``auxd``
        # name -- confirm against the dataset's item type.
        items, auxd = zip(*[((item.input, item.input_label), item.aux)
                            for item in samples])
        # Use "bzna_thumb" until having a dataloader that is able to load
        # the size variant images in "bzna_input"
        inputs = [Track(item.as_file(), label) for item, label in items]
        token = (buffer, *self.collate_fn(auxd))
        t_args = (self.shape, self.RNG)
        with self.core.batch(token) as batch:
            for i, ptr, track in zip(indices, ptrs, inputs):
                with batch.sample(i,
                                  ptr,
                                  track.sample_location(0),
                                  track.video_configuration_location()):
                    # The call order is kept fixed: every transformation is
                    # handed the shared RNG.
                    self.core.setHomography(
                        *self.warp_transform(i, track.shape, *t_args))
                    self.core.selectColorMatrix(
                        *self.color_transform(i, track.shape, *t_args))
                    self.core.setBias(
                        *self.bias_transform(i, track.shape, *t_args))
                    self.core.setScale(
                        *self.norm_transform(i, track.shape, *t_args))
                    self.core.setOOBColor(
                        *self.oob_transform(i, track.shape, *t_args))

    def pull(self):
        """
        Wait for and return the next completed batch.

        Raises:
            StopIteration: if every pushed batch has already been pulled.
        """
        if self.core.pulls >= self.core.pushes:
            raise StopIteration
        return self.core.waitBatch(block=True, timeout=self.timeout)

    def fill_one_batch(self):
        """Draw the next batch of indices and push it to the core."""
        self.push(next(self.batch_iter))

    def check_or_set_batch_size(self, indices):
        """
        Validate the size of a batch against the expected batch size.

        If no batch size is known yet, the size of this batch becomes the
        expected batch size.

        Raises:
            RuntimeError: if the batch is larger than the expected batch size.
            StopIteration: if the batch is smaller than the expected batch
                size and ``drop_last`` is set.
        """
        iter_batch_size = len(indices)
        if self.batch_size is None:
            self.batch_size = iter_batch_size
        elif self.batch_size < iter_batch_size:
            raise RuntimeError("Batch size expected to be {}, but iterator returned larger batch size {}!"
                               .format(self.batch_size, iter_batch_size))
        elif self.batch_size > iter_batch_size:
            if self.drop_last:
                raise StopIteration

    def garbage_collect(self):
        """Drop the core and the multibuffer, then reclaim their memory."""
        self.core = None
        self.multibuffer = None
        gc.collect()
        torch.cuda.empty_cache()