diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index cc8568154ed..5c9e657c2c7 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -971,8 +971,8 @@ def food101(info, root, config): return num_samples_map[config.split] -# @register_mock -def dtd(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "val", "test"), fold=(1, 4, 10))) +def dtd(root, config): data_folder = root / "dtd" num_images_per_class = 3 @@ -1012,11 +1012,11 @@ def dtd(info, root, config): with open(meta_folder / f"{split}{fold}.txt", "w") as file: file.write("\n".join(image_ids_in_config) + "\n") - num_samples_map[info.make_config(split=split, fold=str(fold))] = len(image_ids_in_config) + num_samples_map[(split, fold)] = len(image_ids_in_config) make_tar(root, "dtd-r1.0.1.tar.gz", data_folder, compression="gz") - return num_samples_map[config] + return num_samples_map[config["split"], config["fold"]] # @register_mock diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py index 682fed2d9c2..a5de1359e4e 100644 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ b/torchvision/prototype/datasets/_builtin/dtd.py @@ -1,11 +1,10 @@ import enum import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, LineReader, CSVParser from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, + Dataset2, DatasetInfo, HttpResource, OnlineResource, @@ -14,11 +13,17 @@ INFINITE_BUFFER_SIZE, hint_sharding, path_comparator, + BUILTIN_DIR, getitem, hint_shuffling, ) from torchvision.prototype.features import Label, EncodedImage +from .._api import register_dataset, register_info + + +NAME = "dtd" + class DTDDemux(enum.IntEnum): SPLIT = 0 @@ -26,18 +31,37 @@ class 
DTDDemux(enum.IntEnum): IMAGES = 2 -class DTD(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "dtd", - homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/", - valid_options=dict( - split=("train", "test", "val"), - fold=tuple(str(fold) for fold in range(1, 11)), - ), - ) +@register_info(NAME) +def _info() -> Dict[str, Any]: + categories = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories") + categories = [c[0] for c in categories] + return dict(categories=categories) + + +@register_dataset(NAME) +class DTD(Dataset2): + """DTD Dataset. + homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/", + """ + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + fold: int = 1, + skip_validation_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) + + if not (1 <= fold <= 10): + raise ValueError(f"The fold parameter should be an integer in [1, 10]. Got {fold}") + self._fold = fold + + self._categories = _info()["categories"] + + super().__init__(root, skip_integrity_check=skip_validation_check) - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: archive = HttpResource( "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz", sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205", @@ -71,24 +95,19 @@ def _prepare_sample(self, data: Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO return dict( joint_categories={category for category in joint_categories if category}, - label=Label.from_category(category, categories=self.categories), + label=Label.from_category(category, categories=self._categories), path=path, image=EncodedImage.from_file(buffer), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> 
IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] splits_dp, joint_categories_dp, images_dp = Demultiplexer( archive_dp, 3, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE ) - splits_dp = Filter(splits_dp, path_comparator("name", f"{config.split}{config.fold}.txt")) + splits_dp = Filter(splits_dp, path_comparator("name", f"{self._split}{self._fold}.txt")) splits_dp = LineReader(splits_dp, decode=True, return_path=False) splits_dp = hint_shuffling(splits_dp) splits_dp = hint_sharding(splits_dp) @@ -114,10 +133,13 @@ def _make_datapipe( def _filter_images(self, data: Tuple[str, Any]) -> bool: return self._classify_archive(data) == DTDDemux.IMAGES - def _generate_categories(self, root: pathlib.Path) -> List[str]: - resources = self.resources(self.default_config) + def _generate_categories(self) -> List[str]: + resources = self._resources() - dp = resources[0].load(root) + dp = resources[0].load(self._root) dp = Filter(dp, self._filter_images) return sorted({pathlib.Path(path).parent.name for path, _ in dp}) + + def __len__(self) -> int: + return 1_880 # All splits have the same length