Skip to content

Commit 1691e72

Browse files
authored
migrate cub200 prototype dataset (#5765)

* migrate cub200 prototype dataset
* address comments
* fix category-file generation
1 parent 3b10147 commit 1691e72

File tree

2 files changed

+66
-37
lines changed

2 files changed

+66
-37
lines changed

test/builtin_dataset_mocks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,10 +1402,10 @@ def generate(cls, root):
14021402
return num_samples_map
14031403

14041404

1405-
# @register_mock
1406-
def cub200(info, root, config):
1407-
num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root)
1408-
return num_samples_map[config.split]
1405+
@register_mock(configs=combinations_grid(split=("train", "test"), year=("2010", "2011")))
1406+
def cub200(root, config):
1407+
num_samples_map = (CUB2002011MockData if config["year"] == "2011" else CUB2002010MockData).generate(root)
1408+
return num_samples_map[config["split"]]
14091409

14101410

14111411
@register_mock(configs=[dict()])

torchvision/prototype/datasets/_builtin/cub200.py

Lines changed: 62 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import csv
22
import functools
33
import pathlib
4-
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Callable
4+
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Callable, Union
55

66
from torchdata.datapipes.iter import (
77
IterDataPipe,
@@ -14,8 +14,7 @@
1414
CSVDictParser,
1515
)
1616
from torchvision.prototype.datasets.utils import (
17-
Dataset,
18-
DatasetConfig,
17+
Dataset2,
1918
DatasetInfo,
2019
HttpResource,
2120
OnlineResource,
@@ -28,26 +27,53 @@
2827
getitem,
2928
path_comparator,
3029
path_accessor,
30+
BUILTIN_DIR,
3131
)
3232
from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage
3333

34+
from .._api import register_dataset, register_info
35+
3436
csv.register_dialect("cub200", delimiter=" ")
3537

3638

37-
class CUB200(Dataset):
38-
def _make_info(self) -> DatasetInfo:
39-
return DatasetInfo(
40-
"cub200",
41-
homepage="http://www.vision.caltech.edu/visipedia/CUB-200-2011.html",
42-
dependencies=("scipy",),
43-
valid_options=dict(
44-
split=("train", "test"),
45-
year=("2011", "2010"),
46-
),
39+
NAME = "cub200"
40+
41+
CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))
42+
43+
44+
@register_info(NAME)
45+
def _info() -> Dict[str, Any]:
46+
return dict(categories=CATEGORIES)
47+
48+
49+
@register_dataset(NAME)
50+
class CUB200(Dataset2):
51+
"""
52+
- **homepage**: http://www.vision.caltech.edu/visipedia/CUB-200.html
53+
"""
54+
55+
def __init__(
56+
self,
57+
root: Union[str, pathlib.Path],
58+
*,
59+
split: str = "train",
60+
year: str = "2011",
61+
skip_integrity_check: bool = False,
62+
) -> None:
63+
self._split = self._verify_str_arg(split, "split", ("train", "test"))
64+
self._year = self._verify_str_arg(year, "year", ("2010", "2011"))
65+
66+
self._categories = _info()["categories"]
67+
68+
super().__init__(
69+
root,
70+
# TODO: this will only be available after https://github.com/pytorch/vision/pull/5473
71+
# dependencies=("scipy",),
72+
skip_integrity_check=skip_integrity_check,
4773
)
4874

49-
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
50-
if config.year == "2011":
75+
def _resources(self) -> List[OnlineResource]:
76+
if self._year == "2011":
5177
archive = HttpResource(
5278
"http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz",
5379
sha256="0c685df5597a8b24909f6a7c9db6d11e008733779a671760afef78feb49bf081",
@@ -59,7 +85,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
5985
preprocess="decompress",
6086
)
6187
return [archive, segmentations]
62-
else: # config.year == "2010"
88+
else: # self._year == "2010"
6389
split = HttpResource(
6490
"http://www.vision.caltech.edu/visipedia-data/CUB-200/lists.tgz",
6591
sha256="aeacbd5e3539ae84ea726e8a266a9a119c18f055cd80f3836d5eb4500b005428",
@@ -90,12 +116,12 @@ def _2011_classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
90116
else:
91117
return None
92118

93-
def _2011_filter_split(self, row: List[str], *, split: str) -> bool:
119+
def _2011_filter_split(self, row: List[str]) -> bool:
94120
_, split_id = row
95121
return {
96122
"0": "test",
97123
"1": "train",
98-
}[split_id] == split
124+
}[split_id] == self._split
99125

100126
def _2011_segmentation_key(self, data: Tuple[str, Any]) -> str:
101127
path = pathlib.Path(data[0])
@@ -149,17 +175,12 @@ def _prepare_sample(
149175
return dict(
150176
prepare_ann_fn(anns_data, image.image_size),
151177
image=image,
152-
label=Label(int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]), categories=self.categories),
178+
label=Label(int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]), categories=self._categories),
153179
)
154180

155-
def _make_datapipe(
156-
self,
157-
resource_dps: List[IterDataPipe],
158-
*,
159-
config: DatasetConfig,
160-
) -> IterDataPipe[Dict[str, Any]]:
181+
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
161182
prepare_ann_fn: Callable
162-
if config.year == "2011":
183+
if self._year == "2011":
163184
archive_dp, segmentations_dp = resource_dps
164185
images_dp, split_dp, image_files_dp, bounding_boxes_dp = Demultiplexer(
165186
archive_dp, 4, self._2011_classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
@@ -171,7 +192,7 @@ def _make_datapipe(
171192
)
172193

173194
split_dp = CSVParser(split_dp, dialect="cub200")
174-
split_dp = Filter(split_dp, functools.partial(self._2011_filter_split, split=config.split))
195+
split_dp = Filter(split_dp, self._2011_filter_split)
175196
split_dp = Mapper(split_dp, getitem(0))
176197
split_dp = Mapper(split_dp, image_files_map.get)
177198

@@ -188,10 +209,10 @@ def _make_datapipe(
188209
)
189210

190211
prepare_ann_fn = self._2011_prepare_ann
191-
else: # config.year == "2010"
212+
else: # self._year == "2010"
192213
split_dp, images_dp, anns_dp = resource_dps
193214

194-
split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
215+
split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
195216
split_dp = LineReader(split_dp, decode=True, return_path=False)
196217
split_dp = Mapper(split_dp, self._2010_split_key)
197218

@@ -217,11 +238,19 @@ def _make_datapipe(
217238
)
218239
return Mapper(dp, functools.partial(self._prepare_sample, prepare_ann_fn=prepare_ann_fn))
219240

220-
def _generate_categories(self, root: pathlib.Path) -> List[str]:
221-
config = self.info.make_config(year="2011")
222-
resources = self.resources(config)
241+
def __len__(self) -> int:
242+
return {
243+
("train", "2010"): 3_000,
244+
("test", "2010"): 3_033,
245+
("train", "2011"): 5_994,
246+
("test", "2011"): 5_794,
247+
}[(self._split, self._year)]
248+
249+
def _generate_categories(self) -> List[str]:
250+
self._year = "2011"
251+
resources = self._resources()
223252

224-
dp = resources[0].load(root)
253+
dp = resources[0].load(self._root)
225254
dp = Filter(dp, path_comparator("name", "classes.txt"))
226255
dp = CSVDictParser(dp, fieldnames=("label", "category"), dialect="cub200")
227256

0 commit comments

Comments (0)