@@ -7,9 +7,10 @@
 import torch
 from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
 from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
+from torch.utils.data import DataLoader
 from torch.utils.data.graph import traverse
 from torch.utils.data.graph_settings import get_all_graph_pipes
-from torchdata.datapipes.iter import IterDataPipe, Shuffler, ShardingFilter
+from torchdata.datapipes.iter import Shuffler, ShardingFilter
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import transforms, datasets
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
@@ -42,14 +43,24 @@ def test_coverage():
 
 @pytest.mark.filterwarnings("error")
 class TestCommon:
+    @pytest.mark.parametrize("name", datasets.list_datasets())
+    def test_info(self, name):
+        try:
+            info = datasets.info(name)
+        except ValueError:
+            raise AssertionError("No info available.") from None
+
+        if not (isinstance(info, dict) and all(isinstance(key, str) for key in info.keys())):
+            raise AssertionError("Info should be a dictionary with string keys.")
+
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_smoke(self, test_home, dataset_mock, config):
         dataset_mock.prepare(test_home, config)
 
         dataset = datasets.load(dataset_mock.name, **config)
 
-        if not isinstance(dataset, IterDataPipe):
-            raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")
+        if not isinstance(dataset, datasets.utils.Dataset):
+            raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.")
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_sample(self, test_home, dataset_mock, config):
@@ -76,24 +87,7 @@ def test_num_samples(self, test_home, dataset_mock, config):
 
         dataset = datasets.load(dataset_mock.name, **config)
 
-        num_samples = 0
-        for _ in dataset:
-            num_samples += 1
-
-        assert num_samples == mock_info["num_samples"]
-
-    @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_decoding(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
-
-        dataset = datasets.load(dataset_mock.name, **config)
-
-        undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
-        if undecoded_features:
-            raise AssertionError(
-                f"The values of key(s) "
-                f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
-            )
+        assert len(list(dataset)) == mock_info["num_samples"]
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
@@ -116,14 +110,36 @@ def test_transformable(self, test_home, dataset_mock, config):
 
         next(iter(dataset.map(transforms.Identity())))
 
+    @pytest.mark.parametrize("only_datapipe", [False, True])
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_serializable(self, test_home, dataset_mock, config):
+    def test_traversable(self, test_home, dataset_mock, config, only_datapipe):
         dataset_mock.prepare(test_home, config)
+        dataset = datasets.load(dataset_mock.name, **config)
+
+        traverse(dataset, only_datapipe=only_datapipe)
 
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_serializable(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)
         dataset = datasets.load(dataset_mock.name, **config)
 
         pickle.dumps(dataset)
 
+    @pytest.mark.parametrize("num_workers", [0, 1])
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_data_loader(self, test_home, dataset_mock, config, num_workers):
+        dataset_mock.prepare(test_home, config)
+        dataset = datasets.load(dataset_mock.name, **config)
+
+        dl = DataLoader(
+            dataset,
+            batch_size=2,
+            num_workers=num_workers,
+            collate_fn=lambda batch: batch,
+        )
+
+        next(iter(dl))
+
     # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
     # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
     # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
@@ -132,7 +148,6 @@ def test_serializable(self, test_home, dataset_mock, config):
     def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
 
         dataset_mock.prepare(test_home, config)
-
         dataset = datasets.load(dataset_mock.name, **config)
 
         if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)):
@@ -160,6 +175,13 @@ def test_infinite_buffer_size(self, test_home, dataset_mock, config):
                 # resolved
                 assert dp.buffer_size == INFINITE_BUFFER_SIZE
 
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_has_length(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)
+        dataset = datasets.load(dataset_mock.name, **config)
+
+        assert len(dataset) > 0
+
 
 @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
 class TestQMNIST:
@@ -186,7 +208,7 @@ class TestGTSRB:
     def test_label_matches_path(self, test_home, dataset_mock, config):
         # We read the labels from the csv files instead. But for the trainset, the labels are also part of the path.
        # This test makes sure that they're both the same
-        if config.split != "train":
+        if config["split"] != "train":
             return
 
         dataset_mock.prepare(test_home, config)
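Below is a minimal sketch, not part of the diff above, of the "Shuffler comes before ShardingFilter" check that the TODO comment in TestCommon refers to. It assumes the nested-dict graph format returned by the traverse / get_all_graph_pipes utilities imported at the top of the file, where each key is a datapipe and its value is the graph of that datapipe's inputs; the helper name shuffler_comes_before_sharding_filter is made up for illustration.

from torch.utils.data.graph import traverse
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.datapipes.iter import Shuffler, ShardingFilter


def shuffler_comes_before_sharding_filter(dataset):
    # Assumption: traverse() returns a nested dict in which each key is a
    # datapipe and its value is the graph of that datapipe's inputs, i.e. its
    # ancestors in the pipeline.
    def check(graph):
        for dp, ancestors in graph.items():
            if isinstance(dp, ShardingFilter):
                # The Shuffler has to sit upstream of the ShardingFilter,
                # i.e. appear among its ancestors.
                return any(isinstance(ancestor, Shuffler) for ancestor in get_all_graph_pipes(ancestors))
            result = check(ancestors)
            if result is not None:
                return result
        return None

    return bool(check(traverse(dataset, only_datapipe=True)))

Inside TestCommon this could back an assertion such as assert shuffler_comes_before_sharding_filter(dataset), but as the comment notes, the authors opted to wait for a solution or test from torchdata instead.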