Skip to content

Commit f74bfab

Browse files
authored
Add quantized version of nms (#3601)
* Add quantized version of nms * Added tests * Compute areas only once * remove calls to dequantize_val * fix return type for empty tensor * flake8 * remove use of scale as it gets cancelled out * simpler int convertion in tests * explicitly set ovr to double * add tests for more values of scale and zero_point * comment about underflow * remove unnecessary accessor * properly convert to float for division * Add comments about underflow * explicitely cast coordinates to float to allow vectorization * clang * clang again * hopefully OK now
1 parent 978ba61 commit f74bfab

File tree

3 files changed

+157
-2
lines changed

3 files changed

+157
-2
lines changed

setup.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,11 @@ def get_extensions():
138138

139139
main_file = glob.glob(os.path.join(extensions_dir, '*.cpp')) + glob.glob(os.path.join(extensions_dir, 'ops',
140140
'*.cpp'))
141-
source_cpu = glob.glob(os.path.join(extensions_dir, 'ops', 'autograd', '*.cpp')) + glob.glob(
142-
os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp'))
141+
source_cpu = (
142+
glob.glob(os.path.join(extensions_dir, 'ops', 'autograd', '*.cpp')) +
143+
glob.glob(os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp')) +
144+
glob.glob(os.path.join(extensions_dir, 'ops', 'quantized', 'cpu', '*.cpp'))
145+
)
143146

144147
is_rocm_pytorch = False
145148
if torch.__version__ >= '1.5':

test/test_ops.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,29 @@ def test_nms(self):
418418
self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 4), torch.rand(3, 2), 0.5)
419419
self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 4), torch.rand(4), 0.5)
420420

421+
def test_qnms(self):
422+
# Note: we compare qnms vs nms instead of qnms vs reference implementation.
423+
# This is because with the int convertion, the trick used in _create_tensors_with_iou
424+
# doesn't really work (in fact, nms vs reference implem will also fail with ints)
425+
err_msg = 'NMS and QNMS give different results for IoU={}'
426+
for iou in [0.2, 0.5, 0.8]:
427+
for scale, zero_point in ((1, 0), (2, 50), (3, 10)):
428+
boxes, scores = self._create_tensors_with_iou(1000, iou)
429+
scores *= 100 # otherwise most scores would be 0 or 1 after int convertion
430+
431+
qboxes = torch.quantize_per_tensor(boxes, scale=scale, zero_point=zero_point,
432+
dtype=torch.quint8)
433+
qscores = torch.quantize_per_tensor(scores, scale=scale, zero_point=zero_point,
434+
dtype=torch.quint8)
435+
436+
boxes = qboxes.dequantize()
437+
scores = qscores.dequantize()
438+
439+
keep = ops.nms(boxes, scores, iou)
440+
qkeep = ops.nms(qboxes, qscores, iou)
441+
442+
self.assertTrue(torch.allclose(qkeep, keep), err_msg.format(iou))
443+
421444
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
422445
def test_nms_cuda(self, dtype=torch.float64):
423446
tol = 1e-3 if dtype is torch.half else 1e-5
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/native/quantized/affine_quantizer.h>
3+
#include <torch/library.h>
4+
5+
namespace vision {
6+
namespace ops {
7+
8+
namespace {
9+
10+
template <typename scalar_t>
11+
at::Tensor qnms_kernel_impl(
12+
const at::Tensor& dets,
13+
const at::Tensor& scores,
14+
double iou_threshold) {
15+
TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor");
16+
TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor");
17+
TORCH_CHECK(
18+
dets.scalar_type() == scores.scalar_type(),
19+
"dets should have the same type as scores");
20+
21+
if (dets.numel() == 0)
22+
return at::empty({0}, dets.options().dtype(at::kLong));
23+
24+
const auto ndets = dets.size(0);
25+
26+
auto x1_t = dets.select(1, 0).contiguous();
27+
auto y1_t = dets.select(1, 1).contiguous();
28+
auto x2_t = dets.select(1, 2).contiguous();
29+
auto y2_t = dets.select(1, 3).contiguous();
30+
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
31+
at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
32+
at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
33+
at::Tensor areas_t = at::zeros({ndets}, dets.options().dtype(at::kFloat));
34+
35+
auto suppressed = suppressed_t.data_ptr<uint8_t>();
36+
auto keep = keep_t.data_ptr<int64_t>();
37+
auto order = order_t.data_ptr<int64_t>();
38+
auto x1 = x1_t.data_ptr<scalar_t>();
39+
auto y1 = y1_t.data_ptr<scalar_t>();
40+
auto x2 = x2_t.data_ptr<scalar_t>();
41+
auto y2 = y2_t.data_ptr<scalar_t>();
42+
auto areas = areas_t.data_ptr<float>();
43+
44+
for (int64_t i = 0; i < ndets; i++) {
45+
// Note 1: To get the exact area we'd need to multiply by scale**2, but this
46+
// would get canceled out in the computation of ovr below. So we leave that
47+
// out.
48+
// Note 2: degenerate boxes (x2 < x1 or y2 < y1) may underflow, although
49+
// integral promotion rules will likely prevent it (see
50+
// https://stackoverflow.com/questions/32959564/subtraction-of-two-unsigned-gives-signed
51+
// for more details).
52+
areas[i] = (x2[i].val_ - x1[i].val_) * (y2[i].val_ - y1[i].val_);
53+
}
54+
55+
int64_t num_to_keep = 0;
56+
57+
for (int64_t _i = 0; _i < ndets; _i++) {
58+
auto i = order[_i];
59+
if (suppressed[i] == 1)
60+
continue;
61+
keep[num_to_keep++] = i;
62+
63+
// We explicitely cast coordinates to float so that the code can be
64+
// vectorized.
65+
float ix1val = x1[i].val_;
66+
float iy1val = y1[i].val_;
67+
float ix2val = x2[i].val_;
68+
float iy2val = y2[i].val_;
69+
float iarea = areas[i];
70+
71+
for (int64_t _j = _i + 1; _j < ndets; _j++) {
72+
auto j = order[_j];
73+
if (suppressed[j] == 1)
74+
continue;
75+
float xx1 = std::max(ix1val, (float)x1[j].val_);
76+
float yy1 = std::max(iy1val, (float)y1[j].val_);
77+
float xx2 = std::min(ix2val, (float)x2[j].val_);
78+
float yy2 = std::min(iy2val, (float)y2[j].val_);
79+
80+
auto w = std::max(0.f, xx2 - xx1); // * scale (gets canceled below)
81+
auto h = std::max(0.f, yy2 - yy1); // * scale (gets canceled below)
82+
auto inter = w * h;
83+
auto ovr = inter / (iarea + areas[j] - inter);
84+
if (ovr > iou_threshold)
85+
suppressed[j] = 1;
86+
}
87+
}
88+
return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
89+
}
90+
91+
at::Tensor qnms_kernel(
92+
const at::Tensor& dets,
93+
const at::Tensor& scores,
94+
double iou_threshold) {
95+
TORCH_CHECK(
96+
dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
97+
TORCH_CHECK(
98+
dets.size(1) == 4,
99+
"boxes should have 4 elements in dimension 1, got ",
100+
dets.size(1));
101+
TORCH_CHECK(
102+
scores.dim() == 1,
103+
"scores should be a 1d tensor, got ",
104+
scores.dim(),
105+
"D");
106+
TORCH_CHECK(
107+
dets.size(0) == scores.size(0),
108+
"boxes and scores should have same number of elements in ",
109+
"dimension 0, got ",
110+
dets.size(0),
111+
" and ",
112+
scores.size(0));
113+
114+
auto result = at::empty({0});
115+
116+
AT_DISPATCH_QINT_TYPES(dets.scalar_type(), "qnms_kernel", [&] {
117+
result = qnms_kernel_impl<scalar_t>(dets, scores, iou_threshold);
118+
});
119+
return result;
120+
}
121+
122+
} // namespace
123+
124+
TORCH_LIBRARY_IMPL(torchvision, QuantizedCPU, m) {
125+
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(qnms_kernel));
126+
}
127+
128+
} // namespace ops
129+
} // namespace vision

0 commit comments

Comments
 (0)