From 18799dbdf8314875cd6e8611dfb93fa13766d463 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 20 Jan 2022 15:33:05 +0100 Subject: [PATCH] add test if dataset samples can be collated --- test/datasets_utils.py | 26 +++++++++++++++++ test/test_datasets.py | 38 +++++++++++++++++++++++-- test/test_prototype_builtin_datasets.py | 8 ++++++ 3 files changed, 70 insertions(+), 2 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index b87d50ca3db..99997ea03a6 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -23,6 +23,8 @@ import torchvision.datasets import torchvision.io from common_utils import get_tmp_dir, disable_console_output +from torch.utils.data import default_collate +from torchvision.transforms import ToTensor __all__ = [ @@ -578,6 +580,12 @@ def test_transforms(self, config): mock.assert_called() + @test_all_configs + def test_collate_smoke(self, config, transform_kwargs=None): + with self.create_dataset(config, **transform_kwargs or dict()) as (dataset, _): + batch = [dataset[0]] + default_collate(batch) + class ImageDatasetTestCase(DatasetTestCase): """Abstract base class for image dataset testcases. @@ -623,6 +631,24 @@ def new(fp, *args, **kwargs): with unittest.mock.patch("PIL.Image.open", new=new): yield + @test_all_configs + def test_collate_smoke(self, config, transform_kwargs=None): + if transform_kwargs is None: + if "transforms" in self._HAS_SPECIAL_KWARG: + + class JointTransform: + def __init__(self): + self._transform = ToTensor() + + def __call__(self, *args): + return tuple(self._transform(arg) if isinstance(arg, PIL.Image.Image) else arg for arg in args) + + transform_kwargs = dict(transforms=JointTransform()) + elif "transform" in self._HAS_SPECIAL_KWARG: + transform_kwargs = dict(transform=ToTensor()) + + super().test_collate_smoke.__wrapped__(self, config, transform_kwargs=transform_kwargs) + class VideoDatasetTestCase(DatasetTestCase): """Abstract base class for video dataset testcases. diff --git a/test/test_datasets.py b/test/test_datasets.py index a108479aee3..b8e88d12579 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -21,6 +21,7 @@ import torch import torch.nn.functional as F from torchvision import datasets +from torchvision.transforms import ToTensor class STL10TestCase(datasets_utils.ImageDatasetTestCase): @@ -1302,6 +1303,12 @@ def test_feature_types(self, config): finally: self.FEATURE_TYPES = feature_types + @datasets_utils.test_all_configs + def test_collate_smoke(self, config): + # Unlike all other datasets, PhotoTour returns a `torch.Tensor` as image. Thus, we explicitly pass empty + # transformation keyword arguments here, to avoid trying to convert a sample `torch.Tensor` twice. + super().test_collate_smoke.__wrapped__(self, config, transform_kwargs=dict()) + class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.Flickr8k @@ -1587,6 +1594,12 @@ def test_classes(self, config): assert len(dataset.classes) == len(info["classes"]) assert all([a == b for a, b in zip(dataset.classes, info["classes"])]) + @datasets_utils.test_all_configs + def test_collate_smoke(self, config): + # This test case does not return PIL images, but rather strings. Thus, we explicitly pass empty + # transformation keyword arguments here, to avoid trying to convert them into a `torch.Tensor`. + super().test_collate_smoke.__wrapped__(self, config, transform_kwargs=dict()) + class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.ImageFolder @@ -1795,9 +1808,8 @@ def test_targets(self): assert item[6] == i // 3 -class LFWPeopleTestCase(datasets_utils.DatasetTestCase): +class LFWPeopleTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.LFWPeople - FEATURE_TYPES = (PIL.Image.Image, int) ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( split=("10fold", "train", "test"), image_set=("original", "funneled", "deepfunneled") ) @@ -2518,6 +2530,28 @@ def _meta_to_split_and_classification_ann(self, meta, idx): breed_id = "-1" return (image_id, class_id, species, breed_id) + @datasets_utils.test_all_configs + def test_collate_smoke(self, config): + # OxfordIIITPet can return a PIL image as target either directly or inside a tuple. Thus, we need a special + # joint transform here. + + class JointTransform: + def __init__(self): + self._transform = ToTensor() + + def __call__(self, input, target): + input = self._transform(input) + if isinstance(target, PIL.Image.Image): + target = self._transform(target) + elif isinstance(target, tuple): + target = tuple( + self._transform(feature) if isinstance(feature, PIL.Image.Image) else feature + for feature in target + ) + return input, target + + super().test_collate_smoke.__wrapped__(self, config, transform_kwargs=dict(transforms=JointTransform())) + class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.StanfordCars diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index bebeaccaadd..ae1d357cd96 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -3,6 +3,7 @@ import pytest import torch from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS +from torch.utils.data import default_collate from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter from torch.utils.data.graph import traverse from torchdata.datapipes.iter import IterDataPipe, Shuffler @@ -116,6 +117,13 @@ def scan(graph): else: raise AssertionError(f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe.") + @parametrize_dataset_mocks(DATASET_MOCKS) + def test_collate(self, dataset_mock, config): + dataset, _ = dataset_mock.load(config) + + batch = next(iter(dataset.batch(2))) + default_collate(batch) + @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"]) class TestQMNIST: