Skip to content

Commit 6de6ec4

Browse files
NicolasHug and pmeier authored
Migrate Food101 prototype dataset (#5758)
* Migrate Food101 dataset

* Added length

* Update torchvision/prototype/datasets/_builtin/food101.py

Co-authored-by: Philip Meier <[email protected]>

Co-authored-by: Philip Meier <[email protected]>
1 parent 2612c4c commit 6de6ec4

File tree

2 files changed

+39
-24
lines changed

2 files changed

+39
-24
lines changed

test/builtin_dataset_mocks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -933,8 +933,8 @@ def country211(root, config):
933933
return num_examples * len(classes)
934934

935935

936-
# @register_mock
937-
def food101(info, root, config):
936+
@register_mock(configs=combinations_grid(split=("train", "test")))
937+
def food101(root, config):
938938
data_folder = root / "food-101"
939939

940940
num_images_per_class = 3
@@ -968,7 +968,7 @@ def food101(info, root, config):
968968

969969
make_tar(root, f"{data_folder.name}.tar.gz", compression="gz")
970970

971-
return num_samples_map[config.split]
971+
return num_samples_map[config["split"]]
972972

973973

974974
@register_mock(configs=combinations_grid(split=("train", "val", "test"), fold=(1, 4, 10)))

torchvision/prototype/datasets/_builtin/food101.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Any, Tuple, List, Dict, Optional, BinaryIO
2+
from typing import Any, Tuple, List, Dict, Optional, BinaryIO, Union
33

44
from torchdata.datapipes.iter import (
55
IterDataPipe,
@@ -9,26 +9,43 @@
99
Demultiplexer,
1010
IterKeyZipper,
1111
)
12-
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, HttpResource, OnlineResource
12+
from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource
1313
from torchvision.prototype.datasets.utils._internal import (
1414
hint_shuffling,
15+
BUILTIN_DIR,
1516
hint_sharding,
1617
path_comparator,
1718
getitem,
1819
INFINITE_BUFFER_SIZE,
1920
)
2021
from torchvision.prototype.features import Label, EncodedImage
2122

23+
from .._api import register_dataset, register_info
24+
25+
26+
NAME = "food101"
27+
28+
29+
@register_info(NAME)
30+
def _info() -> Dict[str, Any]:
31+
categories = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")
32+
categories = [c[0] for c in categories]
33+
return dict(categories=categories)
2234

23-
class Food101(Dataset):
24-
def _make_info(self) -> DatasetInfo:
25-
return DatasetInfo(
26-
"food101",
27-
homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101",
28-
valid_options=dict(split=("train", "test")),
29-
)
3035

31-
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
36+
@register_dataset(NAME)
37+
class Food101(Dataset2):
38+
"""Food 101 dataset
39+
homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101",
40+
"""
41+
42+
def __init__(self, root: Union[str, Path], *, split: str = "train", skip_integrity_check: bool = False) -> None:
43+
self._split = self._verify_str_arg(split, "split", {"train", "test"})
44+
self._categories = _info()["categories"]
45+
46+
super().__init__(root, skip_integrity_check=skip_integrity_check)
47+
48+
def _resources(self) -> List[OnlineResource]:
3249
return [
3350
HttpResource(
3451
url="http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz",
@@ -49,7 +66,7 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
4966
def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, Any]:
5067
id, (path, buffer) = data
5168
return dict(
52-
label=Label.from_category(id.split("/", 1)[0], categories=self.categories),
69+
label=Label.from_category(id.split("/", 1)[0], categories=self._categories),
5370
path=path,
5471
image=EncodedImage.from_file(buffer),
5572
)
@@ -58,17 +75,12 @@ def _image_key(self, data: Tuple[str, Any]) -> str:
5875
path = Path(data[0])
5976
return path.relative_to(path.parents[1]).with_suffix("").as_posix()
6077

61-
def _make_datapipe(
62-
self,
63-
resource_dps: List[IterDataPipe],
64-
*,
65-
config: DatasetConfig,
66-
) -> IterDataPipe[Dict[str, Any]]:
78+
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
6779
archive_dp = resource_dps[0]
6880
images_dp, split_dp = Demultiplexer(
6981
archive_dp, 2, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
7082
)
71-
split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
83+
split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
7284
split_dp = LineReader(split_dp, decode=True, return_path=False)
7385
split_dp = hint_sharding(split_dp)
7486
split_dp = hint_shuffling(split_dp)
@@ -83,9 +95,12 @@ def _make_datapipe(
8395

8496
return Mapper(dp, self._prepare_sample)
8597

86-
def _generate_categories(self, root: Path) -> List[str]:
87-
resources = self.resources(self.default_config)
88-
dp = resources[0].load(root)
98+
def _generate_categories(self) -> List[str]:
99+
resources = self.resources()
100+
dp = resources[0].load(self._root)
89101
dp = Filter(dp, path_comparator("name", "classes.txt"))
90102
dp = LineReader(dp, decode=True, return_path=False)
91103
return list(dp)
104+
105+
def __len__(self) -> int:
106+
return 75_750 if self._split == "train" else 25_250

0 commit comments

Comments (0)