Skip to content

Commit 2a212b8

Browse files
authored
Migrate USPS prototype dataset (#5771)
1 parent 1691e72 commit 2a212b8

File tree

2 files changed

+40
-25
lines changed

2 files changed

+40
-25
lines changed

test/builtin_dataset_mocks.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1515,11 +1515,11 @@ def stanford_cars(root, config):
15151515
return num_samples
15161516

15171517

1518-
# @register_mock
1519-
def usps(info, root, config):
1520-
num_samples = {"train": 15, "test": 7}[config.split]
1518+
@register_mock(configs=combinations_grid(split=("train", "test")))
1519+
def usps(root, config):
1520+
num_samples = {"train": 15, "test": 7}[config["split"]]
15211521

1522-
with bz2.open(root / f"usps{'.t' if not config.split == 'train' else ''}.bz2", "wb") as fh:
1522+
with bz2.open(root / f"usps{'.t' if not config['split'] == 'train' else ''}.bz2", "wb") as fh:
15231523
lines = []
15241524
for _ in range(num_samples):
15251525
label = make_tensor(1, low=1, high=11, dtype=torch.int)
Lines changed: 36 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,22 +1,39 @@
1-
from typing import Any, Dict, List
1+
import pathlib
2+
from typing import Any, Dict, List, Union
23

34
import torch
45
from torchdata.datapipes.iter import IterDataPipe, LineReader, Mapper, Decompressor
5-
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource, HttpResource
6+
from torchvision.prototype.datasets.utils import Dataset2, OnlineResource, HttpResource
67
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
78
from torchvision.prototype.features import Image, Label
89

10+
from .._api import register_dataset, register_info
911

10-
class USPS(Dataset):
11-
def _make_info(self) -> DatasetInfo:
12-
return DatasetInfo(
13-
"usps",
14-
homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps",
15-
valid_options=dict(
16-
split=("train", "test"),
17-
),
18-
categories=10,
19-
)
12+
NAME = "usps"
13+
14+
15+
@register_info(NAME)
16+
def _info() -> Dict[str, Any]:
17+
return dict(categories=[str(c) for c in range(10)])
18+
19+
20+
@register_dataset(NAME)
21+
class USPS(Dataset2):
22+
"""USPS Dataset
23+
homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps",
24+
"""
25+
26+
def __init__(
27+
self,
28+
root: Union[str, pathlib.Path],
29+
*,
30+
split: str = "train",
31+
skip_integrity_check: bool = False,
32+
) -> None:
33+
self._split = self._verify_str_arg(split, "split", {"train", "test"})
34+
35+
self._categories = _info()["categories"]
36+
super().__init__(root, skip_integrity_check=skip_integrity_check)
2037

2138
_URL = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass"
2239

@@ -29,26 +46,24 @@ def _make_info(self) -> DatasetInfo:
2946
),
3047
}
3148

32-
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
33-
return [USPS._RESOURCES[config.split]]
49+
def _resources(self) -> List[OnlineResource]:
50+
return [USPS._RESOURCES[self._split]]
3451

3552
def _prepare_sample(self, line: str) -> Dict[str, Any]:
3653
label, *values = line.strip().split(" ")
3754
values = [float(value.split(":")[1]) for value in values]
3855
pixels = torch.tensor(values).add_(1).div_(2)
3956
return dict(
4057
image=Image(pixels.reshape(16, 16)),
41-
label=Label(int(label) - 1, categories=self.categories),
58+
label=Label(int(label) - 1, categories=self._categories),
4259
)
4360

44-
def _make_datapipe(
45-
self,
46-
resource_dps: List[IterDataPipe],
47-
*,
48-
config: DatasetConfig,
49-
) -> IterDataPipe[Dict[str, Any]]:
61+
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
5062
dp = Decompressor(resource_dps[0])
5163
dp = LineReader(dp, decode=True, return_path=False)
5264
dp = hint_shuffling(dp)
5365
dp = hint_sharding(dp)
5466
return Mapper(dp, self._prepare_sample)
67+
68+
def __len__(self) -> int:
69+
return 7_291 if self._split == "train" else 2_007

0 commit comments

Comments
 (0)