# -*- coding: utf-8 -*-
from collections import namedtuple
import os
import typing
import numpy as np
import torch.utils.data
from benzina.utils.file import Track
# A track is designated either by its label (str) or by a Track object.
_TrackType = typing.Union[str, Track]
# A pair of track designations; used for (input, target) track pairs.
_TrackPairType = typing.Tuple[_TrackType, _TrackType]
# Same shape as _TrackPairType; kept as a separate name for the ``tracks``
# parameter of the classification datasets.
_ClassTracksType = _TrackPairType
class Dataset(torch.utils.data.Dataset):
    """
    Map-style dataset over the samples of a single Benzina track.

    Each item is an ``Item`` namedtuple whose ``input`` field holds the
    object returned by indexing the underlying :class:`Track`.

    Args:
        archive (str or :class:`Track`): path to the archive or a Track. If a
            Track, :attr:`track` will be ignored.
        track (str or :class:`Track`, optional): track label or a Track. If a
            Track, :attr:`archive` must not be specified.
            (default: ``"bzna_input"``)

    Raises:
        ValueError: if ``archive``/``track`` are not a (path, label) or
            (``None``, Track) combination, or if the archive file is missing.
    """
    _Item = namedtuple("Item", ["input"])

    def __init__(self,
                 archive: typing.Union[str, _TrackType] = None,
                 track: _TrackType = "bzna_input"):
        # A Track passed as ``archive`` supersedes whatever ``track`` holds.
        if isinstance(archive, Track):
            archive, track = None, archive

        if archive is None:
            if not isinstance(track, Track):
                raise ValueError("track option must be a Track when archive is "
                                 "not specified.")
        else:
            if not isinstance(track, str):
                raise ValueError("track option must be a track label when "
                                 "archive is specified.")
            # Allow "~" and environment variables in the archive path.
            path = os.path.expandvars(os.path.expanduser(archive))
            if not os.path.isfile(path):
                raise ValueError("The archive {} is not present.".format(path))
            track = Track(path, track)

        self._track = track
        track.open()
        self._filename = track.file.path

    @property
    def filename(self):
        """Path of the file the track is read from."""
        return self._filename

    def __len__(self):
        """Number of samples in the track."""
        return len(self._track)

    def __getitem__(self, index: int):
        """Return the sample at ``index`` wrapped in ``Item(input=...)``."""
        sample = self._track[index]
        return Dataset._Item(sample)

    def __add__(self, other):
        # Overrides the inherited ``__add__``: concatenating these datasets
        # is not supported.
        raise NotImplementedError()
class ClassificationDataset(Dataset):
    """
    Dataset pairing each input sample with an integer class target.

    Targets are read once, at construction, from a separate target track and
    kept in memory as an int64 array. Inputs beyond the end of the target
    track get the placeholder target ``-1``.

    Args:
        archive (str or pair of :class:`Track`): path to the archive or a pair
            of Track. If a pair of Track, :attr:`tracks` will be ignored.
        tracks (pair of str or :class:`Track`, optional): pair of input and
            target tracks labels or a pair of input and target Track. If a pair
            of Track, :attr:`archive` must not be specified.
            (default: ``("bzna_input", "bzna_target")``)
        input_label (str, optional): label of the inputs to use in the input
            track. (default: ``"bzna_thumb"``)

    Raises:
        ValueError: if ``archive``/``tracks`` are not a (path, label pair) or
            (``None``, Track pair) combination, or if the archive is missing.
    """
    _Item = namedtuple("Item", ["input", "input_label", "target"])

    def __init__(self,
                 archive: typing.Union[str, _TrackPairType] = None,
                 tracks: _ClassTracksType = ("bzna_input", "bzna_target"),
                 input_label: str = "bzna_thumb"):
        # ``archive`` may itself be the (input, target) Track pair: try that
        # interpretation first, then fall back to treating it as a path.
        try:
            archive, tracks, input_label = \
                ClassificationDataset._validate_args(
                    None, archive, input_label)
        except (TypeError, ValueError):
            archive, tracks, input_label = \
                ClassificationDataset._validate_args(
                    archive, tracks, input_label)

        if archive is not None:
            input_track = Track(archive, tracks[0])
            target_track = Track(archive, tracks[1])
        else:
            input_track, target_track = tracks

        Dataset.__init__(self, input_track)
        self._input_label = input_label

        target_track.open()
        num_targets = len(target_track)
        # Inputs without a stored target keep the placeholder -1.
        self._targets = np.full(len(self._track), -1, np.int64)
        # An empty target track has no first/last entry to locate, so there
        # is nothing to read: every input keeps the -1 placeholder.
        if num_targets:
            # The targets are assumed to be stored back to back as
            # little-endian int64 values, so a single read spanning the first
            # to the last entry covers all of them.
            location_first, _ = target_track[0].location
            location_last, size_last = target_track[-1].location
            target_track.file.seek(location_first)
            buffer = target_track.file.read(
                location_last + size_last - location_first)
            self._targets[:num_targets] = np.frombuffer(buffer,
                                                        np.dtype("<i8"))

    def __getitem__(self, index: int):
        """
        Return ``Item(input, input_label, target)`` for sample ``index``.

        ``target`` is a one-element tuple holding the class target, or ``-1``
        when the target track has no entry for this sample.
        """
        item = Dataset.__getitem__(self, index)
        return self._Item(input=item.input,
                          input_label=self._input_label,
                          target=(self.targets[index],))

    def __add__(self, other):
        # Concatenating these datasets is not supported.
        raise NotImplementedError()

    @property
    def targets(self):
        """int64 array of class targets, one per sample (``-1`` if missing)."""
        return self._targets

    @staticmethod
    def _validate_args(*args):
        """
        Normalize and check ``(archive, tracks, input_label)``.

        Returns:
            tuple: ``(archive, tracks, input_label)`` with ``archive`` expanded
            (``~`` and environment variables), or ``None`` when ``tracks`` is a
            pair of :class:`Track`.

        Raises:
            ValueError: on mismatched ``archive``/``tracks`` kinds, a missing
                archive, or a ``tracks`` value that is not a pair of labels.
            TypeError: may surface from iterating or expanding arguments of
                the wrong type; ``__init__`` relies on this to detect which
                form ``archive`` was given in.
        """
        archive, tracks, input_label = args
        if archive is not None:
            if any(not isinstance(t, str) for t in tracks):
                raise ValueError("tracks option must be a pair of tracks "
                                 "labels when archive is specified.")
            archive = os.path.expandvars(os.path.expanduser(archive))
            if not os.path.isfile(archive):
                raise ValueError("The archive {} is not present.".format(archive))
            # Unpacking enforces that exactly two labels were given.
            _, _ = tracks
        elif any(not isinstance(t, Track) for t in tracks):
            raise ValueError("tracks option must be a pair of Track when "
                             "archive is not specified.")
        return archive, tracks, input_label
class ImageNet(ClassificationDataset):
    """
    ImageNet (ILSVRC 2012) classification dataset stored in a Benzina archive.

    Samples are laid out as train, then validation, then test. A split is
    selected by slicing that layout using :attr:`LEN_VALID` and
    :attr:`LEN_TEST`.

    Args:
        root (str or pair of :class:`Track`): root of the ImageNet dataset or
            path to the archive or a pair of Track. If a pair of Track,
            :attr:`tracks` will be ignored.
        split (None or str, optional): The dataset split, supports ``test``,
            ``train``, ``val``. If not specified, samples will be drawn from
            all splits.
        tracks (pair of str or :class:`Track`, optional): pair of input and
            target tracks labels or a pair of input and target Track. If a pair
            of Track, :attr:`root` must not be specified.
            (default: ``("bzna_input", "bzna_target")``)
        input_label (str, optional): label of the inputs to use in the input
            track. (default: ``"bzna_thumb"``)

    Raises:
        ValueError: if ``split`` is not one of ``test``, ``train``, ``val`` or
            ``None``, if ``root``/``tracks`` kinds mismatch, or if the archive
            cannot be found.
    """
    # Some images are missing from the dataset. Please read the README of the
    # dataset for more information.
    LEN_VALID = 50000 - 1
    LEN_TEST = 100000 - 7

    def __init__(self,
                 root: typing.Union[str, _TrackPairType] = None,
                 split: str = None,
                 tracks: _ClassTracksType = ("bzna_input", "bzna_target"),
                 input_label: str = "bzna_thumb"):
        # ``root`` may itself be the (input, target) Track pair: try that
        # interpretation first, then fall back to treating it as a path.
        try:
            archive, split, tracks, input_label = \
                ImageNet._validate_args(None, split, root, input_label)
        except (TypeError, ValueError):
            archive, split, tracks, input_label = \
                ImageNet._validate_args(root, split, tracks, input_label)
        ClassificationDataset.__init__(self, archive, tracks, input_label)

        self._indices = np.arange(ClassificationDataset.__len__(self),
                                  dtype=np.int64)
        # Everything before the validation and test blocks is training data.
        len_train = len(self._indices) - self.LEN_VALID - self.LEN_TEST
        if split == "test":
            selection = slice(-self.LEN_TEST, None)
        elif split == "train":
            selection = slice(None, len_train)
        elif split == "val":
            selection = slice(len_train, -self.LEN_TEST)
        else:
            selection = None
        if selection is not None:
            self._indices = self._indices[selection]
            self._targets = self._targets[selection]

    def __getitem__(self, index: int):
        """
        Return ``Item(input, input_label, target)`` for position ``index``
        within the selected split.
        """
        item = Dataset.__getitem__(self, self._indices[index])
        return ImageNet._Item(input=item.input,
                              input_label=self._input_label,
                              target=(self._targets[index],))

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self._indices)

    def __add__(self, other):
        # Concatenating these datasets is not supported.
        raise NotImplementedError()

    @staticmethod
    def _validate_args(*args):
        """
        Normalize and check ``(root, split, tracks, input_label)``.

        Returns:
            tuple: ``(archive, split, tracks, input_label)``. ``archive`` is
            the located archive path, or ``None`` when ``tracks`` is a pair
            of :class:`Track`.

        Raises:
            ValueError: on an unsupported ``split``, mismatched
                ``root``/``tracks`` kinds, or a missing archive.
        """
        root, split, tracks, input_label = args
        # Checked first on purpose: __init__ retries this validator with
        # ``root`` reinterpreted whenever it raises. A split error detected
        # last would be masked by an unrelated error from the retry (e.g. a
        # TypeError from os.path.expanduser on a Track pair).
        if split not in {"test", "train", "val", None}:
            raise ValueError("split option must be one of test, train, val")
        archive = None
        if root is not None:
            if any(not isinstance(t, str) for t in tracks):
                raise ValueError("tracks option must be a pair of tracks "
                                 "labels when root is specified.")
            root = os.path.expandvars(os.path.expanduser(root))
            # Accept the archive file itself or a directory containing it.
            if os.path.isfile(root):
                archive = root
            elif os.path.isfile(os.path.join(root, "ilsvrc2012.bzna")):
                archive = os.path.join(root, "ilsvrc2012.bzna")
            elif os.path.isfile(os.path.join(root, "ilsvrc2012.mp4")):
                archive = os.path.join(root, "ilsvrc2012.mp4")
            if archive is None:
                if root.endswith((".mp4", ".bzna")):
                    raise ValueError("The archive {} is not present.".format(root))
                raise ValueError("The archive ilsvrc2012.[mp4|bzna] is not "
                                 "present in root {}.".format(root))
        elif any(not isinstance(t, Track) for t in tracks):
            raise ValueError("tracks option must be a pair of Track when "
                             "root is not specified.")
        return archive, split, tracks, input_label