@@ -45,14 +45,32 @@ def __iter__(self) -> Iterator[Tuple[np.ndarray, int]]:
45
45
46
46
47
47
class _CifarBase (Dataset ):
48
+ _FILE_NAME : str
49
+ _SHA256 : str
48
50
_LABELS_KEY : str
49
51
_META_FILE_NAME : str
50
52
_CATEGORIES_KEY : str
51
53
52
54
@abc.abstractmethod
def _is_data_file(self, data: Tuple[str, io.IOBase], *, split: str) -> bool:
    """Decide whether an archive member belongs to the requested split.

    Subclasses implement this as the predicate handed to ``Filter`` in
    ``_make_datapipe``, so it has to answer with a boolean.

    Args:
        data: ``(path, file handle)`` pair for one member of the archive.
        split: Name of the requested split, taken from ``config.split``.

    Returns:
        ``True`` when the member should be kept for ``split``.
    """
    # The annotation used to be ``Optional[int]``; both concrete overrides
    # return ``bool``, so the abstract signature now states the same contract.
55
57
58
def _make_info(self) -> DatasetInfo:
    """Build the ``DatasetInfo`` shared by every CIFAR variant.

    The dataset name is the lower-cased name of the concrete class, so each
    subclass gets its own name without overriding this method. ``"train"``
    and ``"test"`` are passed as the valid values of the ``split`` option.
    """
    dataset_name = type(self).__name__.lower()
    split_choices = ("train", "test")
    return DatasetInfo(
        dataset_name,
        type=DatasetType.RAW,
        homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
        valid_options={"split": split_choices},
    )
65
+
66
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
    """Describe the single archive that has to be downloaded.

    The URL is assembled from the class attribute ``_FILE_NAME`` and the
    resource is created with the class attribute ``_SHA256`` as its
    ``sha256`` argument.

    Args:
        config: Accepted for interface compatibility; this method does not
            read it.

    Returns:
        A one-element list holding the ``HttpResource`` for the archive.
    """
    # Nothing may follow the file name inside the literal: a trailing space
    # would become part of the URL that is requested.
    archive_url = f"https://www.cs.toronto.edu/~kriz/{self._FILE_NAME}"
    return [HttpResource(archive_url, sha256=self._SHA256)]
73
+
56
74
def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
    """Load the pickled object stored in one archive member.

    Args:
        data: ``(path, file handle)`` pair; only the file handle is used.

    Returns:
        The unpickled object, typed as a string-keyed dictionary for the
        benefit of static checkers (``cast`` performs no runtime check).
    """
    _path, buffer = data
    # NOTE(review): ``pickle.load`` can execute arbitrary code, so this relies
    # on the input being trusted -- confirm the archive is verified upstream.
    # ``latin1`` is presumably needed because the files were pickled under
    # Python 2 -- TODO confirm.
    payload = pickle.load(buffer, encoding="latin1")
    return cast(Dict[str, Any], payload)
@@ -84,7 +102,7 @@ def _make_datapipe(
84
102
decoder : Optional [Callable [[io .IOBase ], torch .Tensor ]],
85
103
) -> IterDataPipe [Dict [str , Any ]]:
86
104
dp = resource_dps [0 ]
87
- dp = Filter (dp , functools .partial (self ._is_data_file , config = config ))
105
+ dp = Filter (dp , functools .partial (self ._is_data_file , split = config . split ))
88
106
dp = Mapper (dp , self ._unpickle )
89
107
dp = CifarFileReader (dp , labels_key = self ._LABELS_KEY )
90
108
dp = hint_sharding (dp )
@@ -102,53 +120,24 @@ def _generate_categories(self, root: pathlib.Path) -> List[str]:
102
120
103
121
104
122
class Cifar10(_CifarBase):
    """CIFAR-10 variant: supplies the archive constants and the split filter."""

    # Archive file name (appended to the download URL by the base class) and
    # the checksum passed along with it.
    _FILE_NAME = "cifar-10-python.tar.gz"
    _SHA256 = "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
    # Keys and meta-file name consumed by the base class.
    _LABELS_KEY = "labels"
    _META_FILE_NAME = "batches.meta"
    _CATEGORIES_KEY = "label_names"

    def _is_data_file(self, data: Tuple[str, Any], *, split: str) -> bool:
        """Return whether the archive member belongs to ``split``.

        Members whose file name starts with ``"data"`` are kept for the
        ``"train"`` split; for any other split value the name has to start
        with ``"test"``.
        """
        member_name = pathlib.Path(data[0]).name
        if split == "train":
            wanted_prefix = "data"
        else:
            wanted_prefix = "test"
        return member_name.startswith(wanted_prefix)
127
132
128
133
129
134
class Cifar100(_CifarBase):
    """CIFAR-100 variant: supplies the archive constants and the split filter."""

    # Archive file name (appended to the download URL by the base class) and
    # the checksum passed along with it.
    _FILE_NAME = "cifar-100-python.tar.gz"
    _SHA256 = "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
    # Keys and meta-file name consumed by the base class.
    _LABELS_KEY = "fine_labels"
    _META_FILE_NAME = "meta"
    _CATEGORIES_KEY = "fine_label_names"

    def _is_data_file(self, data: Tuple[str, Any], *, split: str) -> bool:
        """Return whether the archive member belongs to ``split``.

        A member is kept only when its file name is exactly equal to the
        split name.
        """
        member_name = pathlib.Path(data[0]).name
        return member_name == split
0 commit comments