Skip to content

Commit 34810c0

Browse files
authored
Add more tests to NMS (#2279)
* Add more tests to NMS * Fix lint
1 parent b40f49f commit 34810c0

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

test/test_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,10 @@ def test_nms(self):
393393
keep_ref = self.reference_nms(boxes, scores, iou)
394394
keep = ops.nms(boxes, scores, iou)
395395
self.assertTrue(torch.allclose(keep, keep_ref), err_msg.format(iou))
396+
self.assertRaises(RuntimeError, ops.nms, torch.rand(4), torch.rand(3), 0.5)
397+
self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 5), torch.rand(3), 0.5)
398+
self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 4), torch.rand(3, 2), 0.5)
399+
self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 4), torch.rand(4), 0.5)
396400

397401
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
398402
def test_nms_cuda(self):

torchvision/csrc/nms.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,24 @@ at::Tensor nms(
1212
const at::Tensor& dets,
1313
const at::Tensor& scores,
1414
const double iou_threshold) {
15+
TORCH_CHECK(
16+
dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
17+
TORCH_CHECK(
18+
dets.size(1) == 4,
19+
"boxes should have 4 elements in dimension 1, got ",
20+
dets.size(1));
21+
TORCH_CHECK(
22+
scores.dim() == 1,
23+
"scores should be a 1d tensor, got ",
24+
scores.dim(),
25+
"D");
26+
TORCH_CHECK(
27+
dets.size(0) == scores.size(0),
28+
"boxes and scores should have same number of elements in ",
29+
"dimension 0, got ",
30+
dets.size(0),
31+
" and ",
32+
scores.size(0));
1533
if (dets.is_cuda()) {
1634
#if defined(WITH_CUDA)
1735
if (dets.numel() == 0) {

0 commit comments

Comments
 (0)