1
1
import csv
2
2
import functools
3
3
import pathlib
4
- from typing import Any , Dict , List , Optional , Tuple , BinaryIO , Callable
4
+ from typing import Any , Dict , List , Optional , Tuple , BinaryIO , Callable , Union
5
5
6
6
from torchdata .datapipes .iter import (
7
7
IterDataPipe ,
14
14
CSVDictParser ,
15
15
)
16
16
from torchvision .prototype .datasets .utils import (
17
- Dataset ,
18
- DatasetConfig ,
17
+ Dataset2 ,
19
18
DatasetInfo ,
20
19
HttpResource ,
21
20
OnlineResource ,
28
27
getitem ,
29
28
path_comparator ,
30
29
path_accessor ,
30
+ BUILTIN_DIR ,
31
31
)
32
32
from torchvision .prototype .features import Label , BoundingBox , _Feature , EncodedImage
33
33
34
+ from .._api import register_dataset , register_info
35
+
34
36
csv .register_dialect ("cub200" , delimiter = " " )
35
37
36
38
37
- class CUB200 (Dataset ):
38
- def _make_info (self ) -> DatasetInfo :
39
- return DatasetInfo (
40
- "cub200" ,
41
- homepage = "http://www.vision.caltech.edu/visipedia/CUB-200-2011.html" ,
42
- dependencies = ("scipy" ,),
43
- valid_options = dict (
44
- split = ("train" , "test" ),
45
- year = ("2011" , "2010" ),
46
- ),
39
+ NAME = "cub200"
40
+
41
+ CATEGORIES , * _ = zip (* DatasetInfo .read_categories_file (BUILTIN_DIR / f"{ NAME } .categories" ))
42
+
43
+
44
@register_info(NAME)
def _info() -> Dict[str, Any]:
    """Return the static info dict for this dataset (its category names)."""
    return {"categories": CATEGORIES}
47
+
48
+
49
+ @register_dataset (NAME )
50
+ class CUB200 (Dataset2 ):
51
+ """
52
+ - **homepage**: http://www.vision.caltech.edu/visipedia/CUB-200.html
53
+ """
54
+
55
def __init__(
    self,
    root: Union[str, pathlib.Path],
    *,
    split: str = "train",
    year: str = "2011",
    skip_integrity_check: bool = False,
) -> None:
    """Create a CUB-200 dataset rooted at ``root``.

    Args:
        root: directory where the dataset archives live (or will be downloaded to).
        split: one of ``"train"`` or ``"test"``.
        year: dataset edition, ``"2010"`` or ``"2011"``.
        skip_integrity_check: forwarded to the base class; skips checksum verification.
    """
    # Validate user-supplied options before doing any work.
    self._split = self._verify_str_arg(split, "split", ("train", "test"))
    self._year = self._verify_str_arg(year, "year", ("2010", "2011"))

    self._categories = _info()["categories"]

    super().__init__(
        root,
        # TODO: this will only be available after https://github.com/pytorch/vision/pull/5473
        # dependencies=("scipy",),
        skip_integrity_check=skip_integrity_check,
    )
48
74
49
- def resources (self , config : DatasetConfig ) -> List [OnlineResource ]:
50
- if config . year == "2011" :
75
+ def _resources (self ) -> List [OnlineResource ]:
76
+ if self . _year == "2011" :
51
77
archive = HttpResource (
52
78
"http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz" ,
53
79
sha256 = "0c685df5597a8b24909f6a7c9db6d11e008733779a671760afef78feb49bf081" ,
@@ -59,7 +85,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
59
85
preprocess = "decompress" ,
60
86
)
61
87
return [archive , segmentations ]
62
- else : # config.year == "2010"
88
+ else : # self._year == "2010"
63
89
split = HttpResource (
64
90
"http://www.vision.caltech.edu/visipedia-data/CUB-200/lists.tgz" ,
65
91
sha256 = "aeacbd5e3539ae84ea726e8a266a9a119c18f055cd80f3836d5eb4500b005428" ,
@@ -90,12 +116,12 @@ def _2011_classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
90
116
else :
91
117
return None
92
118
93
def _2011_filter_split(self, row: List[str]) -> bool:
    """Keep only the CSV rows that belong to the configured split.

    Each row of the 2011 split file is ``(image_id, split_id)`` where the
    second column is ``"1"`` for train and ``"0"`` for test.
    """
    _, split_id = row
    id_to_split = {"0": "test", "1": "train"}
    return id_to_split[split_id] == self._split
99
125
100
126
def _2011_segmentation_key (self , data : Tuple [str , Any ]) -> str :
101
127
path = pathlib .Path (data [0 ])
@@ -149,17 +175,12 @@ def _prepare_sample(
149
175
return dict (
150
176
prepare_ann_fn (anns_data , image .image_size ),
151
177
image = image ,
152
- label = Label (int (pathlib .Path (path ).parent .name .rsplit ("." , 1 )[0 ]), categories = self .categories ),
178
+ label = Label (int (pathlib .Path (path ).parent .name .rsplit ("." , 1 )[0 ]), categories = self ._categories ),
153
179
)
154
180
155
- def _make_datapipe (
156
- self ,
157
- resource_dps : List [IterDataPipe ],
158
- * ,
159
- config : DatasetConfig ,
160
- ) -> IterDataPipe [Dict [str , Any ]]:
181
+ def _datapipe (self , resource_dps : List [IterDataPipe ]) -> IterDataPipe [Dict [str , Any ]]:
161
182
prepare_ann_fn : Callable
162
- if config . year == "2011" :
183
+ if self . _year == "2011" :
163
184
archive_dp , segmentations_dp = resource_dps
164
185
images_dp , split_dp , image_files_dp , bounding_boxes_dp = Demultiplexer (
165
186
archive_dp , 4 , self ._2011_classify_archive , drop_none = True , buffer_size = INFINITE_BUFFER_SIZE
@@ -171,7 +192,7 @@ def _make_datapipe(
171
192
)
172
193
173
194
split_dp = CSVParser (split_dp , dialect = "cub200" )
174
- split_dp = Filter (split_dp , functools . partial ( self ._2011_filter_split , split = config . split ) )
195
+ split_dp = Filter (split_dp , self ._2011_filter_split )
175
196
split_dp = Mapper (split_dp , getitem (0 ))
176
197
split_dp = Mapper (split_dp , image_files_map .get )
177
198
@@ -188,10 +209,10 @@ def _make_datapipe(
188
209
)
189
210
190
211
prepare_ann_fn = self ._2011_prepare_ann
191
- else : # config.year == "2010"
212
+ else : # self._year == "2010"
192
213
split_dp , images_dp , anns_dp = resource_dps
193
214
194
- split_dp = Filter (split_dp , path_comparator ("name" , f"{ config . split } .txt" ))
215
+ split_dp = Filter (split_dp , path_comparator ("name" , f"{ self . _split } .txt" ))
195
216
split_dp = LineReader (split_dp , decode = True , return_path = False )
196
217
split_dp = Mapper (split_dp , self ._2010_split_key )
197
218
@@ -217,11 +238,19 @@ def _make_datapipe(
217
238
)
218
239
return Mapper (dp , functools .partial (self ._prepare_sample , prepare_ann_fn = prepare_ann_fn ))
219
240
220
- def _generate_categories (self , root : pathlib .Path ) -> List [str ]:
221
- config = self .info .make_config (year = "2011" )
222
- resources = self .resources (config )
241
def __len__(self) -> int:
    """Return the number of samples for the selected (split, year) combination."""
    num_samples = {
        ("train", "2010"): 3_000,
        ("test", "2010"): 3_033,
        ("train", "2011"): 5_994,
        ("test", "2011"): 5_794,
    }
    return num_samples[self._split, self._year]
248
+
249
+ def _generate_categories (self ) -> List [str ]:
250
+ self ._year = "2011"
251
+ resources = self ._resources ()
223
252
224
- dp = resources [0 ].load (root )
253
+ dp = resources [0 ].load (self . _root )
225
254
dp = Filter (dp , path_comparator ("name" , "classes.txt" ))
226
255
dp = CSVDictParser (dp , fieldnames = ("label" , "category" ), dialect = "cub200" )
227
256
0 commit comments