add proper smoke test for prototype transforms #7238

Merged
merged 10 commits on Feb 14, 2023
224 changes: 182 additions & 42 deletions test/test_prototype_transforms.py
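In short: the ad-hoc `parametrize_from_transforms` helper, which silently skipped any (transform, input) combination that raised, is replaced by an explicit `TestSmoke.test_common` that runs every transform on one composite sample (images, videos, boxes, masks, plus plain Python passthrough objects) and asserts structure and type preservation. A minimal sketch of that invariant, using the same helpers this file imports; `assert_smoke_invariants` is a name invented here for illustration:

```python
from torch.utils._pytree import tree_flatten
from torchvision.prototype.transforms.utils import check_type


# Invented helper illustrating what test_common (below) asserts:
# the container structure is preserved, supported leaves keep their
# type, and unsupported leaves pass through by identity.
def assert_smoke_invariants(transform, input):
    input_flat, input_spec = tree_flatten(input)
    output_flat, output_spec = tree_flatten(transform(input))
    assert output_spec == input_spec
    for out, inp in zip(output_flat, input_flat):
        if check_type(inp, transform._transformed_types):
            assert type(out) is type(inp)
        else:
            assert out is inp
```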
@@ -1,4 +1,5 @@
import itertools
import pathlib
import re
import warnings
from collections import defaultdict
@@ -20,15 +21,16 @@
make_image,
make_images,
make_label,
-    make_masks,
make_one_hot_labels,
make_segmentation_mask,
make_video,
make_videos,
)
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.ops.boxes import box_iou
from torchvision.prototype import datapoints, transforms
-from torchvision.prototype.transforms.utils import check_type, is_simple_tensor
from torchvision.prototype.transforms import functional as F
from torchvision.prototype.transforms.utils import check_type, is_simple_tensor, query_chw
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -66,53 +68,191 @@ def parametrize(transforms_with_inputs):
)


-def parametrize_from_transforms(*transforms):
-    transforms_with_inputs = []
-    for transform in transforms:
-        for creation_fn in [
-            make_images,
-            make_bounding_boxes,
-            make_one_hot_labels,
-            make_vanilla_tensor_images,
-            make_pil_images,
-            make_masks,
-            make_videos,
-        ]:
-            inputs = list(creation_fn())
-            try:
-                output = transform(inputs[0])
-            except Exception:
-                continue
-            else:
-                if output is inputs[0]:
-                    continue
-
-            transforms_with_inputs.append((transform, inputs))
-
-    return parametrize(transforms_with_inputs)


def auto_augment_adapter(transform, input, device):
adapted_input = {}
image_or_video_found = False
for key, value in input.items():
if isinstance(value, (datapoints.BoundingBox, datapoints.Mask)):
# AA transforms don't support bounding boxes or masks
continue
elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)):
if image_or_video_found:
# AA transforms only support a single image or video
continue
            image_or_video_found = True
        adapted_input[key] = value
    return adapted_input


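Since `auto_augment_adapter` ignores its `transform` and `device` arguments and only filters the sample, it can be exercised standalone. A hypothetical call (sample keys invented; assumes the prototype `datapoints` API used throughout this diff):

```python
import torch
from torchvision.prototype import datapoints

sample = {
    "image": datapoints.Image(torch.rand(3, 16, 16)),
    "extra_image": datapoints.Image(torch.rand(3, 16, 16)),  # second image-like value
    "box": datapoints.BoundingBox(
        [0, 0, 8, 8], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(16, 16)
    ),
    "label": 3,
}
adapted = auto_augment_adapter(transform=None, input=sample, device="cpu")
# "extra_image" (only one image/video is kept) and "box" are dropped
assert set(adapted) == {"image", "label"}
```
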
def linear_transformation_adapter(transform, input, device):
c, h, w = query_chw(input.values())
num_elements = c * h * w
transform.transformation_matrix = torch.randn((num_elements, num_elements), device=device)
transform.mean_vector = torch.randn((num_elements,), device=device)
return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}
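LinearTransformation flattens each image to a vector before applying the matrix, so a C×H×W input needs a (C·H·W, C·H·W) transformation matrix and a (C·H·W,) mean vector; that is why the adapter sizes them from `query_chw`. A hypothetical standalone run for a 3×16×16 image, i.e. 3 * 16 * 16 = 768 elements:

```python
import torch
from torchvision.prototype import datapoints, transforms

transform = transforms.LinearTransformation(
    transformation_matrix=torch.empty((1, 1)),  # dummy values, patched by the adapter
    mean_vector=torch.empty((1,)),
)
sample = {"image": datapoints.Image(torch.rand(3, 16, 16))}
sample = linear_transformation_adapter(transform, sample, device="cpu")
assert transform.transformation_matrix.shape == (768, 768)
transform(sample)  # now runs with correctly sized parameters
```
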


def normalize_adapter(transform, input, device):
adapted_input = {}
for key, value in input.items():
if isinstance(value, PIL.Image.Image):
# normalize doesn't support PIL images
continue
elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)):
# normalize doesn't support integer images
value = F.convert_dtype(value, torch.float32)
adapted_input[key] = value
return adapted_input
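All three adapters share the `(transform, input, device) -> dict` contract, applied by `test_common` right before the device transfer; `None` in the parametrization below means the sample needs no adjustment. A sketch of a further adapter following the same contract (`drop_numpy_adapter` is invented here for illustration):

```python
import numpy as np


def drop_numpy_adapter(transform, input, device):
    # Hypothetical: filter numpy arrays out of the sample, mirroring how
    # normalize_adapter filters PIL images; transform and device are unused.
    return {key: value for key, value in input.items() if not isinstance(value, np.ndarray)}
```
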


class TestSmoke:
-    @parametrize_from_transforms(
-        transforms.RandomErasing(p=1.0),
-        transforms.Resize([16, 16], antialias=True),
-        transforms.CenterCrop([16, 16]),
-        transforms.ConvertDtype(),
-        transforms.RandomHorizontalFlip(),
-        transforms.Pad(5),
-        transforms.RandomZoomOut(),
-        transforms.RandomRotation(degrees=(-45, 45)),
-        transforms.RandomAffine(degrees=(-45, 45)),
-        transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True),
-        # TODO: Something wrong with input data setup. Let's fix that
-        # transforms.RandomEqualize(),
-        # transforms.RandomInvert(),
-        # transforms.RandomPosterize(bits=4),
-        # transforms.RandomSolarize(threshold=0.5),
-        # transforms.RandomAdjustSharpness(sharpness_factor=0.5),
-    )
-    def test_common(self, transform, input):
-        transform(input)
@pytest.mark.parametrize(
("transform", "adapter"),
[
(transforms.RandomErasing(p=1.0), None),
(transforms.AugMix(), auto_augment_adapter),
(transforms.AutoAugment(), auto_augment_adapter),
(transforms.RandAugment(), auto_augment_adapter),
(transforms.TrivialAugmentWide(), auto_augment_adapter),
(transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
(transforms.Grayscale(), None),
(transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
(transforms.RandomAutocontrast(p=1.0), None),
(transforms.RandomEqualize(p=1.0), None),
(transforms.RandomGrayscale(p=1.0), None),
(transforms.RandomInvert(p=1.0), None),
(transforms.RandomPhotometricDistort(p=1.0), None),
(transforms.RandomPosterize(bits=4, p=1.0), None),
(transforms.RandomSolarize(threshold=0.5, p=1.0), None),
(transforms.CenterCrop([16, 16]), None),
(transforms.ElasticTransform(sigma=1.0), None),
(transforms.Pad(4), None),
(transforms.RandomAffine(degrees=30.0), None),
(transforms.RandomCrop([16, 16], pad_if_needed=True), None),
(transforms.RandomHorizontalFlip(p=1.0), None),
(transforms.RandomPerspective(p=1.0), None),
(transforms.RandomResize(min_size=10, max_size=20), None),
(transforms.RandomResizedCrop([16, 16]), None),
(transforms.RandomRotation(degrees=30), None),
(transforms.RandomShortestSize(min_size=10), None),
(transforms.RandomVerticalFlip(p=1.0), None),
(transforms.RandomZoomOut(p=1.0), None),
(transforms.Resize([16, 16], antialias=True), None),
(transforms.ScaleJitter((16, 16)), None),
(transforms.ClampBoundingBoxes(), None),
(transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
(transforms.ConvertDtype(), None),
(transforms.GaussianBlur(kernel_size=3), None),
(
transforms.LinearTransformation(
# These are just dummy values that will be filled by the adapter. We can't define them upfront,
# because at this point we know neither the spatial size nor the device
transformation_matrix=torch.empty((1, 1)),
mean_vector=torch.empty((1,)),
),
linear_transformation_adapter,
),
(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), normalize_adapter),
(transforms.ToDtype(torch.float64), None),
(transforms.UniformTemporalSubsample(num_samples=2), None),
],
ids=lambda transform: type(transform).__name__,
)
@pytest.mark.parametrize("container_type", [dict, list, tuple])
@pytest.mark.parametrize(
"image_or_video",
[
make_image(),
make_video(),
next(make_pil_images(color_spaces=["RGB"])),
next(make_vanilla_tensor_images()),
],
)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_common(self, transform, adapter, container_type, image_or_video, device):
spatial_size = F.get_spatial_size(image_or_video)
input = dict(
image_or_video=image_or_video,
image_datapoint=make_image(size=spatial_size),
video_datapoint=make_video(size=spatial_size),
image_pil=next(make_pil_images(sizes=[spatial_size], color_spaces=["RGB"])),
bounding_box_xyxy=make_bounding_box(
format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(3,)
),
bounding_box_xywh=make_bounding_box(
format=datapoints.BoundingBoxFormat.XYWH, spatial_size=spatial_size, extra_dims=(4,)
),
bounding_box_cxcywh=make_bounding_box(
format=datapoints.BoundingBoxFormat.CXCYWH, spatial_size=spatial_size, extra_dims=(5,)
),
bounding_box_degenerate_xyxy=datapoints.BoundingBox(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[2, 0, 1, 1], # x1 > x2, y1 < y2
[0, 2, 1, 1], # x1 < x2, y1 > y2
[2, 2, 1, 1], # x1 > x2, y1 > y2
],
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=spatial_size,
),
bounding_box_degenerate_xywh=datapoints.BoundingBox(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[0, 0, 1, -1], # negative height
[0, 0, -1, 1], # negative width
[0, 0, -1, -1], # negative height and width
],
format=datapoints.BoundingBoxFormat.XYWH,
spatial_size=spatial_size,
),
bounding_box_degenerate_cxcywh=datapoints.BoundingBox(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
[0, 0, 1, 0], # no width
[0, 0, 1, -1], # negative height
[0, 0, -1, 1], # negative width
[0, 0, -1, -1], # negative height and width
],
format=datapoints.BoundingBoxFormat.CXCYWH,
spatial_size=spatial_size,
),
detection_mask=make_detection_mask(size=spatial_size),
segmentation_mask=make_segmentation_mask(size=spatial_size),
int=0,
float=0.0,
bool=True,
none=None,
str="str",
path=pathlib.Path.cwd(),
object=object(),
tensor=torch.empty(5),
array=np.empty(5),
)
if adapter is not None:
input = adapter(transform, input, device)

if container_type in {tuple, list}:
input = container_type(input.values())

input_flat, input_spec = tree_flatten(input)
input_flat = [item.to(device) if isinstance(item, torch.Tensor) else item for item in input_flat]
input = tree_unflatten(input_flat, input_spec)

output = transform(input)
output_flat, output_spec = tree_flatten(output)

assert output_spec == input_spec

for output_item, input_item in zip(output_flat, input_flat):
if check_type(input_item, transform._transformed_types):
assert type(output_item) is type(input_item)
else:
assert output_item is input_item
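The `output_spec == input_spec` assertion above is what catches a transform that adds, drops, or reorders container entries: `tree_flatten` separates the leaves from a spec describing the container layout, and `tree_unflatten` inverts it. A standalone toy round-trip (sample values invented):

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

sample = {"image": torch.rand(3, 16, 16), "label": 7, "path": None}
leaves, spec = tree_flatten(sample)  # leaves: [tensor, 7, None]; spec: dict layout
rebuilt = tree_unflatten(leaves, spec)
assert list(rebuilt) == ["image", "label", "path"]  # keys and order preserved
assert rebuilt["image"] is sample["image"]  # leaves pass through untouched
```
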

@parametrize(
[