@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Any, Tuple, List, Dict, Optional, BinaryIO
+from typing import Any, Tuple, List, Dict, Optional, BinaryIO, Union
 
 from torchdata.datapipes.iter import (
     IterDataPipe,
@@ -9,26 +9,43 @@
     Demultiplexer,
     IterKeyZipper,
 )
-from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_shuffling,
+    BUILTIN_DIR,
     hint_sharding,
     path_comparator,
     getitem,
     INFINITE_BUFFER_SIZE,
 )
 from torchvision.prototype.features import Label, EncodedImage
 
+from .._api import register_dataset, register_info
+
+
+NAME = "food101"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    categories = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")
+    categories = [c[0] for c in categories]
+    return dict(categories=categories)
 
-class Food101(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "food101",
-            homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101",
-            valid_options=dict(split=("train", "test")),
-        )
 
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+@register_dataset(NAME)
+class Food101(Dataset2):
+    """Food 101 dataset
+    homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101",
+    """
+
+    def __init__(self, root: Union[str, Path], *, split: str = "train", skip_integrity_check: bool = False) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "test"})
+        self._categories = _info()["categories"]
+
+        super().__init__(root, skip_integrity_check=skip_integrity_check)
+
+    def _resources(self) -> List[OnlineResource]:
         return [
             HttpResource(
                 url="http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz",
@@ -49,7 +66,7 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
     def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, Any]:
         id, (path, buffer) = data
         return dict(
-            label=Label.from_category(id.split("/", 1)[0], categories=self.categories),
+            label=Label.from_category(id.split("/", 1)[0], categories=self._categories),
             path=path,
             image=EncodedImage.from_file(buffer),
         )
@@ -58,17 +75,12 @@ def _image_key(self, data: Tuple[str, Any]) -> str:
         path = Path(data[0])
         return path.relative_to(path.parents[1]).with_suffix("").as_posix()
 
-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
         images_dp, split_dp = Demultiplexer(
             archive_dp, 2, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
         )
-        split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
+        split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
         split_dp = LineReader(split_dp, decode=True, return_path=False)
         split_dp = hint_sharding(split_dp)
         split_dp = hint_shuffling(split_dp)
@@ -83,9 +95,12 @@ def _make_datapipe(
 
         return Mapper(dp, self._prepare_sample)
 
-    def _generate_categories(self, root: Path) -> List[str]:
-        resources = self.resources(self.default_config)
-        dp = resources[0].load(root)
+    def _generate_categories(self) -> List[str]:
+        resources = self._resources()
+        dp = resources[0].load(self._root)
         dp = Filter(dp, path_comparator("name", "classes.txt"))
         dp = LineReader(dp, decode=True, return_path=False)
         return list(dp)
+
+    def __len__(self) -> int:
+        return 75_750 if self._split == "train" else 25_250
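
For reference, a minimal usage sketch of the migrated class (not part of this diff). It assumes the module is importable as torchvision.prototype.datasets._builtin.food101 and that the Food-101 archive is available (or downloadable) under the given root; the root path below is purely illustrative.

# Hypothetical usage example; import path and root are assumptions.
from torchvision.prototype.datasets._builtin.food101 import Food101

dataset = Food101("path/to/datasets", split="train", skip_integrity_check=True)
print(len(dataset))  # 75_750 for the "train" split, per __len__ above

for sample in dataset:
    # each sample is the dict built by _prepare_sample: label, path, image
    print(sample["path"], sample["label"])
    break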