|
3 | 3 | import importlib |
4 | 4 | import inspect |
5 | 5 | import itertools |
6 | | -import multiprocessing |
7 | 6 | import os |
8 | 7 | import pathlib |
9 | 8 | import random |
@@ -180,27 +179,30 @@ def check_transforms_v2_wrapper(dataset_test_case, *, config=None, supports_targ |
180 | 179 | from torchvision import datapoints |
181 | 180 | from torchvision.datasets import wrap_dataset_for_transforms_v2 |
182 | 181 |
|
| 182 | + def check_wrapped_samples(dataset): |
| 183 | + for wrapped_sample in dataset: |
| 184 | + assert tree_any( |
| 185 | + lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample |
| 186 | + ) |
| 187 | + |
183 | 188 | target_keyss = [None] |
184 | 189 | if supports_target_keys: |
185 | 190 | target_keyss.append("all") |
186 | 191 |
|
187 | | - for target_keys, multiprocessing_context in itertools.product( |
188 | | - target_keyss, multiprocessing.get_all_start_methods() |
189 | | - ): |
| 192 | + for target_keys in target_keyss: |
190 | 193 | with dataset_test_case.create_dataset(config) as (dataset, info): |
191 | 194 | wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys) |
192 | 195 |
|
193 | 196 | assert isinstance(wrapped_dataset, type(dataset)) |
194 | 197 | assert len(wrapped_dataset) == info["num_examples"] |
195 | 198 |
|
196 | | - dataloader = DataLoader( |
197 | | - wrapped_dataset, num_workers=2, multiprocessing_context=multiprocessing_context, collate_fn=_no_collate |
198 | | - ) |
| 199 | + check_wrapped_samples(wrapped_dataset) |
199 | 200 |
|
200 | | - for wrapped_sample in dataloader: |
201 | | - assert tree_any( |
202 | | - lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample |
203 | | - ) |
| 201 | + with dataset_test_case.create_dataset(config) as (dataset, _): |
| 202 | + wrapped_dataset = wrap_dataset_for_transforms_v2(dataset) |
| 203 | + dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate) |
| 204 | + |
| 205 | + check_wrapped_samples(dataloader) |
204 | 206 |
|
205 | 207 |
|
206 | 208 | class DatasetTestCase(unittest.TestCase): |
|
0 commit comments