diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index ad979b6bd84..4e83568ea97 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -899,9 +899,9 @@ def generate(cls, root): return num_samples_map -# @register_mock -def celeba(info, root, config): - return CelebAMockData.generate(root)[config.split] +@register_mock(configs=combinations_grid(split=("train", "val", "test"))) +def celeba(root, config): + return CelebAMockData.generate(root)[config["split"]] # @register_mock diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index 854c705b746..17a42082f3f 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -1,6 +1,6 @@ import csv -import functools -from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO +import pathlib +from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO, Union from torchdata.datapipes.iter import ( IterDataPipe, @@ -10,9 +10,7 @@ IterKeyZipper, ) from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, + Dataset2, GDriveResource, OnlineResource, ) @@ -25,6 +23,7 @@ ) from torchvision.prototype.features import EncodedImage, _Feature, Label, BoundingBox +from .._api import register_dataset, register_info csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True) @@ -60,15 +59,32 @@ def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]: yield line.pop("image_id"), line -class CelebA(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "celeba", - homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html", - valid_options=dict(split=("train", "val", "test")), - ) +NAME = "celeba" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict() + - def resources(self, config: DatasetConfig) -> List[OnlineResource]: +@register_dataset(NAME) +class CelebA(Dataset2): + """ + - **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "val", "test")) + + super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: splits = GDriveResource( "0B7EVK8r0v71pY0NSMzRuSXJEVkk", sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7", @@ -101,14 +117,13 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: ) return [splits, images, identities, attributes, bounding_boxes, landmarks] - _SPLIT_ID_TO_NAME = { - "0": "train", - "1": "val", - "2": "test", - } - - def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool: - return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split + def _filter_split(self, data: Tuple[str, Dict[str, str]]) -> bool: + split_id = { + "train": "0", + "val": "1", + "test": "2", + }[self._split] + return data[1]["split_id"] == split_id def _prepare_sample( self, @@ -145,16 +160,11 @@ def _prepare_sample( }, ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: splits_dp, images_dp, identities_dp, attributes_dp, bounding_boxes_dp, landmarks_dp = resource_dps splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id")) - splits_dp = Filter(splits_dp, functools.partial(self._filter_split, split=config.split)) + splits_dp = Filter(splits_dp, self._filter_split) splits_dp = hint_shuffling(splits_dp) splits_dp = hint_sharding(splits_dp) @@ -186,3 +196,10 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, ) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return { + "train": 162_770, + "val": 19_867, + "test": 19_962, + }[self._split]