@@ -1,4 +1,5 @@
 import itertools
+import pathlib
 import re
 import warnings
 from collections import defaultdict
@@ -20,15 +21,16 @@
     make_image,
     make_images,
     make_label,
-    make_masks,
     make_one_hot_labels,
     make_segmentation_mask,
     make_video,
     make_videos,
 )
+from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.ops.boxes import box_iou
 from torchvision.prototype import datapoints, transforms
-from torchvision.prototype.transforms.utils import check_type, is_simple_tensor
+from torchvision.prototype.transforms import functional as F
+from torchvision.prototype.transforms.utils import check_type, is_simple_tensor, query_chw
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -66,53 +68,201 @@ def parametrize(transforms_with_inputs):
     )


-def parametrize_from_transforms(*transforms):
-    transforms_with_inputs = []
-    for transform in transforms:
-        for creation_fn in [
-            make_images,
-            make_bounding_boxes,
-            make_one_hot_labels,
-            make_vanilla_tensor_images,
-            make_pil_images,
-            make_masks,
-            make_videos,
-        ]:
-            inputs = list(creation_fn())
-            try:
-                output = transform(inputs[0])
-            except Exception:
+def auto_augment_adapter(transform, input, device):
+    adapted_input = {}
+    image_or_video_found = False
+    for key, value in input.items():
+        if isinstance(value, (datapoints.BoundingBox, datapoints.Mask)):
+            # AA transforms don't support bounding boxes or masks
+            continue
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)):
+            if image_or_video_found:
+                # AA transforms only support a single image or video
                 continue
-            else:
-                if output is inputs[0]:
-                    continue
+            image_or_video_found = True
+        adapted_input[key] = value
+    return adapted_input
+
+
+def linear_transformation_adapter(transform, input, device):
+    flat_inputs = list(input.values())
+    c, h, w = query_chw(
+        [
+            item
+            for item, needs_transform in zip(flat_inputs, transforms.Transform()._needs_transform_list(flat_inputs))
+            if needs_transform
+        ]
+    )
+    num_elements = c * h * w
+    transform.transformation_matrix = torch.randn((num_elements, num_elements), device=device)
+    transform.mean_vector = torch.randn((num_elements,), device=device)
+    return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}

-    transforms_with_inputs.append((transform, inputs))

-    return parametrize(transforms_with_inputs)
+def normalize_adapter(transform, input, device):
+    adapted_input = {}
+    for key, value in input.items():
+        if isinstance(value, PIL.Image.Image):
+            # normalize doesn't support PIL images
+            continue
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)):
+            # normalize doesn't support integer images
+            value = F.convert_dtype(value, torch.float32)
+        adapted_input[key] = value
+    return adapted_input


 class TestSmoke:
-    @parametrize_from_transforms(
-        transforms.RandomErasing(p=1.0),
-        transforms.Resize([16, 16], antialias=True),
-        transforms.CenterCrop([16, 16]),
-        transforms.ConvertDtype(),
-        transforms.RandomHorizontalFlip(),
-        transforms.Pad(5),
-        transforms.RandomZoomOut(),
-        transforms.RandomRotation(degrees=(-45, 45)),
-        transforms.RandomAffine(degrees=(-45, 45)),
-        transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True),
-        # TODO: Something wrong with input data setup. Let's fix that
-        # transforms.RandomEqualize(),
-        # transforms.RandomInvert(),
-        # transforms.RandomPosterize(bits=4),
-        # transforms.RandomSolarize(threshold=0.5),
-        # transforms.RandomAdjustSharpness(sharpness_factor=0.5),
+    @pytest.mark.parametrize(
+        ("transform", "adapter"),
+        [
+            (transforms.RandomErasing(p=1.0), None),
+            (transforms.AugMix(), auto_augment_adapter),
+            (transforms.AutoAugment(), auto_augment_adapter),
+            (transforms.RandAugment(), auto_augment_adapter),
+            (transforms.TrivialAugmentWide(), auto_augment_adapter),
+            (transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
+            (transforms.Grayscale(), None),
+            (transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
+            (transforms.RandomAutocontrast(p=1.0), None),
+            (transforms.RandomEqualize(p=1.0), None),
+            (transforms.RandomGrayscale(p=1.0), None),
+            (transforms.RandomInvert(p=1.0), None),
+            (transforms.RandomPhotometricDistort(p=1.0), None),
+            (transforms.RandomPosterize(bits=4, p=1.0), None),
+            (transforms.RandomSolarize(threshold=0.5, p=1.0), None),
+            (transforms.CenterCrop([16, 16]), None),
+            (transforms.ElasticTransform(sigma=1.0), None),
+            (transforms.Pad(4), None),
+            (transforms.RandomAffine(degrees=30.0), None),
+            (transforms.RandomCrop([16, 16], pad_if_needed=True), None),
+            (transforms.RandomHorizontalFlip(p=1.0), None),
+            (transforms.RandomPerspective(p=1.0), None),
+            (transforms.RandomResize(min_size=10, max_size=20), None),
+            (transforms.RandomResizedCrop([16, 16]), None),
+            (transforms.RandomRotation(degrees=30), None),
+            (transforms.RandomShortestSize(min_size=10), None),
+            (transforms.RandomVerticalFlip(p=1.0), None),
+            (transforms.RandomZoomOut(p=1.0), None),
+            (transforms.Resize([16, 16], antialias=True), None),
+            (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2)), None),
+            (transforms.ClampBoundingBoxes(), None),
+            (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
+            (transforms.ConvertDtype(), None),
+            (transforms.GaussianBlur(kernel_size=3), None),
+            (
+                transforms.LinearTransformation(
+                    # These are just dummy values that will be filled by the adapter. We can't define them upfront,
+                    # because we know neither the spatial size nor the device at this point.
+                    transformation_matrix=torch.empty((1, 1)),
+                    mean_vector=torch.empty((1,)),
+                ),
+                linear_transformation_adapter,
+            ),
+            (transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), normalize_adapter),
+            (transforms.ToDtype(torch.float64), None),
+            (transforms.UniformTemporalSubsample(num_samples=2), None),
+        ],
+        ids=lambda transform: type(transform).__name__,
     )
-    def test_common(self, transform, input):
-        transform(input)
+    @pytest.mark.parametrize("container_type", [dict, list, tuple])
+    @pytest.mark.parametrize(
+        "image_or_video",
+        [
+            make_image(),
+            make_video(),
+            next(make_pil_images(color_spaces=["RGB"])),
+            next(make_vanilla_tensor_images()),
+        ],
+    )
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_common(self, transform, adapter, container_type, image_or_video, device):
+        spatial_size = F.get_spatial_size(image_or_video)
+        input = dict(
+            image_or_video=image_or_video,
+            image_datapoint=make_image(size=spatial_size),
+            video_datapoint=make_video(size=spatial_size),
+            image_pil=next(make_pil_images(sizes=[spatial_size], color_spaces=["RGB"])),
+            bounding_box_xyxy=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(3,)
+            ),
+            bounding_box_xywh=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.XYWH, spatial_size=spatial_size, extra_dims=(4,)
+            ),
+            bounding_box_cxcywh=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.CXCYWH, spatial_size=spatial_size, extra_dims=(5,)
+            ),
+            bounding_box_degenerate_xyxy=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [2, 0, 1, 1],  # x1 > x2, y1 < y2
+                    [0, 2, 1, 1],  # x1 < x2, y1 > y2
+                    [2, 2, 1, 1],  # x1 > x2, y1 > y2
+                ],
+                format=datapoints.BoundingBoxFormat.XYXY,
+                spatial_size=spatial_size,
+            ),
+            bounding_box_degenerate_xywh=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [0, 0, 1, -1],  # negative height
+                    [0, 0, -1, 1],  # negative width
+                    [0, 0, -1, -1],  # negative height and width
+                ],
+                format=datapoints.BoundingBoxFormat.XYWH,
+                spatial_size=spatial_size,
+            ),
+            bounding_box_degenerate_cxcywh=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [0, 0, 1, -1],  # negative height
+                    [0, 0, -1, 1],  # negative width
+                    [0, 0, -1, -1],  # negative height and width
+                ],
+                format=datapoints.BoundingBoxFormat.CXCYWH,
+                spatial_size=spatial_size,
+            ),
+            detection_mask=make_detection_mask(size=spatial_size),
+            segmentation_mask=make_segmentation_mask(size=spatial_size),
+            int=0,
+            float=0.0,
+            bool=True,
+            none=None,
+            str="str",
+            path=pathlib.Path.cwd(),
+            object=object(),
+            tensor=torch.empty(5),
+            array=np.empty(5),
+        )
+        if adapter is not None:
+            input = adapter(transform, input, device)
+
+        if container_type in {tuple, list}:
+            input = container_type(input.values())
+
+        input_flat, input_spec = tree_flatten(input)
+        input_flat = [item.to(device) if isinstance(item, torch.Tensor) else item for item in input_flat]
+        input = tree_unflatten(input_flat, input_spec)
+
+        torch.manual_seed(0)
+        output = transform(input)
+        output_flat, output_spec = tree_flatten(output)
+
+        assert output_spec == input_spec
+
+        for output_item, input_item, should_be_transformed in zip(
+            output_flat, input_flat, transforms.Transform()._needs_transform_list(input_flat)
+        ):
+            if should_be_transformed:
+                assert type(output_item) is type(input_item)
+            else:
+                assert output_item is input_item

     @parametrize(
         [