
Commit 42bc682

migrate coco prototype (#5473)
* migrate coco prototype
* revert unrelated change
* add kwargs to super constructor call
* remove unneeded changes
* fix docstring position
* make kwargs explicit
* add dependencies to docstring
* fix missing dependency message
1 parent 2ed549d commit 42bc682

File tree

6 files changed: +118 −72 lines
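For orientation before the per-file diffs: the migration replaces the config-driven `Dataset` base with the class-based `Dataset2` API, so datasets are constructed with explicit keyword arguments. A hedged usage sketch — the constructor signature is taken from the coco.py diff below, while the import path is just the builtin module touched by this commit and not a stable public location:

```python
# Hedged sketch of the migrated, class-based API; prototype import paths
# are assumptions and subject to change.
from torchvision.prototype.datasets._builtin.coco import Coco

# Explicit keyword arguments replace the old DatasetConfig plumbing.
dataset = Coco("~/datasets/coco", split="val", year="2017", annotations="instances")
print(len(dataset))  # 4_952 for val/2017 with instance annotations, per __len__
```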

test/builtin_dataset_mocks.py
Lines changed: 9 additions & 3 deletions

```diff
@@ -600,9 +600,15 @@ def generate(
         return num_samples
 
 
-# @register_mock
-def coco(info, root, config):
-    return CocoMockData.generate(root, year=config.year, num_samples=5)
+@register_mock(
+    configs=combinations_grid(
+        split=("train", "val"),
+        year=("2017", "2014"),
+        annotations=("instances", "captions", None),
+    )
+)
+def coco(root, config):
+    return CocoMockData.generate(root, year=config["year"], num_samples=5)
 
 
 class SBDMockData:
```
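The `combinations_grid` helper above presumably expands the keyword tuples into the cartesian product of config dicts, one per mock configuration. A rough stand-in under that assumption (the real helper lives in the test utilities):

```python
# Rough stand-in for combinations_grid, assuming it builds the cartesian
# product of the keyword tuples as a list of config dicts.
import itertools

def combinations_grid(**kwargs):
    return [dict(zip(kwargs, values)) for values in itertools.product(*kwargs.values())]

print(combinations_grid(split=("train", "val"), year=("2017", "2014")))
# [{'split': 'train', 'year': '2017'}, {'split': 'train', 'year': '2014'},
#  {'split': 'val', 'year': '2017'}, {'split': 'val', 'year': '2014'}]
```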

torchvision/prototype/datasets/_builtin/coco.py
Lines changed: 82 additions & 63 deletions

```diff
@@ -1,8 +1,8 @@
-import functools
 import pathlib
 import re
 from collections import OrderedDict
-from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union
 
 import torch
 from torchdata.datapipes.iter import (
```
```diff
@@ -16,11 +16,10 @@
     UnBatcher,
 )
 from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
     DatasetInfo,
     HttpResource,
     OnlineResource,
+    Dataset2,
 )
 from torchvision.prototype.datasets.utils._internal import (
     MappingIterator,
```
```diff
@@ -32,27 +31,51 @@
     hint_shuffling,
 )
 from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage
-from torchvision.prototype.utils._internal import FrozenMapping
-
-
-class Coco(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        name = "coco"
-        categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories"))
-
-        return DatasetInfo(
-            name,
-            dependencies=("pycocotools",),
-            categories=categories,
-            homepage="https://cocodataset.org/",
-            valid_options=dict(
-                split=("train", "val"),
-                year=("2017", "2014"),
-                annotations=(*self._ANN_DECODERS.keys(), None),
-            ),
-            extra=dict(category_to_super_category=FrozenMapping(zip(categories, super_categories))),
+
+from .._api import register_dataset, register_info
+
+
+NAME = "coco"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))
+    return dict(categories=categories, super_categories=super_categories)
+
+
+@register_dataset(NAME)
+class Coco(Dataset2):
+    """
+    - **homepage**: https://cocodataset.org/
+    - **dependencies**:
+        - `pycocotools <https://github.com/cocodataset/cocoapi>`_
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        year: str = "2017",
+        annotations: Optional[str] = "instances",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "val"})
+        self._year = self._verify_str_arg(year, "year", {"2017", "2014"})
+        self._annotations = (
+            self._verify_str_arg(annotations, "annotations", self._ANN_DECODERS.keys())
+            if annotations is not None
+            else None
         )
 
+        info = _info()
+        categories, super_categories = info["categories"], info["super_categories"]
+        self._categories = categories
+        self._category_to_super_category = dict(zip(categories, super_categories))
+
+        super().__init__(root, dependencies=("pycocotools",), skip_integrity_check=skip_integrity_check)
+
     _IMAGE_URL_BASE = "http://images.cocodataset.org/zips"
 
     _IMAGES_CHECKSUMS = {
```
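The decorator pair introduced in this hunk registers the info function and the dataset class under the shared `NAME`. The real implementations live in `torchvision.prototype.datasets._api`; a minimal sketch of the assumed semantics:

```python
# Minimal sketch of the assumed register_info/register_dataset semantics:
# each decorator files its target in a registry keyed by NAME.
_info_registry, _dataset_registry = {}, {}

def register_info(name):
    def wrapper(fn):
        _info_registry[name] = fn
        return fn
    return wrapper

def register_dataset(name):
    def wrapper(cls):
        _dataset_registry[name] = cls
        return cls
    return wrapper
```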
```diff
@@ -69,14 +92,14 @@ def _make_info(self) -> DatasetInfo:
         "2017": "113a836d90195ee1f884e704da6304dfaaecff1f023f49b6ca93c4aaae470268",
     }
 
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    def _resources(self) -> List[OnlineResource]:
         images = HttpResource(
-            f"{self._IMAGE_URL_BASE}/{config.split}{config.year}.zip",
-            sha256=self._IMAGES_CHECKSUMS[(config.year, config.split)],
+            f"{self._IMAGE_URL_BASE}/{self._split}{self._year}.zip",
+            sha256=self._IMAGES_CHECKSUMS[(self._year, self._split)],
         )
         meta = HttpResource(
-            f"{self._META_URL_BASE}/annotations_trainval{config.year}.zip",
-            sha256=self._META_CHECKSUMS[config.year],
+            f"{self._META_URL_BASE}/annotations_trainval{self._year}.zip",
+            sha256=self._META_CHECKSUMS[self._year],
         )
         return [images, meta]
 
```
```diff
@@ -110,10 +133,8 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st
                 format="xywh",
                 image_size=image_size,
             ),
-            labels=Label(labels, categories=self.categories),
-            super_categories=[
-                self.info.extra.category_to_super_category[self.info.categories[label]] for label in labels
-            ],
+            labels=Label(labels, categories=self._categories),
+            super_categories=[self._category_to_super_category[self._categories[label]] for label in labels],
             ann_ids=[ann["id"] for ann in anns],
         )
 
```
```diff
@@ -134,9 +155,14 @@ def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str,
         fr"(?P<annotations>({'|'.join(_ANN_DECODERS.keys())}))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json"
     )
 
-    def _filter_meta_files(self, data: Tuple[str, Any], *, split: str, year: str, annotations: str) -> bool:
+    def _filter_meta_files(self, data: Tuple[str, Any]) -> bool:
         match = self._META_FILE_PATTERN.match(pathlib.Path(data[0]).name)
-        return bool(match and match["split"] == split and match["year"] == year and match["annotations"] == annotations)
+        return bool(
+            match
+            and match["split"] == self._split
+            and match["year"] == self._year
+            and match["annotations"] == self._annotations
+        )
 
     def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
         key, _ = data
```
```diff
@@ -157,38 +183,26 @@ def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
     def _prepare_sample(
         self,
         data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]],
-        *,
-        annotations: str,
     ) -> Dict[str, Any]:
         ann_data, image_data = data
         anns, image_meta = ann_data
 
         sample = self._prepare_image(image_data)
+        # this method is only called if we have annotations
+        annotations = cast(str, self._annotations)
         sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
         return sample
 
-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         images_dp, meta_dp = resource_dps
 
-        if config.annotations is None:
+        if self._annotations is None:
             dp = hint_shuffling(images_dp)
             dp = hint_sharding(dp)
+            dp = hint_shuffling(dp)
             return Mapper(dp, self._prepare_image)
 
-        meta_dp = Filter(
-            meta_dp,
-            functools.partial(
-                self._filter_meta_files,
-                split=config.split,
-                year=config.year,
-                annotations=config.annotations,
-            ),
-        )
+        meta_dp = Filter(meta_dp, self._filter_meta_files)
         meta_dp = JsonParser(meta_dp)
         meta_dp = Mapper(meta_dp, getitem(1))
         meta_dp: IterDataPipe[Dict[str, Dict[str, Any]]] = MappingIterator(meta_dp)
```
```diff
@@ -216,26 +230,31 @@ def _make_datapipe(
             ref_key_fn=getitem("id"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-
         dp = IterKeyZipper(
             anns_dp,
             images_dp,
             key_fn=getitem(1, "file_name"),
             ref_key_fn=path_accessor("name"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
+        return Mapper(dp, self._prepare_sample)
+
+    def __len__(self) -> int:
+        return {
+            ("train", "2017"): defaultdict(lambda: 118_287, instances=117_266),
+            ("train", "2014"): defaultdict(lambda: 82_783, instances=82_081),
+            ("val", "2017"): defaultdict(lambda: 5_000, instances=4_952),
+            ("val", "2014"): defaultdict(lambda: 40_504, instances=40_137),
+        }[(self._split, self._year)][
+            self._annotations  # type: ignore[index]
+        ]
 
-        return Mapper(dp, functools.partial(self._prepare_sample, annotations=config.annotations))
-
-    def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:
-        config = self.default_config
-        resources = self.resources(config)
+    def _generate_categories(self) -> Tuple[Tuple[str, str]]:
+        self._annotations = "instances"
+        resources = self._resources()
 
-        dp = resources[1].load(root)
-        dp = Filter(
-            dp,
-            functools.partial(self._filter_meta_files, split=config.split, year=config.year, annotations="instances"),
-        )
+        dp = resources[1].load(self._root)
+        dp = Filter(dp, self._filter_meta_files)
         dp = JsonParser(dp)
 
         _, meta = next(iter(dp))
```
torchvision/prototype/datasets/_builtin/imagenet.py
Lines changed: 12 additions & 2 deletions

```diff
@@ -54,7 +54,17 @@ class ImageNetDemux(enum.IntEnum):
 
 @register_dataset(NAME)
 class ImageNet(Dataset2):
-    def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> None:
+    """
+    - **homepage**: https://www.image-net.org/
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        skip_integrity_check: bool = False,
+    ) -> None:
         self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})
 
         info = _info()
```

```diff
@@ -63,7 +73,7 @@ def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> N
         self._wnids = wnids
         self._wnid_to_category = dict(zip(wnids, categories))
 
-        super().__init__(root)
+        super().__init__(root, skip_integrity_check=skip_integrity_check)
 
     _IMAGES_CHECKSUMS = {
         "train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
```

torchvision/prototype/datasets/_builtin/voc.py
Lines changed: 2 additions & 2 deletions

```diff
@@ -50,7 +50,7 @@ def __init__(
         split: str = "train",
         year: str = "2012",
         task: str = "detection",
-        **kwargs: Any,
+        skip_integrity_check: bool = False,
     ) -> None:
         self._year = self._verify_str_arg(year, "year", ("2007", "2008", "2009", "2010", "2011", "2012"))
         if split == "test" and year != "2007":
```

```diff
@@ -64,7 +64,7 @@ def __init__(
 
         self._categories = _info()["categories"]
 
-        super().__init__(root, **kwargs)
+        super().__init__(root, skip_integrity_check=skip_integrity_check)
 
     _TRAIN_VAL_ARCHIVES = {
         "2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"),
```

torchvision/prototype/datasets/generate_category_files.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -51,7 +51,7 @@ def parse_args(argv=None):
 
 
 if __name__ == "__main__":
-    args = parse_args(["-f", "imagenet"])
+    args = parse_args()
 
     try:
         main(*args.names, force=args.force)
```

torchvision/prototype/datasets/utils/_dataset.py
Lines changed: 12 additions & 1 deletion

```diff
@@ -196,7 +196,18 @@ def _verify_str_arg(
     ) -> str:
         return verify_str_arg(value, arg, valid_values, custom_msg=custom_msg)
 
-    def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None:
+    def __init__(
+        self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False, dependencies: Collection[str] = ()
+    ) -> None:
+        for dependency in dependencies:
+            try:
+                importlib.import_module(dependency)
+            except ModuleNotFoundError:
+                raise ModuleNotFoundError(
+                    f"{type(self).__name__}() depends on the third-party package '{dependency}'. "
+                    f"Please install it, for example with `pip install {dependency}`."
+                ) from None
+
         self._root = pathlib.Path(root).expanduser().resolve()
         resources = [
             resource.load(self._root, skip_integrity_check=skip_integrity_check) for resource in self._resources()
```
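The new `dependencies` hook lets subclasses declare third-party requirements in their `super().__init__` call, as coco.py does with `pycocotools` above, so a missing package fails fast with an actionable message instead of an import error mid-pipeline. A standalone restatement of the check:

```python
# Standalone restatement of the dependency check added to the base-class
# __init__ above; "Coco" here just names a caller for the error message.
import importlib

def check_dependencies(owner: str, dependencies=()) -> None:
    for dependency in dependencies:
        try:
            importlib.import_module(dependency)
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                f"{owner}() depends on the third-party package '{dependency}'. "
                f"Please install it, for example with `pip install {dependency}`."
            ) from None

check_dependencies("Coco", ("pycocotools",))  # raises if pycocotools is absent
```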
