diff --git a/test/test_datasets.py b/test/test_datasets.py index bd6d1dcb259..8ebea4e9092 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1528,27 +1528,16 @@ def test_split(self, config): class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.DatasetFolder - # The dataset has no fixed return type since it is defined by the loader parameter. For testing, we use a loader - # that simply returns the path as type 'str' instead of loading anything. See the 'dataset_args()' method. - FEATURE_TYPES = (str, int) - - _IMAGE_EXTENSIONS = ("jpg", "png") - _VIDEO_EXTENSIONS = ("avi", "mp4") - _EXTENSIONS = (*_IMAGE_EXTENSIONS, *_VIDEO_EXTENSIONS) + _EXTENSIONS = ("jpg", "png") # DatasetFolder has two mutually exclusive parameters: 'extensions' and 'is_valid_file'. One of both is required. # We only iterate over different 'extensions' here and handle the tests for 'is_valid_file' in the # 'test_is_valid_file()' method. DEFAULT_CONFIG = dict(extensions=_EXTENSIONS) - ADDITIONAL_CONFIGS = ( - *datasets_utils.combinations_grid(extensions=[(ext,) for ext in _IMAGE_EXTENSIONS]), - dict(extensions=_IMAGE_EXTENSIONS), - *datasets_utils.combinations_grid(extensions=[(ext,) for ext in _VIDEO_EXTENSIONS]), - dict(extensions=_VIDEO_EXTENSIONS), - ) + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(extensions=[(ext,) for ext in _EXTENSIONS]) def dataset_args(self, tmpdir, config): - return tmpdir, lambda x: x + return tmpdir, datasets.folder.pil_loader def inject_fake_data(self, tmpdir, config): extensions = config["extensions"] or self._is_valid_file_to_extensions(config["is_valid_file"]) @@ -1559,14 +1548,8 @@ def inject_fake_data(self, tmpdir, config): if ext not in extensions: continue - create_example_folder = ( - datasets_utils.create_image_folder - if ext in self._IMAGE_EXTENSIONS - else datasets_utils.create_video_folder - ) - num_examples = torch.randint(1, 3, size=()).item() - create_example_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), num_examples) + datasets_utils.create_image_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), num_examples) num_examples_total += num_examples classes.append(cls)