From f1631e1137d3793d51a4fac5b3d201c589f40adb Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 6 Apr 2022 14:44:09 +0100
Subject: [PATCH] Migrate USPS prototype dataset

---
 test/builtin_dataset_mocks.py                 |  8 +--
 .../prototype/datasets/_builtin/usps.py       | 57 ++++++++++++-------
 2 files changed, 40 insertions(+), 25 deletions(-)

diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py
index b33dc1450e3..89c1cf5033c 100644
--- a/test/builtin_dataset_mocks.py
+++ b/test/builtin_dataset_mocks.py
@@ -1515,11 +1515,11 @@ def stanford_cars(root, config):
     return num_samples
 
 
-# @register_mock
-def usps(info, root, config):
-    num_samples = {"train": 15, "test": 7}[config.split]
+@register_mock(configs=combinations_grid(split=("train", "test")))
+def usps(root, config):
+    num_samples = {"train": 15, "test": 7}[config["split"]]
 
-    with bz2.open(root / f"usps{'.t' if not config.split == 'train' else ''}.bz2", "wb") as fh:
+    with bz2.open(root / f"usps{'.t' if not config['split'] == 'train' else ''}.bz2", "wb") as fh:
         lines = []
         for _ in range(num_samples):
             label = make_tensor(1, low=1, high=11, dtype=torch.int)
diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py
index 155fbff5dbb..e1c9940ed86 100644
--- a/torchvision/prototype/datasets/_builtin/usps.py
+++ b/torchvision/prototype/datasets/_builtin/usps.py
@@ -1,22 +1,39 @@
-from typing import Any, Dict, List
+import pathlib
+from typing import Any, Dict, List, Union
 
 import torch
 from torchdata.datapipes.iter import IterDataPipe, LineReader, Mapper, Decompressor
-from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource, HttpResource
+from torchvision.prototype.datasets.utils import Dataset2, OnlineResource, HttpResource
 from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
 from torchvision.prototype.features import Image, Label
 
+from .._api import register_dataset, register_info
 
-class USPS(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "usps",
-            homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps",
-            valid_options=dict(
-                split=("train", "test"),
-            ),
-            categories=10,
-        )
+NAME = "usps"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=[str(c) for c in range(10)])
+
+
+@register_dataset(NAME)
+class USPS(Dataset2):
+    """USPS Dataset
+    homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps",
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "test"})
+
+        self._categories = _info()["categories"]
+        super().__init__(root, skip_integrity_check=skip_integrity_check)
 
     _URL = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass"
 
@@ -29,8 +46,8 @@ def _make_info(self) -> DatasetInfo:
         ),
     }
 
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
-        return [USPS._RESOURCES[config.split]]
+    def _resources(self) -> List[OnlineResource]:
+        return [USPS._RESOURCES[self._split]]
 
     def _prepare_sample(self, line: str) -> Dict[str, Any]:
         label, *values = line.strip().split(" ")
@@ -38,17 +55,15 @@ def _prepare_sample(self, line: str) -> Dict[str, Any]:
         pixels = torch.tensor(values).add_(1).div_(2)
         return dict(
             image=Image(pixels.reshape(16, 16)),
-            label=Label(int(label) - 1, categories=self.categories),
+            label=Label(int(label) - 1, categories=self._categories),
         )
 
-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         dp = Decompressor(resource_dps[0])
         dp = LineReader(dp, decode=True, return_path=False)
         dp = hint_shuffling(dp)
         dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
+
+    def __len__(self) -> int:
+        return 7_291 if self._split == "train" else 2_007