Commit c629504

NicolasHug authored and facebook-github-bot committed
[fbsync] add proper smoke test for prototype transforms (#7238)
Reviewed By: vmoens
Differential Revision: D44416579
fbshipit-source-id: 9c3b0da79fe1270c13b6f705c5894ddd7783911f
1 parent ac6942e commit c629504

File tree

1 file changed: +192 -42 lines changed


test/test_prototype_transforms.py

Lines changed: 192 additions & 42 deletions
@@ -1,4 +1,5 @@
 import itertools
+import pathlib
 import re
 import warnings
 from collections import defaultdict
@@ -20,15 +21,16 @@
     make_image,
     make_images,
     make_label,
-    make_masks,
     make_one_hot_labels,
     make_segmentation_mask,
     make_video,
     make_videos,
 )
+from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.ops.boxes import box_iou
 from torchvision.prototype import datapoints, transforms
-from torchvision.prototype.transforms.utils import check_type, is_simple_tensor
+from torchvision.prototype.transforms import functional as F
+from torchvision.prototype.transforms.utils import check_type, is_simple_tensor, query_chw
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
 
 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -66,53 +68,201 @@ def parametrize(transforms_with_inputs):
 )
 
 
-def parametrize_from_transforms(*transforms):
-    transforms_with_inputs = []
-    for transform in transforms:
-        for creation_fn in [
-            make_images,
-            make_bounding_boxes,
-            make_one_hot_labels,
-            make_vanilla_tensor_images,
-            make_pil_images,
-            make_masks,
-            make_videos,
-        ]:
-            inputs = list(creation_fn())
-            try:
-                output = transform(inputs[0])
-            except Exception:
+def auto_augment_adapter(transform, input, device):
+    adapted_input = {}
+    image_or_video_found = False
+    for key, value in input.items():
+        if isinstance(value, (datapoints.BoundingBox, datapoints.Mask)):
+            # AA transforms don't support bounding boxes or masks
+            continue
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)):
+            if image_or_video_found:
+                # AA transforms only support a single image or video
                 continue
-            else:
-                if output is inputs[0]:
-                    continue
+            image_or_video_found = True
+        adapted_input[key] = value
+    return adapted_input
+
+
+def linear_transformation_adapter(transform, input, device):
+    flat_inputs = list(input.values())
+    c, h, w = query_chw(
+        [
+            item
+            for item, needs_transform in zip(flat_inputs, transforms.Transform()._needs_transform_list(flat_inputs))
+            if needs_transform
+        ]
+    )
+    num_elements = c * h * w
+    transform.transformation_matrix = torch.randn((num_elements, num_elements), device=device)
+    transform.mean_vector = torch.randn((num_elements,), device=device)
+    return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}
 
-            transforms_with_inputs.append((transform, inputs))
 
-    return parametrize(transforms_with_inputs)
+def normalize_adapter(transform, input, device):
+    adapted_input = {}
+    for key, value in input.items():
+        if isinstance(value, PIL.Image.Image):
+            # normalize doesn't support PIL images
+            continue
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)):
+            # normalize doesn't support integer images
+            value = F.convert_dtype(value, torch.float32)
+        adapted_input[key] = value
+    return adapted_input
 
 
 class TestSmoke:
-    @parametrize_from_transforms(
-        transforms.RandomErasing(p=1.0),
-        transforms.Resize([16, 16], antialias=True),
-        transforms.CenterCrop([16, 16]),
-        transforms.ConvertDtype(),
-        transforms.RandomHorizontalFlip(),
-        transforms.Pad(5),
-        transforms.RandomZoomOut(),
-        transforms.RandomRotation(degrees=(-45, 45)),
-        transforms.RandomAffine(degrees=(-45, 45)),
-        transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True),
-        # TODO: Something wrong with input data setup. Let's fix that
-        # transforms.RandomEqualize(),
-        # transforms.RandomInvert(),
-        # transforms.RandomPosterize(bits=4),
-        # transforms.RandomSolarize(threshold=0.5),
-        # transforms.RandomAdjustSharpness(sharpness_factor=0.5),
+    @pytest.mark.parametrize(
+        ("transform", "adapter"),
+        [
+            (transforms.RandomErasing(p=1.0), None),
+            (transforms.AugMix(), auto_augment_adapter),
+            (transforms.AutoAugment(), auto_augment_adapter),
+            (transforms.RandAugment(), auto_augment_adapter),
+            (transforms.TrivialAugmentWide(), auto_augment_adapter),
+            (transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
+            (transforms.Grayscale(), None),
+            (transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
+            (transforms.RandomAutocontrast(p=1.0), None),
+            (transforms.RandomEqualize(p=1.0), None),
+            (transforms.RandomGrayscale(p=1.0), None),
+            (transforms.RandomInvert(p=1.0), None),
+            (transforms.RandomPhotometricDistort(p=1.0), None),
+            (transforms.RandomPosterize(bits=4, p=1.0), None),
+            (transforms.RandomSolarize(threshold=0.5, p=1.0), None),
+            (transforms.CenterCrop([16, 16]), None),
+            (transforms.ElasticTransform(sigma=1.0), None),
+            (transforms.Pad(4), None),
+            (transforms.RandomAffine(degrees=30.0), None),
+            (transforms.RandomCrop([16, 16], pad_if_needed=True), None),
+            (transforms.RandomHorizontalFlip(p=1.0), None),
+            (transforms.RandomPerspective(p=1.0), None),
+            (transforms.RandomResize(min_size=10, max_size=20), None),
+            (transforms.RandomResizedCrop([16, 16]), None),
+            (transforms.RandomRotation(degrees=30), None),
+            (transforms.RandomShortestSize(min_size=10), None),
+            (transforms.RandomVerticalFlip(p=1.0), None),
+            (transforms.RandomZoomOut(p=1.0), None),
+            (transforms.Resize([16, 16], antialias=True), None),
+            (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2)), None),
+            (transforms.ClampBoundingBoxes(), None),
+            (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
+            (transforms.ConvertDtype(), None),
+            (transforms.GaussianBlur(kernel_size=3), None),
+            (
+                transforms.LinearTransformation(
+                    # These are just dummy values that will be filled by the adapter. We can't define them upfront,
+                    # because we neither know the spatial size nor the device at this point
+                    transformation_matrix=torch.empty((1, 1)),
+                    mean_vector=torch.empty((1,)),
+                ),
+                linear_transformation_adapter,
+            ),
+            (transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), normalize_adapter),
+            (transforms.ToDtype(torch.float64), None),
+            (transforms.UniformTemporalSubsample(num_samples=2), None),
+        ],
+        ids=lambda transform: type(transform).__name__,
     )
-    def test_common(self, transform, input):
-        transform(input)
+    @pytest.mark.parametrize("container_type", [dict, list, tuple])
+    @pytest.mark.parametrize(
+        "image_or_video",
+        [
+            make_image(),
+            make_video(),
+            next(make_pil_images(color_spaces=["RGB"])),
+            next(make_vanilla_tensor_images()),
+        ],
+    )
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_common(self, transform, adapter, container_type, image_or_video, device):
+        spatial_size = F.get_spatial_size(image_or_video)
+        input = dict(
+            image_or_video=image_or_video,
+            image_datapoint=make_image(size=spatial_size),
+            video_datapoint=make_video(size=spatial_size),
+            image_pil=next(make_pil_images(sizes=[spatial_size], color_spaces=["RGB"])),
+            bounding_box_xyxy=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(3,)
+            ),
+            bounding_box_xywh=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.XYWH, spatial_size=spatial_size, extra_dims=(4,)
+            ),
+            bounding_box_cxcywh=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.CXCYWH, spatial_size=spatial_size, extra_dims=(5,)
+            ),
+            bounding_box_degenerate_xyxy=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [2, 0, 1, 1],  # x1 > x2, y1 < y2
+                    [0, 2, 1, 1],  # x1 < x2, y1 > y2
+                    [2, 2, 1, 1],  # x1 > x2, y1 > y2
+                ],
+                format=datapoints.BoundingBoxFormat.XYXY,
+                spatial_size=spatial_size,
+            ),
+            bounding_box_degenerate_xywh=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [0, 0, 1, -1],  # negative height
+                    [0, 0, -1, 1],  # negative width
+                    [0, 0, -1, -1],  # negative height and width
+                ],
+                format=datapoints.BoundingBoxFormat.XYWH,
+                spatial_size=spatial_size,
+            ),
+            bounding_box_degenerate_cxcywh=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [0, 0, 1, -1],  # negative height
+                    [0, 0, -1, 1],  # negative width
+                    [0, 0, -1, -1],  # negative height and width
+                ],
+                format=datapoints.BoundingBoxFormat.CXCYWH,
+                spatial_size=spatial_size,
+            ),
+            detection_mask=make_detection_mask(size=spatial_size),
+            segmentation_mask=make_segmentation_mask(size=spatial_size),
+            int=0,
+            float=0.0,
+            bool=True,
+            none=None,
+            str="str",
+            path=pathlib.Path.cwd(),
+            object=object(),
+            tensor=torch.empty(5),
+            array=np.empty(5),
+        )
+        if adapter is not None:
+            input = adapter(transform, input, device)
+
+        if container_type in {tuple, list}:
+            input = container_type(input.values())
+
+        input_flat, input_spec = tree_flatten(input)
+        input_flat = [item.to(device) if isinstance(item, torch.Tensor) else item for item in input_flat]
+        input = tree_unflatten(input_flat, input_spec)
+
+        torch.manual_seed(0)
+        output = transform(input)
+        output_flat, output_spec = tree_flatten(output)
+
+        assert output_spec == input_spec
+
+        for output_item, input_item, should_be_transformed in zip(
+            output_flat, input_flat, transforms.Transform()._needs_transform_list(input_flat)
+        ):
+            if should_be_transformed:
+                assert type(output_item) is type(input_item)
+            else:
+                assert output_item is input_item
 
     @parametrize(
         [
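
The heart of the new test_common is a pytree round-trip: the heterogeneous input dict is flattened with tree_flatten, tensor leaves are moved to the target device, the structure is rebuilt with tree_unflatten, and after the transform runs the output spec is compared against the input spec. Below is a minimal standalone sketch of that round-trip using only torch and a toy dict; it does not use the test's helpers (make_image, make_bounding_box, and so on) and is illustrative rather than part of the commit.

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# Toy stand-in for the test's input dict: one tensor leaf plus pass-through values.
sample = {"image": torch.rand(3, 16, 16), "label": 1, "note": "str"}

# Flatten into a list of leaves plus a spec describing the container structure.
leaves, spec = tree_flatten(sample)

# Process tensor leaves while leaving every other value untouched, mirroring the
# device transfer in test_common.
leaves = [leaf.float() if isinstance(leaf, torch.Tensor) else leaf for leaf in leaves]

# Rebuild the original container from the processed leaves; the spec round-trips.
rebuilt = tree_unflatten(leaves, spec)
assert tree_flatten(rebuilt)[1] == spec

The per-transform adapters exist because a few transforms accept a narrower input set than the shared dict provides: the AutoAugment family operates on a single image or video, Normalize rejects PIL and integer images, and LinearTransformation needs a transformation matrix and mean vector sized to the flattened input, so the adapter prunes or patches the input before the common checks run.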
