diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 0ed51c44d77..7030d2d1b2e 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -1,4 +1,5 @@
 import itertools
+import pathlib
 import re
 import warnings
 from collections import defaultdict
@@ -20,15 +21,16 @@
     make_image,
     make_images,
     make_label,
-    make_masks,
     make_one_hot_labels,
     make_segmentation_mask,
     make_video,
     make_videos,
 )
+from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.ops.boxes import box_iou
 from torchvision.prototype import datapoints, transforms
-from torchvision.prototype.transforms.utils import check_type, is_simple_tensor
+from torchvision.prototype.transforms import functional as F
+from torchvision.prototype.transforms.utils import check_type, is_simple_tensor, query_chw
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -66,53 +68,201 @@ def parametrize(transforms_with_inputs):
     )


-def parametrize_from_transforms(*transforms):
-    transforms_with_inputs = []
-    for transform in transforms:
-        for creation_fn in [
-            make_images,
-            make_bounding_boxes,
-            make_one_hot_labels,
-            make_vanilla_tensor_images,
-            make_pil_images,
-            make_masks,
-            make_videos,
-        ]:
-            inputs = list(creation_fn())
-            try:
-                output = transform(inputs[0])
-            except Exception:
+def auto_augment_adapter(transform, input, device):
+    adapted_input = {}
+    image_or_video_found = False
+    for key, value in input.items():
+        if isinstance(value, (datapoints.BoundingBox, datapoints.Mask)):
+            # AA transforms don't support bounding boxes or masks
+            continue
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)):
+            if image_or_video_found:
+                # AA transforms only support a single image or video
                 continue
-            else:
-                if output is inputs[0]:
-                    continue
+            image_or_video_found = True
+        adapted_input[key] = value
+    return adapted_input
+
+
+def linear_transformation_adapter(transform, input, device):
+    flat_inputs = list(input.values())
+    c, h, w = query_chw(
+        [
+            item
+            for item, needs_transform in zip(flat_inputs, transforms.Transform()._needs_transform_list(flat_inputs))
+            if needs_transform
+        ]
+    )
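+    # LinearTransformation needs a square (C*H*W, C*H*W) transformation matrix and a
+    # matching mean vector, so they can only be built here, once the spatial size and
+    # device of the actual sample are known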
+    num_elements = c * h * w
+    transform.transformation_matrix = torch.randn((num_elements, num_elements), device=device)
+    transform.mean_vector = torch.randn((num_elements,), device=device)
+    return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}

-            transforms_with_inputs.append((transform, inputs))

-    return parametrize(transforms_with_inputs)
+def normalize_adapter(transform, input, device):
+    adapted_input = {}
+    for key, value in input.items():
+        if isinstance(value, PIL.Image.Image):
+            # normalize doesn't support PIL images
+            continue
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)):
+            # normalize doesn't support integer images
+            value = F.convert_dtype(value, torch.float32)
+        adapted_input[key] = value
+    return adapted_input


 class TestSmoke:
-    @parametrize_from_transforms(
-        transforms.RandomErasing(p=1.0),
-        transforms.Resize([16, 16], antialias=True),
-        transforms.CenterCrop([16, 16]),
-        transforms.ConvertDtype(),
-        transforms.RandomHorizontalFlip(),
-        transforms.Pad(5),
-        transforms.RandomZoomOut(),
-        transforms.RandomRotation(degrees=(-45, 45)),
-        transforms.RandomAffine(degrees=(-45, 45)),
-        transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True),
-        # TODO: Something wrong with input data setup. Let's fix that
-        # transforms.RandomEqualize(),
-        # transforms.RandomInvert(),
-        # transforms.RandomPosterize(bits=4),
-        # transforms.RandomSolarize(threshold=0.5),
-        # transforms.RandomAdjustSharpness(sharpness_factor=0.5),
+    @pytest.mark.parametrize(
+        ("transform", "adapter"),
+        [
+            (transforms.RandomErasing(p=1.0), None),
+            (transforms.AugMix(), auto_augment_adapter),
+            (transforms.AutoAugment(), auto_augment_adapter),
+            (transforms.RandAugment(), auto_augment_adapter),
+            (transforms.TrivialAugmentWide(), auto_augment_adapter),
+            (transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
+            (transforms.Grayscale(), None),
+            (transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
+            (transforms.RandomAutocontrast(p=1.0), None),
+            (transforms.RandomEqualize(p=1.0), None),
+            (transforms.RandomGrayscale(p=1.0), None),
+            (transforms.RandomInvert(p=1.0), None),
+            (transforms.RandomPhotometricDistort(p=1.0), None),
+            (transforms.RandomPosterize(bits=4, p=1.0), None),
+            (transforms.RandomSolarize(threshold=0.5, p=1.0), None),
+            (transforms.CenterCrop([16, 16]), None),
+            (transforms.ElasticTransform(sigma=1.0), None),
+            (transforms.Pad(4), None),
+            (transforms.RandomAffine(degrees=30.0), None),
+            (transforms.RandomCrop([16, 16], pad_if_needed=True), None),
+            (transforms.RandomHorizontalFlip(p=1.0), None),
+            (transforms.RandomPerspective(p=1.0), None),
+            (transforms.RandomResize(min_size=10, max_size=20), None),
+            (transforms.RandomResizedCrop([16, 16]), None),
+            (transforms.RandomRotation(degrees=30), None),
+            (transforms.RandomShortestSize(min_size=10), None),
+            (transforms.RandomVerticalFlip(p=1.0), None),
+            (transforms.RandomZoomOut(p=1.0), None),
+            (transforms.Resize([16, 16], antialias=True), None),
+            (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2)), None),
+            (transforms.ClampBoundingBoxes(), None),
+            (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
+            (transforms.ConvertDtype(), None),
+            (transforms.GaussianBlur(kernel_size=3), None),
+            (
+                transforms.LinearTransformation(
+                    # These are just dummy values that will be filled by the adapter. We can't define them upfront,
+                    # because we know neither the spatial size nor the device at this point
+                    transformation_matrix=torch.empty((1, 1)),
+                    mean_vector=torch.empty((1,)),
+                ),
+                linear_transformation_adapter,
+            ),
+            (transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), normalize_adapter),
+            (transforms.ToDtype(torch.float64), None),
+            (transforms.UniformTemporalSubsample(num_samples=2), None),
+        ],
+        ids=lambda transform: type(transform).__name__,
     )
-    def test_common(self, transform, input):
-        transform(input)
+    @pytest.mark.parametrize("container_type", [dict, list, tuple])
+    @pytest.mark.parametrize(
+        "image_or_video",
+        [
+            make_image(),
+            make_video(),
+            next(make_pil_images(color_spaces=["RGB"])),
+            next(make_vanilla_tensor_images()),
+        ],
+    )
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_common(self, transform, adapter, container_type, image_or_video, device):
+        spatial_size = F.get_spatial_size(image_or_video)
+        input = dict(
+            image_or_video=image_or_video,
+            image_datapoint=make_image(size=spatial_size),
+            video_datapoint=make_video(size=spatial_size),
+            image_pil=next(make_pil_images(sizes=[spatial_size], color_spaces=["RGB"])),
+            bounding_box_xyxy=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(3,)
+            ),
+            bounding_box_xywh=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.XYWH, spatial_size=spatial_size, extra_dims=(4,)
+            ),
+            bounding_box_cxcywh=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.CXCYWH, spatial_size=spatial_size, extra_dims=(5,)
+            ),
+            bounding_box_degenerate_xyxy=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [2, 0, 1, 1],  # x1 > x2, y1 < y2
+                    [0, 2, 1, 1],  # x1 < x2, y1 > y2
+                    [2, 2, 1, 1],  # x1 > x2, y1 > y2
+                ],
+                format=datapoints.BoundingBoxFormat.XYXY,
+                spatial_size=spatial_size,
+            ),
+            bounding_box_degenerate_xywh=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [0, 0, 1, -1],  # negative height
+                    [0, 0, -1, 1],  # negative width
+                    [0, 0, -1, -1],  # negative height and width
+                ],
+                format=datapoints.BoundingBoxFormat.XYWH,
+                spatial_size=spatial_size,
+            ),
+            bounding_box_degenerate_cxcywh=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [0, 0, 1, -1],  # negative height
+                    [0, 0, -1, 1],  # negative width
+                    [0, 0, -1, -1],  # negative height and width
+                ],
+                format=datapoints.BoundingBoxFormat.CXCYWH,
+                spatial_size=spatial_size,
+            ),
+            detection_mask=make_detection_mask(size=spatial_size),
+            segmentation_mask=make_segmentation_mask(size=spatial_size),
+            int=0,
+            float=0.0,
+            bool=True,
+            none=None,
+            str="str",
+            path=pathlib.Path.cwd(),
+            object=object(),
+            tensor=torch.empty(5),
+            array=np.empty(5),
+        )
+        if adapter is not None:
+            input = adapter(transform, input, device)
+
+        if container_type in {tuple, list}:
+            input = container_type(input.values())
+
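+        # flatten the arbitrarily nested input into a list, move every tensor in it
+        # to the target device, and rebuild the original container structure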
+        input_flat, input_spec = tree_flatten(input)
+        input_flat = [item.to(device) if isinstance(item, torch.Tensor) else item for item in input_flat]
+        input = tree_unflatten(input_flat, input_spec)
+
+        torch.manual_seed(0)
+        output = transform(input)
+        output_flat, output_spec = tree_flatten(output)
+
+        assert output_spec == input_spec
+
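+        # the transform must preserve the structure of the input: items that
+        # Transform._needs_transform_list flags are expected to come back with the
+        # same type, while everything else has to be passed through untouched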
+        for output_item, input_item, should_be_transformed in zip(
+            output_flat, input_flat, transforms.Transform()._needs_transform_list(input_flat)
+        ):
+            if should_be_transformed:
+                assert type(output_item) is type(input_item)
+            else:
+                assert output_item is input_item

     @parametrize(
         [