Skip to content

Commit bb779bb

Browse files
zhangguanheng66Guanheng Zhang
authored andcommitted
Check boxes shape in RoIPool / Align (#1968) (#2429)
Summary: * add checkout/assert in roi_pool * add checkout/assert in roi_align * move check_roi_boxes_shape func to ops/_utils.py * add tests * fix CI * fix CI Pull Request resolved: #2429 Reviewed By: zhangguanheng66 Differential Revision: D22437763 Pulled By: fmassa fbshipit-source-id: 78727f3bfe2514e2c193e2b27d9146693fa800b0 Co-authored-by: Guanheng Zhang <[email protected]>
1 parent 23880d4 commit bb779bb

File tree

6 files changed

+48
-4
lines changed

6 files changed

+48
-4
lines changed

test/test_ops.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,22 @@ def func(z):
9191
self.assertTrue(gradcheck(func, (x,)))
9292
self.assertTrue(gradcheck(script_func, (x,)))
9393

94+
def test_boxes_shape(self):
95+
self._test_boxes_shape()
96+
97+
def _helper_boxes_shape(self, func):
98+
# test boxes as Tensor[N, 5]
99+
with self.assertRaises(AssertionError):
100+
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
101+
boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
102+
func(a, boxes, output_size=(2, 2))
103+
104+
# test boxes as List[Tensor[N, 4]]
105+
with self.assertRaises(AssertionError):
106+
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
107+
boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
108+
ops.roi_pool(a, [boxes], output_size=(2, 2))
109+
94110
def fn(*args, **kwargs):
95111
pass
96112

@@ -139,6 +155,9 @@ def get_slice(k, block):
139155
y[roi_idx, :, i, j] = bin_x.reshape(n_channels, -1).max(dim=1)[0]
140156
return y
141157

158+
def _test_boxes_shape(self):
159+
self._helper_boxes_shape(ops.roi_pool)
160+
142161

143162
class PSRoIPoolTester(RoIOpTester, unittest.TestCase):
144163
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
@@ -183,6 +202,9 @@ def get_slice(k, block):
183202
y[roi_idx, c_out, i, j] = t / area
184203
return y
185204

205+
def _test_boxes_shape(self):
206+
self._helper_boxes_shape(ops.ps_roi_pool)
207+
186208

187209
def bilinear_interpolate(data, y, x, snap_border=False):
188210
height, width = data.shape
@@ -266,6 +288,9 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r
266288
out_data[r, channel, i, j] = val
267289
return out_data
268290

291+
def _test_boxes_shape(self):
292+
self._helper_boxes_shape(ops.roi_align)
293+
269294

270295
class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
271296
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
@@ -317,6 +342,9 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, device, spatial_scale=1,
317342
out_data[r, c_out, i, j] = val
318343
return out_data
319344

345+
def _test_boxes_shape(self):
346+
self._helper_boxes_shape(ops.ps_roi_align)
347+
320348

321349
class NMSTester(unittest.TestCase):
322350
def reference_nms(self, boxes, scores, iou_threshold):

torchvision/ops/_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,15 @@ def convert_boxes_to_roi_format(boxes):
2424
ids = _cat(temp, dim=0)
2525
rois = torch.cat([ids, concat_boxes], dim=1)
2626
return rois
27+
28+
29+
def check_roi_boxes_shape(boxes):
30+
if isinstance(boxes, list):
31+
for _tensor in boxes:
32+
assert _tensor.size(1) == 4, \
33+
'The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]'
34+
elif isinstance(boxes, torch.Tensor):
35+
assert boxes.size(1) == 5, 'The boxes tensor shape is not correct as Tensor[K, 5]'
36+
else:
37+
assert False, 'boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]'
38+
return

torchvision/ops/ps_roi_align.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch.nn.modules.utils import _pair
55
from torch.jit.annotations import List
66

7-
from ._utils import convert_boxes_to_roi_format
7+
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
88

99

1010
def ps_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
@@ -33,6 +33,7 @@ def ps_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1
3333
Returns:
3434
output (Tensor[K, C, output_size[0], output_size[1]])
3535
"""
36+
check_roi_boxes_shape(boxes)
3637
rois = boxes
3738
output_size = _pair(output_size)
3839
if not isinstance(rois, torch.Tensor):

torchvision/ops/ps_roi_pool.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch.nn.modules.utils import _pair
55
from torch.jit.annotations import List
66

7-
from ._utils import convert_boxes_to_roi_format
7+
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
88

99

1010
def ps_roi_pool(input, boxes, output_size, spatial_scale=1.0):
@@ -28,6 +28,7 @@ def ps_roi_pool(input, boxes, output_size, spatial_scale=1.0):
2828
Returns:
2929
output (Tensor[K, C, output_size[0], output_size[1]])
3030
"""
31+
check_roi_boxes_shape(boxes)
3132
rois = boxes
3233
output_size = _pair(output_size)
3334
if not isinstance(rois, torch.Tensor):

torchvision/ops/roi_align.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch.nn.modules.utils import _pair
55
from torch.jit.annotations import List, BroadcastingList2
66

7-
from ._utils import convert_boxes_to_roi_format
7+
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
88

99

1010
def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False):
@@ -35,6 +35,7 @@ def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, a
3535
Returns:
3636
output (Tensor[K, C, output_size[0], output_size[1]])
3737
"""
38+
check_roi_boxes_shape(boxes)
3839
rois = boxes
3940
output_size = _pair(output_size)
4041
if not isinstance(rois, torch.Tensor):

torchvision/ops/roi_pool.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch.nn.modules.utils import _pair
55
from torch.jit.annotations import List, BroadcastingList2
66

7-
from ._utils import convert_boxes_to_roi_format
7+
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
88

99

1010
def roi_pool(input, boxes, output_size, spatial_scale=1.0):
@@ -27,6 +27,7 @@ def roi_pool(input, boxes, output_size, spatial_scale=1.0):
2727
Returns:
2828
output (Tensor[K, C, output_size[0], output_size[1]])
2929
"""
30+
check_roi_boxes_shape(boxes)
3031
rois = boxes
3132
output_size = _pair(output_size)
3233
if not isinstance(rois, torch.Tensor):

0 commit comments

Comments
 (0)