diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py
index 153094fae07..6129f1892a1 100644
--- a/test/builtin_dataset_mocks.py
+++ b/test/builtin_dataset_mocks.py
@@ -16,7 +16,6 @@
 from torch.testing import make_tensor as _make_tensor
 from torchdata.datapipes.iter import IterDataPipe
 from torchvision.prototype import datasets
-from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DECODER
 from torchvision.prototype.datasets._api import find
 from torchvision.prototype.utils._internal import add_suggestion

@@ -109,21 +108,15 @@ def _get(self, dataset, config, root):
         self._cache[(name, config)] = mock_resources, mock_info
         return mock_resources, mock_info

-    def load(
-        self, name: str, decoder=DEFAULT_DECODER, split="train", **options: Any
-    ) -> Tuple[IterDataPipe, Dict[str, Any]]:
+    def load(self, name: str, **options: Any) -> Tuple[IterDataPipe, Dict[str, Any]]:
         dataset = find(name)
-        config = dataset.info.make_config(split=split, **options)
+        config = dataset.info.make_config(**options)
         root = self._tmp_home / name
         root.mkdir(exist_ok=True)

         resources, mock_info = self._get(dataset, config, root)

-        datapipe = dataset._make_datapipe(
-            [resource.load(root) for resource in resources],
-            config=config,
-            decoder=DEFAULT_DECODER_MAP.get(dataset.info.type) if decoder is DEFAULT_DECODER else decoder,
-        )
+        datapipe = dataset._make_datapipe([resource.load(root) for resource in resources], config=config)
         return datapipe, mock_info
diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py
index 4248870176f..04f2dd90b66 100644
--- a/test/test_prototype_builtin_datasets.py
+++ b/test/test_prototype_builtin_datasets.py
@@ -1,18 +1,37 @@
+import functools
 import io

 import builtin_dataset_mocks
 import pytest
 import torch
+from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair, UnsupportedInputs, ErrorMeta
 from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
 from torch.utils.data.graph import traverse
 from torchdata.datapipes.iter import IterDataPipe, Shuffler
 from torchvision.prototype import datasets, transforms
-from torchvision.prototype.datasets._api import DEFAULT_DECODER
 from torchvision.prototype.utils._internal import sequence_to_str


-def to_bytes(file):
-    return file.read()
+# TODO: remove this patch after https://github.com/pytorch/pytorch/pull/70304 is merged
+def patch(fn):
+    def wrapper(*args, **kwargs):
+        try:
+            return fn(*args, **kwargs)
+        except ErrorMeta as error:
+            if error.type is not ValueError:
+                raise error
+
+            raise UnsupportedInputs()
+
+    return wrapper
+
+
+TensorLikePair._to_tensor = patch(TensorLikePair._to_tensor)
+
+
+assert_samples_equal = functools.partial(
+    assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
+)


 def config_id(name, config):
@@ -26,7 +45,7 @@ def config_id(name, config):
     return "-".join(parts)


-def dataset_parametrization(*names, decoder=to_bytes):
+def dataset_parametrization(*names):
     if not names:
         # TODO: Replace this with torchvision.prototype.datasets.list() as soon as all builtin datasets are supported
         names = (
@@ -46,7 +65,7 @@ def dataset_parametrization(*names, decoder=to_bytes):
     return pytest.mark.parametrize(
         ("dataset", "mock_info"),
         [
-            pytest.param(*builtin_dataset_mocks.load(name, decoder=decoder, **config), id=config_id(name, config))
+            pytest.param(*builtin_dataset_mocks.load(name, **config), id=config_id(name, config))
             for name in names
             for config in datasets.info(name)._configs
         ],
@@ -89,7 +108,7 @@ def test_decoding(self, dataset, mock_info):
             f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
         )

-    @dataset_parametrization(decoder=DEFAULT_DECODER)
+    @dataset_parametrization()
     def test_no_vanilla_tensors(self, dataset, mock_info):
         vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
         if vanilla_tensors:
@@ -120,6 +139,15 @@ def scan(graph):
         else:
             raise AssertionError(f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe.")

+    @dataset_parametrization()
+    def test_save_load(self, dataset, mock_info):
+        sample = next(iter(dataset))
+
+        with io.BytesIO() as buffer:
+            torch.save(sample, buffer)
+            buffer.seek(0)
+            assert_samples_equal(torch.load(buffer), sample)
+

 class TestQMNIST:
     @pytest.mark.parametrize(
diff --git a/test/test_prototype_datasets_api.py b/test/test_prototype_datasets_api.py
index ce50df123cc..33996db0cca 100644
--- a/test/test_prototype_datasets_api.py
+++ b/test/test_prototype_datasets_api.py
@@ -5,8 +5,8 @@
 from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch


-def make_minimal_dataset_info(name="name", type=datasets.utils.DatasetType.RAW, categories=None, **kwargs):
-    return datasets.utils.DatasetInfo(name, type=type, categories=categories or [], **kwargs)
+def make_minimal_dataset_info(name="name", categories=None, **kwargs):
+    return datasets.utils.DatasetInfo(name, categories=categories or [], **kwargs)


 class TestFrozenMapping:
@@ -188,7 +188,7 @@ def resources(self, config):
             # This method is just defined to appease the ABC, but will be overwritten at instantiation
             pass

-        def _make_datapipe(self, resource_dps, *, config, decoder):
+        def _make_datapipe(self, resource_dps, *, config):
             # This method is just defined to appease the ABC, but will be overwritten at instantiation
             pass

@@ -241,12 +241,3 @@ def test_resources(self, mocker):

         (call_args, _) = dataset._make_datapipe.call_args
         assert call_args[0][0] is sentinel
-
-    def test_decoder(self):
-        dataset = self.DatasetMock()
-
-        sentinel = object()
-        dataset.load("", decoder=sentinel)
-
-        (_, call_kwargs) = dataset._make_datapipe.call_args
-        assert call_kwargs["decoder"] is sentinel
diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py
index 1945b5a5d9e..28840081fe7 100644
--- a/torchvision/prototype/datasets/__init__.py
+++ b/torchvision/prototype/datasets/__init__.py
@@ -7,7 +7,7 @@
         "Note that you cannot install it with `pip install torchdata`, since this is another package."
     ) from error

-from . import decoder, utils
+from . import utils
 from ._home import home

 # Load this last, since some parts depend on the above being loaded first
diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py
index 8b534c85413..8b7d6c90e54 100644
--- a/torchvision/prototype/datasets/_api.py
+++ b/torchvision/prototype/datasets/_api.py
@@ -1,12 +1,9 @@
-import io
 import os
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List

-import torch
 from torch.utils.data import IterDataPipe
 from torchvision.prototype.datasets import home
-from torchvision.prototype.datasets.decoder import raw, pil
-from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType
+from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
 from torchvision.prototype.utils._internal import add_suggestion

 from . import _builtin
@@ -49,28 +46,14 @@ def info(name: str) -> DatasetInfo:
     return find(name).info


-DEFAULT_DECODER = object()
-
-DEFAULT_DECODER_MAP: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = {
-    DatasetType.RAW: raw,
-    DatasetType.IMAGE: pil,
-}
-
-
 def load(
     name: str,
     *,
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = DEFAULT_DECODER,  # type: ignore[assignment]
     skip_integrity_check: bool = False,
     split: str = "train",
     **options: Any,
 ) -> IterDataPipe[Dict[str, Any]]:
     dataset = find(name)
-
-    if decoder is DEFAULT_DECODER:
-        decoder = DEFAULT_DECODER_MAP.get(dataset.info.type)
-
     config = dataset.info.make_config(split=split, **options)
     root = os.path.join(home(), dataset.name)
-
-    return dataset.load(root, config=config, decoder=decoder, skip_integrity_check=skip_integrity_check)
+    return dataset.load(root, config=config, skip_integrity_check=skip_integrity_check)
diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py
index f4f8c44f8ee..b49701c0056 100644
--- a/torchvision/prototype/datasets/_builtin/caltech.py
+++ b/torchvision/prototype/datasets/_builtin/caltech.py
@@ -1,11 +1,8 @@
-import functools
-import io
 import pathlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple, BinaryIO

 import numpy as np
-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
@@ -18,7 +15,7 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
+    RawImage,
 )
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding, hint_shuffling
 from torchvision.prototype.features import Label, BoundingBox, Feature
@@ -28,7 +25,6 @@ class Caltech101(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "caltech101",
-            type=DatasetType.IMAGE,
             dependencies=("scipy",),
             homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
         )
@@ -81,32 +77,26 @@ def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:

         return category, id

-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[Tuple[str, str], Tuple[Tuple[str, io.IOBase], Tuple[str, io.IOBase]]],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    def _decode_ann(self, data: BinaryIO) -> Dict[str, Any]:
+        ann = read_mat(data)
+        return dict(
+            bounding_box=BoundingBox(ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy"),
+            contour=Feature(ann["obj_contour"].T),
+        )
+
+    def _prepare_sample(
+        self, data: Tuple[Tuple[str, str], Tuple[Tuple[str, BinaryIO], Tuple[str, BinaryIO]]]
     ) -> Dict[str, Any]:
         key, (image_data, ann_data) = data
         category, _ = key
         image_path, image_buffer = image_data
         ann_path, ann_buffer = ann_data

-        label = self.info.categories.index(category)
-
-        image = decoder(image_buffer) if decoder else image_buffer
-
-        ann = read_mat(ann_buffer)
-        bbox = BoundingBox(ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy")
-        contour = Feature(ann["obj_contour"].T)
-
         return dict(
-            category=category,
-            label=label,
-            image=image,
+            self._decode_ann(ann_buffer),
+            label=Label(self.info.categories.index(category), category=category),
             image_path=image_path,
-            bbox=bbox,
-            contour=contour,
+            image=RawImage.fromfile(image_buffer),
             ann_path=ann_path,
         )
@@ -115,7 +105,6 @@ def _make_datapipe(
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, anns_dp = resource_dps

@@ -133,7 +122,7 @@ def _make_datapipe(
             buffer_size=INFINITE_BUFFER_SIZE,
             keep_key=True,
         )
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
@@ -145,7 +134,6 @@ class Caltech256(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "caltech256",
-            type=DatasetType.IMAGE,
             homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256",
         )
@@ -161,32 +149,28 @@ def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool:
         path = pathlib.Path(data[0])
         return path.name != "RENAME2"

-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[str, io.IOBase],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
         path, buffer = data

         dir_name = pathlib.Path(path).parent.name
         label_str, category = dir_name.split(".")
-        label = Label(int(label_str), category=category)
-
-        return dict(label=label, image=decoder(buffer) if decoder else buffer)
+        return dict(
+            path=path,
+            image=RawImage.fromfile(buffer),
+            label=Label(int(label_str), category=category),
+        )

     def _make_datapipe(
         self,
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = Filter(dp, self._is_not_rogue_file)
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py
index b3c50c07943..593f8af6fdd 100644
--- a/torchvision/prototype/datasets/_builtin/celeba.py
+++ b/torchvision/prototype/datasets/_builtin/celeba.py
@@ -1,9 +1,7 @@
 import csv
 import functools
-import io
-from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence
+from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO

-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
@@ -17,7 +15,7 @@
     DatasetInfo,
     GDriveResource,
     OnlineResource,
-    DatasetType,
+    RawImage,
 )
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
@@ -26,6 +24,8 @@
     hint_sharding,
     hint_shuffling,
 )
+from torchvision.prototype.features import BoundingBox, Feature, Label
+

 csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)

@@ -33,7 +33,7 @@ class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
     def __init__(
         self,
-        datapipe: IterDataPipe[Tuple[Any, io.IOBase]],
+        datapipe: IterDataPipe[Tuple[Any, BinaryIO]],
         *,
         fieldnames: Optional[Sequence[str]] = None,
     ) -> None:
@@ -65,7 +65,6 @@ class CelebA(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "celeba",
-            type=DatasetType.IMAGE,
             homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html",
         )
@@ -90,7 +89,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
             sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0",
             file_name="list_attr_celeba.txt",
         )
-        bboxes = GDriveResource(
+        bounding_boxes = GDriveResource(
"0B7EVK8r0v71pbThiMVRxWXZ4dU0", sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b", file_name="list_bbox_celeba.txt", @@ -100,7 +99,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b", file_name="list_landmarks_align_celeba.txt", ) - return [splits, images, identities, attributes, bboxes, landmarks] + return [splits, images, identities, attributes, bounding_boxes, landmarks] _SPLIT_ID_TO_NAME = { "0": "train", @@ -111,38 +110,54 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool: return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split - def _collate_anns(self, data: Tuple[Tuple[str, Dict[str, str]], ...]) -> Tuple[str, Dict[str, Dict[str, str]]]: - (image_id, identity), (_, attributes), (_, bbox), (_, landmarks) = data - return image_id, dict(identity=identity, attributes=attributes, bbox=bbox, landmarks=landmarks) - - def _collate_and_decode_sample( + def _decode_anns( self, - data: Tuple[Tuple[str, Tuple[str, List[str]], Tuple[str, io.IOBase]], Tuple[str, Dict[str, Any]]], + data: Tuple[ + Tuple[str, Dict[str, str]], + Tuple[str, Dict[str, str]], + Tuple[str, Dict[str, str]], + Tuple[str, Dict[str, str]], + ], *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + image_size: Tuple[int, int], + ) -> Dict[str, Any]: + (_, identity), (_, attributes), (_, bounding_box), (_, landmarks) = data + return dict( + identity=Label(int(identity["identity"])), + attributes={attr: value == "1" for attr, value in attributes.items()}, + bounding_box=BoundingBox( + [int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")], + format="xywh", + image_size=image_size, + ), + landmarks={ + landmark: Feature((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"]))) + for landmark in {key[:-2] for key in landmarks.keys()} + }, + ) + + def _prepare_sample( + self, + data: Tuple[ + Tuple[str, Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO]]], + Tuple[ + Tuple[str, Dict[str, str]], + Tuple[str, Dict[str, str]], + Tuple[str, Dict[str, str]], + Tuple[str, Dict[str, str]], + ], + ], ) -> Dict[str, Any]: split_and_image_data, ann_data = data - _, _, image_data = split_and_image_data + _, (_, image_data) = split_and_image_data path, buffer = image_data - _, ann = ann_data - - image = decoder(buffer) if decoder else buffer - identity = int(ann["identity"]["identity"]) - attributes = {attr: value == "1" for attr, value in ann["attributes"].items()} - bbox = torch.tensor([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")]) - landmarks = { - landmark: torch.tensor((int(ann["landmarks"][f"{landmark}_x"]), int(ann["landmarks"][f"{landmark}_y"]))) - for landmark in {key[:-2] for key in ann["landmarks"].keys()} - } + image = RawImage.fromfile(buffer) return dict( + self._decode_anns(ann_data, image_size=image.probe_image_size()), path=path, image=image, - identity=identity, - attributes=attributes, - bbox=bbox, - landmarks=landmarks, ) def _make_datapipe( @@ -150,9 +165,8 @@ def _make_datapipe( resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: - splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps + splits_dp, images_dp, identities_dp, attributes_dp, bounding_boxes_dp, landmarks_dp = resource_dps splits_dp = 
CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id")) splits_dp = Filter(splits_dp, functools.partial(self._filter_split, split=config.split)) @@ -165,12 +179,11 @@ def _make_datapipe( for dp, fieldnames in ( (identities_dp, ("image_id", "identity")), (attributes_dp, None), - (bboxes_dp, None), + (bounding_boxes_dp, None), (landmarks_dp, None), ) ] ) - anns_dp = Mapper(anns_dp, self._collate_anns) dp = IterKeyZipper( splits_dp, @@ -180,5 +193,11 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, keep_key=True, ) - dp = IterKeyZipper(dp, anns_dp, key_fn=getitem(0), buffer_size=INFINITE_BUFFER_SIZE) - return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) + dp = IterKeyZipper( + dp, + anns_dp, + key_fn=getitem(0), + ref_key_fn=getitem(0, 0), + buffer_size=INFINITE_BUFFER_SIZE, + ) + return Mapper(dp, self._prepare_sample) diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 68147ba0f9e..37038485a2b 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -3,34 +3,28 @@ import io import pathlib import pickle -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator, cast +from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO import numpy as np -import torch from torchdata.datapipes.iter import ( IterDataPipe, Filter, Mapper, ) -from torchvision.prototype.datasets.decoder import raw from torchvision.prototype.datasets.utils import ( Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource, - DatasetType, ) from torchvision.prototype.datasets.utils._internal import ( hint_shuffling, - image_buffer_from_array, path_comparator, hint_sharding, ) from torchvision.prototype.features import Label, Image -__all__ = ["Cifar10", "Cifar100"] - class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]): def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None: @@ -50,38 +44,25 @@ class _CifarBase(Dataset): _CATEGORIES_KEY: str @abc.abstractmethod - def _is_data_file(self, data: Tuple[str, io.IOBase], *, config: DatasetConfig) -> Optional[int]: + def _is_data_file(self, data: Tuple[str, BinaryIO], *, config: DatasetConfig) -> Optional[int]: pass def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]: _, file = data return cast(Dict[str, Any], pickle.load(file, encoding="latin1")) - def _collate_and_decode( - self, - data: Tuple[np.ndarray, int], - *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]: image_array, category_idx = data - - image: Union[Image, io.BytesIO] - if decoder is raw: - image = Image(image_array) - else: - image_buffer = image_buffer_from_array(image_array.transpose((1, 2, 0))) - image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment] - - label = Label(category_idx, category=self.categories[category_idx]) - - return dict(image=image, label=label) + return dict( + image=Image(image_array), + label=Label(category_idx, category=self.categories[category_idx]), + ) def _make_datapipe( self, resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = Filter(dp, functools.partial(self._is_data_file, config=config)) @@ -89,7 +70,7 @@ def _make_datapipe( dp = 
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        return Mapper(dp, functools.partial(self._collate_and_decode, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
@@ -110,7 +91,6 @@ def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "cifar10",
-            type=DatasetType.RAW,
             homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
         )

@@ -135,7 +115,6 @@ def _is_data_file(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "cifar100",
-            type=DatasetType.RAW,
             homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
             valid_options=dict(
                 split=("train", "test"),
diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py
index e400a1db07d..7cf6cc3a3dd 100644
--- a/torchvision/prototype/datasets/_builtin/coco.py
+++ b/torchvision/prototype/datasets/_builtin/coco.py
@@ -1,9 +1,8 @@
 import functools
-import io
 import pathlib
 import re
 from collections import OrderedDict
-from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO

 import torch
 from torchdata.datapipes.iter import (
@@ -22,7 +21,7 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
+    RawImage,
 )
 from torchvision.prototype.datasets.utils._internal import (
     MappingIterator,
@@ -44,7 +43,6 @@
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             name,
-            type=DatasetType.IMAGE,
             dependencies=("pycocotools",),
             categories=categories,
             homepage="https://cocodataset.org/",
@@ -150,26 +148,24 @@ def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
         else:
             return None

-    def _collate_and_decode_image(
-        self, data: Tuple[str, io.IOBase], *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
-    ) -> Dict[str, Any]:
+    def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
         path, buffer = data
-        return dict(path=path, image=decoder(buffer) if decoder else buffer)
+        return dict(
+            path=path,
+            image=RawImage.fromfile(buffer),
+        )

-    def _collate_and_decode_sample(
+    def _prepare_sample(
         self,
-        data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, io.IOBase]],
+        data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]],
         *,
-        annotations: Optional[str],
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        annotations: str,
     ) -> Dict[str, Any]:
         ann_data, image_data = data
         anns, image_meta = ann_data

-        sample = self._collate_and_decode_image(image_data, decoder=decoder)
-        if annotations:
-            sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
-
+        sample = self._prepare_image(image_data)
+        sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
         return sample

     def _make_datapipe(
@@ -177,14 +173,13 @@ def _make_datapipe(
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, meta_dp = resource_dps

         if config.annotations is None:
             dp = hint_sharding(images_dp)
             dp = hint_shuffling(dp)
-            return Mapper(dp, functools.partial(self._collate_and_decode_image, decoder=decoder))
+            return Mapper(dp, self._prepare_image)

         meta_dp = Filter(
             meta_dp,
@@ -230,9 +225,8 @@ def _make_datapipe(
             ref_key_fn=path_accessor("name"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(
-            dp, functools.partial(self._collate_and_decode_sample, annotations=config.annotations, decoder=decoder)
-        )
+
+        return Mapper(dp, functools.partial(self._prepare_sample, annotations=config.annotations))

     def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:
         config = self.default_config
diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py
index 9ea70296427..8a78de429f7 100644
--- a/torchvision/prototype/datasets/_builtin/imagenet.py
+++ b/torchvision/prototype/datasets/_builtin/imagenet.py
@@ -1,10 +1,7 @@
-import functools
-import io
 import pathlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO

-import torch
 from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, TarArchiveReader, Filter
 from torchvision.prototype.datasets.utils import (
     Dataset,
@@ -12,7 +9,7 @@
     DatasetInfo,
     OnlineResource,
     ManualDownloadResource,
-    DatasetType,
+    RawImage,
 )
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
@@ -40,7 +37,6 @@
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             name,
-            type=DatasetType.IMAGE,
             dependencies=("scipy",),
             categories=categories,
             homepage="https://www.image-net.org/",
@@ -88,7 +84,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:

     _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")

-    def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[Label, str, str], Tuple[str, io.IOBase]]:
+    def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str, str], Tuple[str, BinaryIO]]:
         path = pathlib.Path(data[0])
         wnid = self._TRAIN_IMAGE_NAME_PATTERN.match(path.name).group("wnid")  # type: ignore[union-attr]
         category = self.wnid_to_category[wnid]
@@ -101,41 +97,34 @@ def _val_test_image_key(self, data: Tuple[str, Any]) -> int:
         path = pathlib.Path(data[0])
         return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name).group("id"))  # type: ignore[union-attr]

-    def _collate_val_data(
-        self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]]
-    ) -> Tuple[Tuple[Label, str, str], Tuple[str, io.IOBase]]:
+    def _prepare_val_data(
+        self, data: Tuple[Tuple[int, int], Tuple[str, BinaryIO]]
+    ) -> Tuple[Tuple[Label, str, str], Tuple[str, BinaryIO]]:
         label_data, image_data = data
         _, label = label_data
         category = self.categories[label]
         wnid = self.category_to_wnid[category]
         return (Label(label), category, wnid), image_data

-    def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[None, Tuple[str, io.IOBase]]:
+    def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
         return None, data

-    def _collate_and_decode_sample(
+    def _prepare_sample(
         self,
-        data: Tuple[Optional[Tuple[Label, str, str]], Tuple[str, io.IOBase]],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        data: Tuple[Optional[Tuple[Label, str, str]], Tuple[str, BinaryIO]],
     ) -> Dict[str, Any]:
         label_data, (path, buffer) = data

         sample = dict(
             path=path,
-            image=decoder(buffer) if decoder else buffer,
+            image=RawImage.fromfile(buffer),
         )
         if label_data:
             sample.update(dict(zip(("label", "category", "wnid"), label_data)))
-
         return sample

     def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, devkit_dp = resource_dps

@@ -144,7 +133,7 @@ def _make_datapipe(
             dp = TarArchiveReader(images_dp)
             dp = hint_sharding(dp)
             dp = hint_shuffling(dp)
-            dp = Mapper(dp, self._collate_train_data)
+            dp = Mapper(dp, self._prepare_train_data)
         elif config.split == "val":
             devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
             devkit_dp = LineReader(devkit_dp, return_path=False)
@@ -160,13 +149,12 @@ def _make_datapipe(
                 ref_key_fn=self._val_test_image_key,
                 buffer_size=INFINITE_BUFFER_SIZE,
             )
-            dp = Mapper(dp, self._collate_val_data)
+            dp = Mapper(dp, self._prepare_val_data)
         else:  # config.split == "test"
             dp = hint_sharding(images_dp)
             dp = hint_shuffling(dp)
-            dp = Mapper(dp, self._collate_test_data)
-
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
+            dp = Mapper(dp, self._prepare_test_data)
+        return Mapper(dp, self._prepare_sample)

     # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
     # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py
index 8f49f1ce72a..17552535869 100644
--- a/torchvision/prototype/datasets/_builtin/mnist.py
+++ b/torchvision/prototype/datasets/_builtin/mnist.py
@@ -1,10 +1,9 @@
 import abc
 import functools
-import io
 import operator
 import pathlib
 import string
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO
+from typing import Any, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO

 import torch
 from torchdata.datapipes.iter import (
@@ -13,17 +12,14 @@
     Mapper,
     Zipper,
 )
-from torchvision.prototype.datasets.decoder import raw
 from torchvision.prototype.datasets.utils import (
     Dataset,
-    DatasetType,
     DatasetConfig,
     DatasetInfo,
     HttpResource,
     OnlineResource,
 )
 from torchvision.prototype.datasets.utils._internal import (
-    image_buffer_from_array,
     Decompressor,
     INFINITE_BUFFER_SIZE,
     fromfile,
@@ -98,31 +94,15 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
     def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]:
         return None, None

-    def _collate_and_decode(
-        self,
-        data: Tuple[torch.Tensor, torch.Tensor],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
         image, label = data
-
-        if decoder is raw:
-            image = Image(image)
-        else:
-            image_buffer = image_buffer_from_array(image.numpy())
-            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
-
-        label = Label(label, dtype=torch.int64, category=self.info.categories[int(label)])
-
-        return dict(image=image, label=label)
+        return dict(
+            image=Image(image),
+            label=Label(label, dtype=torch.int64, category=self.info.categories[int(label)]),
+        )

     def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, labels_dp = resource_dps
         start, stop = self.start_and_stop(config)
@@ -136,14 +116,13 @@ def _make_datapipe(
         dp = Zipper(images_dp, labels_dp)
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        return Mapper(dp, functools.partial(self._collate_and_decode, config=config, decoder=decoder))
+        return Mapper(dp, functools.partial(self._prepare_sample, config=config))


 class MNIST(_MNISTBase):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "mnist",
-            type=DatasetType.RAW,
             categories=10,
             homepage="http://yann.lecun.com/exdb/mnist",
             valid_options=dict(
@@ -173,7 +152,6 @@ class FashionMNIST(MNIST):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "fashionmnist",
-            type=DatasetType.RAW,
             categories=(
                 "T-shirt/top",
                 "Trouser",
@@ -205,7 +183,6 @@ class KMNIST(MNIST):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "kmnist",
-            type=DatasetType.RAW,
             categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"],
             homepage="http://codh.rois.ac.jp/kmnist/index.html.en",
             valid_options=dict(
@@ -226,7 +203,6 @@ class EMNIST(_MNISTBase):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "emnist",
-            type=DatasetType.RAW,
             categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase),
             homepage="https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist",
             valid_options=dict(
@@ -281,13 +257,7 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) ->
         46: 9,
     }

-    def _collate_and_decode(
-        self,
-        data: Tuple[torch.Tensor, torch.Tensor],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
         # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper).
         # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense,
         # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create these gaps. For example,
@@ -300,14 +270,10 @@ def _collate_and_decode(
         image, label = data
         label += self._LABEL_OFFSETS.get(int(label), 0)
         data = (image, label)
-        return super()._collate_and_decode(data, config=config, decoder=decoder)
+        return super()._prepare_sample(data, config=config)

     def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
     ) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
         images_dp, labels_dp = Demultiplexer(
@@ -317,14 +283,13 @@ def _make_datapipe(
             drop_none=True,
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return super()._make_datapipe([images_dp, labels_dp], config=config, decoder=decoder)
+        return super()._make_datapipe([images_dp, labels_dp], config=config)


 class QMNIST(_MNISTBase):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "qmnist",
-            type=DatasetType.RAW,
             categories=10,
             homepage="https://github.com/facebookresearch/qmnist",
             valid_options=dict(
@@ -366,16 +331,10 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional

         return start, stop

-    def _collate_and_decode(
-        self,
-        data: Tuple[torch.Tensor, torch.Tensor],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
         image, ann = data
         label, *extra_anns = ann
-        sample = super()._collate_and_decode((image, label), config=config, decoder=decoder)
+        sample = super()._prepare_sample((image, label), config=config)

         sample.update(
             dict(
diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py
index 82fdb2adf8b..a5df4356ba6 100644
--- a/torchvision/prototype/datasets/_builtin/sbd.py
+++ b/torchvision/prototype/datasets/_builtin/sbd.py
@@ -1,11 +1,8 @@
-import functools
-import io
 import pathlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO

 import numpy as np
-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
@@ -20,7 +17,7 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
+    RawImage,
 )
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
@@ -31,19 +28,17 @@
     hint_sharding,
     hint_shuffling,
 )
+from torchvision.prototype.features import Feature


 class SBD(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "sbd",
-            type=DatasetType.IMAGE,
             dependencies=("scipy",),
             homepage="http://home.bharathh.info/pubs/codes/SBD/download.html",
             valid_options=dict(
                 split=("train", "val", "train_noval"),
-                boundaries=(True, False),
-                segmentation=(False, True),
             ),
         )
@@ -74,50 +69,25 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
         else:
             return None

-    def _decode_ann(
-        self, data: Dict[str, Any], *, decode_boundaries: bool, decode_segmentation: bool
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
-        raw_anns = data["GTcls"][0]
-        raw_boundaries = raw_anns["Boundaries"][0]
-        raw_segmentation = raw_anns["Segmentation"][0]
-
-        # the boundaries are stored in sparse CSC format, which is not supported by PyTorch
-        boundaries = (
-            torch.as_tensor(np.stack([raw_boundary[0].toarray() for raw_boundary in raw_boundaries]))
-            if decode_boundaries
-            else None
+    def _decode_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
+        raw_anns = read_mat(buffer)["GTcls"][0]
+        return dict(
+            # the boundaries are stored in sparse CSC format, which is not supported by PyTorch
+            boundaries=Feature(np.stack([raw_boundary[0].toarray() for raw_boundary in raw_anns["Boundaries"][0]])),
+            segmentation=Feature(raw_anns["Segmentation"][0]),
         )
-        segmentation = torch.as_tensor(raw_segmentation) if decode_segmentation else None
-
-        return boundaries, segmentation

-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[Tuple[Any, Tuple[str, io.IOBase]], Tuple[str, io.IOBase]],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
         split_and_image_data, ann_data = data
         _, image_data = split_and_image_data
         image_path, image_buffer = image_data
         ann_path, ann_buffer = ann_data

-        image = decoder(image_buffer) if decoder else image_buffer
-
-        if config.boundaries or config.segmentation:
-            boundaries, segmentation = self._decode_ann(
-                read_mat(ann_buffer), decode_boundaries=config.boundaries, decode_segmentation=config.segmentation
-            )
-        else:
-            boundaries = segmentation = None
-
         return dict(
+            self._decode_ann(ann_buffer),
             image_path=image_path,
-            image=image,
+            image=RawImage.fromfile(image_buffer),
             ann_path=ann_path,
-            boundaries=boundaries,
-            segmentation=segmentation,
         )

     def _make_datapipe(
@@ -125,7 +95,6 @@ def _make_datapipe(
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         archive_dp, extra_split_dp = resource_dps

@@ -137,9 +106,9 @@ def _make_datapipe(
             buffer_size=INFINITE_BUFFER_SIZE,
             drop_none=True,
         )
-
         if config.split == "train_noval":
             split_dp = extra_split_dp
+
         split_dp = LineReader(split_dp, decode=True)
         split_dp = hint_sharding(split_dp)
         split_dp = hint_shuffling(split_dp)
@@ -153,7 +122,7 @@ def _make_datapipe(
             ref_key_fn=path_accessor("stem"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)

     def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
diff --git a/torchvision/prototype/datasets/_builtin/semeion.py b/torchvision/prototype/datasets/_builtin/semeion.py
index 9df12c98b9b..85734ad1ec8 100644
--- a/torchvision/prototype/datasets/_builtin/semeion.py
+++ b/torchvision/prototype/datasets/_builtin/semeion.py
@@ -1,6 +1,4 @@
-import functools
-import io
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple

 import torch
 from torchdata.datapipes.iter import (
@@ -8,23 +6,21 @@
     Mapper,
     CSVParser,
 )
-from torchvision.prototype.datasets.decoder import raw
 from torchvision.prototype.datasets.utils import (
     Dataset,
     DatasetConfig,
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
 )
-from torchvision.prototype.datasets.utils._internal import image_buffer_from_array, hint_sharding, hint_shuffling
+from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
+from torchvision.prototype.features import Image, Label


 class SEMEION(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "semeion",
-            type=DatasetType.RAW,
             categories=10,
             homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit",
         )
@@ -36,35 +32,27 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
         )
         return [data]

-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[str, ...],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[str, ...]) -> Dict[str, Any]:
         image_data = torch.tensor([float(pixel) for pixel in data[:256]], dtype=torch.uint8).reshape(16, 16)
         label_data = [int(label) for label in data[256:] if label]

-        if decoder is raw:
-            image = image_data.unsqueeze(0)
-        else:
-            image_buffer = image_buffer_from_array(image_data.numpy())
-            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
-
-        label = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label))
-        category = self.info.categories[label]
-        return dict(image=image, label=label, category=category)
+        label_idx = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label))
+        return dict(
+            image=Image(image_data.unsqueeze(0)),
+            label=Label(
+                label_idx,
+                category=self.info.categories[label_idx],
+            ),
+        )

     def _make_datapipe(
         self,
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = CSVParser(dp, delimiter=" ")
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        dp = Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
-        return dp
+        return Mapper(dp, self._prepare_sample)
diff --git a/torchvision/prototype/datasets/_builtin/voc.categories b/torchvision/prototype/datasets/_builtin/voc.categories
new file mode 100644
index 00000000000..8420ab35ede
--- /dev/null
+++ b/torchvision/prototype/datasets/_builtin/voc.categories
@@ -0,0 +1,20 @@
+aeroplane
+bicycle
+bird
+boat
+bottle
+bus
+car
+cat
+chair
+cow
+diningtable
+dog
+horse
+motorbike
+person
+pottedplant
+sheep
+sofa
+train
+tvmonitor
diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py
index 66905fac3bd..c2065ebe358 100644
--- a/torchvision/prototype/datasets/_builtin/voc.py
+++ b/torchvision/prototype/datasets/_builtin/voc.py
@@ -1,10 +1,8 @@
 import functools
-import io
 import pathlib
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast
 from xml.etree import ElementTree

-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
@@ -20,7 +18,7 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
+    RawImage,
 )
 from torchvision.prototype.datasets.utils._internal import (
     path_accessor,
@@ -30,15 +28,13 @@
     hint_sharding,
     hint_shuffling,
 )
-
-HERE = pathlib.Path(__file__).parent
+from torchvision.prototype.features import BoundingBox, Label


 class VOC(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "voc",
-            type=DatasetType.IMAGE,
             homepage="http://host.robots.ox.ac.uk/pascal/VOC/",
             valid_options=dict(
                 split=("train", "val", "test"),
@@ -83,40 +79,51 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) ->
         else:
             return None

-    def _decode_detection_ann(self, buffer: io.IOBase) -> torch.Tensor:
-        result = VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())  # type: ignore[arg-type]
-        objects = result["annotation"]["object"]
-        bboxes = [obj["bndbox"] for obj in objects]
-        bboxes = [[int(bbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] for bbox in bboxes]
-        return torch.tensor(bboxes)
+    def _parse_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
+        return cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
+
+    def _decode_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
+        anns = self._parse_detection_ann(buffer)
+        instances = anns["object"]
+        return dict(
+            bounding_boxes=BoundingBox(
+                [
+                    [int(instance["bndbox"][part]) for part in ("xmin", "ymin", "xmax", "ymax")]
+                    for instance in instances
+                ],
+                format="xyxy",
+                image_size=tuple(int(anns["size"][dim]) for dim in ("height", "width")),
+            ),
+            labels=[
+                Label(self.info.categories.index(instance["name"]), category=instance["name"]) for instance in instances
+            ],
+        )

-    def _collate_and_decode_sample(
+    def _prepare_sample(
         self,
-        data: Tuple[Tuple[Tuple[str, str], Tuple[str, io.IOBase]], Tuple[str, io.IOBase]],
+        data: Tuple[Tuple[Tuple[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
         split_and_image_data, ann_data = data
         _, image_data = split_and_image_data
         image_path, image_buffer = image_data
         ann_path, ann_buffer = ann_data

-        image = decoder(image_buffer) if decoder else image_buffer
-
-        if config.task == "detection":
-            ann = self._decode_detection_ann(ann_buffer)
-        else:  # config.task == "segmentation":
-            ann = decoder(ann_buffer) if decoder else ann_buffer  # type: ignore[assignment]
-
-        return dict(image_path=image_path, image=image, ann_path=ann_path, ann=ann)
+        return dict(
+            self._decode_detection_ann(ann_buffer)
+            if config.task == "detection"
+            else dict(segmentation=RawImage.fromfile(ann_buffer)),
+            image_path=image_path,
+            image=RawImage.fromfile(image_buffer),
+            ann_path=ann_path,
+        )

     def _make_datapipe(
         self,
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
         split_dp, images_dp, anns_dp = Demultiplexer(
@@ -142,4 +149,17 @@ def _make_datapipe(
             ref_key_fn=path_accessor("stem"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder))
+        return Mapper(dp, functools.partial(self._prepare_sample, config=config))
+
+    def _filter_detection_anns(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
+        return self._classify_archive(data, config=config) == 2
+
+    def _generate_categories(self, root: pathlib.Path) -> List[str]:
+        config = self.info.make_config(task="detection")
+
+        resource = self.resources(config)[0]
+        dp = resource.load(pathlib.Path(root) / self.name)
+        dp = Filter(dp, self._filter_detection_anns, fn_kwargs=dict(config=config))
+        dp = Mapper(dp, self._parse_detection_ann, input_col=1)
+
+        return sorted({instance["name"] for _, anns in dp for instance in anns["object"]})
diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py
index efffaa80f99..7c8a22a829b 100644
--- a/torchvision/prototype/datasets/_folder.py
+++ b/torchvision/prototype/datasets/_folder.py
@@ -1,16 +1,14 @@
 import functools
-import io
 import os
 import os.path
 import pathlib
-from typing import Callable, Optional, Collection
-from typing import Union, Tuple, List, Dict, Any
+from typing import BinaryIO, Optional, Collection, Union, Tuple, List, Dict, Any

-import torch
 from torch.utils.data import IterDataPipe
-from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Shuffler, Filter
-from torchvision.prototype.datasets.decoder import pil
-from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding
+from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Filter
+from torchvision.prototype.datasets.utils import RawData, RawImage
+from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
+from torchvision.prototype.features import Label


 __all__ = ["from_data_folder", "from_image_folder"]
@@ -21,29 +19,24 @@ def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool:
     return rel_path.is_dir() or rel_path.parent != pathlib.Path(".")


-def _collate_and_decode_data(
-    data: Tuple[str, io.IOBase],
+def _prepare_sample(
+    data: Tuple[str, BinaryIO],
     *,
     root: pathlib.Path,
     categories: List[str],
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
 ) -> Dict[str, Any]:
     path, buffer = data
-    data = decoder(buffer) if decoder else buffer
     category = pathlib.Path(path).relative_to(root).parts[0]
-    label = torch.tensor(categories.index(category))
     return dict(
         path=path,
-        data=data,
-        label=label,
-        category=category,
+        data=RawData.fromfile(buffer),
+        label=Label(categories.index(category), category=category),
     )


 def from_data_folder(
     root: Union[str, pathlib.Path],
     *,
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
     valid_extensions: Optional[Collection[str]] = None,
     recursive: bool = True,
 ) -> Tuple[IterDataPipe, List[str]]:
@@ -53,26 +46,22 @@ def from_data_folder(
     dp = FileLister(str(root), recursive=recursive, masks=masks)
     dp: IterDataPipe = Filter(dp, functools.partial(_is_not_top_level_file, root=root))
     dp = hint_sharding(dp)
-    dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
+    dp = hint_shuffling(dp)
     dp = FileLoader(dp)
-    return (
-        Mapper(dp, functools.partial(_collate_and_decode_data, root=root, categories=categories, decoder=decoder)),
-        categories,
-    )
+    return Mapper(dp, functools.partial(_prepare_sample, root=root, categories=categories)), categories


 def _data_to_image_key(sample: Dict[str, Any]) -> Dict[str, Any]:
-    sample["image"] = sample.pop("data")
+    sample["image"] = RawImage(sample.pop("data").data)
     return sample


 def from_image_folder(
     root: Union[str, pathlib.Path],
     *,
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil,
     valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"),
     **kwargs: Any,
 ) -> Tuple[IterDataPipe, List[str]]:
     valid_extensions = [valid_extension for ext in valid_extensions for valid_extension in (ext.lower(), ext.upper())]
-    dp, categories = from_data_folder(root, decoder=decoder, valid_extensions=valid_extensions, **kwargs)
+    dp, categories = from_data_folder(root, valid_extensions=valid_extensions, **kwargs)
     return Mapper(dp, _data_to_image_key), categories
diff --git a/torchvision/prototype/datasets/decoder.py b/torchvision/prototype/datasets/decoder.py
deleted file mode 100644
index 530a357f239..00000000000
--- a/torchvision/prototype/datasets/decoder.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import io
-
-import PIL.Image
-import torch
-from torchvision.prototype import features
-from torchvision.transforms.functional import pil_to_tensor
-
-__all__ = ["raw", "pil"]
-
-
-def raw(buffer: io.IOBase) -> torch.Tensor:
-    raise RuntimeError("This is just a sentinel and should never be called.")
-
-
-def pil(buffer: io.IOBase) -> features.Image:
-    return features.Image(pil_to_tensor(PIL.Image.open(buffer)))
diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py
index 92bcffc0cdb..6b3d07660a1 100644
--- a/torchvision/prototype/datasets/utils/__init__.py
+++ b/torchvision/prototype/datasets/utils/__init__.py
@@ -1,4 +1,12 @@
-from . import _internal
-from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset
+from . import _internal  # usort: skip
+from ._dataset import DatasetConfig, DatasetInfo, Dataset
+from ._decoder import (
+    decode_images,
+    decode_sample,
+    decode_image_with_pil,
+    RawImage,
+    RawData,
+    ReadOnlyTensorBuffer,
+)
 from ._query import SampleQuery
 from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource
diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py
index 242f9c961c0..5ecf9a04454 100644
--- a/torchvision/prototype/datasets/utils/_dataset.py
+++ b/torchvision/prototype/datasets/utils/_dataset.py
@@ -1,14 +1,11 @@
 import abc
 import csv
-import enum
 import importlib
-import io
 import itertools
 import os
 import pathlib
-from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple
+from typing import Any, Dict, List, Optional, Sequence, Union, Tuple

-import torch
 from torch.utils.data import IterDataPipe
 from torchvision.prototype.utils._internal import FrozenBunch, make_repr
 from torchvision.prototype.utils._internal import add_suggestion, sequence_to_str
@@ -18,11 +15,6 @@
 from ._resource import OnlineResource


-class DatasetType(enum.Enum):
-    RAW = enum.auto()
-    IMAGE = enum.auto()
-
-
 class DatasetConfig(FrozenBunch):
     pass

@@ -32,7 +24,6 @@ def __init__(
         self,
         name: str,
         *,
-        type: Union[str, DatasetType],
         dependencies: Sequence[str] = (),
         categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None,
         citation: Optional[str] = None,
@@ -42,7 +33,6 @@
         extra: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.name = name.lower()
-        self.type = DatasetType[type.upper()] if isinstance(type, str) else type

         self.dependecies = dependencies
@@ -165,7 +155,6 @@ def _make_datapipe(
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         pass

@@ -177,7 +166,6 @@ def load(
         root: Union[str, pathlib.Path],
         *,
         config: Optional[DatasetConfig] = None,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
         skip_integrity_check: bool = False,
     ) -> IterDataPipe[Dict[str, Any]]:
         if not config:
@@ -192,7 +180,7 @@ def load(
         resource_dps = [
             resource.load(root, skip_integrity_check=skip_integrity_check) for resource in self.resources(config)
         ]
-        return self._make_datapipe(resource_dps, config=config, decoder=decoder)
+        return self._make_datapipe(resource_dps, config=config)

     def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]:
         raise NotImplementedError
diff --git a/torchvision/prototype/datasets/utils/_decoder.py b/torchvision/prototype/datasets/utils/_decoder.py
new file mode 100644
index 00000000000..5a28e13e7b6
--- /dev/null
+++ b/torchvision/prototype/datasets/utils/_decoder.py
@@ -0,0 +1,80 @@
+import collections.abc
+import sys
+from typing import Any, Dict, Callable, Type, TypeVar, cast, BinaryIO, Tuple
+
+import PIL.Image
+import torch
+from torch._C import _TensorBase
+from torchvision.prototype import features
+from torchvision.transforms.functional import pil_to_tensor
+
+from ._internal import ReadOnlyTensorBuffer, fromfile
+
+D = TypeVar("D", bound="RawData")
+
+
+class RawData(torch.Tensor):
+    def __new__(cls, data: torch.Tensor) -> "RawData":
+        # TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8?
+        return cast(
+            RawData,
+            torch.Tensor._make_subclass(
+                cast(_TensorBase, cls),
+                data,
+                False,  # requires_grad
+            ),
+        )
+
+    @classmethod
+    def fromfile(cls: Type[D], file: BinaryIO) -> D:
+        return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder))
+
+
+class RawImage(RawData):
+    def probe_image_size(self) -> Tuple[int, int]:
+        if not hasattr(self, "_image_size"):
+            image = PIL.Image.open(ReadOnlyTensorBuffer(self))
+            self._image_size = image.height, image.width
+
+        return self._image_size
+
+
+def decode_image_with_pil(raw_image: RawImage) -> Dict[str, Any]:
+    return dict(image=features.Image(pil_to_tensor(PIL.Image.open(ReadOnlyTensorBuffer(raw_image)))))
+
+
+def decode_sample(
+    sample: Any, *, decoder_map: Dict[Type[D], Callable[[D], Dict[str, Any]]], inline_decoded: bool = True
+) -> Any:
+    # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
+    # "a" == "a"[0][0]...
+    if isinstance(sample, collections.abc.Sequence) and not isinstance(sample, str):
+        return [decode_sample(item, decoder_map=decoder_map, inline_decoded=inline_decoded) for item in sample]
+    elif isinstance(sample, collections.abc.Mapping):
+        decoded_sample = {}
+        for name, item in sample.items():
+            decoded_item = decode_sample(item, decoder_map=decoder_map, inline_decoded=inline_decoded)
+            if inline_decoded and isinstance(item, RawData):
+                decoded_sample.update(decoded_item)
+            else:
+                decoded_sample[name] = decoded_item
+        return decoded_sample
+    else:
+        sample_type = type(sample)
+        if not issubclass(sample_type, RawData):
+            return sample
+
+        try:
+            return decoder_map[cast(Type[D], sample_type)](cast(D, sample))
+        except KeyError as error:
+            raise TypeError(f"Unknown type {sample_type}") from error
+
+
+def decode_images(sample: Any, *, inline_decoded: bool = True) -> Any:
+    return decode_sample(
+        sample,
+        decoder_map={
+            RawImage: decode_image_with_pil,
+        },
+        inline_decoded=inline_decoded,
+    )
diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py
index 824594dd28e..c52255e602e 100644
--- a/torchvision/prototype/datasets/utils/_internal.py
+++ b/torchvision/prototype/datasets/utils/_internal.py
@@ -26,7 +26,6 @@
 from typing import cast

 import numpy as np
-import PIL.Image
 import torch
 import torch.distributed as dist
 import torch.utils.data
@@ -38,7 +37,6 @@
     "INFINITE_BUFFER_SIZE",
     "BUILTIN_DIR",
     "read_mat",
-    "image_buffer_from_array",
     "SequenceIterator",
     "MappingIterator",
     "Enumerator",
@@ -49,6 +47,8 @@
     "fromfile",
     "read_flo",
     "hint_sharding",
+    "hint_shuffling",
+    "ReadOnlyTensorBuffer",
 ]

 K = TypeVar("K")
@@ -60,7 +60,7 @@
 BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin"


-def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
+def read_mat(buffer: BinaryIO, **kwargs: Any) -> Any:
     try:
         import scipy.io as sio
     except ImportError as error:
@@ -72,14 +72,6 @@ def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
     return sio.loadmat(buffer, **kwargs)


-def image_buffer_from_array(array: np.ndarray, *, format: str = "png") -> io.BytesIO:
-    image = PIL.Image.fromarray(array)
-    buffer = io.BytesIO()
-    image.save(buffer, format=format)
-    buffer.seek(0)
-    return buffer
-
-
 class SequenceIterator(IterDataPipe[D]):
     def __init__(self, datapipe: IterDataPipe[Sequence[D]]):
         self.datapipe = datapipe
@@ -146,17 +138,17 @@ class CompressionType(enum.Enum):
     LZMA = "lzma"


-class Decompressor(IterDataPipe[Tuple[str, io.IOBase]]):
+class Decompressor(IterDataPipe[Tuple[str, BinaryIO]]):
     types = CompressionType

-    _DECOMPRESSORS = {
-        types.GZIP: lambda file: gzip.GzipFile(fileobj=file),
-        types.LZMA: lambda file: lzma.LZMAFile(file),
+    _DECOMPRESSORS: Dict[CompressionType, Callable[[BinaryIO], BinaryIO]] = {
+        types.GZIP: lambda file: cast(BinaryIO, gzip.GzipFile(fileobj=file)),
+        types.LZMA: lambda file: cast(BinaryIO, lzma.LZMAFile(file)),
     }

     def __init__(
         self,
-        datapipe: IterDataPipe[Tuple[str, io.IOBase]],
+        datapipe: IterDataPipe[Tuple[str, BinaryIO]],
         *,
         type: Optional[Union[str, CompressionType]] = None,
     ) -> None:
@@ -178,7 +170,7 @@ def _detect_compression_type(self, path: str) -> CompressionType:
         else:
             raise RuntimeError("FIXME")

-    def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
+    def __iter__(self) -> Iterator[Tuple[str, BinaryIO]]:
         for path, file in self.datapipe:
             type = self._detect_compression_type(path)
             decompressor = self._DECOMPRESSORS[type]
@@ -311,7 +303,7 @@ def fromfile(
         buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
         # Reading from the memoryview does not advance the file cursor, so we have to do it manually.
         file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
-    except (PermissionError, io.UnsupportedOperation):
+    except (AttributeError, PermissionError, io.UnsupportedOperation):
         buffer = _read_mutable_buffer_fallback(file, count, item_size)
     else:
         # On Windows just trying to call mmap.mmap() on a file that does not support it, may corrupt the internal state
@@ -339,3 +331,31 @@ def hint_sharding(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:

 def hint_shuffling(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:
     return Shuffler(datapipe, default=False, buffer_size=INFINITE_BUFFER_SIZE)
+
+
+class ReadOnlyTensorBuffer:
+    def __init__(self, tensor: torch.Tensor) -> None:
+        self._memory = memoryview(tensor.numpy())
+        self._cursor: int = 0
+
+    def tell(self) -> int:
+        return self._cursor
+
+    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
+        if whence == io.SEEK_SET:
+            self._cursor = offset
+        elif whence == io.SEEK_CUR:
+            self._cursor += offset
+        elif whence == io.SEEK_END:
+            self._cursor = len(self._memory) + offset
+        else:
+            raise ValueError(
+                f"'whence' should be ``{io.SEEK_SET}``, ``{io.SEEK_CUR}``, or ``{io.SEEK_END}``, "
+                f"but got {repr(whence)} instead"
+            )
+        return self.tell()
+
+    def read(self, size: int = -1) -> bytes:
+        cursor = self.tell()
+        offset, whence = (0, io.SEEK_END) if size == -1 else (size, io.SEEK_CUR)
+        return self._memory[slice(cursor, self.seek(offset, whence))].tobytes()
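
For reviewers, a minimal sketch (not part of the patch) of how the decoder-free pipeline above is meant to be used. The dataset name is illustrative and assumes the archive is already present under the datasets home directory; `RawImage` and `decode_images` are the helpers added in `torchvision/prototype/datasets/utils/_decoder.py`:

    import torch
    from torchdata.datapipes.iter import Mapper
    from torchvision.prototype import datasets
    from torchvision.prototype.datasets.utils import RawImage, decode_images

    # Samples now carry the raw, undecoded file bytes as a `RawImage` tensor.
    dp = datasets.load("caltech256")
    sample = next(iter(dp))
    assert isinstance(sample["image"], RawImage)

    # Decoding is an explicit, opt-in step: `decode_images` walks the sample dict
    # and inlines a decoded `features.Image` wherever it finds a `RawImage`.
    dp = Mapper(dp, decode_images)
    sample = next(iter(dp))
    assert not isinstance(sample["image"], RawImage)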