diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py
index c4f51463e34..ad979b6bd84 100644
--- a/test/builtin_dataset_mocks.py
+++ b/test/builtin_dataset_mocks.py
@@ -792,10 +792,23 @@ def generate(cls, root, *, year, trainval):
         return num_samples_map
 
 
-# @register_mock
-def voc(info, root, config):
-    trainval = config.split != "test"
-    return VOCMockData.generate(root, year=config.year, trainval=trainval)[config.split]
+@register_mock(
+    configs=[
+        *combinations_grid(
+            split=("train", "val", "trainval"),
+            year=("2007", "2008", "2009", "2010", "2011", "2012"),
+            task=("detection", "segmentation"),
+        ),
+        *combinations_grid(
+            split=("test",),
+            year=("2007",),
+            task=("detection", "segmentation"),
+        ),
+    ],
+)
+def voc(root, config):
+    trainval = config["split"] != "test"
+    return VOCMockData.generate(root, year=config["year"], trainval=trainval)[config["split"]]
 
 
 class CelebAMockData:
diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py
index fb507af01b0..638878d5ec3 100644
--- a/torchvision/prototype/datasets/_builtin/imagenet.py
+++ b/torchvision/prototype/datasets/_builtin/imagenet.py
@@ -59,8 +59,8 @@ def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> N
 
         info = _info()
         categories, wnids = info["categories"], info["wnids"]
-        self._categories: List[str] = categories
-        self._wnids: List[str] = wnids
+        self._categories = categories
+        self._wnids = wnids
         self._wnid_to_category = dict(zip(wnids, categories))
 
         super().__init__(root)
diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py
index 5c1d3f8c3a3..d000bdbe0e7 100644
--- a/torchvision/prototype/datasets/_builtin/voc.py
+++ b/torchvision/prototype/datasets/_builtin/voc.py
@@ -1,6 +1,7 @@
+import enum
 import functools
 import pathlib
-from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Callable
+from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Union
 from xml.etree import ElementTree
 
 from torchdata.datapipes.iter import (
@@ -12,13 +13,7 @@
     LineReader,
 )
 from torchvision.datasets import VOCDetection
-from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
-    DatasetInfo,
-    HttpResource,
-    OnlineResource,
-)
+from torchvision.prototype.datasets.utils import DatasetInfo, OnlineResource, HttpResource, Dataset2
 from torchvision.prototype.datasets.utils._internal import (
     path_accessor,
     getitem,
@@ -26,34 +21,50 @@
     path_comparator,
     hint_sharding,
     hint_shuffling,
+    BUILTIN_DIR,
 )
 from torchvision.prototype.features import BoundingBox, Label, EncodedImage
 
+from .._api import register_dataset, register_info
+
+NAME = "voc"
+
+CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=CATEGORIES)
+
 
-class VOCDatasetInfo(DatasetInfo):
-    def __init__(self, *args: Any, **kwargs: Any):
-        super().__init__(*args, **kwargs)
-        self._configs = tuple(config for config in self._configs if config.split != "test" or config.year == "2007")
+@register_dataset(NAME)
+class VOC(Dataset2):
+    """
+    - **homepage**: http://host.robots.ox.ac.uk/pascal/VOC/
+    """
 
-    def make_config(self, **options: Any) -> DatasetConfig:
-        config = super().make_config(**options)
-        if config.split == "test" and config.year != "2007":
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        year: str = "2012",
+        task: str = "detection",
+        **kwargs: Any,
+    ) -> None:
+        self._year = self._verify_str_arg(year, "year", ("2007", "2008", "2009", "2010", "2011", "2012"))
+        if split == "test" and year != "2007":
             raise ValueError("`split='test'` is only available for `year='2007'`")
+        else:
+            self._split = self._verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
+        self._task = self._verify_str_arg(task, "task", ("detection", "segmentation"))
 
-        return config
+        self._anns_folder = "Annotations" if task == "detection" else "SegmentationClass"
+        self._split_folder = "Main" if task == "detection" else "Segmentation"
 
+        self._categories = _info()["categories"]
 
-class VOC(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return VOCDatasetInfo(
-            "voc",
-            homepage="http://host.robots.ox.ac.uk/pascal/VOC/",
-            valid_options=dict(
-                split=("train", "val", "trainval", "test"),
-                year=("2012", "2007", "2008", "2009", "2010", "2011"),
-                task=("detection", "segmentation"),
-            ),
-        )
+        super().__init__(root, **kwargs)
 
     _TRAIN_VAL_ARCHIVES = {
         "2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"),
@@ -67,31 +78,27 @@ def _make_info(self) -> DatasetInfo:
         "2007": ("VOCtest_06-Nov-2007.tar", "6836888e2e01dca84577a849d339fa4f73e1e4f135d312430c4856b5609b4892")
     }
 
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
-        file_name, sha256 = (self._TEST_ARCHIVES if config.split == "test" else self._TRAIN_VAL_ARCHIVES)[config.year]
-        archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{config.year}/{file_name}", sha256=sha256)
+    def _resources(self) -> List[OnlineResource]:
+        file_name, sha256 = (self._TEST_ARCHIVES if self._split == "test" else self._TRAIN_VAL_ARCHIVES)[self._year]
+        archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{self._year}/{file_name}", sha256=sha256)
         return [archive]
 
-    _ANNS_FOLDER = dict(
-        detection="Annotations",
-        segmentation="SegmentationClass",
-    )
-    _SPLIT_FOLDER = dict(
-        detection="Main",
-        segmentation="Segmentation",
-    )
-
     def _is_in_folder(self, data: Tuple[str, Any], *, name: str, depth: int = 1) -> bool:
         path = pathlib.Path(data[0])
         return name in path.parent.parts[-depth:]
 
-    def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> Optional[int]:
+    class _Demux(enum.IntEnum):
+        SPLIT = 0
+        IMAGES = 1
+        ANNS = 2
+
+    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
         if self._is_in_folder(data, name="ImageSets", depth=2):
-            return 0
+            return self._Demux.SPLIT
         elif self._is_in_folder(data, name="JPEGImages"):
-            return 1
-        elif self._is_in_folder(data, name=self._ANNS_FOLDER[config.task]):
-            return 2
+            return self._Demux.IMAGES
+        elif self._is_in_folder(data, name=self._anns_folder):
+            return self._Demux.ANNS
         else:
             return None
 
@@ -111,7 +118,7 @@ def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
                 image_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))),
             ),
             labels=Label(
-                [self.categories.index(instance["name"]) for instance in instances], categories=self.categories
+                [self._categories.index(instance["name"]) for instance in instances], categories=self._categories
             ),
         )
 
@@ -121,8 +128,6 @@ def _prepare_segmentation_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
     def _prepare_sample(
         self,
         data: Tuple[Tuple[Tuple[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]],
-        *,
-        prepare_ann_fn: Callable[[BinaryIO], Dict[str, Any]],
     ) -> Dict[str, Any]:
         split_and_image_data, ann_data = data
         _, image_data = split_and_image_data
@@ -130,29 +135,24 @@ def _prepare_sample(
         ann_path, ann_buffer = ann_data
 
         return dict(
-            prepare_ann_fn(ann_buffer),
+            (self._prepare_detection_ann if self._task == "detection" else self._prepare_segmentation_ann)(ann_buffer),
             image_path=image_path,
             image=EncodedImage.from_file(image_buffer),
             ann_path=ann_path,
         )
 
-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
         split_dp, images_dp, anns_dp = Demultiplexer(
             archive_dp,
             3,
-            functools.partial(self._classify_archive, config=config),
+            self._classify_archive,
             drop_none=True,
             buffer_size=INFINITE_BUFFER_SIZE,
         )
 
-        split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._SPLIT_FOLDER[config.task]))
-        split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
+        split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._split_folder))
+        split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
         split_dp = LineReader(split_dp, decode=True)
         split_dp = hint_shuffling(split_dp)
         split_dp = hint_sharding(split_dp)
@@ -166,25 +166,59 @@ def _make_datapipe(
             ref_key_fn=path_accessor("stem"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(
-            dp,
-            functools.partial(
-                self._prepare_sample,
-                prepare_ann_fn=self._prepare_detection_ann
-                if config.task == "detection"
-                else self._prepare_segmentation_ann,
-            ),
-        )
-
-    def _filter_detection_anns(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
-        return self._classify_archive(data, config=config) == 2
-
-    def _generate_categories(self, root: pathlib.Path) -> List[str]:
-        config = self.info.make_config(task="detection")
-
-        resource = self.resources(config)[0]
-        dp = resource.load(pathlib.Path(root) / self.name)
-        dp = Filter(dp, self._filter_detection_anns, fn_kwargs=dict(config=config))
+        return Mapper(dp, self._prepare_sample)
+
+    def __len__(self) -> int:
+        return {
+            ("train", "2007", "detection"): 2_501,
+            ("train", "2007", "segmentation"): 209,
+            ("train", "2008", "detection"): 2_111,
+            ("train", "2008", "segmentation"): 511,
+            ("train", "2009", "detection"): 3_473,
+            ("train", "2009", "segmentation"): 749,
+            ("train", "2010", "detection"): 4_998,
+            ("train", "2010", "segmentation"): 964,
+            ("train", "2011", "detection"): 5_717,
+            ("train", "2011", "segmentation"): 1_112,
+            ("train", "2012", "detection"): 5_717,
+            ("train", "2012", "segmentation"): 1_464,
+            ("val", "2007", "detection"): 2_510,
+            ("val", "2007", "segmentation"): 213,
+            ("val", "2008", "detection"): 2_221,
+            ("val", "2008", "segmentation"): 512,
+            ("val", "2009", "detection"): 3_581,
+            ("val", "2009", "segmentation"): 750,
+            ("val", "2010", "detection"): 5_105,
+            ("val", "2010", "segmentation"): 964,
+            ("val", "2011", "detection"): 5_823,
+            ("val", "2011", "segmentation"): 1_111,
+            ("val", "2012", "detection"): 5_823,
+            ("val", "2012", "segmentation"): 1_449,
+            ("trainval", "2007", "detection"): 5_011,
+            ("trainval", "2007", "segmentation"): 422,
+            ("trainval", "2008", "detection"): 4_332,
+            ("trainval", "2008", "segmentation"): 1_023,
+            ("trainval", "2009", "detection"): 7_054,
+            ("trainval", "2009", "segmentation"): 1_499,
+            ("trainval", "2010", "detection"): 10_103,
+            ("trainval", "2010", "segmentation"): 1_928,
+            ("trainval", "2011", "detection"): 11_540,
+            ("trainval", "2011", "segmentation"): 2_223,
+            ("trainval", "2012", "detection"): 11_540,
+            ("trainval", "2012", "segmentation"): 2_913,
+            ("test", "2007", "detection"): 4_952,
+            ("test", "2007", "segmentation"): 210,
+        }[(self._split, self._year, self._task)]
+
+    def _filter_anns(self, data: Tuple[str, Any]) -> bool:
+        return self._classify_archive(data) == self._Demux.ANNS
+
+    def _generate_categories(self) -> List[str]:
+        self._task = "detection"
+        resources = self._resources()
+
+        archive_dp = resources[0].load(self._root)
+        dp = Filter(archive_dp, self._filter_anns)
         dp = Mapper(dp, self._parse_detection_ann, input_col=1)
 
         return sorted({instance["name"] for _, anns in dp for instance in anns["object"]})