1
1
import pathlib
2
- from typing import Any , Dict , List , Optional , Tuple , BinaryIO
2
+ from typing import Any , Dict , List , Optional , Tuple , BinaryIO , Union
3
3
4
4
from torchdata .datapipes .iter import IterDataPipe , Mapper , Filter , IterKeyZipper , Demultiplexer , JsonParser , UnBatcher
5
- from torchvision .prototype .datasets .utils import (
6
- Dataset ,
7
- DatasetConfig ,
8
- DatasetInfo ,
9
- HttpResource ,
10
- OnlineResource ,
11
- )
5
+ from torchvision .prototype .datasets .utils import Dataset2 , HttpResource , OnlineResource
12
6
from torchvision .prototype .datasets .utils ._internal import (
13
7
INFINITE_BUFFER_SIZE ,
14
8
hint_sharding ,
19
13
)
20
14
from torchvision .prototype .features import Label , EncodedImage
21
15
16
+ from .._api import register_dataset , register_info
17
+
18
NAME = "clevr"


@register_info(NAME)
def _info() -> Dict[str, Any]:
    """Return the static info dict registered for the CLEVR dataset.

    CLEVR carries no category metadata, so the dict is deliberately empty.
    """
    return {}
22
24
23
- class CLEVR (Dataset ):
24
- def _make_info (self ) -> DatasetInfo :
25
- return DatasetInfo (
26
- "clevr" ,
27
- homepage = "https://cs.stanford.edu/people/jcjohns/clevr/" ,
28
- valid_options = dict (split = ("train" , "val" , "test" )),
29
- )
30
25
31
- def resources (self , config : DatasetConfig ) -> List [OnlineResource ]:
26
+ @register_dataset (NAME )
27
+ class CLEVR (Dataset2 ):
28
+ """
29
+ - **homepage**: https://cs.stanford.edu/people/jcjohns/clevr/
30
+ """
31
+
32
+ def __init__ (
33
+ self , root : Union [str , pathlib .Path ], * , split : str = "train" , skip_integrity_check : bool = False
34
+ ) -> None :
35
+ self ._split = self ._verify_str_arg (split , "split" , ("train" , "val" , "test" ))
36
+
37
+ super ().__init__ (root , skip_integrity_check = skip_integrity_check )
38
+
39
+ def _resources (self ) -> List [OnlineResource ]:
32
40
archive = HttpResource (
33
41
"https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip" ,
34
42
sha256 = "5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1" ,
@@ -61,12 +69,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, A
61
69
label = Label (len (scenes_data ["objects" ])) if scenes_data else None ,
62
70
)
63
71
64
- def _make_datapipe (
65
- self ,
66
- resource_dps : List [IterDataPipe ],
67
- * ,
68
- config : DatasetConfig ,
69
- ) -> IterDataPipe [Dict [str , Any ]]:
72
+ def _datapipe (self , resource_dps : List [IterDataPipe ]) -> IterDataPipe [Dict [str , Any ]]:
70
73
archive_dp = resource_dps [0 ]
71
74
images_dp , scenes_dp = Demultiplexer (
72
75
archive_dp ,
@@ -76,12 +79,12 @@ def _make_datapipe(
76
79
buffer_size = INFINITE_BUFFER_SIZE ,
77
80
)
78
81
79
- images_dp = Filter (images_dp , path_comparator ("parent.name" , config . split ))
82
+ images_dp = Filter (images_dp , path_comparator ("parent.name" , self . _split ))
80
83
images_dp = hint_shuffling (images_dp )
81
84
images_dp = hint_sharding (images_dp )
82
85
83
- if config . split != "test" :
84
- scenes_dp = Filter (scenes_dp , path_comparator ("name" , f"CLEVR_{ config . split } _scenes.json" ))
86
+ if self . _split != "test" :
87
+ scenes_dp = Filter (scenes_dp , path_comparator ("name" , f"CLEVR_{ self . _split } _scenes.json" ))
85
88
scenes_dp = JsonParser (scenes_dp )
86
89
scenes_dp = Mapper (scenes_dp , getitem (1 , "scenes" ))
87
90
scenes_dp = UnBatcher (scenes_dp )
@@ -97,3 +100,6 @@ def _make_datapipe(
97
100
dp = Mapper (images_dp , self ._add_empty_anns )
98
101
99
102
return Mapper (dp , self ._prepare_sample )
103
+
104
+ def __len__ (self ) -> int :
105
+ return 70_000 if self ._split == "train" else 15_000
0 commit comments