diff --git a/test/common_utils.py b/test/common_utils.py
index a74f204f429..df3126fa4b4 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -417,6 +417,13 @@ def sample_position(values, max_value):
     format = tv_tensors.BoundingBoxFormat[format]
 
     dtype = dtype or torch.float32
+    int_dtype = dtype in (
+        torch.uint8,
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.int64,
+    )
 
     h, w = (torch.randint(1, s, (num_boxes,)) for s in canvas_size)
     y = sample_position(h, canvas_size[0])
@@ -443,17 +450,17 @@ def sample_position(values, max_value):
     elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
         r_rad = r * torch.pi / 180.0
         cos, sin = torch.cos(r_rad), torch.sin(r_rad)
-        x1, y1 = x, y
-        x2 = x1 + w * cos
-        y2 = y1 - w * sin
-        x3 = x2 + h * sin
-        y3 = y2 + h * cos
-        x4 = x1 + h * sin
-        y4 = y1 + h * cos
+        x1 = torch.round(x) if int_dtype else x
+        y1 = torch.round(y) if int_dtype else y
+        x2 = torch.round(x1 + w * cos) if int_dtype else x1 + w * cos
+        y2 = torch.round(y1 - w * sin) if int_dtype else y1 - w * sin
+        x3 = torch.round(x2 + h * sin) if int_dtype else x2 + h * sin
+        y3 = torch.round(y2 + h * cos) if int_dtype else y2 + h * cos
+        x4 = torch.round(x1 + h * sin) if int_dtype else x1 + h * sin
+        y4 = torch.round(y1 + h * cos) if int_dtype else y1 + h * cos
         parts = (x1, y1, x2, y2, x3, y3, x4, y4)
     else:
         raise ValueError(f"Format {format} is not supported")
-
     return tv_tensors.BoundingBoxes(
         torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
     )
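For illustration, a minimal sketch (not part of the patch; the values are made up) of why the helper above rounds each generated corner for integer dtypes: a plain cast truncates toward zero, which can shift a corner by almost a full pixel and skew the rotated box, while rounding first keeps it on the nearest pixel.

    import torch

    x1, w, r_rad = torch.tensor(3.0), torch.tensor(5.0), torch.tensor(0.3)
    x2 = x1 + w * torch.cos(r_rad)           # 7.776...
    print(x2.to(torch.int64))                # tensor(7), truncated
    print(torch.round(x2).to(torch.int64))   # tensor(8), rounded first
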
diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 21e81ec37f8..7ade10a8ea5 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -49,7 +49,7 @@
 from torchvision.transforms.functional import pil_modes_mapping, to_pil_image
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2._utils import check_type, is_pure_tensor
-from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs
+from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs, _parallelogram_to_bounding_boxes
 from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal
 
 
@@ -560,7 +560,9 @@ def affine_bounding_boxes(bounding_boxes):
     )
 
 
-def reference_affine_rotated_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True):
+def reference_affine_rotated_bounding_boxes_helper(
+    bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True, flip=False
+):
     format = bounding_boxes.format
     canvas_size = new_canvas_size or bounding_boxes.canvas_size
 
@@ -588,21 +590,34 @@ def affine_rotated_bounding_boxes(bounding_boxes):
         transformed_points = np.matmul(points, affine_matrix.astype(points.dtype).T)
         output = torch.tensor(
             [
-                float(transformed_points[1, 0]),
-                float(transformed_points[1, 1]),
                 float(transformed_points[0, 0]),
                 float(transformed_points[0, 1]),
-                float(transformed_points[3, 0]),
-                float(transformed_points[3, 1]),
+                float(transformed_points[1, 0]),
+                float(transformed_points[1, 1]),
                 float(transformed_points[2, 0]),
                 float(transformed_points[2, 1]),
+                float(transformed_points[3, 0]),
+                float(transformed_points[3, 1]),
             ]
         )
 
+        output = output[[2, 3, 0, 1, 6, 7, 4, 5]] if flip else output
+        output = _parallelogram_to_bounding_boxes(output)
+
         output = F.convert_bounding_box_format(
             output, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format
         )
 
+        if torch.is_floating_point(output) and dtype in (
+            torch.uint8,
+            torch.int8,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+        ):
+            # It is better to round before casting
+            output = torch.round(output)
+
         if clamp:
             # It is important to clamp before casting, especially for CXCYWHR format, dtype=int64
             output = F.clamp_bounding_boxes(
@@ -707,7 +722,7 @@ def test_kernel_image(self, size, interpolation, use_max_size, antialias, dtype,
             check_scripted_vs_eager=not isinstance(size, int),
         )
 
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("size", OUTPUT_SIZES)
     @pytest.mark.parametrize("use_max_size", [True, False])
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@@ -725,6 +740,7 @@ def test_kernel_bounding_boxes(self, format, size, use_max_size, dtype, device):
         check_kernel(
             F.resize_bounding_boxes,
             bounding_boxes,
+            format=format,
             canvas_size=bounding_boxes.canvas_size,
             size=size,
             **max_size_kwarg,
@@ -816,7 +832,7 @@ def test_image_correctness(self, size, interpolation, use_max_size, fn):
         self._check_output_size(image, actual, size=size, **max_size_kwarg)
         torch.testing.assert_close(actual, expected, atol=1, rtol=0)
 
-    def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=None):
+    def _reference_resize_bounding_boxes(self, bounding_boxes, format, *, size, max_size=None):
         old_height, old_width = bounding_boxes.canvas_size
         new_height, new_width = self._compute_output_size(
             input_size=bounding_boxes.canvas_size, size=size, max_size=max_size
@@ -832,13 +848,19 @@ def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=Non
             ],
         )
 
-        return reference_affine_bounding_boxes_helper(
+        helper = (
+            reference_affine_rotated_bounding_boxes_helper
+            if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
+            else reference_affine_bounding_boxes_helper
+        )
+
+        return helper(
             bounding_boxes,
             affine_matrix=affine_matrix,
             new_canvas_size=(new_height, new_width),
         )
 
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("size", OUTPUT_SIZES)
    @pytest.mark.parametrize("use_max_size", [True, False])
    @pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)])
@@ -849,7 +871,7 @@ def test_bounding_boxes_correctness(self, format, size, use_max_size, fn):
         bounding_boxes = make_bounding_boxes(format=format, canvas_size=self.INPUT_SIZE)
 
         actual = fn(bounding_boxes, size=size, **max_size_kwarg)
-        expected = self._reference_resize_bounding_boxes(bounding_boxes, size=size, **max_size_kwarg)
+        expected = self._reference_resize_bounding_boxes(bounding_boxes, format=format, size=size, **max_size_kwarg)
 
         self._check_output_size(bounding_boxes, actual, size=size, **max_size_kwarg)
         torch.testing.assert_close(actual, expected)
@@ -1152,7 +1174,7 @@ def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.B
         )
 
         helper = (
-            reference_affine_rotated_bounding_boxes_helper
+            functools.partial(reference_affine_rotated_bounding_boxes_helper, flip=True)
             if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
             else reference_affine_bounding_boxes_helper
         )
@@ -1257,7 +1279,7 @@ def test_kernel_image(self, param, value, dtype, device):
         shear=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"],
         center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
     )
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
@@ -1399,14 +1421,22 @@ def _reference_affine_bounding_boxes(self, bounding_boxes, *, angle, translate, 
         if center is None:
             center = [s * 0.5 for s in bounding_boxes.canvas_size[::-1]]
 
-        return reference_affine_bounding_boxes_helper(
+        affine_matrix = self._compute_affine_matrix(
+            angle=angle, translate=translate, scale=scale, shear=shear, center=center
+        )
+
+        helper = (
+            reference_affine_rotated_bounding_boxes_helper
+            if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
+            else reference_affine_bounding_boxes_helper
+        )
+
+        return helper(
             bounding_boxes,
-            affine_matrix=self._compute_affine_matrix(
-                angle=angle, translate=translate, scale=scale, shear=shear, center=center
-            ),
+            affine_matrix=affine_matrix,
         )
 
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
     @pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"])
     @pytest.mark.parametrize("scale", _CORRECTNESS_AFFINE_KWARGS["scale"])
@@ -1607,7 +1637,7 @@ def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.Bou
         )
 
         helper = (
-            reference_affine_rotated_bounding_boxes_helper
+            functools.partial(reference_affine_rotated_bounding_boxes_helper, flip=True)
             if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
             else reference_affine_bounding_boxes_helper
         )
@@ -2914,7 +2944,7 @@ def test_kernel_image(self, kwargs, dtype, device):
         check_kernel(F.crop_image, make_image(self.INPUT_SIZE, dtype=dtype, device=device), **kwargs)
 
     @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_kernel_bounding_box(self, kwargs, format, dtype, device):
@@ -3059,12 +3089,15 @@ def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, w
                 [0, 1, -top],
             ],
         )
-        return reference_affine_bounding_boxes_helper(
-            bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=(height, width)
+        helper = (
+            reference_affine_rotated_bounding_boxes_helper
+            if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
+            else reference_affine_bounding_boxes_helper
         )
+        return helper(bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=(height, width))
 
     @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device):
@@ -3077,7 +3110,7 @@ def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device
         assert_equal(F.get_size(actual), F.get_size(expected))
 
     @pytest.mark.parametrize("output_size", [(17, 11), (11, 17), (11, 11)])
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("seed", list(range(5)))
@@ -3099,7 +3132,7 @@ def test_transform_bounding_boxes_correctness(self, output_size, format, dtype, 
 
         expected = self._reference_crop_bounding_boxes(bounding_boxes, **params)
 
-        assert_equal(actual, expected)
+        torch.testing.assert_close(actual, expected)
         assert_equal(F.get_size(actual), F.get_size(expected))
 
     def test_errors(self):
@@ -3834,13 +3867,19 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h
         )
         affine_matrix = (resize_affine_matrix @ crop_affine_matrix)[:2, :]
 
-        return reference_affine_bounding_boxes_helper(
+        helper = (
+            reference_affine_rotated_bounding_boxes_helper
+            if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
+            else reference_affine_bounding_boxes_helper
+        )
+
+        return helper(
             bounding_boxes,
             affine_matrix=affine_matrix,
             new_canvas_size=size,
         )
 
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     def test_functional_bounding_boxes_correctness(self, format):
         bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format)
 
@@ -3849,7 +3888,7 @@ def test_functional_bounding_boxes_correctness(self, format):
             bounding_boxes, **self.CROP_KWARGS, size=self.OUTPUT_SIZE
         )
 
-        assert_equal(actual, expected)
+        torch.testing.assert_close(actual, expected)
         assert_equal(F.get_size(actual), F.get_size(expected))
 
     def test_transform_errors_warnings(self):
@@ -3914,7 +3953,7 @@ def test_kernel_image(self, param, value, dtype, device):
             ),
         )
 
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     def test_kernel_bounding_boxes(self, format):
         bounding_boxes = make_bounding_boxes(format=format)
         check_kernel(
@@ -4034,12 +4073,15 @@ def _reference_pad_bounding_boxes(self, bounding_boxes, *, padding):
         height = bounding_boxes.canvas_size[0] + top + bottom
         width = bounding_boxes.canvas_size[1] + left + right
 
-        return reference_affine_bounding_boxes_helper(
-            bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=(height, width)
+        helper = (
+            reference_affine_rotated_bounding_boxes_helper
+            if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
+            else reference_affine_bounding_boxes_helper
         )
+        return helper(bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=(height, width))
 
     @pytest.mark.parametrize("padding", CORRECTNESS_PADDINGS)
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("fn", [F.pad, transform_cls_to_functional(transforms.Pad)])
@@ -4049,7 +4091,7 @@ def test_bounding_boxes_correctness(self, padding, format, dtype, device, fn):
         actual = fn(bounding_boxes, padding=padding)
         expected = self._reference_pad_bounding_boxes(bounding_boxes, padding=padding)
 
-        assert_equal(actual, expected)
+        torch.testing.assert_close(actual, expected)
 
 
 class TestCenterCrop:
@@ -4068,7 +4110,7 @@ def test_kernel_image(self, output_size, dtype, device):
         )
 
     @pytest.mark.parametrize("output_size", OUTPUT_SIZES)
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     def test_kernel_bounding_boxes(self, output_size, format):
         bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format)
         check_kernel(
@@ -4142,12 +4184,15 @@ def _reference_center_crop_bounding_boxes(self, bounding_boxes, output_size):
                 [0, 1, -top],
             ],
         )
-        return reference_affine_bounding_boxes_helper(
-            bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=output_size
+        helper = (
+            reference_affine_rotated_bounding_boxes_helper
+            if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
+            else reference_affine_bounding_boxes_helper
        )
+        return helper(bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=output_size)
 
     @pytest.mark.parametrize("output_size", OUTPUT_SIZES)
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
@@ -4157,7 +4202,7 @@ def test_bounding_boxes_correctness(self, output_size, format, dtype, device, fn
         actual = fn(bounding_boxes, output_size)
         expected = self._reference_center_crop_bounding_boxes(bounding_boxes, output_size)
 
-        assert_equal(actual, expected)
+        torch.testing.assert_close(actual, expected)
 
 
 class TestPerspective:
@@ -5894,6 +5939,37 @@ def test_classification_preset(image_type, label_type, dataset_return_type, to_t
     assert out_label == label
 
 
+@pytest.mark.parametrize("input_size", [(17, 11), (11, 17), (11, 11)])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
+@pytest.mark.parametrize("device", cpu_and_cuda())
+def test_parallelogram_to_bounding_boxes(input_size, dtype, device):
+    # Assert that applying `_parallelogram_to_bounding_boxes` to rotated boxes
+    # does not modify the input.
+    bounding_boxes = make_bounding_boxes(
+        input_size, format=tv_tensors.BoundingBoxFormat.XYXYXYXY, dtype=dtype, device=device
+    )
+    actual = _parallelogram_to_bounding_boxes(bounding_boxes)
+    torch.testing.assert_close(actual, bounding_boxes, rtol=0, atol=1)
+
+    # Test the transformation of two simple parallelograms.
+    # 1---2        1----2
+    #  /  /   ->   |    |
+    # 4---3        4----3
+
+    # 1---2        1----2
+    # \   \   ->   |    |
+    #  4---3       4----3
+    parallelogram = torch.tensor([[1, 0, 4, 0, 3, 2, 0, 2], [0, 0, 3, 0, 4, 2, 1, 2]])
+    expected = torch.tensor(
+        [
+            [0, 0, 4, 0, 4, 2, 0, 2],
+            [0, 0, 4, 0, 4, 2, 0, 2],
+        ]
+    )
+    actual = _parallelogram_to_bounding_boxes(parallelogram)
+    assert_equal(actual, expected)
+
+
 @pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
 @pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
 @pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
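As a small illustration (not part of the patch; the box values are made up) of the `flip=True` reindexing used by the reference helper above: a horizontal or vertical flip reverses the orientation of the four corners, and the index `[2, 3, 0, 1, 6, 7, 4, 5]` swaps points 1 and 2 and points 3 and 4 so the expected output keeps the canonical corner order.

    import torch

    box = torch.tensor([0.0, 0.0, 4.0, 0.0, 4.0, 2.0, 0.0, 2.0])  # x1, y1, ..., x4, y4
    reordered = box[[2, 3, 0, 1, 6, 7, 4, 5]]  # swap points 1<->2 and 3<->4
    print(reordered)  # tensor([4., 0., 0., 0., 0., 2., 4., 2.])
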
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index 34a6b3692b9..09be745a695 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -381,12 +381,84 @@ def _resize_mask_dispatch(
     return tv_tensors.wrap(output, like=inpt)
 
 
+def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tensor:
+    """
+    Convert a parallelogram to a rectangle while keeping two points unchanged.
+
+    This function transforms a parallelogram represented by 8 coordinates (4 points) into a rectangle.
+    The two diagonally opposed points of the parallelogram forming the longest diagonal remain fixed,
+    and the other two points are adjusted to form a proper rectangle.
+
+    Note:
+        This function is not applied in-place and will return a copy of the input tensor.
+
+    Args:
+        parallelogram (torch.Tensor): Tensor of shape (..., 8) containing the coordinates of the parallelograms.
+            The format is [x1, y1, x2, y2, x3, y3, x4, y4].
+
+    Returns:
+        torch.Tensor: Tensor of the same shape and dtype as the input, containing the rectangle coordinates.
+    """
+    dtype = parallelogram.dtype
+    int_dtype = dtype in (
+        torch.uint8,
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.int64,
+    )
+
+    out_boxes = parallelogram.clone()
+
+    # Calculate the parallelogram diagonal vectors
+    dx13 = parallelogram[..., 4] - parallelogram[..., 0]
+    dy13 = parallelogram[..., 5] - parallelogram[..., 1]
+    dx42 = parallelogram[..., 2] - parallelogram[..., 6]
+    dy42 = parallelogram[..., 3] - parallelogram[..., 7]
+    diag13 = torch.sqrt(dx13**2 + dy13**2)
+    diag24 = torch.sqrt(dx42**2 + dy42**2)
+    mask = diag13 > diag24
+
+    # Calculate the rotation angle in radians
+    r_rad = torch.atan2(parallelogram[..., 1] - parallelogram[..., 3], parallelogram[..., 2] - parallelogram[..., 0])
+    cos, sin = torch.cos(r_rad), torch.sin(r_rad)
+
+    # Calculate the width using the angle between the longest diagonal and the rotation angle
+    w = torch.where(
+        mask,
+        diag13 * torch.abs(torch.sin(torch.atan2(dx13, dy13) - r_rad)),
+        diag24 * torch.abs(torch.sin(torch.atan2(dx42, dy42) - r_rad)),
+    )
+
+    delta_x = torch.round(w * cos).to(dtype) if int_dtype else w * cos
+    delta_y = torch.round(w * sin).to(dtype) if int_dtype else w * sin
+
+    # Update the coordinates to form a rectangle.
+    # Where the 1-3 diagonal is the longer one, keep the points (x1, y1) and (x3, y3)
+    # unchanged and recompute (x2, y2) and (x4, y4).
+    out_boxes[..., 2] = torch.where(mask, parallelogram[..., 0] + delta_x, parallelogram[..., 2])
+    out_boxes[..., 3] = torch.where(mask, parallelogram[..., 1] - delta_y, parallelogram[..., 3])
+    out_boxes[..., 6] = torch.where(mask, parallelogram[..., 4] - delta_x, parallelogram[..., 6])
+    out_boxes[..., 7] = torch.where(mask, parallelogram[..., 5] + delta_y, parallelogram[..., 7])
+
+    # Otherwise, keep the points (x2, y2) and (x4, y4) unchanged
+    # and recompute (x1, y1) and (x3, y3).
+    out_boxes[..., 0] = torch.where(~mask, parallelogram[..., 2] - delta_x, parallelogram[..., 0])
+    out_boxes[..., 1] = torch.where(~mask, parallelogram[..., 3] + delta_y, parallelogram[..., 1])
+    out_boxes[..., 4] = torch.where(~mask, parallelogram[..., 6] + delta_x, parallelogram[..., 4])
+    out_boxes[..., 5] = torch.where(~mask, parallelogram[..., 7] - delta_y, parallelogram[..., 5])
+
+    return out_boxes
+
+
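To make the geometry above concrete, here is a worked example (not part of the patch) that mirrors one of the new unit tests: for the slanted parallelogram below, the 2-4 diagonal is the longer one, so points 2 and 4 stay fixed and points 1 and 3 are recomputed.

    import torch

    # Points: (1, 0), (4, 0), (3, 2), (0, 2), i.e. the top edge is shifted right.
    p = torch.tensor([[1.0, 0.0, 4.0, 0.0, 3.0, 2.0, 0.0, 2.0]])
    # diag13 = sqrt(2**2 + 2**2) ~= 2.83, diag24 = sqrt(4**2 + 2**2) ~= 4.47, so mask is False.
    # r_rad = atan2(y1 - y2, x2 - x1) = atan2(0, 3) = 0, hence cos = 1, sin = 0.
    # w = diag24 * |sin(atan2(4, -2) - 0)| ~= 4.47 * 0.894 ~= 4.0, so delta_x = 4.0, delta_y = 0.0.
    # Rebuilt corners: (4 - 4, 0 + 0), (4, 0), (0 + 4, 2 - 0), (0, 2), i.e. an axis-aligned rectangle.
    # _parallelogram_to_bounding_boxes(p) -> tensor([[0., 0., 4., 0., 4., 2., 0., 2.]])
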
 def resize_bounding_boxes(
     bounding_boxes: torch.Tensor,
     canvas_size: tuple[int, int],
     size: Optional[list[int]],
     max_size: Optional[int] = None,
+    format: tv_tensors.BoundingBoxFormat = tv_tensors.BoundingBoxFormat.XYXY,
 ) -> tuple[torch.Tensor, tuple[int, int]]:
+    # We set the default format as `tv_tensors.BoundingBoxFormat.XYXY`
+    # to ensure backward compatibility.
+    # Indeed, before the introduction of the rotated bounding box formats,
+    # this function did not receive a `format` parameter as input.
     old_height, old_width = canvas_size
     new_height, new_width = _compute_resized_output_size(canvas_size, size=size, max_size=max_size)
 
@@ -395,11 +467,34 @@ def resize_bounding_boxes(
 
     w_ratio = new_width / old_width
     h_ratio = new_height / old_height
-    ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_boxes.device)
-    return (
-        bounding_boxes.mul(ratios).to(bounding_boxes.dtype),
-        (new_height, new_width),
-    )
+    if tv_tensors.is_rotated_bounding_format(format):
+        original_shape = bounding_boxes.shape
+        xyxyxyxy_boxes = convert_bounding_box_format(
+            bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, inplace=False
+        ).reshape(-1, 8)
+
+        ratios = torch.tensor(
+            [w_ratio, h_ratio, w_ratio, h_ratio, w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_boxes.device
+        )
+        transformed_points = xyxyxyxy_boxes.mul(ratios)
+        out_bboxes = _parallelogram_to_bounding_boxes(transformed_points)
+        return (
+            convert_bounding_box_format(
+                out_bboxes,
+                old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY,
+                new_format=format,
+                inplace=False,
+            )
+            .to(bounding_boxes.dtype)
+            .reshape(original_shape),
+            (new_height, new_width),
+        )
+    else:
+        ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_boxes.device)
+        return (
+            bounding_boxes.mul(ratios).to(bounding_boxes.dtype),
+            (new_height, new_width),
+        )
 
 
 @_register_kernel_internal(resize, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
@@ -407,7 +502,7 @@ def _resize_bounding_boxes_dispatch(
     inpt: tv_tensors.BoundingBoxes, size: Optional[list[int]], max_size: Optional[int] = None, **kwargs: Any
 ) -> tv_tensors.BoundingBoxes:
     output, canvas_size = resize_bounding_boxes(
-        inpt.as_subclass(torch.Tensor), inpt.canvas_size, size, max_size=max_size
+        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, size=size, max_size=max_size
     )
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
 
@@ -825,11 +920,12 @@ def _affine_bounding_boxes_with_expand(
     bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
     dtype = bounding_boxes.dtype
     device = bounding_boxes.device
+    is_rotated = tv_tensors.is_rotated_bounding_format(format)
+    intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
+    intermediate_shape = 8 if is_rotated else 4
     bounding_boxes = (
-        convert_bounding_box_format(
-            bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
-        )
-    ).reshape(-1, 4)
+        convert_bounding_box_format(bounding_boxes, old_format=format, new_format=intermediate_format, inplace=True)
+    ).reshape(-1, intermediate_shape)
 
     angle, translate, shear, center = _affine_parse_args(
         angle, translate, scale, shear, InterpolationMode.NEAREST, center
@@ -853,15 +949,22 @@ def _affine_bounding_boxes_with_expand(
     # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
     # Single point structure is similar to
     # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
-    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
+    if is_rotated:
+        points = bounding_boxes.reshape(-1, 2)
+    else:
+        points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
     points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
     # 2) Now let's transform the points using affine matrix
     transformed_points = torch.matmul(points, transposed_affine_matrix)
     # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
     # and compute bounding box from 4 transformed points:
-    transformed_points = transformed_points.reshape(-1, 4, 2)
-    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
-    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
+    if is_rotated:
+        transformed_points = transformed_points.reshape(-1, 8)
+        out_bboxes = _parallelogram_to_bounding_boxes(transformed_points)
+    else:
+        transformed_points = transformed_points.reshape(-1, 4, 2)
+        out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
+        out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
 
     if expand:
         # Compute minimum point for transformed image frame:
@@ -886,9 +989,9 @@ def _affine_bounding_boxes_with_expand(
         new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
         canvas_size = (new_height, new_width)
 
-    out_bboxes = clamp_bounding_boxes(out_bboxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size)
+    out_bboxes = clamp_bounding_boxes(out_bboxes, format=intermediate_format, canvas_size=canvas_size)
     out_bboxes = convert_bounding_box_format(
-        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
+        out_bboxes, old_format=intermediate_format, new_format=format, inplace=True
     ).reshape(original_shape)
 
     out_bboxes = out_bboxes.to(original_dtype)
@@ -1379,7 +1482,11 @@ def pad_bounding_boxes(
 
     left, right, top, bottom = _parse_pad_padding(padding)
 
-    if format == tv_tensors.BoundingBoxFormat.XYXY:
+    if format == tv_tensors.BoundingBoxFormat.XYXYXYXY:
+        pad = [left, top, left, top, left, top, left, top]
+    elif format == tv_tensors.BoundingBoxFormat.XYWHR or format == tv_tensors.BoundingBoxFormat.CXCYWHR:
+        pad = [left, top, 0, 0, 0]
+    elif format == tv_tensors.BoundingBoxFormat.XYXY:
         pad = [left, top, left, top]
     else:
         pad = [left, top, 0, 0]
@@ -1462,7 +1569,11 @@ def crop_bounding_boxes(
 ) -> tuple[torch.Tensor, tuple[int, int]]:
 
     # Crop or implicit pad if left and/or top have negative values:
-    if format == tv_tensors.BoundingBoxFormat.XYXY:
+    if format == tv_tensors.BoundingBoxFormat.XYXYXYXY:
+        sub = [left, top, left, top, left, top, left, top]
+    elif format == tv_tensors.BoundingBoxFormat.XYWHR or format == tv_tensors.BoundingBoxFormat.CXCYWHR:
+        sub = [left, top, 0, 0, 0]
+    elif format == tv_tensors.BoundingBoxFormat.XYXY:
         sub = [left, top, left, top]
     else:
         sub = [left, top, 0, 0]
@@ -1470,6 +1581,9 @@ def crop_bounding_boxes(
     bounding_boxes = bounding_boxes - torch.tensor(sub, dtype=bounding_boxes.dtype, device=bounding_boxes.device)
     canvas_size = (height, width)
 
+    if format == tv_tensors.BoundingBoxFormat.XYXYXYXY:
+        bounding_boxes = _parallelogram_to_bounding_boxes(bounding_boxes)
+
     return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size
 
 
@@ -2204,7 +2318,7 @@ def resized_crop_bounding_boxes(
     size: list[int],
 ) -> tuple[torch.Tensor, tuple[int, int]]:
     bounding_boxes, canvas_size = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
-    return resize_bounding_boxes(bounding_boxes, canvas_size=canvas_size, size=size)
+    return resize_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size, size=size)
 
 
 @_register_kernel_internal(resized_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
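A hedged usage sketch (not part of the patch; the box values are made up) of the updated kernel path: with a rotated format and a non-uniform scale, the boxes pass through XYXYXYXY space, are rescaled point-wise, squared up by `_parallelogram_to_bounding_boxes`, and converted back to the original format.

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    boxes = tv_tensors.BoundingBoxes(
        torch.tensor([[50.0, 50.0, 20.0, 10.0, 30.0]]),  # cx, cy, w, h, angle in degrees
        format=tv_tensors.BoundingBoxFormat.CXCYWHR,
        canvas_size=(100, 100),
    )
    # A non-uniform resize (100x100 -> 50x100) turns the rotated rectangle into
    # a parallelogram, which the new code squares up before converting back.
    resized = F.resize(boxes, size=[50, 100])
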
diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py
index 8dce16957d9..019f4e25cb7 100644
--- a/torchvision/transforms/v2/functional/_meta.py
+++ b/torchvision/transforms/v2/functional/_meta.py
@@ -263,14 +263,6 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
     return xyxyxyxy[..., :5].to(dtype)
 
 
-def is_rotated_bounding_box_format(format: BoundingBoxFormat) -> bool:
-    return format.value in [
-        BoundingBoxFormat.XYWHR.value,
-        BoundingBoxFormat.CXCYWHR.value,
-        BoundingBoxFormat.XYXYXYXY.value,
-    ]
-
-
 def _convert_bounding_box_format(
     bounding_boxes: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
 ) -> torch.Tensor:
@@ -278,7 +270,7 @@ def _convert_bounding_box_format(
     if new_format == old_format:
         return bounding_boxes
 
-    if is_rotated_bounding_box_format(old_format) ^ is_rotated_bounding_box_format(new_format):
+    if tv_tensors.is_rotated_bounding_format(old_format) ^ tv_tensors.is_rotated_bounding_format(new_format):
         raise ValueError("Cannot convert between rotated and unrotated bounding boxes.")
 
     # TODO: Add _xywh_to_cxcywh and _cxcywh_to_xywh to improve performance
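For readers tracking the removal above: the private `is_rotated_bounding_box_format` helper is superseded by the public `tv_tensors.is_rotated_bounding_format` that the rest of this patch relies on. A quick sanity check (not part of the patch, assuming a torchvision build that includes it):

    from torchvision import tv_tensors

    assert tv_tensors.is_rotated_bounding_format(tv_tensors.BoundingBoxFormat.XYXYXYXY)
    assert tv_tensors.is_rotated_bounding_format(tv_tensors.BoundingBoxFormat.CXCYWHR)
    assert not tv_tensors.is_rotated_bounding_format(tv_tensors.BoundingBoxFormat.XYXY)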