Skip to content

add prototype transforms that use the prototype dispatchers #5418

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Feb 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion test/test_prototype_builtin_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
)

@pytest.mark.xfail
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_transformable(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
Expand Down
186 changes: 186 additions & 0 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import itertools
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests are a quick-and-dirty implementation that checks the transforms do not error when called with the supported inputs.


import PIL.Image
import pytest
import torch
from test_prototype_transforms_kernels import make_images, make_bounding_boxes, make_one_hot_labels
from torchvision.prototype import transforms, features
from torchvision.transforms.functional import to_pil_image


def make_vanilla_tensor_images(*args, **kwargs):
    """Yield the plain ``torch.Tensor`` data of every 3D image from ``make_images``.

    Images with extra leading (batch) dimensions are skipped, so only
    single CHW images are produced.
    """
    yield from (
        image.data for image in make_images(*args, **kwargs) if image.ndim <= 3
    )


def make_pil_images(*args, **kwargs):
    """Yield a PIL image for every vanilla tensor image."""
    yield from map(to_pil_image, make_vanilla_tensor_images(*args, **kwargs))


def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
    """Yield the plain ``torch.Tensor`` data of every generated bounding box."""
    yield from (box.data for box in make_bounding_boxes(*args, **kwargs))


# Maps each supported input type to the generator function that produces
# sample inputs of that type. parametrize_from_transforms looks up the
# kernel types of a transform's dispatcher here; types without an entry
# are silently skipped.
INPUT_CREATIONS_FNS = {
    features.Image: make_images,
    features.BoundingBox: make_bounding_boxes,
    features.OneHotLabel: make_one_hot_labels,
    torch.Tensor: make_vanilla_tensor_images,
    PIL.Image.Image: make_pil_images,
}


def parametrize(transforms_with_inputs):
    """Expand ``(transform, inputs)`` pairs into a ``pytest.mark.parametrize``.

    Each generated parametrization id encodes the transform class name, the
    fully qualified type of the input, and the input's position within its
    iterable, e.g. ``Resize-torch.Tensor-0``.
    """
    params = []
    for transform, inputs in transforms_with_inputs:
        transform_name = type(transform).__name__
        for position, sample in enumerate(inputs):
            sample_type = type(sample)
            params.append(
                pytest.param(
                    transform,
                    sample,
                    id=f"{transform_name}-{sample_type.__module__}.{sample_type.__name__}-{position}",
                )
            )
    return pytest.mark.parametrize(("transform", "input"), params)


def parametrize_from_transforms(*transforms):
    """Parametrize a test over every (transform, supported-input) combination.

    For each transform that has a dispatcher, one input iterable is created
    per kernel type for which INPUT_CREATIONS_FNS has an entry; transforms
    without a dispatcher and kernel types without a creation function are
    skipped.
    """
    pairs = []
    for transform in transforms:
        dispatcher = transform._DISPATCHER
        if dispatcher is None:
            continue

        for feature_type in dispatcher._kernels:
            # The creation functions are generators, so calling them here
            # cannot raise — only an unknown type is skipped.
            make_inputs = INPUT_CREATIONS_FNS.get(feature_type)
            if make_inputs is not None:
                pairs.append((transform, make_inputs()))

    return parametrize(pairs)


class TestSmoke:
    """Smoke tests: each transform is invoked once per supported input and
    must simply not raise. Return values are not checked."""

    @parametrize_from_transforms(
        transforms.RandomErasing(),
        transforms.HorizontalFlip(),
        transforms.Resize([16, 16]),
        transforms.CenterCrop([16, 16]),
        transforms.ConvertImageDtype(),
    )
    def test_common(self, transform, input):
        # Inputs are derived automatically from each transform's dispatcher
        # kernels via INPUT_CREATIONS_FNS.
        transform(input)

    @parametrize(
        [
            (
                transform,
                [
                    # Mixup/Cutmix operate on batched float samples, hence the
                    # unsqueeze(0) and dtype conversion on both features.
                    dict(
                        image=features.Image.new_like(image, image.unsqueeze(0), dtype=torch.float),
                        one_hot_label=features.OneHotLabel.new_like(
                            one_hot_label, one_hot_label.unsqueeze(0), dtype=torch.float
                        ),
                    )
                    for image, one_hot_label in itertools.product(make_images(), make_one_hot_labels())
                ],
            )
            for transform in [
                transforms.RandomMixup(alpha=1.0),
                transforms.RandomCutmix(alpha=1.0),
            ]
        ]
    )
    def test_mixup_cutmix(self, transform, input):
        transform(input)

    @parametrize(
        [
            (
                transform,
                # Auto-augment policies expect uint8 images; a single extra
                # batch dim of 4 is exercised alongside plain/PIL images.
                itertools.chain.from_iterable(
                    fn(dtypes=[torch.uint8], extra_dims=[(4,)])
                    for fn in [
                        make_images,
                        make_vanilla_tensor_images,
                        make_pil_images,
                    ]
                ),
            )
            for transform in (
                transforms.RandAugment(),
                transforms.TrivialAugmentWide(),
                transforms.AutoAugment(),
            )
        ]
    )
    def test_auto_augment(self, transform, input):
        transform(input)

    @parametrize(
        [
            (
                # Identity normalization (zero mean, unit std) — only checks
                # the transform accepts float RGB tensors, not the math.
                transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
                itertools.chain.from_iterable(
                    fn(color_spaces=["rgb"], dtypes=[torch.float32])
                    for fn in [
                        make_images,
                        make_vanilla_tensor_images,
                    ]
                ),
            ),
        ]
    )
    def test_normalize(self, transform, input):
        transform(input)

    @parametrize(
        [
            (
                transforms.ConvertColorSpace("grayscale"),
                itertools.chain(
                    make_images(),
                    make_vanilla_tensor_images(color_spaces=["rgb"]),
                    make_pil_images(color_spaces=["rgb"]),
                ),
            )
        ]
    )
    # NOTE(review): this exercises ConvertColorSpace, not bounding boxes —
    # the name was presumably meant to be test_convert_color_space.
    def test_convert_bounding_color_space(self, transform, input):
        transform(input)

    @parametrize(
        [
            (
                transforms.ConvertBoundingBoxFormat("xyxy", old_format="xywh"),
                itertools.chain(
                    make_bounding_boxes(),
                    make_vanilla_tensor_bounding_boxes(formats=["xywh"]),
                ),
            )
        ]
    )
    def test_convert_bounding_box_format(self, transform, input):
        transform(input)

    @parametrize(
        [
            (
                transforms.RandomResizedCrop([16, 16]),
                itertools.chain(
                    make_images(extra_dims=[(4,)]),
                    make_vanilla_tensor_images(),
                    make_pil_images(),
                ),
            )
        ]
    )
    def test_random_resized_crop(self, transform, input):
        transform(input)
28 changes: 25 additions & 3 deletions test/test_prototype_transforms_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch.testing
import torchvision.prototype.transforms.kernels as K
from torch import jit
from torch.nn.functional import one_hot
from torchvision.prototype import features

make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
Expand Down Expand Up @@ -39,10 +40,10 @@ def make_images(
extra_dims=((4,), (2, 3)),
):
for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes):
yield make_image(size, color_space=color_space)
yield make_image(size, color_space=color_space, dtype=dtype)

for color_space, extra_dims_ in itertools.product(color_spaces, extra_dims):
yield make_image(color_space=color_space, extra_dims=extra_dims_)
for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims):
yield make_image(color_space=color_space, extra_dims=extra_dims_, dtype=dtype)


def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
Expand Down Expand Up @@ -106,6 +107,27 @@ def make_bounding_boxes(
yield make_bounding_box(format=format, extra_dims=extra_dims_)


def make_label(size=(), *, categories=("category0", "category1")):
    """Create a random ``features.Label`` of the given shape.

    When ``categories`` is empty or falsy, labels are drawn from 10 classes
    as a fallback so ``randint`` still has a valid range.
    """
    num_classes = len(categories) if categories else 10
    data = torch.randint(0, num_classes, size)
    return features.Label(data, categories=categories)


def make_one_hot_label(*args, **kwargs):
    """One-hot encode a randomly generated label.

    All arguments are forwarded to ``make_label``; the number of classes is
    taken from the generated label's categories.
    """
    label = make_label(*args, **kwargs)
    encoded = one_hot(label, num_classes=len(label.categories))
    return features.OneHotLabel(encoded, categories=label.categories)


def make_one_hot_labels(
    *,
    num_categories=(1, 2, 10),
    extra_dims=((4,), (2, 3)),
):
    """Yield one-hot labels, first varying the category count, then the
    leading (batch) dimensions with the default two categories."""
    for count in num_categories:
        categories = [f"category{idx}" for idx in range(count)]
        yield make_one_hot_label(categories=categories)

    # NOTE(review): the shape is forwarded positionally as make_label's
    # `size`, producing a batch of scalar labels that one_hot then expands
    # by one trailing class dim — confirm this matches the other make_*
    # helpers' extra_dims semantics.
    for batch_shape in extra_dims:
        yield make_one_hot_label(batch_shape)


class SampleInput:
def __init__(self, *args, **kwargs):
self.args = args
Expand Down
10 changes: 9 additions & 1 deletion torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip
from . import kernels # usort: skip
from . import functional # usort: skip
from .kernels import InterpolationMode # usort: skip
from ._transform import Transform # usort: skip

from ._augment import RandomErasing, RandomMixup, RandomCutmix
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertColorSpace
from ._misc import Identity, Normalize, ToDtype, Lambda
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
from ._type_conversion import DecodeImage, LabelToOneHot
Loading