Skip to content

Commit 264d4d7

Browse files
committed
introduce thresholds on init to fix onnx
1 parent 64b33a9 commit 264d4d7

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

test/test_onnx.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def _init_test_generalized_rcnn_transform(self):
184184
transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
185185
return transform
186186

187-
def _init_test_rpn(self):
187+
def _init_test_rpn(self, score_threshold=0.0):
188188
anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
189189
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
190190
rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
@@ -197,7 +197,7 @@ def _init_test_rpn(self):
197197
rpn_pre_nms_top_n = dict(training=2000, testing=1000)
198198
rpn_post_nms_top_n = dict(training=2000, testing=1000)
199199
rpn_nms_thresh = 0.7
200-
rpn_score_thresh = 0.0
200+
rpn_score_thresh = score_threshold
201201

202202
rpn = RegionProposalNetwork(
203203
rpn_anchor_generator, rpn_head,
@@ -260,7 +260,7 @@ def test_rpn(self):
260260
class RPNModule(torch.nn.Module):
261261
def __init__(self_module):
262262
super(RPNModule, self_module).__init__()
263-
self_module.rpn = self._init_test_rpn()
263+
self_module.rpn = self._init_test_rpn(0.5)
264264

265265
def forward(self_module, images, features):
266266
images = ImageList(images, [i.shape[-2:] for i in images])

torchvision/models/detection/rpn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,8 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_
264264
boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
265265

266266
# remove low scoring boxes
267-
keep = torch.where(scores > self.score_thresh)[0]
267+
# use >= for Backwards compatibility
268+
keep = torch.where(scores >= self.score_thresh)[0]
268269
boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
269270

270271
# non-maximum suppression, independently done per level

0 commit comments

Comments
 (0)