Skip to content

Commit 38bd83d

Browse files
oke-aditya and fmassa
authored and committed
Adds bounding boxes conversion (pytorch#2710)
* adds boxes conversion * adds documentation * adds xywh tests * fixes small typo * adds tests * Remove sphinx theme * corrects assertions * cleans code as per suggestion Signed-off-by: Aditya Oke <[email protected]> * reverts assertion * fixes to assertEqual * fixes inplace operations * Adds docstrings * added documentation * changes tests * moves code to box_convert * adds more tests * Apply suggestions from code review Let's leave those changes to a separate PR * fixes documentation Co-authored-by: Francisco Massa <[email protected]>
1 parent 045a23e commit 38bd83d

File tree

5 files changed

+239
-1
lines changed

5 files changed

+239
-1
lines changed

docs/source/ops.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ torchvision.ops
1313
.. autofunction:: batched_nms
1414
.. autofunction:: remove_small_boxes
1515
.. autofunction:: clip_boxes_to_image
16+
.. autofunction:: box_convert
1617
.. autofunction:: box_area
1718
.. autofunction:: box_iou
1819
.. autofunction:: generalized_box_iou

test/test_ops.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,102 @@ def test_convert_boxes_to_roi_format(self):
647647
self.assertTrue(torch.equal(ref_tensor, ops._utils.convert_boxes_to_roi_format(box_sequence)))
648648

649649

650+
class BoxTester(unittest.TestCase):
    """Checks ``ops.box_convert`` for same-format conversions and round trips between box formats."""

    def _check_boxes(self, actual, expected):
        """Assert that ``actual`` has the shape, dtype and values of ``expected``."""
        # The checks run on the converted tensor itself (not on the fixture),
        # so a wrong output shape or dtype actually fails the test.
        self.assertEqual(actual.size(), expected.size())
        self.assertEqual(actual.dtype, expected.dtype)
        self.assertTrue(torch.equal(actual, expected))

    def test_bbox_same(self):
        # Converting to the very same format must hand back unchanged values,
        # whichever of the supported formats is named.
        box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                                   [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
        exp_same = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                                 [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

        for fmt in ("xyxy", "xywh", "cxcywh"):
            box_same = ops.box_convert(box_tensor, in_fmt=fmt, out_fmt=fmt)
            self._check_boxes(box_same, exp_same)

    def test_bbox_xyxy_xywh(self):
        # Convert boxes to xywh and back; the round trip must reproduce the input.
        # box_tensor is in x1 y1 x2 y2 format.
        box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                                   [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
        exp_xywh = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                                 [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)

        box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
        self._check_boxes(box_xywh, exp_xywh)

        # Reverse conversion
        box_xyxy = ops.box_convert(box_xywh, in_fmt="xywh", out_fmt="xyxy")
        self._check_boxes(box_xyxy, box_tensor)

    def test_bbox_xyxy_cxcywh(self):
        # Convert boxes to cxcywh and back; the round trip must reproduce the input.
        # box_tensor is in x1 y1 x2 y2 format.
        box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                                   [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
        exp_cxcywh = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0],
                                   [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float)

        box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
        self._check_boxes(box_cxcywh, exp_cxcywh)

        # Reverse conversion
        box_xyxy = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xyxy")
        self._check_boxes(box_xyxy, box_tensor)

    def test_bbox_xywh_cxcywh(self):
        # Neither side is xyxy here. box_tensor is in x y w h format.
        box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                                   [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)

        # Centre = top-left corner + half the size, e.g. (23 + 70 / 2, 35 + 60 / 2) = (58, 65).
        exp_cxcywh = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0],
                                   [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float)

        box_cxcywh = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="cxcywh")
        self._check_boxes(box_cxcywh, exp_cxcywh)

        # Reverse conversion
        box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh")
        self._check_boxes(box_xywh, box_tensor)

    @unittest.skip("was disabled (commented out) originally; enable after verifying box_convert under torch.jit.script")
    def test_bbox_convert_jit(self):
        # Compares eager results against the scripted function for two target formats.
        box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0],
                                   [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

        scripted_fn = torch.jit.script(ops.box_convert)
        TOLERANCE = 1e-3

        box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
        scripted_xywh = scripted_fn(box_tensor, 'xyxy', 'xywh')
        self.assertTrue((scripted_xywh - box_xywh).abs().max() < TOLERANCE)

        box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
        scripted_cxcywh = scripted_fn(box_tensor, 'xyxy', 'cxcywh')
        self.assertTrue((scripted_cxcywh - box_cxcywh).abs().max() < TOLERANCE)
744+
745+
650746
class BoxAreaTester(unittest.TestCase):
651747
def test_box_area(self):
652748
# A bounding box of area 10000 and a degenerate case

torchvision/ops/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .boxes import nms, batched_nms, remove_small_boxes, clip_boxes_to_image, box_area, box_iou, generalized_box_iou
2+
from .boxes import box_convert
23
from .new_empty_tensor import _new_empty_tensor
34
from .deform_conv import deform_conv2d, DeformConv2d
45
from .roi_align import roi_align, RoIAlign
@@ -15,7 +16,8 @@
1516

1617
__all__ = [
1718
'deform_conv2d', 'DeformConv2d', 'nms', 'batched_nms', 'remove_small_boxes',
18-
'clip_boxes_to_image', 'box_area', 'box_iou', 'generalized_box_iou', 'roi_align', 'RoIAlign', 'roi_pool',
19+
'clip_boxes_to_image', 'box_convert',
20+
'box_area', 'box_iou', 'generalized_box_iou', 'roi_align', 'RoIAlign', 'roi_pool',
1921
'RoIPool', '_new_empty_tensor', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool',
2022
'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork'
2123
]

torchvision/ops/_box_convert.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import torch
2+
from torch.jit.annotations import Tuple
3+
from torch import Tensor
4+
import torchvision
5+
6+
7+
def _box_cxcywh_to_xyxy(boxes: Tensor) -> Tensor:
8+
"""
9+
Converts bounding boxes from (cx, cy, w, h) format to (x1, y1, x2, y2) format.
10+
(cx, cy) refers to center of bounding box
11+
(w, h) are width and height of bounding box
12+
Arguments:
13+
boxes (Tensor[N, 4]): boxes in (cx, cy, w, h) format which will be converted.
14+
15+
Returns:
16+
boxes (Tensor(N, 4)): boxes in (x1, y1, x2, y2) format.
17+
"""
18+
# We need to change all 4 of them so some temporary variable is needed.
19+
cx, cy, w, h = boxes.unbind(-1)
20+
x1 = cx - 0.5 * w
21+
y1 = cy - 0.5 * h
22+
x2 = cx + 0.5 * w
23+
y2 = cy + 0.5 * h
24+
25+
boxes = torch.stack((x1, y1, x2, y2), dim=-1)
26+
27+
return boxes
28+
29+
30+
def _box_xyxy_to_cxcywh(boxes: Tensor) -> Tensor:
31+
"""
32+
Converts bounding boxes from (x1, y1, x2, y2) format to (cx, cy, w, h) format.
33+
(x1, y1) refer to top left of bounding box
34+
(x2, y2) refer to bottom right of bounding box
35+
Arguments:
36+
boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format which will be converted.
37+
38+
Returns:
39+
boxes (Tensor(N, 4)): boxes in (cx, cy, w, h) format.
40+
"""
41+
x1, y1, x2, y2 = boxes.unbind(-1)
42+
cx = (x1 + x2) / 2
43+
cy = (y1 + y2) / 2
44+
w = x2 - x1
45+
h = y2 - y1
46+
47+
boxes = torch.stack((cx, cy, w, h), dim=-1)
48+
49+
return boxes
50+
51+
52+
def _box_xywh_to_xyxy(boxes: Tensor) -> Tensor:
53+
"""
54+
Converts bounding boxes from (x, y, w, h) format to (x1, y1, x2, y2) format.
55+
(x, y) refers to top left of bouding box.
56+
(w, h) refers to width and height of box.
57+
Arguments:
58+
boxes (Tensor[N, 4]): boxes in (x, y, w, h) which will be converted.
59+
60+
Returns:
61+
boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format.
62+
"""
63+
x, y, w, h = boxes.unbind(-1)
64+
boxes = torch.stack([x, y, x + w, y + h], dim=-1)
65+
return boxes
66+
67+
68+
def _box_xyxy_to_xywh(boxes: Tensor) -> Tensor:
69+
"""
70+
Converts bounding boxes from (x1, y1, x2, y2) format to (x, y, w, h) format.
71+
(x1, y1) refer to top left of bounding box
72+
(x2, y2) refer to bottom right of bounding box
73+
Arguments:
74+
boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) which will be converted.
75+
76+
Returns:
77+
boxes (Tensor[N, 4]): boxes in (x, y, w, h) format.
78+
"""
79+
x1, y1, x2, y2 = boxes.unbind(-1)
80+
x2 = x2 - x1 # x2 - x1
81+
y2 = y2 - y1 # y2 - y1
82+
boxes = torch.stack((x1, y1, x2, y2), dim=-1)
83+
return boxes

torchvision/ops/boxes.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
from torch.jit.annotations import Tuple
33
from torch import Tensor
4+
from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh
45
import torchvision
56

67

@@ -133,6 +134,61 @@ def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor:
133134
return clipped_boxes.reshape(boxes.shape)
134135

135136

137+
def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
    """
    Converts boxes from given in_fmt to out_fmt.
    Supported in_fmt and out_fmt are:

    'xyxy': boxes are represented via corners, x1, y1 being top left and x2, y2 being bottom right.

    'xywh' : boxes are represented via corner, width and height, x1, y1 being top left, w, h being width and height.

    'cxcywh' : boxes are represented via centre, width and height, cx, cy being center of box, w, h
    being width and height.

    Arguments:
        boxes (Tensor[N, 4]): boxes which will be converted.
        in_fmt (str): Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh'].
        out_fmt (str): Output format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh']

    Returns:
        boxes (Tensor[N, 4]): Boxes in the converted format. When in_fmt equals out_fmt
        a copy of the input is returned.

    Raises:
        ValueError: if in_fmt or out_fmt is not one of the supported formats.
    """
    allowed_fmts = ("xyxy", "xywh", "cxcywh")
    # Explicit raise instead of `assert`: asserts vanish under `python -O`, after which an
    # unsupported format would only surface later as an unbound local variable.
    if in_fmt not in allowed_fmts or out_fmt not in allowed_fmts:
        raise ValueError("Unsupported box format: in_fmt and out_fmt must each be one of 'xyxy', 'xywh', 'cxcywh'")

    if in_fmt == out_fmt:
        return boxes.clone()

    # Step 1: bring the boxes into xyxy, the pivot format every helper converts from / to.
    if in_fmt == "xywh":
        boxes_xyxy = _box_xywh_to_xyxy(boxes)
    elif in_fmt == "cxcywh":
        boxes_xyxy = _box_cxcywh_to_xyxy(boxes)
    else:
        boxes_xyxy = boxes

    # Step 2: go from xyxy to the requested format. Every path assigns the result.
    if out_fmt == "xywh":
        boxes_converted = _box_xyxy_to_xywh(boxes_xyxy)
    elif out_fmt == "cxcywh":
        boxes_converted = _box_xyxy_to_cxcywh(boxes_xyxy)
    else:
        # out_fmt is "xyxy"; in_fmt differs from it here, so step 1 already produced the result.
        boxes_converted = boxes_xyxy

    return boxes_converted
190+
191+
136192
def box_area(boxes: Tensor) -> Tensor:
137193
"""
138194
Computes the area of a set of bounding boxes, which are specified by its

0 commit comments

Comments
 (0)