Adjust clamping for rotated bboxes #9112


Open
wants to merge 7 commits into base: main
14 changes: 1 addition & 13 deletions test/common_utils.py
@@ -21,7 +21,7 @@
 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
 from torchvision import io, tv_tensors
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import clamp_bounding_boxes, to_image, to_pil_image
+from torchvision.transforms.v2.functional import to_image, to_pil_image


 IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -468,18 +468,6 @@ def sample_position(values, max_value):
     else:
         raise ValueError(f"Format {format} is not supported")
     out_boxes = torch.stack(parts, dim=-1).to(dtype=dtype, device=device)
-    if tv_tensors.is_rotated_bounding_format(format):
-        # The rotated bounding boxes are not guaranteed to be within the canvas by design,
-        # so we apply clamping. We also add a 2 buffer to the canvas size to avoid
-        # numerical issues during the testing
-        buffer = 4
-        out_boxes = clamp_bounding_boxes(
-            out_boxes, format=format, canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer)
-        )
-        if format is tv_tensors.BoundingBoxFormat.XYWHR or format is tv_tensors.BoundingBoxFormat.CXCYWHR:
-            out_boxes[:, :2] += buffer // 2
-        elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
-            out_boxes[:, :] += buffer // 2
     return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size)


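Note (editor's illustration, not part of the diff): with the buffer-and-clamp logic removed, make_bounding_boxes can now return rotated boxes that extend past the canvas, so a test that needs in-canvas boxes would clamp explicitly. A minimal sketch with made-up values:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2.functional import clamp_bounding_boxes

# A rotated box (cx, cy, w, h, angle) that pokes past the left edge of a 32x32 canvas.
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[2.0, 16.0, 10.0, 4.0, 45.0]]),
    format=tv_tensors.BoundingBoxFormat.CXCYWHR,
    canvas_size=(32, 32),
)
clamped = clamp_bounding_boxes(boxes)  # format and canvas_size are read off the tv_tensor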
25 changes: 16 additions & 9 deletions test/test_transforms_v2.py
@@ -1298,7 +1298,7 @@ def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.B
         )

         helper = (
-            functools.partial(reference_affine_rotated_bounding_boxes_helper, flip=True)
+            functools.partial(reference_affine_rotated_bounding_boxes_helper, flip=True, clamp=False)
             if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
             else reference_affine_bounding_boxes_helper
         )
@@ -1907,7 +1907,7 @@ def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.Bou
         )

         helper = (
-            functools.partial(reference_affine_rotated_bounding_boxes_helper, flip=True)
+            functools.partial(reference_affine_rotated_bounding_boxes_helper, flip=True, clamp=False)
             if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
             else reference_affine_bounding_boxes_helper
         )
@@ -2196,7 +2196,7 @@ def _recenter_bounding_boxes_after_expand(self, bounding_boxes, *, recenter_xy):
             (bounding_boxes.to(torch.float64) - torch.tensor(translate)).to(bounding_boxes.dtype), like=bounding_boxes
         )

-    def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, center):
+    def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, center, canvas_size=None):
         if center is None:
             center = [s * 0.5 for s in bounding_boxes.canvas_size[::-1]]
         cx, cy = center
@@ -2222,7 +2222,7 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen
         output = helper(
             bounding_boxes,
             affine_matrix=affine_matrix,
-            new_canvas_size=new_canvas_size,
+            new_canvas_size=new_canvas_size if canvas_size is None else canvas_size,
             clamp=False,
         )

@@ -2239,9 +2239,12 @@ def test_functional_bounding_boxes_correctness(self, format, angle, expand, cent

         actual = F.rotate(bounding_boxes, angle=angle, expand=expand, center=center)
         expected = self._reference_rotate_bounding_boxes(bounding_boxes, angle=angle, expand=expand, center=center)
+        torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)

+        expected = self._reference_rotate_bounding_boxes(
+            bounding_boxes, angle=angle, expand=expand, center=center, canvas_size=actual.canvas_size
+        )
         torch.testing.assert_close(actual, expected)
-        torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)

     @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("expand", [False, True])
@@ -2259,9 +2262,12 @@ def test_transform_bounding_boxes_correctness(self, format, expand, center, seed
         actual = transform(bounding_boxes)

         expected = self._reference_rotate_bounding_boxes(bounding_boxes, **params, expand=expand, center=center)
+        torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)

+        expected = self._reference_rotate_bounding_boxes(
+            bounding_boxes, **params, expand=expand, center=center, canvas_size=actual.canvas_size
+        )
         torch.testing.assert_close(actual, expected)
-        torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)

     def _recenter_keypoints_after_expand(self, keypoints, *, recenter_xy):
         x, y = recenter_xy
@@ -4413,17 +4419,18 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h
                 [0, 0, 1],
             ],
         )
-        affine_matrix = (resize_affine_matrix @ crop_affine_matrix)[:2, :]

         helper = (
             reference_affine_rotated_bounding_boxes_helper
             if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
             else reference_affine_bounding_boxes_helper
         )

+        bounding_boxes = helper(bounding_boxes, affine_matrix=crop_affine_matrix, new_canvas_size=(height, width))
+
         return helper(
             bounding_boxes,
-            affine_matrix=affine_matrix,
+            affine_matrix=resize_affine_matrix,
             new_canvas_size=size,
         )

@@ -4436,7 +4443,7 @@ def test_functional_bounding_boxes_correctness(self, format):
             bounding_boxes, **self.CROP_KWARGS, size=self.OUTPUT_SIZE
         )

-        torch.testing.assert_close(actual, expected)
+        torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)
         assert_equal(F.get_size(actual), F.get_size(expected))

     def _reference_resized_crop_keypoints(self, keypoints, *, top, left, height, width, size):
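Note (editor's illustration, not part of the diff): the resized-crop reference now applies the crop affine and the resize affine as two helper calls instead of composing them into a single matrix. On exact arithmetic the two routes agree, which is why only float tolerances (atol/rtol of 1e-5) change; a sketch with made-up numbers:

import torch

top, left, scale = 3.0, 5.0, 2.0
crop = torch.tensor([[1.0, 0.0, -left], [0.0, 1.0, -top], [0.0, 0.0, 1.0]])
resize = torch.tensor([[scale, 0.0, 0.0], [0.0, scale, 0.0], [0.0, 0.0, 1.0]])

p = torch.tensor([10.0, 7.0, 1.0])  # homogeneous point (x, y, 1)
# Single composed matrix (old reference) vs. sequential application (new reference).
assert torch.allclose((resize @ crop) @ p, resize @ (crop @ p))

The two-step route lets the rotated-box helper work against the intermediate (height, width) canvas between the crop and the resize, at the cost of an extra float round-trip.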
15 changes: 8 additions & 7 deletions torchvision/transforms/v2/functional/_geometry.py
@@ -1104,8 +1104,9 @@ def _affine_bounding_boxes_with_expand(

    original_shape = bounding_boxes.shape
    dtype = bounding_boxes.dtype
-    need_cast = not bounding_boxes.is_floating_point()
-    bounding_boxes = bounding_boxes.float() if need_cast else bounding_boxes.clone()
+    acceptable_dtypes = [torch.float64]  # Ensure consistency between CPU and GPU.
+    need_cast = dtype not in acceptable_dtypes
+    bounding_boxes = bounding_boxes.to(torch.float64) if need_cast else bounding_boxes.clone()
    device = bounding_boxes.device
    is_rotated = tv_tensors.is_rotated_bounding_format(format)
    intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
@@ -2397,19 +2398,19 @@ def elastic_bounding_boxes(

    original_shape = bounding_boxes.shape
    # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
-    intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
+    intermediate_format = tv_tensors.BoundingBoxFormat.CXCYWHR if is_rotated else tv_tensors.BoundingBoxFormat.XYXY

    bounding_boxes = (
        convert_bounding_box_format(bounding_boxes.clone(), old_format=format, new_format=intermediate_format)
-    ).reshape(-1, 8 if is_rotated else 4)
+    ).reshape(-1, 5 if is_rotated else 4)

    id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
    # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
    # This is not an exact inverse of the grid
    inv_grid = id_grid.sub_(displacement)

    # Get points from bboxes
-    points = bounding_boxes if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]]
+    points = bounding_boxes[:, :2] if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]]
    points = points.reshape(-1, 2)
    if points.is_floating_point():
        points = points.ceil_()

@@ -2421,8 +2422,8 @@ def elastic_bounding_boxes(
    transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

    if is_rotated:
-        transformed_points = transformed_points.reshape(-1, 8)
-        out_bboxes = _parallelogram_to_bounding_boxes(transformed_points).to(bounding_boxes.dtype)
+        transformed_points = transformed_points.reshape(-1, 2)
+        out_bboxes = torch.cat([transformed_points, bounding_boxes[:, 2:]], dim=1).to(bounding_boxes.dtype)
    else:
        transformed_points = transformed_points.reshape(-1, 4, 2)
        out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
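Note (editor's illustration, not part of the diff): for rotated boxes, elastic_bounding_boxes now converts to CXCYWHR and pushes only the box center through the inverse displacement grid, carrying width, height, and angle over unchanged, instead of displacing all four corners and refitting a box. A sketch with a made-up displacement:

import torch

cxcywhr = torch.tensor([[16.0, 12.0, 8.0, 4.0, 30.0]])  # cx, cy, w, h, angle
center = cxcywhr[:, :2]                                  # only the center is transformed
displaced = center + torch.tensor([[1.5, -0.75]])        # stand-in for the inv_grid lookup
out = torch.cat([displaced, cxcywhr[:, 2:]], dim=1)      # w, h, angle are untouched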
166 changes: 125 additions & 41 deletions torchvision/transforms/v2/functional/_meta.py
@@ -409,23 +409,88 @@ def _order_bounding_boxes_points(
    if indices is None:
        output_xyxyxyxy = bounding_boxes.reshape(-1, 8)
        x, y = output_xyxyxyxy[..., 0::2], output_xyxyxyxy[..., 1::2]
-        y_max = torch.max(y, dim=1, keepdim=True)[0]
-        _, x1 = ((y_max - y) / y_max + (x + 1) * 100).min(dim=1)
+        y_max = torch.max(y.abs(), dim=1, keepdim=True)[0]
+        x_max = torch.max(x.abs(), dim=1, keepdim=True)[0]
+        _, x1 = (y / y_max + (x / x_max) * 100).min(dim=1)
        indices = torch.ones_like(output_xyxyxyxy)
        indices[..., 0] = x1.mul(2)
        indices.cumsum_(1).remainder_(8)
    return indices, bounding_boxes.gather(1, indices.to(torch.int64))


-def _area(box: torch.Tensor) -> torch.Tensor:
-    x1, y1, x2, y2, x3, y3, x4, y4 = box.reshape(-1, 8).unbind(-1)
-    w = torch.sqrt((y2 - y1) ** 2 + (x2 - x1) ** 2)
-    h = torch.sqrt((y3 - y2) ** 2 + (x3 - x2) ** 2)
-    return w * h
+def _get_slope_and_intercept(box: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Calculate the slope and y-intercept of the lines defined by consecutive vertices in a bounding box.
+    This function computes the slope (a) and y-intercept (b) for each line segment in a bounding box,
+    where each line is defined by two consecutive vertices.
+    """
+    x, y = box[..., ::2], box[..., 1::2]
+    a = y.diff(append=y[..., 0:1]) / x.diff(append=x[..., 0:1])
+    b = y - a * x
+    return a, b
+
+
+def _get_intersection_point(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+    """
+    Calculate the intersection point of two lines defined by their slopes and y-intercepts.
+    This function computes the intersection points between pairs of lines, where each line
+    is defined by the equation y = ax + b (slope and y-intercept form).
+    """
+    batch_size = a.shape[0]
+    x = b.diff(prepend=b[..., 3:4]).neg() / a.diff(prepend=a[..., 3:4])
+    y = a * x + b
+    return torch.cat((x.unsqueeze(-1), y.unsqueeze(-1)), dim=-1).view(batch_size, 8)
+
+
+def _clamp_y_intercept(
+    bounding_boxes: torch.Tensor,
+    original_bounding_boxes: torch.Tensor,
+    canvas_size: tuple[int, int],
+    clamping_mode: str = "hard",
+) -> torch.Tensor:
+    """
+    Apply clamping to bounding box y-intercepts. This function handles two clamping strategies:
+    - Hard clamping: Ensures all box vertices stay within canvas boundaries, finding the largest
+      angle-preserving box enclosed within the original box and the image canvas.
+    - Soft clamping: Allows some vertices to extend beyond the canvas, finding the smallest
+      angle-preserving box that encloses the intersection of the original box and the image canvas.
+
+    The function first calculates the slopes and y-intercepts of the lines forming the bounding box,
+    then applies various constraints to ensure the clamping conditions are respected.
+    """
+
+    a, b = _get_slope_and_intercept(bounding_boxes)
+    a1, a2, a3, a4 = a.unbind(-1)
+    b1, b2, b3, b4 = b.unbind(-1)
+
+    # Clamp y-intercepts (soft clamping)
+    b1 = b2.clamp(b1, b3).clamp(0, canvas_size[0])
+    b4 = b3.clamp(b2, b4).clamp(0, canvas_size[0])
+
+    if clamping_mode == "hard":
+        # Get y-intercepts from original bounding boxes
+        _, b = _get_slope_and_intercept(original_bounding_boxes)
+        _, b2, b3, _ = b.unbind(-1)
+
+        # Set b1 and b4 to the average of their clamped values
+        b1 = b4 = (b1.clamp(0, canvas_size[0]) + b4.clamp(0, canvas_size[0])) / 2
+
+        # Ensure b2 and b3 define the box of maximum area after clamping b1 and b4
+        b2.clamp_(b1 * a2 / a1, b4).clamp_((a1 - a2) * canvas_size[1] + b1)
+        b2.clamp_(b3 * a2 / a3, b4).clamp_((a3 - a2) * canvas_size[1] + b3)
+        b3.clamp_(max=canvas_size[0] * (1 - a3 / a4) + b4 * a3 / a4)
+        b3.clamp_(max=canvas_size[0] * (1 - a3 / a2) + b2 * a3 / a2)
+        b3.clamp_(b1, (a2 - a3) * canvas_size[1] + b2)
+        b3.clamp_(b1, (a4 - a3) * canvas_size[1] + b4)
+
+    return torch.stack([b1, b2, b3, b4], dim=-1)


 def _clamp_along_y_axis(
     bounding_boxes: torch.Tensor,
+    original_bounding_boxes: torch.Tensor,
+    canvas_size: tuple[int, int],
+    clamping_mode: str = "hard",
 ) -> torch.Tensor:
     """
     Adjusts bounding boxes along the y-axis based on specific conditions.
@@ -436,48 +501,59 @@

     Args:
         bounding_boxes (torch.Tensor): A tensor containing bounding box coordinates.
+        original_bounding_boxes (torch.Tensor): The original bounding boxes before any clamping is applied.
+        canvas_size (tuple[int, int]): The size of the canvas as (height, width).
+        clamping_mode (str, optional): The clamping strategy to use. Defaults to "hard".

     Returns:
         torch.Tensor: The adjusted bounding boxes.
     """
-    original_dtype = bounding_boxes.dtype
+    dtype = bounding_boxes.dtype
+    acceptable_dtypes = [torch.float64]  # Ensure consistency between CPU and GPU.
+    need_cast = dtype not in acceptable_dtypes
+    eps = 1e-06  # Ensure consistency between CPU and GPU.
     original_shape = bounding_boxes.shape
-    x1, y1, x2, y2, x3, y3, x4, y4 = bounding_boxes.reshape(-1, 8).unbind(-1)
-    a = (y2 - y1) / (x2 - x1)
-    b1 = y1 - a * x1
-    b2 = y2 + x2 / a
-    b3 = y3 - a * x3
-    b4 = y4 + x4 / a
-    b23 = (b2 - b3) / 2 * a / (1 + a**2)
-    z = torch.zeros_like(b1)
-    case_a = torch.cat([x.unsqueeze(1) for x in [z, b1, x2, y2, x3, y3, x3 - x2, y3 + b1 - y2]], dim=1)
-    case_b = torch.cat([x.unsqueeze(1) for x in [z, b4, x2 - x1, y2 - y1 + b4, x3, y3, x4, y4]], dim=1)
-    case_c = torch.cat(
-        [x.unsqueeze(1) for x in [z, (b2 + b3) / 2, b23, -b23 / a + b2, x3, y3, b23, b23 * a + b3]], dim=1
-    )
-    case_d = torch.zeros_like(case_c)
-    case_e = torch.cat([x.unsqueeze(1) for x in [x1.clamp(0), y1, x2.clamp(0), y2, x3, y3, x4, y4]], dim=1)
-
-    cond_a = (x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 >= 0)
-    cond_a = cond_a.logical_and(_area(case_a) > _area(case_b))
-    cond_a = cond_a.logical_or((x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 <= 0))
-    cond_b = (x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 >= 0)
-    cond_b = cond_b.logical_and(_area(case_a) <= _area(case_b))
-    cond_b = cond_b.logical_or((x1 < 0).logical_and(x2 <= 0).logical_and(x3 >= 0).logical_and(x4 >= 0))
-    cond_c = (x1 < 0).logical_and(x2 <= 0).logical_and(x3 >= 0).logical_and(x4 <= 0)
-    cond_d = (x1 < 0).logical_and(x2 <= 0).logical_and(x3 <= 0).logical_and(x4 <= 0)
-    cond_e = x1.isclose(x2)
-
-    for cond, case in zip(
-        [cond_a, cond_b, cond_c, cond_d, cond_e],
-        [case_a, case_b, case_c, case_d, case_e],
+    bounding_boxes = bounding_boxes.reshape(-1, 8)
+    original_bounding_boxes = original_bounding_boxes.reshape(-1, 8)
+
+    # Calculate slopes (a) and y-intercepts (b) for all lines in the bounding boxes
+    a, b = _get_slope_and_intercept(bounding_boxes)
+    x1, y1, x2, y2, x3, y3, x4, y4 = bounding_boxes.unbind(-1)
+    b = _clamp_y_intercept(bounding_boxes, original_bounding_boxes, canvas_size, clamping_mode)
+
+    case_a = _get_intersection_point(a, b)
+    case_b = bounding_boxes.clone()
+    case_b[..., 0].clamp_(0)  # Clamp x1 to 0
+    case_b[..., 6].clamp_(0)  # Clamp x4 to 0
+    case_c = torch.zeros_like(case_b)
+
+    cond_a = (x1 < eps) & ~case_a.isnan().any(-1)  # First point is outside left boundary
+    cond_b = y1.isclose(y2, rtol=eps, atol=eps) | y3.isclose(y4, rtol=eps, atol=eps)  # First line is nearly vertical
+    cond_c = (x1 <= 0) & (x2 <= 0) & (x3 <= 0) & (x4 <= 0)  # All points outside left boundary
+    cond_c = (
+        cond_c
+        | y1.isclose(y4, rtol=eps, atol=eps)
+        | y2.isclose(y3, rtol=eps, atol=eps)
+        | (cond_b & x1.isclose(x2, rtol=eps, atol=eps))
+    )  # First line is nearly horizontal
+
+    for (cond, case) in zip(
+        [cond_a, cond_b, cond_c],
+        [case_a, case_b, case_c],
     ):
         bounding_boxes = torch.where(cond.unsqueeze(1).repeat(1, 8), case.reshape(-1, 8), bounding_boxes)
-    return bounding_boxes.to(original_dtype).reshape(original_shape)
+    if clamping_mode == "hard":
+        bounding_boxes[..., 0].clamp_(0)  # Clamp x1 to 0
+
+    if need_cast:
+        if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
+            bounding_boxes.round_()
+        bounding_boxes = bounding_boxes.to(dtype)
+    return bounding_boxes.reshape(original_shape)


 def _clamp_rotated_bounding_boxes(
-    bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: tuple[int, int]
+    bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: tuple[int, int], clamping_mode: str = "soft"
 ) -> torch.Tensor:
     """
     Clamp rotated bounding boxes to ensure they stay within the canvas boundaries.
@@ -510,15 +586,22 @@ def _clamp_rotated_bounding_boxes(
        )
    ).reshape(-1, 8)

+    original_boxes = out_boxes.clone()
    for _ in range(4):  # Iterate over the 4 vertices.
        indices, out_boxes = _order_bounding_boxes_points(out_boxes)
-        out_boxes = _clamp_along_y_axis(out_boxes)
+        _, original_boxes = _order_bounding_boxes_points(original_boxes, indices)
+        out_boxes = _clamp_along_y_axis(out_boxes, original_boxes, canvas_size, clamping_mode)
        _, out_boxes = _order_bounding_boxes_points(out_boxes, indices)
+        _, original_boxes = _order_bounding_boxes_points(original_boxes, indices)
        # rotate 90 degrees counter clock wise
        out_boxes[:, ::2], out_boxes[:, 1::2] = (
            out_boxes[:, 1::2].clone(),
            canvas_size[1] - out_boxes[:, ::2].clone(),
        )
+        original_boxes[:, ::2], original_boxes[:, 1::2] = (
+            original_boxes[:, 1::2].clone(),
+            canvas_size[1] - original_boxes[:, ::2].clone(),
+        )
        canvas_size = (canvas_size[1], canvas_size[0])

    out_boxes = convert_bounding_box_format(
@@ -527,7 +610,8 @@ def _clamp_rotated_bounding_boxes(

    if need_cast:
        if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
-            out_boxes.round_()
+            # Adding epsilon to ensure consistency between CPU and GPU rounding.
+            out_boxes.add_(1e-7).round_()
        out_boxes = out_boxes.to(dtype)
    return out_boxes
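Note (editor's illustration, not part of the diff): the new _meta.py helpers treat each box edge in slope-intercept form, y = a*x + b, and recover vertices by intersecting adjacent edges. The same algebra as _get_intersection_point, written out for two made-up lines:

import torch

# Edge 1: y = 0.5*x + 1; edge 2: y = -2*x + 6; they meet at (2, 2).
a = torch.tensor([0.5, -2.0])
b = torch.tensor([1.0, 6.0])
x = -(b[1] - b[0]) / (a[1] - a[0])  # a1*x + b1 = a2*x + b2  =>  x = -(b2 - b1) / (a2 - a1)
y = a[0] * x + b[0]
assert torch.isclose(x, torch.tensor(2.0)) and torch.isclose(y, torch.tensor(2.0))

Clamping then adjusts only the intercepts b while keeping the slopes a, so the box angle is preserved: per the docstrings above, "soft" mode finds the smallest angle-preserving box enclosing the intersection of the original box and the canvas, while "hard" mode finds the largest such box fully inside it.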