
Commit e581dd0

YosuaMichael, pmeier, and NicolasHug authored and committed
[fbsync] Refactor and simplify prototype datasets (#5778)
Summary:

* refactor prototype datasets to inherit from IterDataPipe (#5448)
  * refactor prototype datasets to inherit from IterDataPipe
  * depend on new architecture
  * fix missing file detection
  * remove unrelated file
  * reinstate decorator for mock registering
  * options -> config
  * remove passing of info to mock data functions
  * refactor categories file generation
  * fix imagenet
* fix prototype datasets data loading tests (#5711)
  * reenable serialization test
  * cleanup
  * fix dill test
  * trigger CI
  * patch DILL_AVAILABLE for pickle serialization
  * revert CI changes
  * remove dill test and traversable test
  * add data loader test
  * parametrize over only_datapipe
  * draw one sample rather than exhaust data loader
  * cleanup
  * trigger CI
* migrate VOC prototype dataset (#5743)
  * migrate VOC prototype dataset
  * cleanup
  * revert unrelated mock data changes
  * remove categories annotations
  * move properties to constructor
  * readd homepage
* migrate CIFAR prototype datasets (#5751)
* migrate country211 prototype dataset (#5753)
* migrate CLEVR prototype dataset (#5752)
* migrate coco prototype (#5473)
  * migrate coco prototype
  * revert unrelated change
  * add kwargs to super constructor call
  * remove unneeded changes
  * fix docstring position
  * make kwargs explicit
  * add dependencies to docstring
  * fix missing dependency message
* Migrate PCAM prototype dataset (#5745)
  * Port PCAM
  * skip_integrity_check
  * Update torchvision/prototype/datasets/_builtin/pcam.py
  * Address comments
* Migrate DTD prototype dataset (#5757)
  * Migrate DTD prototype dataset
  * Docstring
  * Apply suggestions from code review
* Migrate GTSRB prototype dataset (#5746)
  * Migrate GTSRB prototype dataset
  * ufmt
  * Address comments
  * Apparently mypy doesn't know that __len__ returns ints. How cute.
  * why is the CI not triggered??
  * Update torchvision/prototype/datasets/_builtin/gtsrb.py
* migrate CelebA prototype dataset (#5750)
  * migrate CelebA prototype dataset
  * inline split_id
* Migrate Food101 prototype dataset (#5758)
  * Migrate Food101 dataset
  * Added length
  * Update torchvision/prototype/datasets/_builtin/food101.py
* Migrate Fer2013 prototype dataset (#5759)
  * Migrate Fer2013 prototype dataset
  * Update torchvision/prototype/datasets/_builtin/fer2013.py
* Migrate EuroSAT prototype dataset (#5760)
* Migrate Semeion prototype dataset (#5761)
* migrate caltech prototype datasets (#5749)
  * migrate caltech prototype datasets
  * resolve third party dependencies
* Migrate Oxford Pets prototype dataset (#5764)
  * Migrate Oxford Pets prototype dataset
  * Update torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py
* migrate mnist prototype datasets (#5480)
  * migrate MNIST prototype datasets
  * Update torchvision/prototype/datasets/_builtin/mnist.py
* Migrate Stanford Cars prototype dataset (#5767)
  * Migrate Stanford Cars prototype dataset
  * Address comments
* fix category file generation (#5770)
  * fix category file generation
  * revert unrelated change
  * revert unrelated change
* migrate cub200 prototype dataset (#5765)
  * migrate cub200 prototype dataset
  * address comments
  * fix category-file-generation
* Migrate USPS prototype dataset (#5771)
* migrate SBD prototype dataset (#5772)
  * migrate SBD prototype dataset
  * reuse categories
* Migrate SVHN prototype dataset (#5769)
* add test to enforce __len__ is working on prototype datasets (#5742)
  * reactivate special dataset tests
  * add missing annotation
* Cleanup prototype dataset implementation (#5774)
  * Remove Dataset2 class
  * Move read_categories_file out of DatasetInfo
  * Remove FrozenBunch and FrozenMapping
  * Remove test_prototype_datasets_api.py and move missing dep test somewhere else
  * ufmt
  * Let read_categories_file accept names instead of paths
  * Mypy
  * flake8
  * fix category file reading
* update prototype dataset README (#5777)
  * update prototype dataset README
  * fix header level
  * Apply suggestions from code review

(Note: this ignores all push blocking failures!)

Reviewed By: jdsgomes, NicolasHug

Differential Revision: D36095693

fbshipit-source-id: d57f2b4a89ef1c45f5e2783ea57bce08df5c561d

Co-authored-by: Philip Meier <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
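Taken together, the squashed commits above replace the old `DatasetInfo`/`Dataset2` machinery with a single `Dataset` abstract base class that itself subclasses `IterDataPipe`. Below is a minimal sketch of the resulting contract, inferred from the tests in this commit; the class name, resource URL, checksum, and sample count are placeholders, not part of the commit:

```python
from typing import Any, Dict, List

from torchdata.datapipes.iter import IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, HttpResource


class MyDataset(Dataset):
    """Hypothetical dataset showing the three methods the new ABC requires."""

    def __init__(self, root: str) -> None:
        # The base class also accepts a `dependencies` tuple of third-party
        # package names; see test_missing_dependency_error in the diff below.
        super().__init__(root)

    def _resources(self) -> List[HttpResource]:
        # Placeholder URL and checksum; resources are downloaded into `root`.
        return [HttpResource("https://example.com/mydataset.tar", sha256="...")]

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        # Each resource datapipe yields (path, file handle) pairs.
        return Mapper(resource_dps[0], self._prepare_sample)

    def _prepare_sample(self, data) -> Dict[str, Any]:
        path, file = data
        return dict(path=path, file=file)

    def __len__(self) -> int:
        return 1_000  # placeholder; must report the real number of samples
```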
1 parent f9cc788 · commit e581dd0

35 files changed: +1680 / -1494 lines

test/builtin_dataset_mocks.py

Lines changed: 184 additions & 145 deletions
Large diffs are not rendered by default.

test/test_prototype_builtin_datasets.py

Lines changed: 46 additions & 24 deletions
```diff
@@ -7,9 +7,10 @@
 import torch
 from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
 from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
+from torch.utils.data import DataLoader
 from torch.utils.data.graph import traverse
 from torch.utils.data.graph_settings import get_all_graph_pipes
-from torchdata.datapipes.iter import IterDataPipe, Shuffler, ShardingFilter
+from torchdata.datapipes.iter import Shuffler, ShardingFilter
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import transforms, datasets
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
@@ -42,14 +43,24 @@ def test_coverage():
 
 @pytest.mark.filterwarnings("error")
 class TestCommon:
+    @pytest.mark.parametrize("name", datasets.list_datasets())
+    def test_info(self, name):
+        try:
+            info = datasets.info(name)
+        except ValueError:
+            raise AssertionError("No info available.") from None
+
+        if not (isinstance(info, dict) and all(isinstance(key, str) for key in info.keys())):
+            raise AssertionError("Info should be a dictionary with string keys.")
+
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_smoke(self, test_home, dataset_mock, config):
         dataset_mock.prepare(test_home, config)
 
         dataset = datasets.load(dataset_mock.name, **config)
 
-        if not isinstance(dataset, IterDataPipe):
-            raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")
+        if not isinstance(dataset, datasets.utils.Dataset):
+            raise AssertionError(f"Loading the dataset should return a Dataset, but got {type(dataset)} instead.")
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_sample(self, test_home, dataset_mock, config):
@@ -76,24 +87,7 @@ def test_num_samples(self, test_home, dataset_mock, config):
 
         dataset = datasets.load(dataset_mock.name, **config)
 
-        num_samples = 0
-        for _ in dataset:
-            num_samples += 1
-
-        assert num_samples == mock_info["num_samples"]
-
-    @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_decoding(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
-
-        dataset = datasets.load(dataset_mock.name, **config)
-
-        undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
-        if undecoded_features:
-            raise AssertionError(
-                f"The values of key(s) "
-                f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
-            )
+        assert len(list(dataset)) == mock_info["num_samples"]
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
@@ -116,14 +110,36 @@ def test_transformable(self, test_home, dataset_mock, config):
 
         next(iter(dataset.map(transforms.Identity())))
 
+    @pytest.mark.parametrize("only_datapipe", [False, True])
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_serializable(self, test_home, dataset_mock, config):
+    def test_traversable(self, test_home, dataset_mock, config, only_datapipe):
         dataset_mock.prepare(test_home, config)
+        dataset = datasets.load(dataset_mock.name, **config)
+
+        traverse(dataset, only_datapipe=only_datapipe)
 
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_serializable(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)
         dataset = datasets.load(dataset_mock.name, **config)
 
         pickle.dumps(dataset)
 
+    @pytest.mark.parametrize("num_workers", [0, 1])
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_data_loader(self, test_home, dataset_mock, config, num_workers):
+        dataset_mock.prepare(test_home, config)
+        dataset = datasets.load(dataset_mock.name, **config)
+
+        dl = DataLoader(
+            dataset,
+            batch_size=2,
+            num_workers=num_workers,
+            collate_fn=lambda batch: batch,
+        )
+
+        next(iter(dl))
+
     # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
     # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
     # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
@@ -132,7 +148,6 @@ def test_serializable(self, test_home, dataset_mock, config):
     def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
 
         dataset_mock.prepare(test_home, config)
-
         dataset = datasets.load(dataset_mock.name, **config)
 
         if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)):
@@ -160,6 +175,13 @@ def test_infinite_buffer_size(self, test_home, dataset_mock, config):
             # resolved
             assert dp.buffer_size == INFINITE_BUFFER_SIZE
 
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_has_length(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)
+        dataset = datasets.load(dataset_mock.name, **config)
+
+        assert len(dataset) > 0
+
 
 @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
 class TestQMNIST:
@@ -186,7 +208,7 @@ class TestGTSRB:
     def test_label_matches_path(self, test_home, dataset_mock, config):
         # We read the labels from the csv files instead. But for the trainset, the labels are also part of the path.
         # This test makes sure that they're both the same
-        if config.split != "train":
+        if config["split"] != "train":
            return
 
        dataset_mock.prepare(test_home, config)
```
test/test_prototype_datasets_api.py

Lines changed: 0 additions & 231 deletions
This file was deleted.

test/test_prototype_datasets_utils.py

Lines changed: 19 additions & 1 deletion
```diff
@@ -5,7 +5,7 @@
 import torch
 from datasets_utils import make_fake_flo_file
 from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
-from torchvision.prototype.datasets.utils import HttpResource, GDriveResource
+from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset
 from torchvision.prototype.datasets.utils._internal import read_flo, fromfile
 
 
@@ -101,3 +101,21 @@ def preprocess_sentinel(path):
     assert redirected_resource.file_name == file_name
     assert redirected_resource.sha256 == sha256_sentinel
     assert redirected_resource._preprocess is preprocess_sentinel
+
+
+def test_missing_dependency_error():
+    class DummyDataset(Dataset):
+        def __init__(self):
+            super().__init__(root="root", dependencies=("fake_dependency",))
+
+        def _resources(self):
+            pass
+
+        def _datapipe(self, resource_dps):
+            pass
+
+        def __len__(self):
+            pass
+
+    with pytest.raises(ModuleNotFoundError, match="depends on the third-party package 'fake_dependency'"):
+        DummyDataset()
```
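This new test pins down the error raised when a dataset's third-party dependency is missing. As a rough illustration only (not the actual torchvision implementation), a constructor-time check of the following shape would produce the message the test matches:

```python
import importlib.util
from typing import Collection


def check_dependencies(dependencies: Collection[str]) -> None:
    # Raise eagerly at construction time rather than mid-iteration.
    for package in dependencies:
        if importlib.util.find_spec(package) is None:
            raise ModuleNotFoundError(
                f"The dataset depends on the third-party package '{package}', "
                f"but it could not be found. Please install it."
            )
```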
