diff --git a/test/common_utils.py b/test/common_utils.py
index 9af40cec878..ee3a2d5cbde 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -21,7 +21,7 @@
 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
 from torchvision import io, tv_tensors
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import clamp_bounding_boxes, to_image, to_pil_image
+from torchvision.transforms.v2.functional import to_image, to_pil_image


 IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -410,7 +410,7 @@ def make_bounding_boxes(
     canvas_size=DEFAULT_SIZE,
     *,
     format=tv_tensors.BoundingBoxFormat.XYXY,
-    clamping_mode="hard",  # TODOBB
+    clamping_mode="soft",
     num_boxes=1,
     dtype=None,
     device="cpu",
@@ -469,21 +469,6 @@ def sample_position(values, max_value):
     else:
         raise ValueError(f"Format {format} is not supported")
     out_boxes = torch.stack(parts, dim=-1).to(dtype=dtype, device=device)
-    if tv_tensors.is_rotated_bounding_format(format):
-        # The rotated bounding boxes are not guaranteed to be within the canvas by design,
-        # so we apply clamping. We also add a 2 buffer to the canvas size to avoid
-        # numerical issues during the testing
-        buffer = 4
-        out_boxes = clamp_bounding_boxes(
-            out_boxes,
-            format=format,
-            canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer),
-            clamping_mode=clamping_mode,
-        )
-        if format is tv_tensors.BoundingBoxFormat.XYWHR or format is tv_tensors.BoundingBoxFormat.CXCYWHR:
-            out_boxes[:, :2] += buffer // 2
-        elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
-            out_boxes[:, :] += buffer // 2
     return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode)


diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index dd774672273..416b2e4facb 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -551,6 +551,7 @@ def affine_bounding_boxes(bounding_boxes):
         ),
         format=format,
         canvas_size=canvas_size,
+        clamping_mode=clamping_mode,
     )


@@ -639,6 +640,7 @@ def affine_rotated_bounding_boxes(bounding_boxes):
         ).reshape(bounding_boxes.shape),
         format=format,
         canvas_size=canvas_size,
+        clamping_mode=clamping_mode,
     )


@@ -1305,7 +1307,7 @@ def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.B
             if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
             else reference_affine_bounding_boxes_helper
         )
-        return helper(bounding_boxes, affine_matrix=affine_matrix)
+        return helper(bounding_boxes, affine_matrix=affine_matrix, clamp=False)

     @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize(
@@ -1914,7 +1916,7 @@ def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.Bou
             if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
             else reference_affine_bounding_boxes_helper
         )
-        return helper(bounding_boxes, affine_matrix=affine_matrix)
+        return helper(bounding_boxes, affine_matrix=affine_matrix, clamp=False)

     @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
@@ -2079,7 +2081,6 @@ def test_functional(self, make_input):
             (F.rotate_image, torch.Tensor),
             (F._geometry._rotate_image_pil, PIL.Image.Image),
             (F.rotate_image, tv_tensors.Image),
-            (F.rotate_bounding_boxes, tv_tensors.BoundingBoxes),
             (F.rotate_mask, tv_tensors.Mask),
             (F.rotate_video, tv_tensors.Video),
             (F.rotate_keypoints, tv_tensors.KeyPoints),
@@ -2229,29 +2230,26 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen
             clamp=False,
         )

-        return F.clamp_bounding_boxes(self._recenter_bounding_boxes_after_expand(output, recenter_xy=recenter_xy)).to(
-            bounding_boxes
-        )
+        return self._recenter_bounding_boxes_after_expand(output, recenter_xy=recenter_xy).to(bounding_boxes)

     @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
     @pytest.mark.parametrize("expand", [False, True])
     @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
     def test_functional_bounding_boxes_correctness(self, format, angle, expand, center):
-        bounding_boxes = make_bounding_boxes(format=format)
+        bounding_boxes = make_bounding_boxes(format=format, clamping_mode="none")

         actual = F.rotate(bounding_boxes, angle=angle, expand=expand, center=center)
         expected = self._reference_rotate_bounding_boxes(bounding_boxes, angle=angle, expand=expand, center=center)
-
-        torch.testing.assert_close(actual, expected)
         torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)
+        torch.testing.assert_close(actual, expected)

     @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("expand", [False, True])
     @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
     @pytest.mark.parametrize("seed", list(range(5)))
     def test_transform_bounding_boxes_correctness(self, format, expand, center, seed):
-        bounding_boxes = make_bounding_boxes(format=format)
+        bounding_boxes = make_bounding_boxes(format=format, clamping_mode="none")

         transform = transforms.RandomRotation(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, expand=expand, center=center)

@@ -2262,9 +2260,8 @@ def test_transform_bounding_boxes_correctness(self, format, expand, center, seed

         actual = transform(bounding_boxes)
         expected = self._reference_rotate_bounding_boxes(bounding_boxes, **params, expand=expand, center=center)
-
-        torch.testing.assert_close(actual, expected)
         torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)
+        torch.testing.assert_close(actual, expected)

     def _recenter_keypoints_after_expand(self, keypoints, *, recenter_xy):
         x, y = recenter_xy
@@ -4349,7 +4346,6 @@ def test_functional(self, make_input):
             (F.resized_crop_image, torch.Tensor),
             (F._geometry._resized_crop_image_pil, PIL.Image.Image),
             (F.resized_crop_image, tv_tensors.Image),
-            (F.resized_crop_bounding_boxes, tv_tensors.BoundingBoxes),
             (F.resized_crop_mask, tv_tensors.Mask),
             (F.resized_crop_video, tv_tensors.Video),
             (F.resized_crop_keypoints, tv_tensors.KeyPoints),
@@ -4415,6 +4411,7 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h
                 [0, 0, 1],
             ],
         )
+
         affine_matrix = (resize_affine_matrix @ crop_affine_matrix)[:2, :]

         helper = (
@@ -4423,15 +4420,15 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h
             else reference_affine_bounding_boxes_helper
         )

-        return helper(
-            bounding_boxes,
-            affine_matrix=affine_matrix,
-            new_canvas_size=size,
-        )
+        return helper(bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=size, clamp=False)

     @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     def test_functional_bounding_boxes_correctness(self, format):
-        bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format)
+        # Note that we don't want to clamp because in
+        # _reference_resized_crop_bounding_boxes we are fusing the crop and the
+        # resize operation, where none of the clamping happens - particularly
+        # not the intermediate one.
+        bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, clamping_mode="none")

         actual = F.resized_crop(bounding_boxes, **self.CROP_KWARGS, size=self.OUTPUT_SIZE)
         expected = self._reference_resized_crop_bounding_boxes(
diff --git a/test/test_tv_tensors.py b/test/test_tv_tensors.py
index 43efceba5c9..bed419b312c 100644
--- a/test/test_tv_tensors.py
+++ b/test/test_tv_tensors.py
@@ -406,3 +406,8 @@ def test_return_type_input():
         tv_tensors.set_return_type("typo")

     tv_tensors.set_return_type("tensor")
+
+
+def test_box_clamping_mode_default():
+    assert tv_tensors.BoundingBoxes([0, 0, 10, 10], format="XYXY", canvas_size=(100, 100)).clamping_mode == "soft"
+    assert tv_tensors.BoundingBoxes([0, 0, 10, 10, 0], format="XYWHR", canvas_size=(100, 100)).clamping_mode == "soft"
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index aff96d0a7e8..f109247dc6b 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -522,7 +522,7 @@ def resize_bounding_boxes(
     size: Optional[list[int]],
     max_size: Optional[int] = None,
     format: tv_tensors.BoundingBoxFormat = tv_tensors.BoundingBoxFormat.XYXY,
-    clamping_mode: CLAMPING_MODE_TYPE = "hard",  # TODOBB soft
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:
     # We set the default format as `tv_tensors.BoundingBoxFormat.XYXY`
     # to ensure backward compatibility.
@@ -1108,15 +1108,16 @@ def _affine_bounding_boxes_with_expand(
     shear: list[float],
     center: Optional[list[float]] = None,
     expand: bool = False,
-    clamping_mode: CLAMPING_MODE_TYPE = "hard",  # TODOBB soft
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:
     if bounding_boxes.numel() == 0:
         return bounding_boxes, canvas_size

     original_shape = bounding_boxes.shape
     dtype = bounding_boxes.dtype
-    need_cast = not bounding_boxes.is_floating_point()
-    bounding_boxes = bounding_boxes.float() if need_cast else bounding_boxes.clone()
+    acceptable_dtypes = [torch.float64]  # Ensure consistency between CPU and GPU.
+    need_cast = dtype not in acceptable_dtypes
+    bounding_boxes = bounding_boxes.to(torch.float64) if need_cast else bounding_boxes.clone()
     device = bounding_boxes.device
     is_rotated = tv_tensors.is_rotated_bounding_format(format)
     intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
@@ -1210,7 +1211,7 @@ def affine_bounding_boxes(
     scale: float,
     shear: list[float],
     center: Optional[list[float]] = None,
-    clamping_mode: CLAMPING_MODE_TYPE = "hard",  # TODOBB soft
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> torch.Tensor:
     out_box, _ = _affine_bounding_boxes_with_expand(
         bounding_boxes,
@@ -1448,6 +1449,7 @@ def rotate_bounding_boxes(
     angle: float,
     expand: bool = False,
     center: Optional[list[float]] = None,
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:
     return _affine_bounding_boxes_with_expand(
         bounding_boxes,
@@ -1459,6 +1461,7 @@ def rotate_bounding_boxes(
         shear=[0.0, 0.0],
         center=center,
         expand=expand,
+        clamping_mode=clamping_mode,
     )


@@ -1473,6 +1476,7 @@ def _rotate_bounding_boxes_dispatch(
         angle=angle,
         expand=expand,
         center=center,
+        clamping_mode=inpt.clamping_mode,
     )
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

@@ -1739,7 +1743,7 @@ def pad_bounding_boxes(
     canvas_size: tuple[int, int],
     padding: list[int],
     padding_mode: str = "constant",
-    clamping_mode: CLAMPING_MODE_TYPE = "hard",  # TODOBB soft
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:
     if padding_mode not in ["constant"]:
         # TODO: add support of other padding modes
@@ -1857,7 +1861,7 @@ def crop_bounding_boxes(
     left: int,
     height: int,
     width: int,
-    clamping_mode: CLAMPING_MODE_TYPE = "hard",  # TODOBB soft
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:

     # Crop or implicit pad if left and/or top have negative values:
@@ -2097,7 +2101,7 @@ def perspective_bounding_boxes(
     startpoints: Optional[list[list[int]]],
     endpoints: Optional[list[list[int]]],
     coefficients: Optional[list[float]] = None,
-    clamping_mode: CLAMPING_MODE_TYPE = "hard",  # TODOBB soft
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> torch.Tensor:
     if bounding_boxes.numel() == 0:
         return bounding_boxes
@@ -2412,7 +2416,7 @@ def elastic_bounding_boxes(
     format: tv_tensors.BoundingBoxFormat,
     canvas_size: tuple[int, int],
     displacement: torch.Tensor,
-    clamping_mode: CLAMPING_MODE_TYPE = "hard",  # TODOBB soft
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> torch.Tensor:
     expected_shape = (1, canvas_size[0], canvas_size[1], 2)
     if not isinstance(displacement, torch.Tensor):
@@ -2433,11 +2437,11 @@ def elastic_bounding_boxes(

     original_shape = bounding_boxes.shape
     # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
-    intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
+    intermediate_format = tv_tensors.BoundingBoxFormat.CXCYWHR if is_rotated else tv_tensors.BoundingBoxFormat.XYXY

     bounding_boxes = (
         convert_bounding_box_format(bounding_boxes.clone(), old_format=format, new_format=intermediate_format)
-    ).reshape(-1, 8 if is_rotated else 4)
+    ).reshape(-1, 5 if is_rotated else 4)

     id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
     # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
@@ -2445,7 +2449,7 @@ def elastic_bounding_boxes(
     inv_grid = id_grid.sub_(displacement)

     # Get points from bboxes
-    points = bounding_boxes if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]]
+    points = bounding_boxes[:, :2] if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]]
     points = points.reshape(-1, 2)
     if points.is_floating_point():
         points = points.ceil_()
@@ -2457,8 +2461,8 @@ def elastic_bounding_boxes(
     transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

     if is_rotated:
-        transformed_points = transformed_points.reshape(-1, 8)
-        out_bboxes = _parallelogram_to_bounding_boxes(transformed_points).to(bounding_boxes.dtype)
+        transformed_points = transformed_points.reshape(-1, 2)
+        out_bboxes = torch.cat([transformed_points, bounding_boxes[:, 2:]], dim=1).to(bounding_boxes.dtype)
     else:
         transformed_points = transformed_points.reshape(-1, 4, 2)
         out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
@@ -2619,11 +2623,18 @@ def center_crop_bounding_boxes(
     format: tv_tensors.BoundingBoxFormat,
     canvas_size: tuple[int, int],
     output_size: list[int],
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
     crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size)
     return crop_bounding_boxes(
-        bounding_boxes, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width
+        bounding_boxes,
+        format,
+        top=crop_top,
+        left=crop_left,
+        height=crop_height,
+        width=crop_width,
+        clamping_mode=clamping_mode,
     )


@@ -2632,7 +2643,11 @@ def _center_crop_bounding_boxes_dispatch(
     inpt: tv_tensors.BoundingBoxes, output_size: list[int]
 ) -> tv_tensors.BoundingBoxes:
     output, canvas_size = center_crop_bounding_boxes(
-        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size
+        inpt.as_subclass(torch.Tensor),
+        format=inpt.format,
+        canvas_size=inpt.canvas_size,
+        output_size=output_size,
+        clamping_mode=inpt.clamping_mode,
     )
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

@@ -2779,9 +2794,14 @@ def resized_crop_bounding_boxes(
     height: int,
     width: int,
     size: list[int],
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> tuple[torch.Tensor, tuple[int, int]]:
-    bounding_boxes, canvas_size = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
-    return resize_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size, size=size)
+    bounding_boxes, canvas_size = crop_bounding_boxes(
+        bounding_boxes, format, top, left, height, width, clamping_mode=clamping_mode
+    )
+    return resize_bounding_boxes(
+        bounding_boxes, format=format, canvas_size=canvas_size, size=size, clamping_mode=clamping_mode
+    )


 @_register_kernel_internal(resized_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
@@ -2789,7 +2809,14 @@ def _resized_crop_bounding_boxes_dispatch(
     inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: list[int], **kwargs
 ) -> tv_tensors.BoundingBoxes:
     output, canvas_size = resized_crop_bounding_boxes(
-        inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size
+        inpt.as_subclass(torch.Tensor),
+        format=inpt.format,
+        top=top,
+        left=left,
+        height=height,
+        width=width,
+        size=size,
+        clamping_mode=inpt.clamping_mode,
     )
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py
index cf23471f770..4cc3c2f3f8e 100644
--- a/torchvision/transforms/v2/functional/_meta.py
+++ b/torchvision/transforms/v2/functional/_meta.py
@@ -374,7 +374,7 @@ def _clamp_bounding_boxes(
     bounding_boxes: torch.Tensor,
     format: BoundingBoxFormat,
     canvas_size: tuple[int, int],
-    clamping_mode: Optional[CLAMPING_MODE_TYPE],  # TODOBB shouldn't be Optional
+    clamping_mode: CLAMPING_MODE_TYPE,
 ) -> torch.Tensor:
     if clamping_mode is not None and clamping_mode == "none":
         return bounding_boxes.clone()
@@ -385,6 +385,7 @@ def _clamp_bounding_boxes(
     xyxy_boxes = convert_bounding_box_format(
         bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
     )
+    # hard and soft modes are equivalent for non-rotated boxes
     xyxy_boxes[..., 0::2].clamp_(min=0, max=canvas_size[1])
     xyxy_boxes[..., 1::2].clamp_(min=0, max=canvas_size[0])
     out_boxes = convert_bounding_box_format(
@@ -415,23 +416,88 @@ def _order_bounding_boxes_points(
     if indices is None:
         output_xyxyxyxy = bounding_boxes.reshape(-1, 8)
         x, y = output_xyxyxyxy[..., 0::2], output_xyxyxyxy[..., 1::2]
-        y_max = torch.max(y, dim=1, keepdim=True)[0]
-        _, x1 = ((y_max - y) / y_max + (x + 1) * 100).min(dim=1)
+        y_max = torch.max(y.abs(), dim=1, keepdim=True)[0]
+        x_max = torch.max(x.abs(), dim=1, keepdim=True)[0]
+        _, x1 = (y / y_max + (x / x_max) * 100).min(dim=1)
         indices = torch.ones_like(output_xyxyxyxy)
         indices[..., 0] = x1.mul(2)
         indices.cumsum_(1).remainder_(8)
     return indices, bounding_boxes.gather(1, indices.to(torch.int64))


-def _area(box: torch.Tensor) -> torch.Tensor:
-    x1, y1, x2, y2, x3, y3, x4, y4 = box.reshape(-1, 8).unbind(-1)
-    w = torch.sqrt((y2 - y1) ** 2 + (x2 - x1) ** 2)
-    h = torch.sqrt((y3 - y2) ** 2 + (x3 - x2) ** 2)
-    return w * h
+def _get_slope_and_intercept(box: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Calculate the slope and y-intercept of the lines defined by consecutive vertices in a bounding box.
+    This function computes the slope (a) and y-intercept (b) for each line segment in a bounding box,
+    where each line is defined by two consecutive vertices.
+    """
+    x, y = box[..., ::2], box[..., 1::2]
+    a = y.diff(append=y[..., 0:1]) / x.diff(append=x[..., 0:1])
+    b = y - a * x
+    return a, b
+
+
+def _get_intersection_point(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+    """
+    Calculate the intersection point of two lines defined by their slopes and y-intercepts.
+    This function computes the intersection points between pairs of lines, where each line
+    is defined by the equation y = ax + b (slope and y-intercept form).
+    """
+    batch_size = a.shape[0]
+    x = b.diff(prepend=b[..., 3:4]).neg() / a.diff(prepend=a[..., 3:4])
+    y = a * x + b
+    return torch.cat((x.unsqueeze(-1), y.unsqueeze(-1)), dim=-1).view(batch_size, 8)
+
+
+def _clamp_y_intercept(
+    bounding_boxes: torch.Tensor,
+    original_bounding_boxes: torch.Tensor,
+    canvas_size: tuple[int, int],
+    clamping_mode: CLAMPING_MODE_TYPE,
+) -> torch.Tensor:
+    """
+    Apply clamping to bounding box y-intercepts. This function handles two clamping strategies:
+    - Hard clamping: Ensures all box vertices stay within canvas boundaries, finding the largest
+      angle-preserving box enclosed within the original box and the image canvas.
+    - Soft clamping: Allows some vertices to extend beyond the canvas, finding the smallest
+      angle-preserving box that encloses the intersection of the original box and the image canvas.
+
+    The function first calculates the slopes and y-intercepts of the lines forming the bounding box,
+    then applies various constraints to ensure the clamping conditions are respected.
+    """
+
+    a, b = _get_slope_and_intercept(bounding_boxes)
+    a1, a2, a3, a4 = a.unbind(-1)
+    b1, b2, b3, b4 = b.unbind(-1)
+
+    # Clamp y-intercepts (soft clamping)
+    b1 = b2.clamp(b1, b3).clamp(0, canvas_size[0])
+    b4 = b3.clamp(b2, b4).clamp(0, canvas_size[0])
+
+    if clamping_mode == "hard":
+        # Get y-intercepts from original bounding boxes
+        _, b = _get_slope_and_intercept(original_bounding_boxes)
+        _, b2, b3, _ = b.unbind(-1)
+
+        # Set b1 and b4 to the average of their clamped values
+        b1 = b4 = (b1.clamp(0, canvas_size[0]) + b4.clamp(0, canvas_size[0])) / 2
+
+        # Ensure b2 and b3 define the box of maximum area after clamping b1 and b4
+        b2.clamp_(b1 * a2 / a1, b4).clamp_((a1 - a2) * canvas_size[1] + b1)
+        b2.clamp_(b3 * a2 / a3, b4).clamp_((a3 - a2) * canvas_size[1] + b3)
+        b3.clamp_(max=canvas_size[0] * (1 - a3 / a4) + b4 * a3 / a4)
+        b3.clamp_(max=canvas_size[0] * (1 - a3 / a2) + b2 * a3 / a2)
+        b3.clamp_(b1, (a2 - a3) * canvas_size[1] + b2)
+        b3.clamp_(b1, (a4 - a3) * canvas_size[1] + b4)
+
+    return torch.stack([b1, b2, b3, b4], dim=-1)


 def _clamp_along_y_axis(
     bounding_boxes: torch.Tensor,
+    original_bounding_boxes: torch.Tensor,
+    canvas_size: tuple[int, int],
+    clamping_mode: CLAMPING_MODE_TYPE,
 ) -> torch.Tensor:
     """
     Adjusts bounding boxes along the y-axis based on specific conditions.
@@ -442,51 +508,62 @@ def _clamp_along_y_axis(

     Args:
         bounding_boxes (torch.Tensor): A tensor containing bounding box coordinates.
+        original_bounding_boxes (torch.Tensor): The original bounding boxes before any clamping is applied.
+        canvas_size (tuple[int, int]): The size of the canvas as (height, width).
+        clamping_mode (str, optional): The clamping strategy to use.

     Returns:
         torch.Tensor: The adjusted bounding boxes.
     """
-    original_dtype = bounding_boxes.dtype
+    dtype = bounding_boxes.dtype
+    acceptable_dtypes = [torch.float64]  # Ensure consistency between CPU and GPU.
+    need_cast = dtype not in acceptable_dtypes
+    eps = 1e-06  # Ensure consistency between CPU and GPU.
     original_shape = bounding_boxes.shape
-    x1, y1, x2, y2, x3, y3, x4, y4 = bounding_boxes.reshape(-1, 8).unbind(-1)
-    a = (y2 - y1) / (x2 - x1)
-    b1 = y1 - a * x1
-    b2 = y2 + x2 / a
-    b3 = y3 - a * x3
-    b4 = y4 + x4 / a
-    b23 = (b2 - b3) / 2 * a / (1 + a**2)
-    z = torch.zeros_like(b1)
-    case_a = torch.cat([x.unsqueeze(1) for x in [z, b1, x2, y2, x3, y3, x3 - x2, y3 + b1 - y2]], dim=1)
-    case_b = torch.cat([x.unsqueeze(1) for x in [z, b4, x2 - x1, y2 - y1 + b4, x3, y3, x4, y4]], dim=1)
-    case_c = torch.cat(
-        [x.unsqueeze(1) for x in [z, (b2 + b3) / 2, b23, -b23 / a + b2, x3, y3, b23, b23 * a + b3]], dim=1
-    )
-    case_d = torch.zeros_like(case_c)
-    case_e = torch.cat([x.unsqueeze(1) for x in [x1.clamp(0), y1, x2.clamp(0), y2, x3, y3, x4, y4]], dim=1)
-
-    cond_a = (x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 >= 0)
-    cond_a = cond_a.logical_and(_area(case_a) > _area(case_b))
-    cond_a = cond_a.logical_or((x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 <= 0))
-    cond_b = (x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 >= 0)
-    cond_b = cond_b.logical_and(_area(case_a) <= _area(case_b))
-    cond_b = cond_b.logical_or((x1 < 0).logical_and(x2 <= 0).logical_and(x3 >= 0).logical_and(x4 >= 0))
-    cond_c = (x1 < 0).logical_and(x2 <= 0).logical_and(x3 >= 0).logical_and(x4 <= 0)
-    cond_d = (x1 < 0).logical_and(x2 <= 0).logical_and(x3 <= 0).logical_and(x4 <= 0)
-    cond_e = x1.isclose(x2)
-
-    for cond, case in zip(
-        [cond_a, cond_b, cond_c, cond_d, cond_e],
-        [case_a, case_b, case_c, case_d, case_e],
+    bounding_boxes = bounding_boxes.reshape(-1, 8)
+    original_bounding_boxes = original_bounding_boxes.reshape(-1, 8)
+
+    # Calculate slopes (a) and y-intercepts (b) for all lines in the bounding boxes
+    a, b = _get_slope_and_intercept(bounding_boxes)
+    x1, y1, x2, y2, x3, y3, x4, y4 = bounding_boxes.unbind(-1)
+    b = _clamp_y_intercept(bounding_boxes, original_bounding_boxes, canvas_size, clamping_mode)
+
+    case_a = _get_intersection_point(a, b)
+    case_b = bounding_boxes.clone()
+    case_b[..., 0].clamp_(0)  # Clamp x1 to 0
+    case_b[..., 6].clamp_(0)  # Clamp x4 to 0
+    case_c = torch.zeros_like(case_b)
+
+    cond_a = (x1 < eps) & ~case_a.isnan().any(-1)  # First point is outside left boundary
+    cond_b = y1.isclose(y2, rtol=eps, atol=eps) | y3.isclose(y4, rtol=eps, atol=eps)  # First line is nearly vertical
+    cond_c = (x1 <= 0) & (x2 <= 0) & (x3 <= 0) & (x4 <= 0)  # All points outside left boundary
+    cond_c = (
+        cond_c
+        | y1.isclose(y4, rtol=eps, atol=eps)
+        | y2.isclose(y3, rtol=eps, atol=eps)
+        | (cond_b & x1.isclose(x2, rtol=eps, atol=eps))
+    )  # First line is nearly horizontal
+
+    for (cond, case) in zip(
+        [cond_a, cond_b, cond_c],
+        [case_a, case_b, case_c],
     ):
         bounding_boxes = torch.where(cond.unsqueeze(1).repeat(1, 8), case.reshape(-1, 8), bounding_boxes)
-    return bounding_boxes.to(original_dtype).reshape(original_shape)
+    if clamping_mode == "hard":
+        bounding_boxes[..., 0].clamp_(0)  # Clamp x1 to 0
+
+    if need_cast:
+        if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
+            bounding_boxes.round_()
+        bounding_boxes = bounding_boxes.to(dtype)
+    return bounding_boxes.reshape(original_shape)


 def _clamp_rotated_bounding_boxes(
     bounding_boxes: torch.Tensor,
     format: BoundingBoxFormat,
     canvas_size: tuple[int, int],
-    clamping_mode: Optional[CLAMPING_MODE_TYPE],  # TODOBB shouldn't be Optional
+    clamping_mode: CLAMPING_MODE_TYPE,
 ) -> torch.Tensor:
     """
     Clamp rotated bounding boxes to ensure they stay within the canvas boundaries.
@@ -521,15 +598,22 @@ def _clamp_rotated_bounding_boxes(
         )
     ).reshape(-1, 8)

+    original_boxes = out_boxes.clone()
     for _ in range(4):  # Iterate over the 4 vertices.
         indices, out_boxes = _order_bounding_boxes_points(out_boxes)
-        out_boxes = _clamp_along_y_axis(out_boxes)
+        _, original_boxes = _order_bounding_boxes_points(original_boxes, indices)
+        out_boxes = _clamp_along_y_axis(out_boxes, original_boxes, canvas_size, clamping_mode)
         _, out_boxes = _order_bounding_boxes_points(out_boxes, indices)
+        _, original_boxes = _order_bounding_boxes_points(original_boxes, indices)
         # rotate 90 degrees counter clock wise
         out_boxes[:, ::2], out_boxes[:, 1::2] = (
             out_boxes[:, 1::2].clone(),
             canvas_size[1] - out_boxes[:, ::2].clone(),
         )
+        original_boxes[:, ::2], original_boxes[:, 1::2] = (
+            original_boxes[:, 1::2].clone(),
+            canvas_size[1] - original_boxes[:, ::2].clone(),
+        )
         canvas_size = (canvas_size[1], canvas_size[0])

     out_boxes = convert_bounding_box_format(
@@ -538,7 +622,8 @@ def _clamp_rotated_bounding_boxes(

     if need_cast:
         if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
-            out_boxes.round_()
+            # Adding epsilon to ensure consistency between CPU and GPU rounding.
+            out_boxes.add_(1e-7).round_()
         out_boxes = out_boxes.to(dtype)

     return out_boxes
diff --git a/torchvision/tv_tensors/_bounding_boxes.py b/torchvision/tv_tensors/_bounding_boxes.py
index 22a32b7dfa5..72a2825aad1 100644
--- a/torchvision/tv_tensors/_bounding_boxes.py
+++ b/torchvision/tv_tensors/_bounding_boxes.py
@@ -105,7 +105,7 @@ def __new__(
         *,
         format: BoundingBoxFormat | str,
         canvas_size: tuple[int, int],
-        clamping_mode: CLAMPING_MODE_TYPE = "hard",  # TODOBB change default to soft!
+        clamping_mode: CLAMPING_MODE_TYPE = "soft",
         dtype: torch.dtype | None = None,
         device: torch.device | str | int | None = None,
         requires_grad: bool | None = None,
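For reviewers, a minimal usage sketch of the behaviour this patch settles on is shown below. Only two facts are taken from the patch itself: `BoundingBoxes` now defaults to `clamping_mode="soft"` (asserted by `test_box_clamping_mode_default`), and the correctness tests opt out with `clamping_mode="none"` so kernel output can be compared against an unclamped reference. The box coordinates and rotation angle are arbitrary illustration values.

```python
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

# Default clamping mode is now "soft" (mirrors test_box_clamping_mode_default).
boxes = tv_tensors.BoundingBoxes([[10.0, 10.0, 90.0, 50.0]], format="XYXY", canvas_size=(100, 100))
assert boxes.clamping_mode == "soft"

# Opting out of clamping, as the rotation and resized-crop correctness tests now do,
# keeps the raw transformed coordinates for comparison against an unclamped reference.
unclamped = tv_tensors.BoundingBoxes(
    [[10.0, 10.0, 90.0, 50.0]],
    format="XYXY",
    canvas_size=(100, 100),
    clamping_mode="none",
)
rotated = F.rotate(unclamped, angle=30, expand=True)
print(rotated, rotated.canvas_size)
```

With `expand=True` the returned canvas grows, and under `clamping_mode="none"` the rotated corners are left wherever the transform puts them, which is exactly what the reference helpers in the tests expect.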