12
12
import tempfile
13
13
import unittest .mock
14
14
import xml .etree .ElementTree as ET
15
- from collections import defaultdict , Counter , UserDict
15
+ from collections import defaultdict , Counter
16
16
17
17
import numpy as np
18
18
import PIL .Image
34
34
__all__ = ["DATASET_MOCKS" , "parametrize_dataset_mocks" ]
35
35
36
36
37
- class ResourceMock (datasets .utils .OnlineResource ):
38
- def __init__ (self , * , dataset_name , dataset_config , ** kwargs ):
39
- super ().__init__ (** kwargs )
40
- self .dataset_name = dataset_name
41
- self .dataset_config = dataset_config
42
-
43
- def _download (self , _ ):
44
- raise pytest .UsageError (
45
- f"Dataset '{ self .dataset_name } ' requires the file '{ self .file_name } ' for { self .dataset_config } , "
46
- f"but this file does not exist."
47
- )
48
-
49
-
50
37
class DatasetMock :
51
- def __init__ (self , name , mock_data_fn , * , configs = None ):
38
+ def __init__ (self , name , mock_data_fn ):
52
39
self .dataset = find (name )
40
+ self .info = self .dataset .info
41
+ self .name = self .info .name
42
+
53
43
self .root = TEST_HOME / self .dataset .name
54
44
self .mock_data_fn = mock_data_fn
55
- self .configs = configs or self .info ._configs
45
+ self .configs = self .info ._configs
56
46
self ._cache = {}
57
47
58
- @property
59
- def info (self ):
60
- return self .dataset .info
61
-
62
- @property
63
- def name (self ):
64
- return self .info .name
65
-
66
48
def _parse_mock_data (self , config , mock_infos ):
67
49
if mock_infos is None :
68
50
raise pytest .UsageError (
@@ -79,7 +61,7 @@ def _parse_mock_data(self, config, mock_infos):
79
61
f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
80
62
)
81
63
82
- for config_ , mock_info in list ( mock_infos .items () ):
64
+ for config_ , mock_info in mock_infos .items ():
83
65
if config_ in self ._cache :
84
66
raise pytest .UsageError (
85
67
f"The mock info for config { config_ } of dataset { self .name } generated for config { config } "
@@ -103,7 +85,7 @@ def _parse_mock_data(self, config, mock_infos):
103
85
return mock_infos
104
86
105
87
def _prepare_resources (self , config ):
106
- with contextlib . suppress ( KeyError ) :
88
+ if config in self . _cache :
107
89
return self ._cache [config ]
108
90
109
91
self .root .mkdir (exist_ok = True )
@@ -146,8 +128,6 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None):
146
128
for mock in dataset_mocks :
147
129
if isinstance (mock , DatasetMock ):
148
130
mocks [mock .name ] = mock
149
- elif isinstance (mock , collections .abc .Sequence ):
150
- mocks .update ({mock_ .name : mock_ for mock_ in mock })
151
131
elif isinstance (mock , collections .abc .Mapping ):
152
132
mocks .update (mock )
153
133
else :
@@ -173,14 +153,13 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None):
173
153
)
174
154
175
155
176
- class DatasetMocks (UserDict ):
177
- def set_from_named_callable (self , fn ):
178
- name = fn .__name__ .replace ("_" , "-" )
179
- self .data [name ] = DatasetMock (name , fn )
180
- return fn
156
+ DATASET_MOCKS = {}
181
157
182
158
183
- DATASET_MOCKS = DatasetMocks ()
159
+ def register_mock (fn ):
160
+ name = fn .__name__ .replace ("_" , "-" )
161
+ DATASET_MOCKS [name ] = DatasetMock (name , fn )
162
+ return fn
184
163
185
164
186
165
class MNISTMockData :
@@ -258,7 +237,7 @@ def generate(
258
237
return num_samples
259
238
260
239
261
- @DATASET_MOCKS . set_from_named_callable
240
+ @register_mock
262
241
def mnist (info , root , config ):
263
242
train = config .split == "train"
264
243
images_file = f"{ 'train' if train else 't10k' } -images-idx3-ubyte.gz"
@@ -274,7 +253,7 @@ def mnist(info, root, config):
274
253
DATASET_MOCKS .update ({name : DatasetMock (name , mnist ) for name in ["fashionmnist" , "kmnist" ]})
275
254
276
255
277
- @DATASET_MOCKS . set_from_named_callable
256
+ @register_mock
278
257
def emnist (info , root , _ ):
279
258
# The image sets that merge some lower case letters in their respective upper case variant, still use dense
280
259
# labels in the data files. Thus, num_categories != len(categories) there.
@@ -303,7 +282,7 @@ def emnist(info, root, _):
303
282
return mock_infos
304
283
305
284
306
- @DATASET_MOCKS . set_from_named_callable
285
+ @register_mock
307
286
def qmnist (info , root , config ):
308
287
num_categories = len (info .categories )
309
288
if config .split == "train" :
@@ -382,7 +361,7 @@ def generate(
382
361
make_tar (root , name , folder , compression = "gz" )
383
362
384
363
385
- @DATASET_MOCKS . set_from_named_callable
364
+ @register_mock
386
365
def cifar10 (info , root , config ):
387
366
train_files = [f"data_batch_{ idx } " for idx in range (1 , 6 )]
388
367
test_files = ["test_batch" ]
@@ -400,7 +379,7 @@ def cifar10(info, root, config):
400
379
return len (train_files if config .split == "train" else test_files )
401
380
402
381
403
- @DATASET_MOCKS . set_from_named_callable
382
+ @register_mock
404
383
def cifar100 (info , root , config ):
405
384
train_files = ["train" ]
406
385
test_files = ["test" ]
@@ -418,7 +397,7 @@ def cifar100(info, root, config):
418
397
return len (train_files if config .split == "train" else test_files )
419
398
420
399
421
- @DATASET_MOCKS . set_from_named_callable
400
+ @register_mock
422
401
def caltech101 (info , root , config ):
423
402
def create_ann_file (root , name ):
424
403
import scipy .io
@@ -468,7 +447,7 @@ def create_ann_folder(root, name, file_name_fn, num_examples):
468
447
return num_images_per_category * len (info .categories )
469
448
470
449
471
- @DATASET_MOCKS . set_from_named_callable
450
+ @register_mock
472
451
def caltech256 (info , root , config ):
473
452
dir = root / "256_ObjectCategories"
474
453
num_images_per_category = 2
@@ -488,7 +467,7 @@ def caltech256(info, root, config):
488
467
return num_images_per_category * len (info .categories )
489
468
490
469
491
- @DATASET_MOCKS . set_from_named_callable
470
+ @register_mock
492
471
def imagenet (info , root , config ):
493
472
wnids = tuple (info .extra .wnid_to_category .keys ())
494
473
if config .split == "train" :
@@ -643,7 +622,7 @@ def generate(
643
622
return num_samples
644
623
645
624
646
- @DATASET_MOCKS . set_from_named_callable
625
+ @register_mock
647
626
def coco (info , root , config ):
648
627
return dict (
649
628
zip (
@@ -722,13 +701,13 @@ def generate(cls, root):
722
701
return num_samples_map
723
702
724
703
725
- @DATASET_MOCKS . set_from_named_callable
704
+ @register_mock
726
705
def sbd (info , root , _ ):
727
706
num_samples_map = SBDMockData .generate (root )
728
707
return {config : num_samples_map [config .split ] for config in info ._configs }
729
708
730
709
731
- @DATASET_MOCKS . set_from_named_callable
710
+ @register_mock
732
711
def semeion (info , root , config ):
733
712
num_samples = 3
734
713
@@ -839,7 +818,7 @@ def generate(cls, root, *, year, trainval):
839
818
return num_samples_map
840
819
841
820
842
- @DATASET_MOCKS . set_from_named_callable
821
+ @register_mock
843
822
def voc (info , root , config ):
844
823
trainval = config .split != "test"
845
824
num_samples_map = VOCMockData .generate (root , year = config .year , trainval = trainval )
@@ -938,13 +917,13 @@ def generate(cls, root):
938
917
return num_samples_map
939
918
940
919
941
- @DATASET_MOCKS . set_from_named_callable
920
+ @register_mock
942
921
def celeba (info , root , _ ):
943
922
num_samples_map = CelebAMockData .generate (root )
944
923
return {config : num_samples_map [config .split ] for config in info ._configs }
945
924
946
925
947
- @DATASET_MOCKS . set_from_named_callable
926
+ @register_mock
948
927
def dtd (info , root , _ ):
949
928
data_folder = root / "dtd"
950
929
@@ -992,7 +971,7 @@ def dtd(info, root, _):
992
971
return num_samples_map
993
972
994
973
995
- @DATASET_MOCKS . set_from_named_callable
974
+ @register_mock
996
975
def fer2013 (info , root , config ):
997
976
num_samples = 5 if config .split == "train" else 3
998
977
@@ -1017,7 +996,7 @@ def fer2013(info, root, config):
1017
996
return num_samples
1018
997
1019
998
1020
- @DATASET_MOCKS . set_from_named_callable
999
+ @register_mock
1021
1000
def gtsrb (info , root , config ):
1022
1001
num_examples_per_class = 5 if config .split == "train" else 3
1023
1002
classes = ("00000" , "00042" , "00012" )
@@ -1087,7 +1066,7 @@ def _make_ann_file(path, num_examples, class_idx):
1087
1066
return num_examples
1088
1067
1089
1068
1090
- @DATASET_MOCKS . set_from_named_callable
1069
+ @register_mock
1091
1070
def clevr (info , root , config ):
1092
1071
data_folder = root / "CLEVR_v1.0"
1093
1072
@@ -1193,7 +1172,7 @@ def generate(self, root):
1193
1172
return num_samples_map
1194
1173
1195
1174
1196
- @DATASET_MOCKS . set_from_named_callable
1175
+ @register_mock
1197
1176
def oxford_iiit_pet (info , root , config ):
1198
1177
num_samples_map = OxfordIIITPetMockData .generate (root )
1199
1178
return {config_ : num_samples_map [config_ .split ] for config_ in info ._configs }
@@ -1360,13 +1339,13 @@ def generate(cls, root):
1360
1339
return num_samples_map
1361
1340
1362
1341
1363
- @DATASET_MOCKS . set_from_named_callable
1342
+ @register_mock
1364
1343
def cub200 (info , root , config ):
1365
1344
num_samples_map = (CUB2002011MockData if config .year == "2011" else CUB2002010MockData ).generate (root )
1366
1345
return {config_ : num_samples_map [config_ .split ] for config_ in info ._configs if config_ .year == config .year }
1367
1346
1368
1347
1369
- @DATASET_MOCKS . set_from_named_callable
1348
+ @register_mock
1370
1349
def svhn (info , root , config ):
1371
1350
import scipy .io as sio
1372
1351
0 commit comments