diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py
index 99de85d5c73..f2218568894 100644
--- a/test/builtin_dataset_mocks.py
+++ b/test/builtin_dataset_mocks.py
@@ -431,50 +431,52 @@ def caltech256(info, root, config):
 
 @register_mock
 def imagenet(info, root, config):
-    wnids = tuple(info.extra.wnid_to_category.keys())
-    if config.split == "train":
-        images_root = root / "ILSVRC2012_img_train"
+    from scipy.io import savemat
 
+    categories = info.categories
+    wnids = [info.extra.category_to_wnid[category] for category in categories]
+    if config.split == "train":
         num_samples = len(wnids)
+        archive_name = "ILSVRC2012_img_train.tar"
 
+        files = []
         for wnid in wnids:
-            files = create_image_folder(
-                root=images_root,
+            create_image_folder(
+                root=root,
                 name=wnid,
                 file_name_fn=lambda image_idx: f"{wnid}_{image_idx:04d}.JPEG",
                 num_examples=1,
             )
-            make_tar(images_root, f"{wnid}.tar", files[0].parent)
+            files.append(make_tar(root, f"{wnid}.tar"))
     elif config.split == "val":
         num_samples = 3
-        files = create_image_folder(
-            root=root,
-            name="ILSVRC2012_img_val",
-            file_name_fn=lambda image_idx: f"ILSVRC2012_val_{image_idx + 1:08d}.JPEG",
-            num_examples=num_samples,
-        )
-        images_root = files[0].parent
-    else:  # config.split == "test"
-        images_root = root / "ILSVRC2012_img_test_v10102019"
+        archive_name = "ILSVRC2012_img_val.tar"
+        files = [create_image_file(root, f"ILSVRC2012_val_{idx + 1:08d}.JPEG") for idx in range(num_samples)]
 
-        num_samples = 3
+        devkit_root = root / "ILSVRC2012_devkit_t12"
+        data_root = devkit_root / "data"
+        data_root.mkdir(parents=True)
 
-        create_image_folder(
-            root=images_root,
-            name="test",
-            file_name_fn=lambda image_idx: f"ILSVRC2012_test_{image_idx + 1:08d}.JPEG",
-            num_examples=num_samples,
-        )
-    make_tar(root, f"{images_root.name}.tar", images_root)
+        with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
+            for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
+                file.write(f"{label}\n")
+
+        num_children = 0
+        synsets = [
+            (idx, wnid, category, "", num_children, [], 0, 0)
+            for idx, (category, wnid) in enumerate(zip(categories, wnids), 1)
+        ]
+        num_children = 1
+        synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5))
+        savemat(data_root / "meta.mat", dict(synsets=synsets))
+
+        make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz")
+    else:  # config.split == "test"
+        num_samples = 5
+        archive_name = "ILSVRC2012_img_test_v10102019.tar"
+        files = [create_image_file(root, f"ILSVRC2012_test_{idx + 1:08d}.JPEG") for idx in range(num_samples)]
 
-    devkit_root = root / "ILSVRC2012_devkit_t12"
-    devkit_root.mkdir()
-    data_root = devkit_root / "data"
-    data_root.mkdir()
-    with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
-        for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
-            file.write(f"{label}\n")
-    make_tar(root, f"{devkit_root}.tar.gz", devkit_root, compression="gz")
+    make_tar(root, archive_name, *files)
 
     return num_samples
 
@@ -666,14 +668,15 @@ def sbd(info, root, config):
 @register_mock
 def semeion(info, root, config):
     num_samples = 3
+    num_categories = len(info.categories)
 
     images = torch.rand(num_samples, 256)
-    labels = one_hot(torch.randint(len(info.categories), size=(num_samples,)))
+    labels = one_hot(torch.randint(num_categories, size=(num_samples,)), num_classes=num_categories)
     with open(root / "semeion.data", "w") as fh:
         for image, one_hot_label in zip(images, labels):
             image_columns = " ".join([f"{pixel.item():.4f}" for pixel in image])
             labels_columns = " ".join([str(label.item()) for label in one_hot_label])
-            fh.write(f"{image_columns} {labels_columns}\n")
+            fh.write(f"{image_columns} {labels_columns} \n")
 
     return num_samples
 
@@ -728,32 +731,33 @@ def _make_detection_anns_folder(cls, root, name, *, file_name_fn, num_examples):
     def _make_detection_ann_file(cls, root, name):
         def add_child(parent, name, text=None):
             child = ET.SubElement(parent, name)
-            child.text = text
+            child.text = str(text)
             return child
 
         def add_name(obj, name="dog"):
             add_child(obj, "name", name)
-            return name
 
-        def add_bndbox(obj, bndbox=None):
-            if bndbox is None:
-                bndbox = {"xmin": "1", "xmax": "2", "ymin": "3", "ymax": "4"}
+        def add_size(obj):
+            obj = add_child(obj, "size")
+            size = {"width": 0, "height": 0, "depth": 3}
+            for name, text in size.items():
+                add_child(obj, name, text)
 
+        def add_bndbox(obj):
             obj = add_child(obj, "bndbox")
+            bndbox = {"xmin": 1, "xmax": 2, "ymin": 3, "ymax": 4}
             for name, text in bndbox.items():
                 add_child(obj, name, text)
 
-            return bndbox
-
         annotation = ET.Element("annotation")
+        add_size(annotation)
         obj = add_child(annotation, "object")
-        data = dict(name=add_name(obj), bndbox=add_bndbox(obj))
+        add_name(obj)
+        add_bndbox(obj)
 
         with open(root / name, "wb") as fh:
             fh.write(ET.tostring(annotation))
 
-        return data
-
     @classmethod
     def generate(cls, root, *, year, trainval):
         archive_folder = root
diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py
index 7b6a22600e1..8319b2bab55 100644
--- a/test/test_prototype_builtin_datasets.py
+++ b/test/test_prototype_builtin_datasets.py
@@ -1,9 +1,11 @@
+import functools
 import io
 from pathlib import Path
 
 import pytest
 import torch
 from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
+from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
 from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
 from torch.utils.data.graph import traverse
 from torchdata.datapipes.iter import IterDataPipe, Shuffler
@@ -11,6 +13,11 @@
 from torchvision.prototype.utils._internal import sequence_to_str
 
 
+assert_samples_equal = functools.partial(
+    assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
+)
+
+
 @pytest.fixture
 def test_home(mocker, tmp_path):
     mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
@@ -92,6 +99,7 @@ def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
                 f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
             )
 
+    @pytest.mark.xfail
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_transformable(self, test_home, dataset_mock, config):
         dataset_mock.prepare(test_home, config)
@@ -137,6 +145,17 @@ def scan(graph):
         if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
             raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
 
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_save_load(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)
+        dataset = datasets.load(dataset_mock.name, **config)
+        sample = next(iter(dataset))
+
+        with io.BytesIO() as buffer:
+            torch.save(sample, buffer)
+            buffer.seek(0)
+            assert_samples_equal(torch.load(buffer), sample)
+
 
 @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
 class TestQMNIST:
@@ -171,5 +190,5 @@ def test_label_matches_path(self, test_home, dataset_mock, config):
         dataset = datasets.load(dataset_mock.name, **config)
 
         for sample in dataset:
-            label_from_path = int(Path(sample["image_path"]).parent.name)
+            label_from_path = int(Path(sample["path"]).parent.name)
             assert sample["label"] == label_from_path
diff --git a/test/test_prototype_datasets_api.py b/test/test_prototype_datasets_api.py
index 3a6126f4990..70a2707d050 100644
--- a/test/test_prototype_datasets_api.py
+++ b/test/test_prototype_datasets_api.py
@@ -5,8 +5,8 @@
 from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch
 
 
-def make_minimal_dataset_info(name="name", type=datasets.utils.DatasetType.RAW, categories=None, **kwargs):
-    return datasets.utils.DatasetInfo(name, type=type, categories=categories or [], **kwargs)
+def make_minimal_dataset_info(name="name", categories=None, **kwargs):
+    return datasets.utils.DatasetInfo(name, categories=categories or [], **kwargs)
 
 
 class TestFrozenMapping:
@@ -176,7 +176,7 @@ def resources(self, config):
             # This method is just defined to appease the ABC, but will be overwritten at instantiation
             pass
 
-        def _make_datapipe(self, resource_dps, *, config, decoder):
+        def _make_datapipe(self, resource_dps, *, config):
             # This method is just defined to appease the ABC, but will be overwritten at instantiation
             pass
 
@@ -229,12 +229,3 @@ def test_resources(self, mocker):
 
         (call_args, _) = dataset._make_datapipe.call_args
         assert call_args[0][0] is sentinel
-
-    def test_decoder(self):
-        dataset = self.DatasetMock()
-
-        sentinel = object()
-        dataset.load("", decoder=sentinel)
-
-        (_, call_kwargs) = dataset._make_datapipe.call_args
-        assert call_kwargs["decoder"] is sentinel
diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
deleted file mode 100644
index 2bcd6692e81..00000000000
--- a/test/test_prototype_transforms.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import pytest
-from torchvision.prototype import transforms, features
-from torchvision.prototype.utils._internal import sequence_to_str
-
-
-FEATURE_TYPES = {
-    feature_type
-    for name, feature_type in features.__dict__.items()
-    if not name.startswith("_")
-    and isinstance(feature_type, type)
-    and issubclass(feature_type, features.Feature)
-    and feature_type is not features.Feature
-}
-
-TRANSFORM_TYPES = tuple(
-    transform_type
-    for name, transform_type in transforms.__dict__.items()
-    if not name.startswith("_")
-    and isinstance(transform_type, type)
-    and issubclass(transform_type, transforms.Transform)
-    and transform_type is not transforms.Transform
-)
-
-
-def test_feature_type_support():
-    missing_feature_types = FEATURE_TYPES - set(transforms.Transform._BUILTIN_FEATURE_TYPES)
-    if missing_feature_types:
-        names = sorted([feature_type.__name__ for feature_type in missing_feature_types])
-        raise AssertionError(
-            f"The feature(s) {sequence_to_str(names, separate_last='and ')} is/are exposed at "
-            f"`torchvision.prototype.features`, but are missing in Transform._BUILTIN_FEATURE_TYPES. "
-            f"Please add it/them to the collection."
-        )
-
-
-@pytest.mark.parametrize(
-    "transform_type",
-    [transform_type for transform_type in TRANSFORM_TYPES if transform_type is not transforms.Identity],
-    ids=lambda transform_type: transform_type.__name__,
-)
-def test_feature_no_op_coverage(transform_type):
-    unsupported_features = (
-        FEATURE_TYPES - transform_type.supported_feature_types() - set(transform_type.NO_OP_FEATURE_TYPES)
-    )
-    if unsupported_features:
-        names = sorted([feature_type.__name__ for feature_type in unsupported_features])
-        raise AssertionError(
-            f"The feature(s) {sequence_to_str(names, separate_last='and ')} are neither supported nor declared as "
-            f"no-op for transform `{transform_type.__name__}`. Please either implement a feature transform for them, "
-            f"or add them to the the `{transform_type.__name__}.NO_OP_FEATURE_TYPES` collection."
-        )
-
-
-def test_non_feature_no_op():
-    class TestTransform(transforms.Transform):
-        @staticmethod
-        def image(input):
-            return input
-
-    no_op_sample = dict(int=0, float=0.0, bool=False, str="str")
-    assert TestTransform()(no_op_sample) == no_op_sample
diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py
index 1945b5a5d9e..28840081fe7 100644
--- a/torchvision/prototype/datasets/__init__.py
+++ b/torchvision/prototype/datasets/__init__.py
@@ -7,7 +7,7 @@
         "Note that you cannot install it with `pip install torchdata`, since this is another package."
     ) from error
 
-from . import decoder, utils
+from . import utils
 from ._home import home
 
 # Load this last, since some parts depend on the above being loaded first
diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py
index e9240eb46ce..63e2bb8cf6f 100644
--- a/torchvision/prototype/datasets/_api.py
+++ b/torchvision/prototype/datasets/_api.py
@@ -1,12 +1,9 @@
-import io
 import os
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List
 
-import torch
 from torch.utils.data import IterDataPipe
 from torchvision.prototype.datasets import home
-from torchvision.prototype.datasets.decoder import raw, pil
-from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType
+from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
 from torchvision.prototype.utils._internal import add_suggestion
 
 from . import _builtin
@@ -49,27 +46,15 @@ def info(name: str) -> DatasetInfo:
     return find(name).info
 
 
-DEFAULT_DECODER = object()
-
-DEFAULT_DECODER_MAP: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = {
-    DatasetType.RAW: raw,
-    DatasetType.IMAGE: pil,
-}
-
-
 def load(
     name: str,
     *,
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = DEFAULT_DECODER,  # type: ignore[assignment]
     skip_integrity_check: bool = False,
     **options: Any,
 ) -> IterDataPipe[Dict[str, Any]]:
     dataset = find(name)
 
-    if decoder is DEFAULT_DECODER:
-        decoder = DEFAULT_DECODER_MAP.get(dataset.info.type)
-
     config = dataset.info.make_config(**options)
     root = os.path.join(home(), dataset.name)
 
-    return dataset.load(root, config=config, decoder=decoder, skip_integrity_check=skip_integrity_check)
+    return dataset.load(root, config=config, skip_integrity_check=skip_integrity_check)
diff --git a/torchvision/prototype/datasets/_builtin/README.md b/torchvision/prototype/datasets/_builtin/README.md
index 8ee6e8e5a66..fbe84856aeb 100644
--- a/torchvision/prototype/datasets/_builtin/README.md
+++ b/torchvision/prototype/datasets/_builtin/README.md
@@ -19,10 +19,8 @@ that module create a class that inherits from `datasets.utils.Dataset` and
 overwrites at minimum three methods that will be discussed in detail below:
 
 ```python
-import io
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List
 
-import torch
 from torchdata.datapipes.iter import IterDataPipe
 from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource
 
@@ -34,11 +32,7 @@ class MyDataset(Dataset):
         ...
 
     def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig,
     ) -> IterDataPipe[Dict[str, Any]]:
         ...
 ```
@@ -49,10 +43,6 @@ The `DatasetInfo` carries static information about the dataset. There are two
 required fields:
 
 - `name`: Name of the dataset. This will be used to load the dataset with
   `datasets.load(name)`. Should only contain lowercase characters.
-- `type`: Field of the `datasets.utils.DatasetType` enum. This is used to select
-  the default decoder in case the user doesn't pass one. There are currently
-  only two options: `IMAGE` and `RAW` ([see
-  below](what-is-the-datasettyperaw-and-when-do-i-use-it) for details).
 
 There are more optional parameters that can be passed:
 
@@ -105,7 +95,7 @@ def sha256sum(path, chunk_size=1024 * 1024):
     print(checksum.hexdigest())
 ```
 
-### `_make_datapipe(resource_dps, *, config, decoder)`
+### `_make_datapipe(resource_dps, *, config)`
 
 This method is the heart of the dataset, where we transform the raw data into
 a usable form. A major difference compared to the current stable datasets is
@@ -178,28 +168,6 @@ contains. You can also do that with `resources_dp[1]` or `resources_dp[2]`
 (etc.) if they exist. Then follow the instructions above to manipulate these
 datapipes and return the appropriate dictionary format.
 
-### What is the `DatasetType.RAW` and when do I use it?
-
-`DatasetType.RAW` marks dataset that provides decoded, i.e. raw pixel values,
-rather than encoded image files such as `.jpg` or `.png`. This is usually only
-the case for small datasets, since it requires a lot more disk space. The
-default decoder `datasets.decoder.raw` is only a sentinel and should not be
-called directly. The decoding should look something like
-
-```python
-from torchvision.prototype.datasets.decoder import raw
-
-image = ...
-
-if decoder is raw:
-    image = Image(image)
-else:
-    image_buffer = image_buffer_from_raw(image)
-    image = decoder(image_buffer) if decoder else image_buffer
-```
-
-For examples, have a look at the MNIST, CIFAR, or SEMEION datasets.
-
 ### How do I handle a dataset that defines many categories?
 
 As a rule of thumb, `datasets.utils.DatasetInfo(..., categories=)` should only
diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py
index be19b7c240f..c170f102dc5 100644
--- a/torchvision/prototype/datasets/_builtin/caltech.py
+++ b/torchvision/prototype/datasets/_builtin/caltech.py
@@ -1,11 +1,8 @@
-import functools
-import io
 import pathlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple, BinaryIO
 
 import numpy as np
-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
@@ -18,17 +15,15 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding, hint_shuffling
-from torchvision.prototype.features import Label, BoundingBox, Feature
+from torchvision.prototype.features import Label, BoundingBox, Feature, EncodedImage
 
 
 class Caltech101(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "caltech101",
-            type=DatasetType.IMAGE,
             dependencies=("scipy",),
             homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
         )
@@ -81,33 +76,26 @@ def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
 
         return category, id
 
-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[Tuple[str, str], Tuple[Tuple[str, io.IOBase], Tuple[str, io.IOBase]]],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    def _prepare_sample(
+        self, data: Tuple[Tuple[str, str], Tuple[Tuple[str, BinaryIO], Tuple[str, BinaryIO]]]
     ) -> Dict[str, Any]:
         key, (image_data, ann_data) = data
         category, _ = key
         image_path, image_buffer = image_data
         ann_path, ann_buffer = ann_data
 
-        label = self.info.categories.index(category)
-
-        image = decoder(image_buffer) if decoder else image_buffer
-
+        image = EncodedImage.from_file(image_buffer)
         ann = read_mat(ann_buffer)
-        bbox = BoundingBox(ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy")
-        contour = Feature(ann["obj_contour"].T)
 
         return dict(
-            category=category,
-            label=label,
-            image=image,
+            label=Label.from_category(category, categories=self.categories),
             image_path=image_path,
-            bbox=bbox,
-            contour=contour,
+            image=image,
             ann_path=ann_path,
+            bounding_box=BoundingBox(
+                ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy", image_size=image.image_size
+            ),
+            contour=Feature(ann["obj_contour"].T),
         )
 
     def _make_datapipe(
@@ -115,7 +103,6 @@
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, anns_dp = resource_dps
 
@@ -133,7 +120,7 @@
             buffer_size=INFINITE_BUFFER_SIZE,
             keep_key=True,
         )
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)
 
     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         resources = self.resources(self.default_config)
@@ -148,7 +135,6 @@
 class Caltech256(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "caltech256",
-            type=DatasetType.IMAGE,
             homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256",
         )
 
@@ -164,32 +150,26 @@ def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool:
         path = pathlib.Path(data[0])
         return path.name != "RENAME2"
 
-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[str, io.IOBase],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
         path, buffer = data
 
-        dir_name = pathlib.Path(path).parent.name
-        label_str, category = dir_name.split(".")
-        label = Label(int(label_str), category=category)
-
-        return dict(label=label, image=decoder(buffer) if decoder else buffer)
+        return dict(
+            path=path,
+            image=EncodedImage.from_file(buffer),
+            label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self.categories),
+        )
 
     def _make_datapipe(
         self,
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = Filter(dp, self._is_not_rogue_file)
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)
 
     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         resources = self.resources(self.default_config)
diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py
index b59959b49f1..e27b359c11a 100644
--- a/torchvision/prototype/datasets/_builtin/celeba.py
+++ b/torchvision/prototype/datasets/_builtin/celeba.py
@@ -1,9 +1,7 @@
 import csv
 import functools
-import io
-from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence
+from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO
 
-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
@@ -17,7 +15,6 @@
     DatasetInfo,
     GDriveResource,
     OnlineResource,
-    DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
@@ -26,7 +23,8 @@
     hint_sharding,
     hint_shuffling,
 )
-from torchvision.prototype.features import Feature, Label, BoundingBox
+from torchvision.prototype.features import EncodedImage, Feature, Label, BoundingBox
+
 
 csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
 
@@ -34,7 +32,7 @@
 class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
     def __init__(
         self,
-        datapipe: IterDataPipe[Tuple[Any, io.IOBase]],
+        datapipe: IterDataPipe[Tuple[Any, BinaryIO]],
         *,
         fieldnames: Optional[Sequence[str]] = None,
     ) -> None:
@@ -66,7 +64,6 @@ class CelebA(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "celeba",
-            type=DatasetType.IMAGE,
             homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html",
             valid_options=dict(split=("train", "val", "test")),
         )
@@ -92,7 +89,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
             sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0",
             file_name="list_attr_celeba.txt",
         )
-        bboxes = GDriveResource(
+        bounding_boxes = GDriveResource(
            "0B7EVK8r0v71pbThiMVRxWXZ4dU0",
             sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b",
             file_name="list_bbox_celeba.txt",
@@ -102,7 +99,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
             sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b",
             file_name="list_landmarks_align_celeba.txt",
         )
-        return [splits, images, identities, attributes, bboxes, landmarks]
+        return [splits, images, identities, attributes, bounding_boxes, landmarks]
 
     _SPLIT_ID_TO_NAME = {
         "0": "train",
@@ -113,38 +110,39 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
     def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool:
         return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split
 
-    def _collate_anns(self, data: Tuple[Tuple[str, Dict[str, str]], ...]) -> Tuple[str, Dict[str, Dict[str, str]]]:
-        (image_id, identity), (_, attributes), (_, bbox), (_, landmarks) = data
-        return image_id, dict(identity=identity, attributes=attributes, bbox=bbox, landmarks=landmarks)
-
-    def _collate_and_decode_sample(
+    def _prepare_sample(
         self,
-        data: Tuple[Tuple[str, Tuple[Tuple[str, Dict[str, Any]], Tuple[str, io.IOBase]]], Tuple[str, Dict[str, Any]]],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        data: Tuple[
+            Tuple[str, Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO]]],
+            Tuple[
+                Tuple[str, Dict[str, str]],
+                Tuple[str, Dict[str, str]],
+                Tuple[str, Dict[str, str]],
+                Tuple[str, Dict[str, str]],
+            ],
+        ],
     ) -> Dict[str, Any]:
         split_and_image_data, ann_data = data
         _, (_, image_data) = split_and_image_data
         path, buffer = image_data
-        _, ann = ann_data
-
-        image = decoder(buffer) if decoder else buffer
 
-        identity = Label(int(ann["identity"]["identity"]))
-        attributes = {attr: value == "1" for attr, value in ann["attributes"].items()}
-        bbox = BoundingBox([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")])
-        landmarks = {
-            landmark: Feature((int(ann["landmarks"][f"{landmark}_x"]), int(ann["landmarks"][f"{landmark}_y"])))
-            for landmark in {key[:-2] for key in ann["landmarks"].keys()}
-        }
+        image = EncodedImage.from_file(buffer)
+        (_, identity), (_, attributes), (_, bounding_box), (_, landmarks) = ann_data
 
         return dict(
             path=path,
             image=image,
-            identity=identity,
-            attributes=attributes,
-            bbox=bbox,
-            landmarks=landmarks,
+            identity=Label(int(identity["identity"])),
+            attributes={attr: value == "1" for attr, value in attributes.items()},
+            bounding_box=BoundingBox(
+                [int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")],
+                format="xywh",
+                image_size=image.image_size,
+            ),
+            landmarks={
+                landmark: Feature((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
+                for landmark in {key[:-2] for key in landmarks.keys()}
+            },
         )
 
     def _make_datapipe(
@@ -152,9 +150,8 @@
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
-        splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps
+        splits_dp, images_dp, identities_dp, attributes_dp, bounding_boxes_dp, landmarks_dp = resource_dps
 
         splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
         splits_dp = Filter(splits_dp, functools.partial(self._filter_split, split=config.split))
@@ -167,12 +164,11 @@
                 for dp, fieldnames in (
                     (identities_dp, ("image_id", "identity")),
                     (attributes_dp, None),
-                    (bboxes_dp, None),
+                    (bounding_boxes_dp, None),
                     (landmarks_dp, None),
                 )
             ]
         )
-        anns_dp = Mapper(anns_dp, self._collate_anns)
 
         dp = IterKeyZipper(
             splits_dp,
@@ -182,5 +178,11 @@
             buffer_size=INFINITE_BUFFER_SIZE,
             keep_key=True,
         )
-        dp = IterKeyZipper(dp, anns_dp, key_fn=getitem(0), buffer_size=INFINITE_BUFFER_SIZE)
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
+        dp = IterKeyZipper(
+            dp,
+            anns_dp,
+            key_fn=getitem(0),
+            ref_key_fn=getitem(0, 0),
+            buffer_size=INFINITE_BUFFER_SIZE,
+        )
+        return Mapper(dp, self._prepare_sample)
diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py
index 6ac2de3c9e6..f15ed9e9782 100644
--- a/torchvision/prototype/datasets/_builtin/cifar.py
+++ b/torchvision/prototype/datasets/_builtin/cifar.py
@@ -3,34 +3,28 @@
 import io
 import pathlib
 import pickle
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator, cast
+from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO
 
 import numpy as np
-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Filter,
     Mapper,
 )
-from torchvision.prototype.datasets.decoder import raw
 from torchvision.prototype.datasets.utils import (
     Dataset,
     DatasetConfig,
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import (
     hint_shuffling,
-    image_buffer_from_array,
     path_comparator,
     hint_sharding,
 )
 from torchvision.prototype.features import Label, Image
 
-__all__ = ["Cifar10", "Cifar100"]
-
 
 class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]):
     def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None:
@@ -52,13 +46,12 @@ class _CifarBase(Dataset):
     _CATEGORIES_KEY: str
 
     @abc.abstractmethod
-    def _is_data_file(self, data: Tuple[str, io.IOBase], *, split: str) -> Optional[int]:
+    def _is_data_file(self, data: Tuple[str, BinaryIO], *, split: str) -> Optional[int]:
         pass
 
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             type(self).__name__.lower(),
-            type=DatasetType.RAW,
             homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
             valid_options=dict(split=("train", "test")),
         )
@@ -75,31 +68,18 @@ def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
         _, file = data
         return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
 
-    def _collate_and_decode(
-        self,
-        data: Tuple[np.ndarray, int],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
         image_array, category_idx = data
-
-        image: Union[Image, io.BytesIO]
-        if decoder is raw:
-            image = Image(image_array)
-        else:
-            image_buffer = image_buffer_from_array(image_array.transpose((1, 2, 0)))
-            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
-
-        label = Label(category_idx, category=self.categories[category_idx])
-
-        return dict(image=image, label=label)
+        return dict(
+            image=Image(image_array),
+            label=Label(category_idx, categories=self.categories),
+        )
 
     def _make_datapipe(
         self,
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = Filter(dp, functools.partial(self._is_data_file, split=config.split))
@@ -107,7 +87,7 @@
         dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        return Mapper(dp, functools.partial(self._collate_and_decode, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)
 
     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         resources = self.resources(self.default_config)
diff --git a/torchvision/prototype/datasets/_builtin/clevr.py b/torchvision/prototype/datasets/_builtin/clevr.py
index 447c1b5190d..af5a49a9822 100644
--- a/torchvision/prototype/datasets/_builtin/clevr.py
+++ b/torchvision/prototype/datasets/_builtin/clevr.py
@@ -1,9 +1,6 @@
-import functools
-import io
 import pathlib
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, BinaryIO
 
-import torch
 from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher
 from torchvision.prototype.datasets.utils import (
     Dataset,
@@ -11,7 +8,6 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
@@ -21,14 +17,13 @@
     path_accessor,
     getitem,
 )
-from torchvision.prototype.features import Label
+from torchvision.prototype.features import Label, EncodedImage
 
 
 class CLEVR(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "clevr",
-            type=DatasetType.IMAGE,
             homepage="https://cs.stanford.edu/people/jcjohns/clevr/",
             valid_options=dict(split=("train", "val", "test")),
         )
@@ -53,21 +48,16 @@ def _filter_scene_anns(self, data: Tuple[str, Any]) -> bool:
         key, _ = data
         return key == "scenes"
 
-    def _add_empty_anns(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[str, io.IOBase], None]:
+    def _add_empty_anns(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[str, BinaryIO], None]:
         return data, None
 
-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[Tuple[str, io.IOBase], Optional[Dict[str, Any]]],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, Any]]]) -> Dict[str, Any]:
         image_data, scenes_data = data
         path, buffer = image_data
 
         return dict(
             path=path,
-            image=decoder(buffer) if decoder else buffer,
+            image=EncodedImage.from_file(buffer),
             label=Label(len(scenes_data["objects"])) if scenes_data else None,
         )
 
@@ -76,7 +66,6 @@ def _make_datapipe(
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
         images_dp, scenes_dp = Demultiplexer(
@@ -107,4 +96,4 @@ def _make_datapipe(
         else:
             dp = Mapper(images_dp, self._add_empty_anns)
 
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)
diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py
index 6fde966402c..93d9d4b627c 100644
--- a/torchvision/prototype/datasets/_builtin/coco.py
+++ b/torchvision/prototype/datasets/_builtin/coco.py
@@ -1,9 +1,8 @@
 import functools
-import io
 import pathlib
 import re
 from collections import OrderedDict
-from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO
 
 import torch
 from torchdata.datapipes.iter import (
@@ -22,7 +21,6 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import (
     MappingIterator,
@@ -33,7 +31,7 @@
     hint_sharding,
     hint_shuffling,
 )
-from torchvision.prototype.features import BoundingBox, Label, Feature
+from torchvision.prototype.features import BoundingBox, Label, Feature, EncodedImage
 from torchvision.prototype.utils._internal import FrozenMapping
 
 
@@ -44,7 +42,6 @@ def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             name,
-            type=DatasetType.IMAGE,
             dependencies=("pycocotools",),
             categories=categories,
             homepage="https://cocodataset.org/",
@@ -96,7 +93,6 @@ def _segmentation_to_mask(self, segmentation: Any, *, is_crowd: bool, image_size
     def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
         image_size = (image_meta["height"], image_meta["width"])
         labels = [ann["category_id"] for ann in anns]
-        categories = [self.info.categories[label] for label in labels]
         return dict(
             # TODO: create a segmentation feature
             segmentations=Feature(
@@ -114,9 +110,10 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st
                 format="xywh",
                 image_size=image_size,
             ),
-            labels=Label(labels),
-            categories=categories,
-            super_categories=[self.info.extra.category_to_super_category[category] for category in categories],
+            labels=Label(labels, categories=self.categories),
+            super_categories=[
+                self.info.extra.category_to_super_category[self.info.categories[label]] for label in labels
+            ],
             ann_ids=[ann["id"] for ann in anns],
         )
 
@@ -150,26 +147,24 @@ def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
         else:
             return None
 
-    def _collate_and_decode_image(
-        self, data: Tuple[str, io.IOBase], *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
-    ) -> Dict[str, Any]:
+    def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
         path, buffer = data
-        return dict(path=path, image=decoder(buffer) if decoder else buffer)
+        return dict(
+            path=path,
+            image=EncodedImage.from_file(buffer),
+        )
 
-    def _collate_and_decode_sample(
+    def _prepare_sample(
         self,
-        data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, io.IOBase]],
+        data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]],
         *,
-        annotations: Optional[str],
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        annotations: str,
     ) -> Dict[str, Any]:
         ann_data, image_data = data
         anns, image_meta = ann_data
 
-        sample = self._collate_and_decode_image(image_data, decoder=decoder)
-        if annotations:
-            sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
-
+        sample = self._prepare_image(image_data)
+        sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
         return sample
 
     def _make_datapipe(
@@ -177,14 +172,13 @@
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, meta_dp = resource_dps
 
         if config.annotations is None:
             dp = hint_sharding(images_dp)
             dp = hint_shuffling(dp)
-            return Mapper(dp, functools.partial(self._collate_and_decode_image, decoder=decoder))
+            return Mapper(dp, self._prepare_image)
 
         meta_dp = Filter(
             meta_dp,
@@ -230,9 +224,8 @@ def _make_datapipe(
             ref_key_fn=path_accessor("name"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(
-            dp, functools.partial(self._collate_and_decode_sample, annotations=config.annotations, decoder=decoder)
-        )
+
+        return Mapper(dp, functools.partial(self._prepare_sample, annotations=config.annotations))
 
     def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:
         config = self.default_config
diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py
index facd909f468..008e4fd06b1 100644
--- a/torchvision/prototype/datasets/_builtin/cub200.py
+++ b/torchvision/prototype/datasets/_builtin/cub200.py
@@ -1,10 +1,8 @@
 import csv
 import functools
-import io
 import pathlib
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Callable
 
-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
@@ -21,7 +19,6 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
@@ -32,7 +29,7 @@
     path_comparator,
     path_accessor,
 )
-from torchvision.prototype.features import Label, BoundingBox, Feature
+from torchvision.prototype.features import Label, BoundingBox, Feature, EncodedImage
 
 csv.register_dialect("cub200", delimiter=" ")
 
@@ -41,7 +38,6 @@ class CUB200(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "cub200",
-            type=DatasetType.IMAGE,
             homepage="http://www.vision.caltech.edu/visipedia/CUB-200-2011.html",
             dependencies=("scipy",),
             valid_options=dict(
@@ -105,58 +101,55 @@ def _2011_segmentation_key(self, data: Tuple[str, Any]) -> str:
         path = pathlib.Path(data[0])
         return path.with_suffix(".jpg").name
 
-    def _2011_load_ann(
-        self,
-        data: Tuple[str, Tuple[List[str], Tuple[str, io.IOBase]]],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    def _2011_prepare_ann(
+        self, data: Tuple[str, Tuple[List[str], Tuple[str, BinaryIO]]], image_size: Tuple[int, int]
     ) -> Dict[str, Any]:
         _, (bounding_box_data, segmentation_data) = data
         segmentation_path, segmentation_buffer = segmentation_data
         return dict(
-            bounding_box=BoundingBox([float(part) for part in bounding_box_data[1:]], format="xywh"),
+            bounding_box=BoundingBox(
+                [float(part) for part in bounding_box_data[1:]], format="xywh", image_size=image_size
+            ),
             segmentation_path=segmentation_path,
-            segmentation=Feature(decoder(segmentation_buffer)) if decoder else segmentation_buffer,
+            segmentation=EncodedImage.from_file(segmentation_buffer),
         )
 
     def _2010_split_key(self, data: str) -> str:
         return data.rsplit("/", maxsplit=1)[1]
 
-    def _2010_anns_key(self, data: Tuple[str, io.IOBase]) -> Tuple[str, Tuple[str, io.IOBase]]:
+    def _2010_anns_key(self, data: Tuple[str, BinaryIO]) -> Tuple[str, Tuple[str, BinaryIO]]:
         path = pathlib.Path(data[0])
         return path.with_suffix(".jpg").name, data
 
-    def _2010_load_ann(
-        self, data: Tuple[str, Tuple[str, io.IOBase]], *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
-    ) -> Dict[str, Any]:
+    def _2010_prepare_ann(self, data: Tuple[str, Tuple[str, BinaryIO]], image_size: Tuple[int, int]) -> Dict[str, Any]:
         _, (path, buffer) = data
         content = read_mat(buffer)
         return dict(
             ann_path=path,
             bounding_box=BoundingBox(
-                [int(content["bbox"][coord]) for coord in ("left", "bottom", "right", "top")], format="xyxy"
+                [int(content["bbox"][coord]) for coord in ("left", "bottom", "right", "top")],
+                format="xyxy",
+                image_size=image_size,
             ),
             segmentation=Feature(content["seg"]),
         )
 
-    def _collate_and_decode_sample(
+    def _prepare_sample(
         self,
-        data: Tuple[Tuple[str, Tuple[str, io.IOBase]], Any],
+        data: Tuple[Tuple[str, Tuple[str, BinaryIO]], Any],
         *,
-        year: str,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        prepare_ann_fn: Callable[[Any, Tuple[int, int]], Dict[str, Any]],
     ) -> Dict[str, Any]:
         data, anns_data = data
         _, image_data = data
         path, buffer = image_data
 
-        dir_name = pathlib.Path(path).parent.name
-        label_str, category = dir_name.split(".")
+        image = EncodedImage.from_file(buffer)
 
         return dict(
-            (self._2011_load_ann if year == "2011" else self._2010_load_ann)(anns_data, decoder=decoder),
-            image=decoder(buffer) if decoder else buffer,
-            label=Label(int(label_str), category=category),
+            prepare_ann_fn(anns_data, image.image_size),
+            image=image,
+            label=Label(int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]), categories=self.categories),
        )
 
     def _make_datapipe(
@@ -164,8 +157,8 @@
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
+        prepare_ann_fn: Callable
         if config.year == "2011":
             archive_dp, segmentations_dp = resource_dps
             images_dp, split_dp, image_files_dp, bounding_boxes_dp = Demultiplexer(
@@ -193,6 +186,8 @@ def _make_datapipe(
                 keep_key=True,
                 buffer_size=INFINITE_BUFFER_SIZE,
             )
+
+            prepare_ann_fn = self._2011_prepare_ann
         else:  # config.year == "2010"
             split_dp, images_dp, anns_dp = resource_dps
 
@@ -202,6 +197,8 @@ def _make_datapipe(
 
             anns_dp = Mapper(anns_dp, self._2010_anns_key)
 
+            prepare_ann_fn = self._2010_prepare_ann
+
         split_dp = hint_sharding(split_dp)
         split_dp = hint_shuffling(split_dp)
 
@@ -218,7 +215,7 @@ def _make_datapipe(
             getitem(0),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, year=config.year, decoder=decoder))
+        return Mapper(dp, functools.partial(self._prepare_sample, prepare_ann_fn=prepare_ann_fn))
 
     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         config = self.info.make_config(year="2011")
diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py
index 36990e8a21d..18966e087a6 100644
--- a/torchvision/prototype/datasets/_builtin/dtd.py
+++ b/torchvision/prototype/datasets/_builtin/dtd.py
@@ -1,10 +1,7 @@
 import enum
-import functools
-import io
 import pathlib
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, BinaryIO
 
-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
@@ -21,7 +18,6 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
@@ -29,7 +25,7 @@
     path_comparator,
     getitem,
 )
-from torchvision.prototype.features import Label
+from torchvision.prototype.features import Label, EncodedImage
 
 
 class DTDDemux(enum.IntEnum):
@@ -42,7 +38,6 @@ class DTD(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "dtd",
-            type=DatasetType.IMAGE,
             homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
             valid_options=dict(
                 split=("train", "test", "val"),
@@ -74,12 +69,7 @@ def _image_key_fn(self, data: Tuple[str, Any]) -> str:
         path = pathlib.Path(data[0])
         return str(path.relative_to(path.parents[1]))
 
-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[Tuple[str, List[str]], Tuple[str, io.IOBase]],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
         (_, joint_categories_data), image_data = data
         _, *joint_categories = joint_categories_data
         path, buffer = image_data
@@ -88,9 +78,9 @@ def _collate_and_decode_sample(
 
         return dict(
             joint_categories={category for category in joint_categories if category},
-            label=Label(self.info.categories.index(category), category=category),
+            label=Label.from_category(category, categories=self.categories),
             path=path,
-            image=decoder(buffer) if decoder else buffer,
+            image=EncodedImage.from_file(buffer),
         )
 
     def _make_datapipe(
@@ -98,7 +88,6 @@
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
 
@@ -127,7 +116,7 @@ def _make_datapipe(
             ref_key_fn=self._image_key_fn,
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)
 
     def _filter_images(self, data: Tuple[str, Any]) -> bool:
         return self._classify_archive(data) == DTDDemux.IMAGES
diff --git a/torchvision/prototype/datasets/_builtin/fer2013.py b/torchvision/prototype/datasets/_builtin/fer2013.py
index 2d9bd713990..47d2ddc9acc 100644
--- a/torchvision/prototype/datasets/_builtin/fer2013.py
+++ b/torchvision/prototype/datasets/_builtin/fer2013.py
@@ -1,22 +1,17 @@
-import functools
-import io
-from typing import Any, Callable, Dict, List, Optional, Union, cast
+from typing import Any, Dict, List, cast
 
 import torch
 from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser
-from torchvision.prototype.datasets.decoder import raw
 from torchvision.prototype.datasets.utils import (
     Dataset,
     DatasetConfig,
     DatasetInfo,
     OnlineResource,
-    DatasetType,
     KaggleDownloadResource,
 )
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
     hint_shuffling,
-    image_buffer_from_array,
 )
 from torchvision.prototype.features import Label, Image
 
@@ -25,7 +20,6 @@ class FER2013(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "fer2013",
-            type=DatasetType.RAW,
             homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
             categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"),
             valid_options=dict(split=("train", "test")),
@@ -44,26 +38,12 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
         )
         return [archive]
 
-    def _collate_and_decode_sample(
-        self,
-        data: Dict[str, Any],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
-        raw_image = torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)
+    def _prepare_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         label_id = data.get("emotion")
-        label_idx = int(label_id) if label_id is not None else None
-
-        image: Union[Image, io.BytesIO]
-        if decoder is raw:
-            image = Image(raw_image)
-        else:
-            image_buffer = image_buffer_from_array(raw_image.numpy())
-            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
 
         return dict(
-            image=image,
-            label=Label(label_idx, category=self.info.categories[label_idx]) if label_idx is not None else None,
+            image=Image(torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)),
+            label=Label(int(label_id), categories=self.categories) if label_id is not None else None,
         )
 
     def _make_datapipe(
@@ -71,10 +51,9 @@
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = CSVDictParser(dp)
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)
diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py
index 08855b3a2bd..2288766c10f 100644
--- a/torchvision/prototype/datasets/_builtin/gtsrb.py
+++ b/torchvision/prototype/datasets/_builtin/gtsrb.py
@@ -1,16 +1,12 @@
-import io
 import pathlib
-from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
-import torch
 from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, Zipper, Demultiplexer
 from torchvision.prototype.datasets.utils import (
     Dataset,
     DatasetConfig,
     DatasetInfo,
     OnlineResource,
-    DatasetType,
     HttpResource,
 )
 from torchvision.prototype.datasets.utils._internal import (
@@ -19,14 +15,13 @@
     hint_shuffling,
     INFINITE_BUFFER_SIZE,
 )
-from torchvision.prototype.features import Label, BoundingBox
+from torchvision.prototype.features import Label, BoundingBox, EncodedImage
 
 
 class GTSRB(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "gtsrb",
-            type=DatasetType.IMAGE,
             homepage="https://benchmark.ini.rub.de",
             categories=[f"{label:05d}" for label in range(43)],
             valid_options=dict(split=("train", "test")),
@@ -66,33 +61,26 @@ def _classify_train_archive(self, data: Tuple[str, Any]) -> Optional[int]:
         else:
             return None
 
-    def _collate_and_decode(
-        self, data: Tuple[Tuple[str, Any], Dict[str, Any]], decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
-    ) -> Dict[str, Any]:
-        (image_path, image_buffer), csv_info = data
+    def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[str, Any]:
+        (path, buffer), csv_info = data
         label = int(csv_info["ClassId"])
 
-        bbox = BoundingBox(
-            torch.tensor([int(csv_info[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")]),
+        bounding_box = BoundingBox(
+            [int(csv_info[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")],
             format="xyxy",
             image_size=(int(csv_info["Height"]), int(csv_info["Width"])),
         )
 
         return {
-            "image_path": image_path,
-            "image": decoder(image_buffer) if decoder else image_buffer,
-            "label": Label(label, category=self.categories[label]),
-            "bbox": bbox,
+            "path": path,
+            "image": EncodedImage.from_file(buffer),
+            "label": Label(label, categories=self.categories),
+            "bounding_box": bounding_box,
         }
 
     def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
     ) -> IterDataPipe[Dict[str, Any]]:
-
         if config.split == "train":
             images_dp, ann_dp = Demultiplexer(
                 resource_dps[0], 2, self._classify_train_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
@@ -101,13 +89,12 @@ def _make_datapipe(
             images_dp, ann_dp = resource_dps
             images_dp = Filter(images_dp, path_comparator("suffix", ".ppm"))
 
-        # The order of the image files in the the .zip archives perfectly match the order of the entries in
-        # the (possibly concatenated) .csv files. So we're able to use Zipper here instead of a IterKeyZipper.
+        # The order of the image files in the .zip archives perfectly match the order of the entries in the
+        # (possibly concatenated) .csv files. So we're able to use Zipper here instead of a IterKeyZipper.
         ann_dp = CSVDictParser(ann_dp, delimiter=";")
         dp = Zipper(images_dp, ann_dp)
 
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
 
-        dp = Mapper(dp, partial(self._collate_and_decode, decoder=decoder))
-        return dp
+        return Mapper(dp, self._prepare_sample)
diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py
index ac3649c8839..0d11b642c13 100644
--- a/torchvision/prototype/datasets/_builtin/imagenet.py
+++ b/torchvision/prototype/datasets/_builtin/imagenet.py
@@ -1,18 +1,16 @@
 import functools
-import io
 import pathlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast
 
-import torch
-from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, TarArchiveReader, Filter
+from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, Filter, Demultiplexer
+from torchdata.datapipes.iter import TarArchiveReader
 from torchvision.prototype.datasets.utils import (
     Dataset,
     DatasetConfig,
     DatasetInfo,
     OnlineResource,
     ManualDownloadResource,
-    DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
@@ -24,7 +22,7 @@
     hint_sharding,
     hint_shuffling,
 )
-from torchvision.prototype.features import Label
+from torchvision.prototype.features import Label, EncodedImage
 from torchvision.prototype.utils._internal import FrozenMapping
 
 
@@ -40,7 +38,6 @@
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             name,
-            type=DatasetType.IMAGE,
             dependencies=("scipy",),
             categories=categories,
             homepage="https://www.image-net.org/",
@@ -61,14 +58,6 @@ def _make_info(self) -> DatasetInfo:
     def supports_sharded(self) -> bool:
         return True
 
-    @property
-    def category_to_wnid(self) -> Dict[str, str]:
-        return cast(Dict[str, str], self.info.extra.category_to_wnid)
-
-    @property
-    def wnid_to_category(self) -> Dict[str, str]:
-        return cast(Dict[str, str], self.info.extra.wnid_to_category)
-
     _IMAGES_CHECKSUMS = {
         "train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
         "val": "c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
@@ -77,23 +66,56 @@ def wnid_to_category(self) -> Dict[str, str]:
     def resources(self, config: DatasetConfig) -> List[OnlineResource]:
         name = "test_v10102019" if config.split == "test" else config.split
-        images = ImageNetResource(file_name=f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name])
-
-        devkit = ImageNetResource(
-            file_name="ILSVRC2012_devkit_t12.tar.gz",
-            sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
+        images = ImageNetResource(
+            file_name=f"ILSVRC2012_img_{name}.tar",
+            sha256=self._IMAGES_CHECKSUMS[name],
         )
+        resources: List[OnlineResource] = [images]
+
+        if config.split == "val":
+            devkit = ImageNetResource(
+                file_name="ILSVRC2012_devkit_t12.tar.gz",
+                sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
+            )
+            resources.append(devkit)
+
+        return resources
 
-        return [images, devkit]
+    def num_samples(self, config: DatasetConfig) -> int:
+        return {
+            "train": 1_281_167,
+            "val": 50_000,
+            "test": 100_000,
+        }[config.split]
 
     _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")
 
-    def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[Label, str, str], Tuple[str, io.IOBase]]:
+    def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
         path = pathlib.Path(data[0])
-        wnid = self._TRAIN_IMAGE_NAME_PATTERN.match(path.name).group("wnid")  # type: ignore[union-attr]
-        category = self.wnid_to_category[wnid]
-        label_data = (Label(self.categories.index(category)), category, wnid)
-        return label_data, data
+        wnid = cast(Match[str], self._TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"]
+        label = Label.from_category(self.info.extra.wnid_to_category[wnid], categories=self.categories)
+        return (label, wnid), data
+
+    def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
+        return None, data
+
+    def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
+        return {
+            "meta.mat": 0,
+            "ILSVRC2012_validation_ground_truth.txt": 1,
+        }.get(pathlib.Path(data[0]).name)
+
+    def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
+        synsets = read_mat(data[1], squeeze_me=True)["synsets"]
+        return [
+            (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
+            for _, wnid, category, _, num_children, *_ in synsets
+            # if num_children > 0, we are looking at a superclass that has no direct instance
+            if num_children == 0
+        ]
+
+    def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: List[str]) -> str:
+        return wnids[int(imagenet_label) - 1]
 
     _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
 
@@ -101,72 +123,65 @@ def _val_test_image_key(self, data: Tuple[str, Any]) -> int:
         path = pathlib.Path(data[0])
         return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name).group("id"))  # type: ignore[union-attr]
 
-    def _collate_val_data(
-        self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]]
-    ) -> Tuple[Tuple[Label, str, str], Tuple[str, io.IOBase]]:
+    def _prepare_val_data(
+        self, data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]]
+    ) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
         label_data, image_data = data
-        _, label = label_data
-        category = self.categories[label]
-        wnid = self.category_to_wnid[category]
-        return (Label(label), category, wnid), image_data
-
-    def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[None, Tuple[str, io.IOBase]]:
-        return None, data
+        _, wnid = label_data
+        label = Label.from_category(self.info.extra.wnid_to_category[wnid], categories=self.categories)
+        return (label, wnid), image_data
 
-    def _collate_and_decode_sample(
+    def _prepare_sample(
         self,
-        data: Tuple[Optional[Tuple[Label, str, str]], Tuple[str, io.IOBase]],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]],
     ) -> Dict[str, Any]:
         label_data, (path, buffer) = data
 
-        sample = dict(
+        return dict(
+            dict(zip(("label", "wnid"), label_data if label_data else (None, None))),
             path=path,
-            image=decoder(buffer) if decoder else buffer,
+            image=EncodedImage.from_file(buffer),
         )
-        if label_data:
-            sample.update(dict(zip(("label", "category", "wnid"), label_data)))
-
-        return sample
 
     def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
     ) -> IterDataPipe[Dict[str, Any]]:
-        images_dp, devkit_dp = resource_dps
+        if config.split in {"train", "test"}:
+            dp = resource_dps[0]
 
-        if config.split == "train":
             # the train archive is a tar of tars
-            dp = TarArchiveReader(images_dp)
+            if config.split == "train":
+                dp = TarArchiveReader(dp)
+
             dp = hint_sharding(dp)
             dp = hint_shuffling(dp)
-            dp = Mapper(dp, self._collate_train_data)
-        elif config.split == "val":
-            devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
-            devkit_dp = LineReader(devkit_dp, return_path=False)
-            devkit_dp = Mapper(devkit_dp, int)
-            devkit_dp = Enumerator(devkit_dp, 1)
-            devkit_dp = hint_sharding(devkit_dp)
-            devkit_dp = hint_shuffling(devkit_dp)
+            dp = Mapper(dp, self._prepare_train_data if config.split == "train" else self._prepare_test_data)
+        else:  # config.split == "val":
+            images_dp, devkit_dp = resource_dps
+
+            meta_dp, label_dp = Demultiplexer(
+                devkit_dp, 2, self._classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
+            )
+
+            meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
+            _, wnids = zip(*next(iter(meta_dp)))
+
+            label_dp = LineReader(label_dp, decode=True, return_path=False)
+            label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids))
+            label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
+            label_dp = hint_sharding(label_dp)
+            label_dp = hint_shuffling(label_dp)
 
             dp = IterKeyZipper(
-                devkit_dp,
+                label_dp,
                 images_dp,
                 key_fn=getitem(0),
                 ref_key_fn=self._val_test_image_key,
                 buffer_size=INFINITE_BUFFER_SIZE,
             )
-            dp = Mapper(dp, self._collate_val_data)
-        else:  # config.split == "test"
-            dp = hint_sharding(images_dp)
-            dp = hint_shuffling(dp)
-            dp = Mapper(dp, self._collate_test_data)
+            dp = Mapper(dp, self._prepare_val_data)
 
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)
 
     # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
     # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
@@ -176,22 +191,13 @@ def _make_datapipe(
     }
 
     def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]:
-        resources = self.resources(self.default_config)
+        config = self.info.make_config(split="val")
+        resources = self.resources(config)
 
         devkit_dp = resources[1].load(root)
-        devkit_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))
-
-        meta = next(iter(devkit_dp))[1]
-        synsets = read_mat(meta, squeeze_me=True)["synsets"]
-        categories_and_wnids = cast(
-            List[Tuple[str, ...]],
-            [
-                (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
-                for _, wnid, category, _, num_children, *_ in synsets
-                # if num_children > 0, we are looking at a superclass that has no direct instance
-                if num_children == 0
-            ],
-        )
-        categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])
+        meta_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))
+        meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
+        categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp)))
+        categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])
 
         return categories_and_wnids
diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py
index ee9b2a65b51..e5b9fa84b0d 100644
--- a/torchvision/prototype/datasets/_builtin/mnist.py
+++ b/torchvision/prototype/datasets/_builtin/mnist.py
@@ -1,10 +1,9 @@
 import abc
 import functools
-import io
 import operator
 import pathlib
 import string
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO, Union, Sequence
+from typing import Any, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO, Union, Sequence
 
 import torch
 from torchdata.datapipes.iter import (
@@ -13,17 +12,14 @@ Mapper, Zipper, ) -from torchvision.prototype.datasets.decoder import raw from torchvision.prototype.datasets.utils import ( Dataset, - DatasetType, DatasetConfig, DatasetInfo, HttpResource, OnlineResource, ) from torchvision.prototype.datasets.utils._internal import ( - image_buffer_from_array, Decompressor, INFINITE_BUFFER_SIZE, hint_sharding, @@ -105,31 +101,15 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]: return None, None - def _collate_and_decode( - self, - data: Tuple[torch.Tensor, torch.Tensor], - *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]: image, label = data - - if decoder is raw: - image = Image(image) - else: - image_buffer = image_buffer_from_array(image.numpy()) - image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment] - - label = Label(label, dtype=torch.int64, category=self.info.categories[int(label)]) - - return dict(image=image, label=label) + return dict( + image=Image(image), + label=Label(label, dtype=torch.int64, categories=self.categories), + ) def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + self, resource_dps: List[IterDataPipe], *, config: DatasetConfig ) -> IterDataPipe[Dict[str, Any]]: images_dp, labels_dp = resource_dps start, stop = self.start_and_stop(config) @@ -143,14 +123,13 @@ def _make_datapipe( dp = Zipper(images_dp, labels_dp) dp = hint_sharding(dp) dp = hint_shuffling(dp) - return Mapper(dp, functools.partial(self._collate_and_decode, config=config, decoder=decoder)) + return Mapper(dp, functools.partial(self._prepare_sample, config=config)) class MNIST(_MNISTBase): def _make_info(self) -> DatasetInfo: return DatasetInfo( "mnist", - type=DatasetType.RAW, categories=10, homepage="http://yann.lecun.com/exdb/mnist", valid_options=dict( @@ -183,7 +162,6 @@ class FashionMNIST(MNIST): def _make_info(self) -> DatasetInfo: return DatasetInfo( "fashionmnist", - type=DatasetType.RAW, categories=( "T-shirt/top", "Trouser", @@ -215,7 +193,6 @@ class KMNIST(MNIST): def _make_info(self) -> DatasetInfo: return DatasetInfo( "kmnist", - type=DatasetType.RAW, categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"], homepage="http://codh.rois.ac.jp/kmnist/index.html.en", valid_options=dict( @@ -236,7 +213,6 @@ class EMNIST(_MNISTBase): def _make_info(self) -> DatasetInfo: return DatasetInfo( "emnist", - type=DatasetType.RAW, categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase), homepage="https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist", valid_options=dict( @@ -291,13 +267,7 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> 46: 9, } - def _collate_and_decode( - self, - data: Tuple[torch.Tensor, torch.Tensor], - *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]: # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper). # That means for example that there is 'D', 'd', and 'C', but not 'c'. 
Since the labels are nevertheless dense, # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create these gaps. For example, @@ -310,14 +280,10 @@ def _collate_and_decode( image, label = data label += self._LABEL_OFFSETS.get(int(label), 0) data = (image, label) - return super()._collate_and_decode(data, config=config, decoder=decoder) + return super()._prepare_sample(data, config=config) def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + self, resource_dps: List[IterDataPipe], *, config: DatasetConfig ) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] images_dp, labels_dp = Demultiplexer( @@ -327,14 +293,13 @@ def _make_datapipe( drop_none=True, buffer_size=INFINITE_BUFFER_SIZE, ) - return super()._make_datapipe([images_dp, labels_dp], config=config, decoder=decoder) + return super()._make_datapipe([images_dp, labels_dp], config=config) class QMNIST(_MNISTBase): def _make_info(self) -> DatasetInfo: return DatasetInfo( "qmnist", - type=DatasetType.RAW, categories=10, homepage="https://github.com/facebookresearch/qmnist", valid_options=dict( @@ -376,16 +341,10 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional return start, stop - def _collate_and_decode( - self, - data: Tuple[torch.Tensor, torch.Tensor], - *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]: image, ann = data label, *extra_anns = ann - sample = super()._collate_and_decode((image, label), config=config, decoder=decoder) + sample = super()._prepare_sample((image, label), config=config) sample.update( dict( diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py index 59a28796cbc..1780b8829f4 100644 --- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py +++ b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py @@ -1,10 +1,7 @@ import enum -import functools -import io import pathlib -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, BinaryIO -import torch from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, CSVDictParser from torchvision.prototype.datasets.utils import ( Dataset, @@ -12,7 +9,6 @@ DatasetInfo, HttpResource, OnlineResource, - DatasetType, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, @@ -22,7 +18,7 @@ path_accessor, path_comparator, ) -from torchvision.prototype.features import Label +from torchvision.prototype.features import Label, EncodedImage class OxfordIITPetDemux(enum.IntEnum): @@ -34,7 +30,6 @@ class OxfordIITPet(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( "oxford-iiit-pet", - type=DatasetType.IMAGE, homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/", valid_options=dict( split=("trainval", "test"), @@ -66,18 +61,8 @@ def _filter_images(self, data: Tuple[str, Any]) -> bool: def _filter_segmentations(self, data: Tuple[str, Any]) -> bool: return not pathlib.Path(data[0]).name.startswith(".") - def _decode_classification_data(self, data: Dict[str, str]) -> Dict[str, Any]: - label_idx = int(data["label"]) - 1 - return dict( - label=Label(label_idx, category=self.info.categories[label_idx]), -
species="cat" if data["species"] == "1" else "dog", - ) - - def _collate_and_decode_sample( - self, - data: Tuple[Tuple[Dict[str, str], Tuple[str, io.IOBase]], Tuple[str, io.IOBase]], - *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + def _prepare_sample( + self, data: Tuple[Tuple[Dict[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]] ) -> Dict[str, Any]: ann_data, image_data = data classification_data, segmentation_data = ann_data @@ -85,19 +70,16 @@ def _collate_and_decode_sample( image_path, image_buffer = image_data return dict( - self._decode_classification_data(classification_data), + label=Label(int(classification_data["label"]) - 1, categories=self.categories), + species="cat" if classification_data["species"] == "1" else "dog", segmentation_path=segmentation_path, - segmentation=decoder(segmentation_buffer) if decoder else segmentation_buffer, + segmentation=EncodedImage.from_file(segmentation_buffer), image_path=image_path, - image=decoder(image_buffer) if decoder else image_buffer, + image=EncodedImage.from_file(image_buffer), ) def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + self, resource_dps: List[IterDataPipe], *, config: DatasetConfig ) -> IterDataPipe[Dict[str, Any]]: images_dp, anns_dp = resource_dps @@ -137,7 +119,7 @@ def _make_datapipe( ref_key_fn=path_accessor("stem"), buffer_size=INFINITE_BUFFER_SIZE, ) - return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) + return Mapper(dp, self._prepare_sample) def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool: return self._classify_anns(data) == OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py index 27b27b2745b..e019b765a30 100644 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ b/torchvision/prototype/datasets/_builtin/sbd.py @@ -1,11 +1,8 @@ -import functools -import io import pathlib import re -from typing import Any, Callable, Dict, List, Optional, Tuple, cast +from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO import numpy as np -import torch from torchdata.datapipes.iter import ( IterDataPipe, Mapper, @@ -20,7 +17,6 @@ DatasetInfo, HttpResource, OnlineResource, - DatasetType, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, @@ -31,20 +27,17 @@ hint_sharding, hint_shuffling, ) -from torchvision.prototype.features import Feature +from torchvision.prototype.features import Feature, EncodedImage class SBD(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( "sbd", - type=DatasetType.IMAGE, dependencies=("scipy",), homepage="http://home.bharathh.info/pubs/codes/SBD/download.html", valid_options=dict( split=("train", "val", "train_noval"), - boundaries=(True, False), - segmentation=(False, True), ), ) @@ -75,50 +68,21 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: else: return None - def _decode_ann( - self, data: Dict[str, Any], *, decode_boundaries: bool, decode_segmentation: bool - ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - raw_anns = data["GTcls"][0] - raw_boundaries = raw_anns["Boundaries"][0] - raw_segmentation = raw_anns["Segmentation"][0] - - # the boundaries are stored in sparse CSC format, which is not supported by PyTorch - boundaries = ( - Feature(np.stack([raw_boundary[0].toarray() for raw_boundary in 
raw_boundaries])) - if decode_boundaries - else None - ) - segmentation = Feature(raw_segmentation) if decode_segmentation else None - - return boundaries, segmentation - - def _collate_and_decode_sample( - self, - data: Tuple[Tuple[Any, Tuple[str, io.IOBase]], Tuple[str, io.IOBase]], - *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]: split_and_image_data, ann_data = data _, image_data = split_and_image_data image_path, image_buffer = image_data ann_path, ann_buffer = ann_data - image = decoder(image_buffer) if decoder else image_buffer - - if config.boundaries or config.segmentation: - boundaries, segmentation = self._decode_ann( - read_mat(ann_buffer), decode_boundaries=config.boundaries, decode_segmentation=config.segmentation - ) - else: - boundaries = segmentation = None + anns = read_mat(ann_buffer, squeeze_me=True)["GTcls"] return dict( image_path=image_path, - image=image, + image=EncodedImage.from_file(image_buffer), ann_path=ann_path, - boundaries=boundaries, - segmentation=segmentation, + # the boundaries are stored in sparse CSC format, which is not supported by PyTorch + boundaries=Feature(np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])), + segmentation=Feature(anns["Segmentation"].item()), ) def _make_datapipe( @@ -126,7 +90,6 @@ def _make_datapipe( resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: archive_dp, extra_split_dp = resource_dps @@ -138,10 +101,10 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, drop_none=True, ) - if config.split == "train_noval": split_dp = extra_split_dp - split_dp = Filter(split_dp, path_comparator("stem", config.split)) + + split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt")) split_dp = LineReader(split_dp, decode=True) split_dp = hint_sharding(split_dp) split_dp = hint_shuffling(split_dp) @@ -155,7 +118,7 @@ def _make_datapipe( ref_key_fn=path_accessor("stem"), buffer_size=INFINITE_BUFFER_SIZE, ) - return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder)) + return Mapper(dp, self._prepare_sample) def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]: resources = self.resources(self.default_config) diff --git a/torchvision/prototype/datasets/_builtin/semeion.py b/torchvision/prototype/datasets/_builtin/semeion.py index d153debcefd..a6fc1098fda 100644 --- a/torchvision/prototype/datasets/_builtin/semeion.py +++ b/torchvision/prototype/datasets/_builtin/semeion.py @@ -1,6 +1,4 @@ -import functools -import io -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple import torch from torchdata.datapipes.iter import ( @@ -8,24 +6,21 @@ Mapper, CSVParser, ) -from torchvision.prototype.datasets.decoder import raw from torchvision.prototype.datasets.utils import ( Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource, - DatasetType, ) -from torchvision.prototype.datasets.utils._internal import image_buffer_from_array, hint_sharding, hint_shuffling -from torchvision.prototype.features import Image, Label +from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling +from torchvision.prototype.features import Image, OneHotLabel class SEMEION(Dataset): def 
_make_info(self) -> DatasetInfo: return DatasetInfo( "semeion", - type=DatasetType.RAW, categories=10, homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit", ) @@ -37,34 +32,22 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: ) return [data] - def _collate_and_decode_sample( - self, - data: Tuple[str, ...], - *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: - image_data = torch.tensor([float(pixel) for pixel in data[:256]], dtype=torch.uint8).reshape(16, 16) - label_data = [int(label) for label in data[256:] if label] + def _prepare_sample(self, data: Tuple[str, ...]) -> Dict[str, Any]: + image_data, label_data = data[:256], data[256:-1] - if decoder is raw: - image = Image(image_data.unsqueeze(0)) - else: - image_buffer = image_buffer_from_array(image_data.numpy()) - image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment] - - label_idx = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label)) - return dict(image=image, label=Label(label_idx, category=self.info.categories[label_idx])) + return dict( + image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.uint8).reshape(16, 16)), + label=OneHotLabel([int(label) for label in label_data], categories=self.categories), + ) def _make_datapipe( self, resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = CSVParser(dp, delimiter=" ") dp = hint_sharding(dp) dp = hint_shuffling(dp) - dp = Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) - return dp + return Mapper(dp, self._prepare_sample) diff --git a/torchvision/prototype/datasets/_builtin/svhn.py b/torchvision/prototype/datasets/_builtin/svhn.py index 7f9c019e92e..21af4add909 100644 --- a/torchvision/prototype/datasets/_builtin/svhn.py +++ b/torchvision/prototype/datasets/_builtin/svhn.py @@ -1,28 +1,22 @@ -import functools -import io -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple, BinaryIO import numpy as np -import torch from torchdata.datapipes.iter import ( IterDataPipe, Mapper, UnBatcher, ) -from torchvision.prototype.datasets.decoder import raw from torchvision.prototype.datasets.utils import ( Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource, - DatasetType, ) from torchvision.prototype.datasets.utils._internal import ( read_mat, hint_sharding, hint_shuffling, - image_buffer_from_array, ) from torchvision.prototype.features import Label, Image @@ -31,7 +25,6 @@ class SVHN(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( "svhn", - type=DatasetType.RAW, dependencies=("scipy",), categories=10, homepage="http://ufldl.stanford.edu/housenumbers/", @@ -52,7 +45,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: return [data] - def _read_images_and_labels(self, data: Tuple[str, io.IOBase]) -> List[Tuple[np.ndarray, np.ndarray]]: + def _read_images_and_labels(self, data: Tuple[str, BinaryIO]) -> List[Tuple[np.ndarray, np.ndarray]]: _, buffer = data content = read_mat(buffer) return list( @@ -62,23 +55,12 @@ def _read_images_and_labels(self, data: Tuple[str, io.IOBase]) -> List[Tuple[np. 
) ) - def _collate_and_decode_sample( - self, - data: Tuple[np.ndarray, np.ndarray], - *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[np.ndarray, np.ndarray]) -> Dict[str, Any]: image_array, label_array = data - if decoder is raw: - image = Image(image_array.transpose((2, 0, 1))) - else: - image_buffer = image_buffer_from_array(image_array) - image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment] - return dict( - image=image, - label=Label(int(label_array) % 10), + image=Image(image_array.transpose((2, 0, 1))), + label=Label(int(label_array) % 10, categories=self.categories), ) def _make_datapipe( @@ -86,11 +68,10 @@ def _make_datapipe( resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = Mapper(dp, self._read_images_and_labels) dp = UnBatcher(dp) dp = hint_sharding(dp) dp = hint_shuffling(dp) - return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) + return Mapper(dp, self._prepare_sample) diff --git a/torchvision/prototype/datasets/_builtin/voc.categories b/torchvision/prototype/datasets/_builtin/voc.categories new file mode 100644 index 00000000000..8420ab35ede --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/voc.categories @@ -0,0 +1,20 @@ +aeroplane +bicycle +bird +boat +bottle +bus +car +cat +chair +cow +diningtable +dog +horse +motorbike +person +pottedplant +sheep +sofa +train +tvmonitor diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index da145ab1e1c..6ba2186853d 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -1,10 +1,8 @@ import functools -import io import pathlib -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Callable from xml.etree import ElementTree -import torch from torchdata.datapipes.iter import ( IterDataPipe, Mapper, @@ -20,7 +18,6 @@ DatasetInfo, HttpResource, OnlineResource, - DatasetType, ) from torchvision.prototype.datasets.utils._internal import ( path_accessor, @@ -30,7 +27,7 @@ hint_sharding, hint_shuffling, ) -from torchvision.prototype.features import BoundingBox +from torchvision.prototype.features import BoundingBox, Label, EncodedImage class VOCDatasetInfo(DatasetInfo): @@ -50,7 +47,6 @@ class VOC(Dataset): def _make_info(self) -> DatasetInfo: return VOCDatasetInfo( "voc", - type=DatasetType.IMAGE, homepage="http://host.robots.ox.ac.uk/pascal/VOC/", valid_options=dict( split=("train", "val", "trainval", "test"), @@ -99,40 +95,52 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> else: return None - def _decode_detection_ann(self, buffer: io.IOBase) -> torch.Tensor: - result = VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot()) # type: ignore[arg-type] - objects = result["annotation"]["object"] - bboxes = [obj["bndbox"] for obj in objects] - bboxes = [[int(bbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] for bbox in bboxes] - return BoundingBox(bboxes) + def _parse_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]: + return cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"]) + + def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]: + anns = 
self._parse_detection_ann(buffer) + instances = anns["object"] + return dict( + bounding_boxes=BoundingBox( + [ + [int(instance["bndbox"][part]) for part in ("xmin", "ymin", "xmax", "ymax")] + for instance in instances + ], + format="xyxy", + image_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))), + ), + labels=Label( + [self.categories.index(instance["name"]) for instance in instances], categories=self.categories + ), + ) + + def _prepare_segmentation_ann(self, buffer: BinaryIO) -> Dict[str, Any]: + return dict(segmentation=EncodedImage.from_file(buffer)) - def _collate_and_decode_sample( + def _prepare_sample( self, - data: Tuple[Tuple[Tuple[str, str], Tuple[str, io.IOBase]], Tuple[str, io.IOBase]], + data: Tuple[Tuple[Tuple[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]], *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + prepare_ann_fn: Callable[[BinaryIO], Dict[str, Any]], ) -> Dict[str, Any]: split_and_image_data, ann_data = data _, image_data = split_and_image_data image_path, image_buffer = image_data ann_path, ann_buffer = ann_data - image = decoder(image_buffer) if decoder else image_buffer - - if config.task == "detection": - ann = self._decode_detection_ann(ann_buffer) - else: # config.task == "segmentation": - ann = decoder(ann_buffer) if decoder else ann_buffer # type: ignore[assignment] - - return dict(image_path=image_path, image=image, ann_path=ann_path, ann=ann) + return dict( + prepare_ann_fn(ann_buffer), + image_path=image_path, + image=EncodedImage.from_file(image_buffer), + ann_path=ann_path, + ) def _make_datapipe( self, resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] split_dp, images_dp, anns_dp = Demultiplexer( @@ -158,4 +166,25 @@ def _make_datapipe( ref_key_fn=path_accessor("stem"), buffer_size=INFINITE_BUFFER_SIZE, ) - return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder)) + return Mapper( + dp, + functools.partial( + self._prepare_sample, + prepare_ann_fn=self._prepare_detection_ann + if config.task == "detection" + else self._prepare_segmentation_ann, + ), + ) + + def _filter_detection_anns(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool: + return self._classify_archive(data, config=config) == 2 + + def _generate_categories(self, root: pathlib.Path) -> List[str]: + config = self.info.make_config(task="detection") + + resource = self.resources(config)[0] + dp = resource.load(pathlib.Path(root) / self.name) + dp = Filter(dp, self._filter_detection_anns, fn_kwargs=dict(config=config)) + dp = Mapper(dp, self._parse_detection_ann, input_col=1) + + return sorted({instance["name"] for _, anns in dp for instance in anns["object"]}) diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py index fbca8b07b1a..c3a38becb6c 100644 --- a/torchvision/prototype/datasets/_folder.py +++ b/torchvision/prototype/datasets/_folder.py @@ -1,15 +1,12 @@ import functools -import io import os import os.path import pathlib -from typing import Callable, Optional, Collection -from typing import Union, Tuple, List, Dict, Any +from typing import BinaryIO, Optional, Collection, Union, Tuple, List, Dict, Any -import torch -from torchdata.datapipes.iter import IterDataPipe, FileLister, FileOpener, Mapper, Shuffler, Filter -from torchvision.prototype.datasets.decoder import pil 
-from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding +from torchdata.datapipes.iter import IterDataPipe, FileLister, Mapper, Filter, FileOpener +from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling +from torchvision.prototype.features import Label, EncodedImage, EncodedData __all__ = ["from_data_folder", "from_image_folder"] @@ -20,29 +17,24 @@ def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool: return rel_path.is_dir() or rel_path.parent != pathlib.Path(".") -def _collate_and_decode_data( - data: Tuple[str, io.IOBase], +def _prepare_sample( + data: Tuple[str, BinaryIO], *, root: pathlib.Path, categories: List[str], - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> Dict[str, Any]: path, buffer = data - data = decoder(buffer) if decoder else buffer category = pathlib.Path(path).relative_to(root).parts[0] - label = torch.tensor(categories.index(category)) return dict( path=path, - data=data, - label=label, - category=category, + data=EncodedData.from_file(buffer), + label=Label.from_category(category, categories=categories), ) def from_data_folder( root: Union[str, pathlib.Path], *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None, valid_extensions: Optional[Collection[str]] = None, recursive: bool = True, ) -> Tuple[IterDataPipe, List[str]]: @@ -52,26 +44,22 @@ def from_data_folder( dp = FileLister(str(root), recursive=recursive, masks=masks) dp: IterDataPipe = Filter(dp, functools.partial(_is_not_top_level_file, root=root)) dp = hint_sharding(dp) - dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) + dp = hint_shuffling(dp) dp = FileOpener(dp, mode="rb") - return ( - Mapper(dp, functools.partial(_collate_and_decode_data, root=root, categories=categories, decoder=decoder)), - categories, - ) + return Mapper(dp, functools.partial(_prepare_sample, root=root, categories=categories)), categories def _data_to_image_key(sample: Dict[str, Any]) -> Dict[str, Any]: - sample["image"] = sample.pop("data") + sample["image"] = EncodedImage(sample.pop("data").data) return sample def from_image_folder( root: Union[str, pathlib.Path], *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil, valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"), **kwargs: Any, ) -> Tuple[IterDataPipe, List[str]]: valid_extensions = [valid_extension for ext in valid_extensions for valid_extension in (ext.lower(), ext.upper())] - dp, categories = from_data_folder(root, decoder=decoder, valid_extensions=valid_extensions, **kwargs) + dp, categories = from_data_folder(root, valid_extensions=valid_extensions, **kwargs) return Mapper(dp, _data_to_image_key), categories diff --git a/torchvision/prototype/datasets/decoder.py b/torchvision/prototype/datasets/decoder.py deleted file mode 100644 index 530a357f239..00000000000 --- a/torchvision/prototype/datasets/decoder.py +++ /dev/null @@ -1,16 +0,0 @@ -import io - -import PIL.Image -import torch -from torchvision.prototype import features -from torchvision.transforms.functional import pil_to_tensor - -__all__ = ["raw", "pil"] - - -def raw(buffer: io.IOBase) -> torch.Tensor: - raise RuntimeError("This is just a sentinel and should never be called.") - - -def pil(buffer: io.IOBase) -> features.Image: - return features.Image(pil_to_tensor(PIL.Image.open(buffer))) diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index bde05c49cb1..9423b65a8ee 
100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -1,4 +1,4 @@ -from . import _internal -from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset +from . import _internal # usort: skip +from ._dataset import DatasetConfig, DatasetInfo, Dataset from ._query import SampleQuery from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 38c991fe7a1..37768d150e6 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -1,28 +1,19 @@ import abc import csv -import enum import importlib -import io import itertools import os import pathlib -from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple, Collection +from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection -import torch from torch.utils.data import IterDataPipe -from torchvision.prototype.utils._internal import FrozenBunch, make_repr -from torchvision.prototype.utils._internal import add_suggestion, sequence_to_str +from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion, sequence_to_str from .._home import use_sharded_dataset from ._internal import BUILTIN_DIR, _make_sharded_datapipe from ._resource import OnlineResource -class DatasetType(enum.Enum): - RAW = enum.auto() - IMAGE = enum.auto() - - class DatasetConfig(FrozenBunch): pass @@ -32,7 +23,6 @@ def __init__( self, name: str, *, - type: Union[str, DatasetType], dependencies: Collection[str] = (), categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None, citation: Optional[str] = None, @@ -42,7 +32,6 @@ def __init__( extra: Optional[Dict[str, Any]] = None, ) -> None: self.name = name.lower() - self.type = DatasetType[type.upper()] if isinstance(type, str) else type self.dependecies = dependencies @@ -161,7 +150,6 @@ def _make_datapipe( resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: pass @@ -173,7 +161,6 @@ def load( root: Union[str, pathlib.Path], *, config: Optional[DatasetConfig] = None, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None, skip_integrity_check: bool = False, ) -> IterDataPipe[Dict[str, Any]]: if not config: @@ -188,7 +175,7 @@ def load( resource_dps = [ resource.load(root, skip_integrity_check=skip_integrity_check) for resource in self.resources(config) ] - return self._make_datapipe(resource_dps, config=config, decoder=decoder) + return self._make_datapipe(resource_dps, config=config) def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]: raise NotImplementedError diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index 7106ea44a44..4d4cbb65cd7 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -1,7 +1,6 @@ import enum import functools import gzip -import io import lzma import os import os.path @@ -23,8 +22,6 @@ ) from typing import cast -import numpy as np -import PIL.Image import torch import torch.distributed as dist import torch.utils.data @@ -37,7 +34,6 @@ "INFINITE_BUFFER_SIZE", "BUILTIN_DIR", "read_mat", - "image_buffer_from_array", "SequenceIterator", 
"MappingIterator", "Enumerator", @@ -58,7 +54,7 @@ BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin" -def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any: +def read_mat(buffer: BinaryIO, **kwargs: Any) -> Any: try: import scipy.io as sio except ImportError as error: @@ -70,14 +66,6 @@ def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any: return sio.loadmat(buffer, **kwargs) -def image_buffer_from_array(array: np.ndarray, *, format: str = "png") -> io.BytesIO: - image = PIL.Image.fromarray(array) - buffer = io.BytesIO() - image.save(buffer, format=format) - buffer.seek(0) - return buffer - - class SequenceIterator(IterDataPipe[D]): def __init__(self, datapipe: IterDataPipe[Sequence[D]]): self.datapipe = datapipe @@ -150,17 +138,17 @@ class CompressionType(enum.Enum): LZMA = "lzma" -class Decompressor(IterDataPipe[Tuple[str, io.IOBase]]): +class Decompressor(IterDataPipe[Tuple[str, BinaryIO]]): types = CompressionType - _DECOMPRESSORS = { - types.GZIP: lambda file: gzip.GzipFile(fileobj=file), - types.LZMA: lambda file: lzma.LZMAFile(file), + _DECOMPRESSORS: Dict[CompressionType, Callable[[BinaryIO], BinaryIO]] = { + types.GZIP: lambda file: cast(BinaryIO, gzip.GzipFile(fileobj=file)), + types.LZMA: lambda file: cast(BinaryIO, lzma.LZMAFile(file)), } def __init__( self, - datapipe: IterDataPipe[Tuple[str, io.IOBase]], + datapipe: IterDataPipe[Tuple[str, BinaryIO]], *, type: Optional[Union[str, CompressionType]] = None, ) -> None: @@ -182,7 +170,7 @@ def _detect_compression_type(self, path: str) -> CompressionType: else: raise RuntimeError("FIXME") - def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]: + def __iter__(self) -> Iterator[Tuple[str, BinaryIO]]: for path, file in self.datapipe: type = self._detect_compression_type(path) decompressor = self._DECOMPRESSORS[type] @@ -274,9 +262,9 @@ def read_flo(file: BinaryIO) -> torch.Tensor: return flow.reshape((height, width, 2)).permute((2, 0, 1)) -def hint_sharding(datapipe: IterDataPipe[D]) -> IterDataPipe[D]: +def hint_sharding(datapipe: IterDataPipe[D]) -> ShardingFilter[D]: return ShardingFilter(datapipe) -def hint_shuffling(datapipe: IterDataPipe[D]) -> IterDataPipe[D]: +def hint_shuffling(datapipe: IterDataPipe[D]) -> Shuffler[D]: return Shuffler(datapipe, default=False, buffer_size=INFINITE_BUFFER_SIZE) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 085b39204dd..38fff2da04a 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -1,7 +1,7 @@ -from typing import Any, Callable, cast, Dict, Mapping, Optional, Sequence, Set, Tuple, Type, TypeVar +from typing import Any, cast, Dict, Set, TypeVar import torch -from torch._C import _TensorBase, DisableTorchFunction +from torch._C import _TensorBase F = TypeVar("F", bound="Feature")