
Commit 6d43f4a

Merge branch 'main' into proto-mask-affine
2 parents d17decb + be462be

3 files changed, 254 additions and 22 deletions

test/test_prototype_transforms_functional.py

Lines changed: 167 additions & 9 deletions
@@ -279,6 +279,20 @@ def affine_segmentation_mask():
             translate=(translate, translate),
             scale=scale,
             shear=(shear, shear),
         )
+
+
+@register_kernel_info_from_sample_inputs_fn
+def rotate_bounding_box():
+    for bounding_box, angle, expand, center in itertools.product(
+        make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]]  # angle  # expand  # center
+    ):
+        yield SampleInput(
+            bounding_box,
+            format=bounding_box.format,
+            image_size=bounding_box.image_size,
+            angle=angle,
+            expand=expand,
+            center=center,
+        )

@@ -364,7 +378,7 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
             np.max(transformed_points[:, 1]),
         ]
         out_bbox = features.BoundingBox(
-            out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=(32, 32), dtype=torch.float32
+            out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=bbox.image_size, dtype=torch.float32
         )
         out_bbox = convert_bounding_box_format(
             out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
@@ -379,25 +393,25 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
         ],
         extra_dims=((4,),),
     ):
+        bboxes_format = bboxes.format
+        bboxes_image_size = bboxes.image_size
+
         output_bboxes = F.affine_bounding_box(
             bboxes,
-            bboxes.format,
-            image_size=image_size,
+            bboxes_format,
+            image_size=bboxes_image_size,
             angle=angle,
             translate=(translate, translate),
             scale=scale,
             shear=(shear, shear),
             center=center,
         )
+
         if center is None:
-            center = [s // 2 for s in image_size[::-1]]
+            center = [s // 2 for s in bboxes_image_size[::-1]]

-        bboxes_format = bboxes.format
-        bboxes_image_size = bboxes.image_size
         if bboxes.ndim < 2:
-            bboxes = [
-                bboxes,
-            ]
+            bboxes = [bboxes]

         expected_bboxes = []
         for bbox in bboxes:
@@ -531,3 +545,147 @@ def test_correctness_affine_segmentation_mask_on_fixed_input(device):
     out_mask = F.affine_segmentation_mask(mask, 90, [0.0, 0.0], 64.0 / 32.0, [0.0, 0.0])

     torch.testing.assert_close(out_mask, expected_mask)
+
+
+@pytest.mark.parametrize("angle", range(-90, 90, 56))
+@pytest.mark.parametrize("expand", [True, False])
+@pytest.mark.parametrize("center", [None, (12, 14)])
+def test_correctness_rotate_bounding_box(angle, expand, center):
+    def _compute_expected_bbox(bbox, angle_, expand_, center_):
+        affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
+        affine_matrix = affine_matrix[:2, :]
+
+        image_size = bbox.image_size
+        bbox_xyxy = convert_bounding_box_format(
+            bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY
+        )
+        points = np.array(
+            [
+                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
+                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
+                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
+                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
+                # image frame
+                [0.0, 0.0, 1.0],
+                [0.0, image_size[0], 1.0],
+                [image_size[1], image_size[0], 1.0],
+                [image_size[1], 0.0, 1.0],
+            ]
+        )
+        transformed_points = np.matmul(points, affine_matrix.T)
+        out_bbox = [
+            np.min(transformed_points[:4, 0]),
+            np.min(transformed_points[:4, 1]),
+            np.max(transformed_points[:4, 0]),
+            np.max(transformed_points[:4, 1]),
+        ]
+        if expand_:
+            tr_x = np.min(transformed_points[4:, 0])
+            tr_y = np.min(transformed_points[4:, 1])
+            out_bbox[0] -= tr_x
+            out_bbox[1] -= tr_y
+            out_bbox[2] -= tr_x
+            out_bbox[3] -= tr_y
+
+        out_bbox = features.BoundingBox(
+            out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float32
+        )
+        out_bbox = convert_bounding_box_format(
+            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
+        )
+        return out_bbox.to(bbox.device)
+
+    image_size = (32, 38)
+
+    for bboxes in make_bounding_boxes(
+        image_sizes=[
+            image_size,
+        ],
+        extra_dims=((4,),),
+    ):
+        bboxes_format = bboxes.format
+        bboxes_image_size = bboxes.image_size
+
+        output_bboxes = F.rotate_bounding_box(
+            bboxes,
+            bboxes_format,
+            image_size=bboxes_image_size,
+            angle=angle,
+            expand=expand,
+            center=center,
+        )
+
+        if center is None:
+            center = [s // 2 for s in bboxes_image_size[::-1]]
+
+        if bboxes.ndim < 2:
+            bboxes = [bboxes]
+
+        expected_bboxes = []
+        for bbox in bboxes:
+            bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
+            expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center))
+        if len(expected_bboxes) > 1:
+            expected_bboxes = torch.stack(expected_bboxes)
+        else:
+            expected_bboxes = expected_bboxes[0]
+        print("input:", bboxes)
+        print("output_bboxes:", output_bboxes)
+        print("expected_bboxes:", expected_bboxes)
+        torch.testing.assert_close(output_bboxes, expected_bboxes)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize("expand", [False])  # expand=True does not match D2, analysis in progress
+def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
+    # Check transformation against known expected output
+    image_size = (64, 64)
+    # xyxy format
+    in_boxes = [
+        [1, 1, 5, 5],
+        [1, image_size[0] - 6, 5, image_size[0] - 2],
+        [image_size[1] - 6, image_size[0] - 6, image_size[1] - 2, image_size[0] - 2],
+        [image_size[1] // 2 - 10, image_size[0] // 2 - 10, image_size[1] // 2 + 10, image_size[0] // 2 + 10],
+    ]
+    in_boxes = features.BoundingBox(
+        in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64
+    ).to(device)
+    # Tested parameters
+    angle = 45
+    center = None if expand else [12, 23]
+
+    # Expected bboxes computed using Detectron2:
+    # from detectron2.data.transforms import RotationTransform, AugmentationList
+    # from detectron2.data.transforms import AugInput
+    # import cv2
+    # inpt = AugInput(im1, boxes=np.array(in_boxes, dtype="float32"))
+    # augs = AugmentationList([RotationTransform(*size, angle, expand=expand, center=center, interp=cv2.INTER_NEAREST), ])
+    # out = augs(inpt)
+    # print(inpt.boxes)
+    if expand:
+        expected_bboxes = [
+            [1.65937957, 42.67157288, 7.31623382, 48.32842712],
+            [41.96446609, 82.9766594, 47.62132034, 88.63351365],
+            [82.26955262, 42.67157288, 87.92640687, 48.32842712],
+            [31.35786438, 31.35786438, 59.64213562, 59.64213562],
+        ]
+    else:
+        expected_bboxes = [
+            [-11.33452378, 12.39339828, -5.67766953, 18.05025253],
+            [28.97056275, 52.69848481, 34.627417, 58.35533906],
+            [69.27564928, 12.39339828, 74.93250353, 18.05025253],
+            [18.36396103, 1.07968978, 46.64823228, 29.36396103],
+        ]
+
+    output_boxes = F.rotate_bounding_box(
+        in_boxes,
+        in_boxes.format,
+        in_boxes.image_size,
+        angle,
+        expand=expand,
+        center=center,
+    )
+
+    assert len(output_boxes) == len(expected_bboxes)
+    for a_out_box, out_box in zip(expected_bboxes, output_boxes.cpu()):
+        np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box)
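
A note on the tests above: the parametrized angles, range(-90, 90, 56), evaluate to -90, -34, 22 and 78. The helper _compute_expected_bbox rotates the four box corners together with the four image-frame corners, takes the min/max over the box corners, and, when expand is set, shifts the result by the minimum of the transformed frame corners so the box lands on the enlarged canvas. Below is a minimal standalone NumPy sketch of that geometry; it is not part of the diff, and the 2x3 rotation matrix is built by hand here rather than via the test's _compute_affine_matrix helper, so sign conventions may not line up exactly with the library's.

    import numpy as np

    def expected_rotated_xyxy(box_xyxy, angle_deg, image_size, expand=False, center=None):
        # Rotate about `center` (image center by default); image_size is (height, width).
        height, width = image_size
        cx, cy = center if center is not None else (width * 0.5, height * 0.5)
        a = np.deg2rad(angle_deg)
        cos_a, sin_a = np.cos(a), np.sin(a)
        # 2x3 affine matrix for a rotation about (cx, cy): p' = R @ (p - c) + c
        matrix = np.array(
            [
                [cos_a, -sin_a, cx - cx * cos_a + cy * sin_a],
                [sin_a, cos_a, cy - cx * sin_a - cy * cos_a],
            ]
        )
        x1, y1, x2, y2 = box_xyxy
        points = np.array(
            [
                [x1, y1, 1.0], [x2, y1, 1.0], [x1, y2, 1.0], [x2, y2, 1.0],  # box corners
                [0.0, 0.0, 1.0], [0.0, height, 1.0], [width, height, 1.0], [width, 0.0, 1.0],  # image frame
            ]
        )
        transformed = points @ matrix.T
        out = [
            transformed[:4, 0].min(), transformed[:4, 1].min(),
            transformed[:4, 0].max(), transformed[:4, 1].max(),
        ]
        if expand:
            # Shift so the transformed image frame starts at (0, 0).
            tr_x, tr_y = transformed[4:, 0].min(), transformed[4:, 1].min()
            out = [out[0] - tr_x, out[1] - tr_y, out[2] - tr_x, out[3] - tr_y]
        return out

    print(expected_rotated_xyxy([1, 1, 5, 5], 45, (64, 64), expand=True))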

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@
     affine_image_tensor,
     affine_image_pil,
     affine_segmentation_mask,
+    rotate_bounding_box,
     rotate_image_tensor,
     rotate_image_pil,
     pad_image_tensor,
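
With this re-export in place, the new kernel should be importable next to the existing rotate kernels, e.g. (a hypothetical usage line, not part of the diff):

    from torchvision.prototype.transforms.functional import rotate_bounding_box, rotate_image_tensor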

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 86 additions & 13 deletions
@@ -1,4 +1,5 @@
 import numbers
+import warnings
 from typing import Tuple, List, Optional, Sequence, Union

 import PIL.Image
@@ -197,24 +198,28 @@ def affine_image_pil(
     return _FP.affine(img, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)


-def affine_bounding_box(
+def _affine_bounding_box_xyxy(
     bounding_box: torch.Tensor,
-    format: features.BoundingBoxFormat,
     image_size: Tuple[int, int],
     angle: float,
-    translate: List[float],
-    scale: float,
-    shear: List[float],
+    translate: Optional[List[float]] = None,
+    scale: Optional[float] = None,
+    shear: Optional[List[float]] = None,
     center: Optional[List[float]] = None,
+    expand: bool = False,
 ) -> torch.Tensor:
-    original_shape = bounding_box.shape
-    bounding_box = convert_bounding_box_format(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
-    ).view(-1, 4)
-
     dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
     device = bounding_box.device

+    if translate is None:
+        translate = [0.0, 0.0]
+
+    if scale is None:
+        scale = 1.0
+
+    if shear is None:
+        shear = [0.0, 0.0]
+
     if center is None:
         height, width = image_size
         center_f = [width * 0.5, height * 0.5]
@@ -241,6 +246,47 @@ def affine_bounding_box(
     out_bbox_mins, _ = torch.min(transformed_points, dim=1)
     out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
     out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
+
+    if expand:
+        # Compute minimum point for transformed image frame:
+        # Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
+        height, width = image_size
+        points = torch.tensor(
+            [
+                [0.0, 0.0, 1.0],
+                [0.0, 1.0 * height, 1.0],
+                [1.0 * width, 1.0 * height, 1.0],
+                [1.0 * width, 0.0, 1.0],
+            ],
+            dtype=dtype,
+            device=device,
+        )
+        new_points = torch.matmul(points, affine_matrix.T)
+        tr, _ = torch.min(new_points, dim=0, keepdim=True)
+        # Translate bounding boxes
+        out_bboxes[:, 0::2] = out_bboxes[:, 0::2] - tr[:, 0]
+        out_bboxes[:, 1::2] = out_bboxes[:, 1::2] - tr[:, 1]
+
+    return out_bboxes
+
+
+def affine_bounding_box(
+    bounding_box: torch.Tensor,
+    format: features.BoundingBoxFormat,
+    image_size: Tuple[int, int],
+    angle: float,
+    translate: List[float],
+    scale: float,
+    shear: List[float],
+    center: Optional[List[float]] = None,
+) -> torch.Tensor:
+    original_shape = bounding_box.shape
+    bounding_box = convert_bounding_box_format(
+        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    ).view(-1, 4)
+
+    out_bboxes = _affine_bounding_box_xyxy(bounding_box, image_size, angle, translate, scale, shear, center)
+
     # out_bboxes should be of shape [N boxes, 4]

     return convert_bounding_box_format(
@@ -277,9 +323,12 @@ def rotate_image_tensor(
 ) -> torch.Tensor:
     center_f = [0.0, 0.0]
     if center is not None:
-        _, height, width = get_dimensions_image_tensor(img)
-        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
+        if expand:
+            warnings.warn("The provided center argument is ignored if expand is True")
+        else:
+            _, height, width = get_dimensions_image_tensor(img)
+            # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
+            center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

     # due to current incoherence of rotation angle direction between affine and rotate implementations
     # we need to set -angle.
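
A small sketch of the behaviour the hunk above introduces: passing both center and expand=True now warns and the provided center is ignored, since the expanded canvas is only well defined for a rotation about the image center. This is standalone, not part of the diff, and assumes the prototype import path used by the tests plus keyword arguments angle, expand and center on rotate_image_tensor.

    import warnings
    import torch
    from torchvision.prototype.transforms import functional as F

    img = torch.zeros(3, 32, 32)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        out = F.rotate_image_tensor(img, angle=45.0, expand=True, center=[10, 10])
    # The provided center should have been ignored with a warning.
    assert any("center argument is ignored" in str(w.message) for w in caught)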
@@ -295,11 +344,35 @@ def rotate_image_pil(
     fill: Optional[List[float]] = None,
     center: Optional[List[float]] = None,
 ) -> PIL.Image.Image:
+    if center is not None and expand:
+        warnings.warn("The provided center argument is ignored if expand is True")
+        center = None
+
     return _FP.rotate(
         img, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center
     )


+def rotate_bounding_box(
+    bounding_box: torch.Tensor,
+    format: features.BoundingBoxFormat,
+    image_size: Tuple[int, int],
+    angle: float,
+    expand: bool = False,
+    center: Optional[List[float]] = None,
+) -> torch.Tensor:
+    original_shape = bounding_box.shape
+    bounding_box = convert_bounding_box_format(
+        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    ).view(-1, 4)
+
+    out_bboxes = _affine_bounding_box_xyxy(bounding_box, image_size, angle=-angle, center=center, expand=expand)
+
+    return convert_bounding_box_format(
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+    ).view(original_shape)
+
+
 pad_image_tensor = _FT.pad
 pad_image_pil = _FP.pad
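Since rotate_bounding_box forwards angle=-angle to _affine_bounding_box_xyxy while affine_bounding_box forwards the angle unchanged, a rotation with expand=False should agree with the corresponding affine call that uses identity translate, scale and shear. A hedged sanity-check sketch follows; it is standalone, not part of the diff, and the import paths are assumed from the prototype layout shown above.

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    bbox = features.BoundingBox(
        [[10.0, 10.0, 20.0, 20.0]], format=features.BoundingBoxFormat.XYXY, image_size=(32, 32)
    )

    # Rotation by +30 degrees should match an affine call with the angle negated
    # and identity translate/scale/shear, since both go through the shared helper.
    rotated = F.rotate_bounding_box(bbox, bbox.format, bbox.image_size, angle=30.0)
    affined = F.affine_bounding_box(
        bbox, bbox.format, bbox.image_size, angle=-30.0, translate=[0.0, 0.0], scale=1.0, shear=[0.0, 0.0]
    )
    torch.testing.assert_close(rotated, affined)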