1
- import functools
2
1
import pathlib
3
2
import re
4
3
from collections import OrderedDict
5
- from typing import Any , Dict , List , Optional , Tuple , cast , BinaryIO
4
+ from collections import defaultdict
5
+ from typing import Any , Dict , List , Optional , Tuple , cast , BinaryIO , Union
6
6
7
7
import torch
8
8
from torchdata .datapipes .iter import (
16
16
UnBatcher ,
17
17
)
18
18
from torchvision .prototype .datasets .utils import (
19
- Dataset ,
20
- DatasetConfig ,
21
19
DatasetInfo ,
22
20
HttpResource ,
23
21
OnlineResource ,
22
+ Dataset2 ,
24
23
)
25
24
from torchvision .prototype .datasets .utils ._internal import (
26
25
MappingIterator ,
32
31
hint_shuffling ,
33
32
)
34
33
from torchvision .prototype .features import BoundingBox , Label , _Feature , EncodedImage
35
- from torchvision .prototype .utils ._internal import FrozenMapping
36
-
37
-
38
- class Coco (Dataset ):
39
- def _make_info (self ) -> DatasetInfo :
40
- name = "coco"
41
- categories , super_categories = zip (* DatasetInfo .read_categories_file (BUILTIN_DIR / f"{ name } .categories" ))
42
-
43
- return DatasetInfo (
44
- name ,
45
- dependencies = ("pycocotools" ,),
46
- categories = categories ,
47
- homepage = "https://cocodataset.org/" ,
48
- valid_options = dict (
49
- split = ("train" , "val" ),
50
- year = ("2017" , "2014" ),
51
- annotations = (* self ._ANN_DECODERS .keys (), None ),
52
- ),
53
- extra = dict (category_to_super_category = FrozenMapping (zip (categories , super_categories ))),
34
+
35
+ from .._api import register_dataset , register_info
36
+
37
+
38
+ NAME = "coco"
39
+
40
+
41
+ @register_info (NAME )
42
+ def _info () -> Dict [str , Any ]:
43
+ categories , super_categories = zip (* DatasetInfo .read_categories_file (BUILTIN_DIR / f"{ NAME } .categories" ))
44
+ return dict (categories = categories , super_categories = super_categories )
45
+
46
+
47
+ @register_dataset (NAME )
48
+ class Coco (Dataset2 ):
49
+ """
50
+ - **homepage**: https://cocodataset.org/
51
+ - **dependencies**:
52
+ - <pycocotools `https://github.com/cocodataset/cocoapi`>_
53
+ """
54
+
55
+ def __init__ (
56
+ self ,
57
+ root : Union [str , pathlib .Path ],
58
+ * ,
59
+ split : str = "train" ,
60
+ year : str = "2017" ,
61
+ annotations : Optional [str ] = "instances" ,
62
+ skip_integrity_check : bool = False ,
63
+ ) -> None :
64
+ self ._split = self ._verify_str_arg (split , "split" , {"train" , "val" })
65
+ self ._year = self ._verify_str_arg (year , "year" , {"2017" , "2014" })
66
+ self ._annotations = (
67
+ self ._verify_str_arg (annotations , "annotations" , self ._ANN_DECODERS .keys ())
68
+ if annotations is not None
69
+ else None
54
70
)
55
71
72
+ info = _info ()
73
+ categories , super_categories = info ["categories" ], info ["super_categories" ]
74
+ self ._categories = categories
75
+ self ._category_to_super_category = dict (zip (categories , super_categories ))
76
+
77
+ super ().__init__ (root , dependencies = ("pycocotools" ,), skip_integrity_check = skip_integrity_check )
78
+
56
79
_IMAGE_URL_BASE = "http://images.cocodataset.org/zips"
57
80
58
81
_IMAGES_CHECKSUMS = {
@@ -69,14 +92,14 @@ def _make_info(self) -> DatasetInfo:
69
92
"2017" : "113a836d90195ee1f884e704da6304dfaaecff1f023f49b6ca93c4aaae470268" ,
70
93
}
71
94
72
- def resources (self , config : DatasetConfig ) -> List [OnlineResource ]:
95
+ def _resources (self ) -> List [OnlineResource ]:
73
96
images = HttpResource (
74
- f"{ self ._IMAGE_URL_BASE } /{ config . split } { config . year } .zip" ,
75
- sha256 = self ._IMAGES_CHECKSUMS [(config . year , config . split )],
97
+ f"{ self ._IMAGE_URL_BASE } /{ self . _split } { self . _year } .zip" ,
98
+ sha256 = self ._IMAGES_CHECKSUMS [(self . _year , self . _split )],
76
99
)
77
100
meta = HttpResource (
78
- f"{ self ._META_URL_BASE } /annotations_trainval{ config . year } .zip" ,
79
- sha256 = self ._META_CHECKSUMS [config . year ],
101
+ f"{ self ._META_URL_BASE } /annotations_trainval{ self . _year } .zip" ,
102
+ sha256 = self ._META_CHECKSUMS [self . _year ],
80
103
)
81
104
return [images , meta ]
82
105
@@ -110,10 +133,8 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st
110
133
format = "xywh" ,
111
134
image_size = image_size ,
112
135
),
113
- labels = Label (labels , categories = self .categories ),
114
- super_categories = [
115
- self .info .extra .category_to_super_category [self .info .categories [label ]] for label in labels
116
- ],
136
+ labels = Label (labels , categories = self ._categories ),
137
+ super_categories = [self ._category_to_super_category [self ._categories [label ]] for label in labels ],
117
138
ann_ids = [ann ["id" ] for ann in anns ],
118
139
)
119
140
@@ -134,9 +155,14 @@ def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str,
134
155
fr"(?P<annotations>({ '|' .join (_ANN_DECODERS .keys ())} ))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json"
135
156
)
136
157
137
- def _filter_meta_files (self , data : Tuple [str , Any ], * , split : str , year : str , annotations : str ) -> bool :
158
+ def _filter_meta_files (self , data : Tuple [str , Any ]) -> bool :
138
159
match = self ._META_FILE_PATTERN .match (pathlib .Path (data [0 ]).name )
139
- return bool (match and match ["split" ] == split and match ["year" ] == year and match ["annotations" ] == annotations )
160
+ return bool (
161
+ match
162
+ and match ["split" ] == self ._split
163
+ and match ["year" ] == self ._year
164
+ and match ["annotations" ] == self ._annotations
165
+ )
140
166
141
167
def _classify_meta (self , data : Tuple [str , Any ]) -> Optional [int ]:
142
168
key , _ = data
@@ -157,38 +183,26 @@ def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
157
183
def _prepare_sample (
158
184
self ,
159
185
data : Tuple [Tuple [List [Dict [str , Any ]], Dict [str , Any ]], Tuple [str , BinaryIO ]],
160
- * ,
161
- annotations : str ,
162
186
) -> Dict [str , Any ]:
163
187
ann_data , image_data = data
164
188
anns , image_meta = ann_data
165
189
166
190
sample = self ._prepare_image (image_data )
191
+ # this method is only called if we have annotations
192
+ annotations = cast (str , self ._annotations )
167
193
sample .update (self ._ANN_DECODERS [annotations ](self , anns , image_meta ))
168
194
return sample
169
195
170
- def _make_datapipe (
171
- self ,
172
- resource_dps : List [IterDataPipe ],
173
- * ,
174
- config : DatasetConfig ,
175
- ) -> IterDataPipe [Dict [str , Any ]]:
196
+ def _datapipe (self , resource_dps : List [IterDataPipe ]) -> IterDataPipe [Dict [str , Any ]]:
176
197
images_dp , meta_dp = resource_dps
177
198
178
- if config . annotations is None :
199
+ if self . _annotations is None :
179
200
dp = hint_shuffling (images_dp )
180
201
dp = hint_sharding (dp )
202
+ dp = hint_shuffling (dp )
181
203
return Mapper (dp , self ._prepare_image )
182
204
183
- meta_dp = Filter (
184
- meta_dp ,
185
- functools .partial (
186
- self ._filter_meta_files ,
187
- split = config .split ,
188
- year = config .year ,
189
- annotations = config .annotations ,
190
- ),
191
- )
205
+ meta_dp = Filter (meta_dp , self ._filter_meta_files )
192
206
meta_dp = JsonParser (meta_dp )
193
207
meta_dp = Mapper (meta_dp , getitem (1 ))
194
208
meta_dp : IterDataPipe [Dict [str , Dict [str , Any ]]] = MappingIterator (meta_dp )
@@ -216,26 +230,31 @@ def _make_datapipe(
216
230
ref_key_fn = getitem ("id" ),
217
231
buffer_size = INFINITE_BUFFER_SIZE ,
218
232
)
219
-
220
233
dp = IterKeyZipper (
221
234
anns_dp ,
222
235
images_dp ,
223
236
key_fn = getitem (1 , "file_name" ),
224
237
ref_key_fn = path_accessor ("name" ),
225
238
buffer_size = INFINITE_BUFFER_SIZE ,
226
239
)
240
+ return Mapper (dp , self ._prepare_sample )
241
+
242
+ def __len__ (self ) -> int :
243
+ return {
244
+ ("train" , "2017" ): defaultdict (lambda : 118_287 , instances = 117_266 ),
245
+ ("train" , "2014" ): defaultdict (lambda : 82_783 , instances = 82_081 ),
246
+ ("val" , "2017" ): defaultdict (lambda : 5_000 , instances = 4_952 ),
247
+ ("val" , "2014" ): defaultdict (lambda : 40_504 , instances = 40_137 ),
248
+ }[(self ._split , self ._year )][
249
+ self ._annotations # type: ignore[index]
250
+ ]
227
251
228
- return Mapper (dp , functools .partial (self ._prepare_sample , annotations = config .annotations ))
229
-
230
- def _generate_categories (self , root : pathlib .Path ) -> Tuple [Tuple [str , str ]]:
231
- config = self .default_config
232
- resources = self .resources (config )
252
+ def _generate_categories (self ) -> Tuple [Tuple [str , str ]]:
253
+ self ._annotations = "instances"
254
+ resources = self ._resources ()
233
255
234
- dp = resources [1 ].load (root )
235
- dp = Filter (
236
- dp ,
237
- functools .partial (self ._filter_meta_files , split = config .split , year = config .year , annotations = "instances" ),
238
- )
256
+ dp = resources [1 ].load (self ._root )
257
+ dp = Filter (dp , self ._filter_meta_files )
239
258
dp = JsonParser (dp )
240
259
241
260
_ , meta = next (iter (dp ))
0 commit comments