Skip to content

Commit 2ed549d

Browse files
authored
migrate CLEVR prototype datsaet (#5752)
1 parent 6a0592f commit 2ed549d

File tree

2 files changed

+34
-28
lines changed

2 files changed

+34
-28
lines changed

test/builtin_dataset_mocks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,8 +1108,8 @@ def _make_ann_file(path, num_examples, class_idx):
11081108
return num_examples
11091109

11101110

1111-
# @register_mock
1112-
def clevr(info, root, config):
1111+
@register_mock(configs=combinations_grid(split=("train", "val", "test")))
1112+
def clevr(root, config):
11131113
data_folder = root / "CLEVR_v1.0"
11141114

11151115
num_samples_map = {
@@ -1150,7 +1150,7 @@ def clevr(info, root, config):
11501150

11511151
make_zip(root, f"{data_folder.name}.zip", data_folder)
11521152

1153-
return num_samples_map[config.split]
1153+
return num_samples_map[config["split"]]
11541154

11551155

11561156
class OxfordIIITPetMockData:

torchvision/prototype/datasets/_builtin/clevr.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,8 @@
11
import pathlib
2-
from typing import Any, Dict, List, Optional, Tuple, BinaryIO
2+
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union
33

44
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher
5-
from torchvision.prototype.datasets.utils import (
6-
Dataset,
7-
DatasetConfig,
8-
DatasetInfo,
9-
HttpResource,
10-
OnlineResource,
11-
)
5+
from torchvision.prototype.datasets.utils import Dataset2, HttpResource, OnlineResource
126
from torchvision.prototype.datasets.utils._internal import (
137
INFINITE_BUFFER_SIZE,
148
hint_sharding,
@@ -19,16 +13,30 @@
1913
)
2014
from torchvision.prototype.features import Label, EncodedImage
2115

16+
from .._api import register_dataset, register_info
17+
18+
NAME = "clevr"
19+
20+
21+
@register_info(NAME)
22+
def _info() -> Dict[str, Any]:
23+
return dict()
2224

23-
class CLEVR(Dataset):
24-
def _make_info(self) -> DatasetInfo:
25-
return DatasetInfo(
26-
"clevr",
27-
homepage="https://cs.stanford.edu/people/jcjohns/clevr/",
28-
valid_options=dict(split=("train", "val", "test")),
29-
)
3025

31-
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
26+
@register_dataset(NAME)
27+
class CLEVR(Dataset2):
28+
"""
29+
- **homepage**: https://cs.stanford.edu/people/jcjohns/clevr/
30+
"""
31+
32+
def __init__(
33+
self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
34+
) -> None:
35+
self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
36+
37+
super().__init__(root, skip_integrity_check=skip_integrity_check)
38+
39+
def _resources(self) -> List[OnlineResource]:
3240
archive = HttpResource(
3341
"https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip",
3442
sha256="5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1",
@@ -61,12 +69,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, A
6169
label=Label(len(scenes_data["objects"])) if scenes_data else None,
6270
)
6371

64-
def _make_datapipe(
65-
self,
66-
resource_dps: List[IterDataPipe],
67-
*,
68-
config: DatasetConfig,
69-
) -> IterDataPipe[Dict[str, Any]]:
72+
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
7073
archive_dp = resource_dps[0]
7174
images_dp, scenes_dp = Demultiplexer(
7275
archive_dp,
@@ -76,12 +79,12 @@ def _make_datapipe(
7679
buffer_size=INFINITE_BUFFER_SIZE,
7780
)
7881

79-
images_dp = Filter(images_dp, path_comparator("parent.name", config.split))
82+
images_dp = Filter(images_dp, path_comparator("parent.name", self._split))
8083
images_dp = hint_shuffling(images_dp)
8184
images_dp = hint_sharding(images_dp)
8285

83-
if config.split != "test":
84-
scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{config.split}_scenes.json"))
86+
if self._split != "test":
87+
scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{self._split}_scenes.json"))
8588
scenes_dp = JsonParser(scenes_dp)
8689
scenes_dp = Mapper(scenes_dp, getitem(1, "scenes"))
8790
scenes_dp = UnBatcher(scenes_dp)
@@ -97,3 +100,6 @@ def _make_datapipe(
97100
dp = Mapper(images_dp, self._add_empty_anns)
98101

99102
return Mapper(dp, self._prepare_sample)
103+
104+
def __len__(self) -> int:
105+
return 70_000 if self._split == "train" else 15_000

0 commit comments

Comments
 (0)