Commit e61538c

improve stability of test_nms_cuda (#2044)
* improve stability of test_nms_cuda

  This change addresses two issues:

  1. _create_tensors_with_iou() creates test data for the NMS tests. It takes care to ensure that at least one pair of boxes (the first and the last) has an IoU around the threshold for the test. However, the constructed IoU for that pair is so close to the threshold that rounding differences (presumably between the CPU and CUDA implementations) may result in one implementation suppressing a box in the pair and the other not. Adjust the construction so that the IoU for the box pair is near the threshold, but far enough above it that both implementations should agree.

  2. Where two boxes have nearly or exactly the same score, the CPU and CUDA implementations may order them differently. Adjust test_nms_cuda() to check only that the non-suppressed box lists include the same members, without regard for ordering.

* adjust assertion in test_nms_cuda

  The CPU and CUDA nms implementations each sort the box scores as part of their work, but the sorts they use are not stable. So boxes with the same score may be processed in opposite order by the two implementations. Relax the assertion in test_nms_cuda (following the model in pytorch's test_topk()) to allow the test to pass if the output differences are caused by similarly-scored boxes.

* improve stability of test_nms_cuda

  Adjust _create_tensors_with_iou() to ensure we create at least one box just over the threshold that should be suppressed.
1 parent 9ed2fa3 commit e61538c

1 file changed, +10 -1 lines changed
test/test_ops.py

Lines changed: 10 additions & 1 deletion
@@ -374,10 +374,14 @@ def _create_tensors_with_iou(self, N, iou_thresh):
         # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
         # then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
         # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
+        # Adjust the threshold upward a bit with the intent of creating
+        # at least one box that exceeds (barely) the threshold and so
+        # should be suppressed.
         boxes = torch.rand(N, 4) * 100
         boxes[:, 2:] += boxes[:, :2]
         boxes[-1, :] = boxes[0, :]
         x0, y0, x1, y1 = boxes[-1].tolist()
+        iou_thresh += 1e-5
         boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
         scores = torch.rand(N)
         return boxes, scores
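
The construction in this hunk can be checked by hand. Below is a minimal pure-Python sketch, not the torchvision code: the helpers `iou` and `widen_to_iou` and the sample box are hypothetical names and values chosen for illustration. It assumes boxes are `[x0, y0, x1, y1]` lists and shows that widening a box by `d` gives exactly the target IoU, so bumping the target by `1e-5` places the pair just above the original threshold.

```python
def iou(a, b):
    """IoU of two axis-aligned boxes given as [x0, y0, x1, y1]."""
    ix0, iy0 = max(a[0], b[0]), max(a[1], b[1])
    ix1, iy1 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix1 - ix0) * max(0.0, iy1 - iy0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def widen_to_iou(box, target_iou):
    """Copy `box` and extend x1 by d = (x1 - x0) * (1 - t) / t (hypothetical helper)."""
    x0, y0, x1, y1 = box
    d = (x1 - x0) * (1 - target_iou) / target_iou
    return [x0, y0, x1 + d, y1]


b0 = [10.0, 20.0, 50.0, 80.0]  # arbitrary sample box, not from the test suite
iou_thresh = 0.5
eps = 1e-5  # same bump the commit adds to iou_thresh

# Without the bump, the pair sits exactly on the threshold.
exact = iou(b0, widen_to_iou(b0, iou_thresh))
# With the bump, the pair sits slightly above the original threshold.
bumped = iou(b0, widen_to_iou(b0, iou_thresh + eps))

print(exact, bumped)
```

Because `b1` fully contains `b0`, the IoU reduces to `w / (w + d)` with `w = x1 - x0`, which equals the target by construction. A pair sitting exactly on the threshold is the case where tiny rounding differences could flip the suppression decision.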
@@ -399,7 +403,12 @@ def test_nms_cuda(self):
             r_cpu = ops.nms(boxes, scores, iou)
             r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)

-            self.assertTrue(torch.allclose(r_cpu, r_cuda.cpu()), err_msg.format(iou))
+            is_eq = torch.allclose(r_cpu, r_cuda.cpu())
+            if not is_eq:
+                # if the indices are not the same, ensure that it's because the scores
+                # are duplicate
+                is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()])
+            self.assertTrue(is_eq, err_msg.format(iou))


 class NewEmptyTensorTester(unittest.TestCase):
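
The relaxed assertion can be illustrated without a GPU. The following is a rough pure-Python sketch of the same idea, not the test itself: `same_kept_boxes` is a hypothetical helper, `math.isclose` stands in for `torch.allclose`, and the explicit length check is an added assumption. It accepts two index lists either when they are identical or when the scores they select match position by position, which is what happens when tied boxes are merely processed in a different order.

```python
import math


def same_kept_boxes(kept_a, kept_b, scores, rel_tol=1e-5, abs_tol=1e-8):
    """Return True if two NMS outputs agree up to reordering of equal-score boxes."""
    if kept_a == kept_b:
        return True
    if len(kept_a) != len(kept_b):
        # Assumption for this sketch: a different number of kept boxes is a real mismatch.
        return False
    # Indices differ, so fall back to comparing the scores they point at.
    return all(
        math.isclose(scores[i], scores[j], rel_tol=rel_tol, abs_tol=abs_tol)
        for i, j in zip(kept_a, kept_b)
    )


scores = [0.9, 0.7, 0.7, 0.1]   # boxes 1 and 2 are tied
kept_cpu = [0, 1, 2, 3]          # one possible order from an unstable sort
kept_cuda = [0, 2, 1, 3]         # the other order for the tied pair

print(same_kept_boxes(kept_cpu, kept_cuda, scores))  # True
print(same_kept_boxes([0, 1], [0, 3], scores))       # False
```

The fallback is deliberately narrow: it tolerates swapped indices only when the corresponding scores are close, so a result that keeps a genuinely different box with a different score still fails.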
