From d6d7140d81ccd6ef8ab946177b96125af47aa5e4 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 6 Apr 2022 09:42:35 +0200
Subject: [PATCH] migrate country211 prototype dataset

---
 test/builtin_dataset_mocks.py                 | 11 +--
 .../prototype/datasets/_builtin/country211.py | 72 ++++++++++++-------
 2 files changed, 50 insertions(+), 33 deletions(-)

diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py
index ad979b6bd84..a8e0f0bc96a 100644
--- a/test/builtin_dataset_mocks.py
+++ b/test/builtin_dataset_mocks.py
@@ -904,14 +904,9 @@ def celeba(info, root, config):
     return CelebAMockData.generate(root)[config.split]
 
 
-# @register_mock
-def country211(info, root, config):
-    split_name_mapper = {
-        "train": "train",
-        "val": "valid",
-        "test": "test",
-    }
-    split_folder = pathlib.Path(root, "country211", split_name_mapper[config["split"]])
+@register_mock(configs=combinations_grid(split=("train", "val", "test")))
+def country211(root, config):
+    split_folder = pathlib.Path(root, "country211", "valid" if config["split"] == "val" else config["split"])
     split_folder.mkdir(parents=True, exist_ok=True)
 
     num_examples = {
diff --git a/torchvision/prototype/datasets/_builtin/country211.py b/torchvision/prototype/datasets/_builtin/country211.py
index 0b4dc306734..ae0564b224b 100644
--- a/torchvision/prototype/datasets/_builtin/country211.py
+++ b/torchvision/prototype/datasets/_builtin/country211.py
@@ -1,21 +1,44 @@
 import pathlib
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Union
 
 from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter
-from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource
-from torchvision.prototype.datasets.utils._internal import path_comparator, hint_sharding, hint_shuffling
+from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils._internal import path_comparator, hint_sharding, hint_shuffling, BUILTIN_DIR
 from torchvision.prototype.features import EncodedImage, Label
 
+from .._api import register_dataset, register_info
 
-class Country211(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "country211",
-            homepage="https://github.com/openai/CLIP/blob/main/data/country211.md",
-            valid_options=dict(split=("train", "val", "test")),
-        )
+NAME = "country211"
+
+CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=CATEGORIES)
+
+
+@register_dataset(NAME)
+class Country211(Dataset2):
+    """
+    - **homepage**: https://github.com/openai/CLIP/blob/main/data/country211.md
+    """
 
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
+        self._split_folder_name = "valid" if split == "val" else split
+
+        self._categories = _info()["categories"]
+
+        super().__init__(root, skip_integrity_check=skip_integrity_check)
+
+    def _resources(self) -> List[OnlineResource]:
         return [
             HttpResource(
                 "https://openaipublic.azureedge.net/clip/data/country211.tgz",
@@ -23,17 +46,11 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
             )
         ]
 
-    _SPLIT_NAME_MAPPER = {
-        "train": "train",
-        "val": "valid",
-        "test": "test",
-    }
-
     def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
         path, buffer = data
         category = pathlib.Path(path).parent.name
         return dict(
-            label=Label.from_category(category, categories=self.categories),
+            label=Label.from_category(category, categories=self._categories),
             path=path,
             image=EncodedImage.from_file(buffer),
         )
@@ -41,16 +58,21 @@ def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
     def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool:
         return pathlib.Path(data[0]).parent.parent.name == split
 
-    def _make_datapipe(
-        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
-        dp = Filter(dp, path_comparator("parent.parent.name", self._SPLIT_NAME_MAPPER[config.split]))
+        dp = Filter(dp, path_comparator("parent.parent.name", self._split_folder_name))
         dp = hint_shuffling(dp)
         dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
 
-    def _generate_categories(self, root: pathlib.Path) -> List[str]:
-        resources = self.resources(self.default_config)
-        dp = resources[0].load(root)
+    def __len__(self) -> int:
+        return {
+            "train": 31_650,
+            "val": 10_550,
+            "test": 21_100,
+        }[self._split]
+
+    def _generate_categories(self) -> List[str]:
+        resources = self._resources()
+        dp = resources[0].load(self._root)
         return sorted({pathlib.Path(path).parent.name for path, _ in dp})
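
Usage note (not part of the patch): with country211 now wired up through
register_dataset/register_info, it can be exercised either by constructing
Country211 directly with the new __init__ signature or through the prototype
registry. The sketch below is a minimal example, assuming the prototype entry
points torchvision.prototype.datasets.info()/load() dispatch to the registered
objects; the root path is a placeholder, and iterating will fetch
country211.tgz if it is not already present under the chosen root.

    from torchvision.prototype import datasets

    # Categories come from the registered _info() function.
    assert len(datasets.info("country211")["categories"]) == 211

    # Roughly equivalent to Country211(root, split="val"), resolved via the registry.
    dataset = datasets.load("country211", split="val")
    assert len(dataset) == 10_550  # per Country211.__len__

    # Samples carry the keys produced by _prepare_sample: "label", "path", "image".
    sample = next(iter(dataset))
    print(sample["path"], sample["label"])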