Skip to content

Commit 7e27337

Browse files
committed
explicitely cast coordinates to float to allow vectorization
1 parent 618bbe1 commit 7e27337

File tree

1 file changed

+14
-18
lines changed

1 file changed

+14
-18
lines changed

torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,30 +61,26 @@ at::Tensor qnms_kernel_impl(
6161
continue;
6262
keep[num_to_keep++] = i;
6363

64-
auto ix1val = x1[i].val_;
65-
auto iy1val = y1[i].val_;
66-
auto ix2val = x2[i].val_;
67-
auto iy2val = y2[i].val_;
68-
auto iarea = areas[i];
64+
// We explicitely cast coordinates to float so that the code can be vectorized.
65+
float ix1val = x1[i].val_;
66+
float iy1val = y1[i].val_;
67+
float ix2val = x2[i].val_;
68+
float iy2val = y2[i].val_;
69+
float iarea = areas[i];
6970

7071
for (int64_t _j = _i + 1; _j < ndets; _j++) {
7172
auto j = order[_j];
7273
if (suppressed[j] == 1)
7374
continue;
74-
auto xx1 = std::max(ix1val, x1[j].val_);
75-
auto yy1 = std::max(iy1val, y1[j].val_);
76-
auto xx2 = std::min(ix2val, x2[j].val_);
77-
auto yy2 = std::min(iy2val, y2[j].val_);
78-
79-
// This may underflow if xx2 < xx1 on unsigned types but as noted above,
80-
// integral promotion should prevent it. Also, an actual underflow would
81-
// lead to a negative ovr (because of high value for inter), but since the
82-
// actual over should have been 0 the condition below isn't altered, and
83-
// thus the underflow should be effectively harmless.
84-
auto w = std::max(0, xx2 - xx1); // * scale (gets canceled below)
85-
auto h = std::max(0, yy2 - yy1); // * scale (gets canceled below)
75+
float xx1 = std::max(ix1val, (float)x1[j].val_);
76+
float yy1 = std::max(iy1val, (float)y1[j].val_);
77+
float xx2 = std::min(ix2val, (float)x2[j].val_);
78+
float yy2 = std::min(iy2val, (float)y2[j].val_);
79+
80+
auto w = std::max(0.f, xx2 - xx1); // * scale (gets canceled below)
81+
auto h = std::max(0.f, yy2 - yy1); // * scale (gets canceled below)
8682
auto inter = w * h;
87-
auto ovr = (float)inter / (iarea + areas[j] - inter);
83+
auto ovr = inter / (iarea + areas[j] - inter);
8884
if (ovr > iou_threshold)
8985
suppressed[j] = 1;
9086
}

0 commit comments

Comments
 (0)