Skip to content

add test if dataset samples can be collated #5233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import torchvision.datasets
import torchvision.io
from common_utils import get_tmp_dir, disable_console_output
from torch.utils.data import default_collate
from torchvision.transforms import ToTensor


__all__ = [
Expand Down Expand Up @@ -578,6 +580,12 @@ def test_transforms(self, config):

mock.assert_called()

@test_all_configs
def test_collate_smoke(self, config, transform_kwargs=None):
with self.create_dataset(config, **transform_kwargs or dict()) as (dataset, _):
batch = [dataset[0]]
default_collate(batch)


class ImageDatasetTestCase(DatasetTestCase):
"""Abstract base class for image dataset testcases.
Expand Down Expand Up @@ -623,6 +631,24 @@ def new(fp, *args, **kwargs):
with unittest.mock.patch("PIL.Image.open", new=new):
yield

@test_all_configs
def test_collate_smoke(self, config, transform_kwargs=None):
if transform_kwargs is None:
if "transforms" in self._HAS_SPECIAL_KWARG:

class JointTransform:
def __init__(self):
self._transform = ToTensor()

def __call__(self, *args):
return tuple(self._transform(arg) if isinstance(arg, PIL.Image.Image) else arg for arg in args)

transform_kwargs = dict(transforms=JointTransform())
elif "transform" in self._HAS_SPECIAL_KWARG:
transform_kwargs = dict(transform=ToTensor())

super().test_collate_smoke.__wrapped__(self, config, transform_kwargs=transform_kwargs)


class VideoDatasetTestCase(DatasetTestCase):
"""Abstract base class for video dataset testcases.
Expand Down
38 changes: 36 additions & 2 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch
import torch.nn.functional as F
from torchvision import datasets
from torchvision.transforms import ToTensor


class STL10TestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -1302,6 +1303,12 @@ def test_feature_types(self, config):
finally:
self.FEATURE_TYPES = feature_types

@datasets_utils.test_all_configs
def test_collate_smoke(self, config):
# Unlike all other datasets, PhotoTour returns a `torch.Tensor` as image. Thus, we explicitly pass empty
# transformation keyword arguments here, to avoid trying to convert a sample `torch.Tensor` twice.
super().test_collate_smoke.__wrapped__(self, config, transform_kwargs=dict())


class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Flickr8k
Expand Down Expand Up @@ -1587,6 +1594,12 @@ def test_classes(self, config):
assert len(dataset.classes) == len(info["classes"])
assert all([a == b for a, b in zip(dataset.classes, info["classes"])])

@datasets_utils.test_all_configs
def test_collate_smoke(self, config):
# This test case does not return PIL images, but rather strings. Thus, we explicitly pass empty
# transformation keyword arguments here, to avoid trying to convert them into a `torch.Tensor`.
super().test_collate_smoke.__wrapped__(self, config, transform_kwargs=dict())


class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.ImageFolder
Expand Down Expand Up @@ -1795,9 +1808,8 @@ def test_targets(self):
assert item[6] == i // 3


class LFWPeopleTestCase(datasets_utils.DatasetTestCase):
class LFWPeopleTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.LFWPeople
FEATURE_TYPES = (PIL.Image.Image, int)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
split=("10fold", "train", "test"), image_set=("original", "funneled", "deepfunneled")
)
Expand Down Expand Up @@ -2518,6 +2530,28 @@ def _meta_to_split_and_classification_ann(self, meta, idx):
breed_id = "-1"
return (image_id, class_id, species, breed_id)

@datasets_utils.test_all_configs
def test_collate_smoke(self, config):
# OxfordIIITPet can return a PIL image as target either directly or inside a tuple. Thus, we need a special
# joint transform here.

class JointTransform:
def __init__(self):
self._transform = ToTensor()

def __call__(self, input, target):
input = self._transform(input)
if isinstance(target, PIL.Image.Image):
target = self._transform(target)
elif isinstance(target, tuple):
target = tuple(
self._transform(feature) if isinstance(feature, PIL.Image.Image) else feature
for feature in target
)
return input, target

super().test_collate_smoke.__wrapped__(self, config, transform_kwargs=dict(transforms=JointTransform()))


class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.StanfordCars
Expand Down
8 changes: 8 additions & 0 deletions test/test_prototype_builtin_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
import torch
from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
from torch.utils.data import default_collate
from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import IterDataPipe, Shuffler
Expand Down Expand Up @@ -116,6 +117,13 @@ def scan(graph):
else:
raise AssertionError(f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe.")

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_collate(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)

batch = next(iter(dataset.batch(2)))
default_collate(batch)


@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
class TestQMNIST:
Expand Down