Skip to content

Commit dd76ba9

Browse files
committed
Speed up postprocess_detections() by introducing a prefilter step.
1 parent 072d8b2 commit dd76ba9

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

torchvision/models/detection/retinanet.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,13 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes):
439439
image_labels = []
440440
image_other_outputs = torch.jit.annotate(Dict[str, List[Tensor]], {})
441441

442-
for class_index in range(num_classes):
442+
# prefilter step: drop all boxes that dont pass the threshold in any class and keep only relevant classes
443+
inds, class_ids = (t.unique() for t in (scores_per_image > self.score_thresh).nonzero(as_tuple=True))
444+
boxes_per_image, scores_per_image, labels_per_image = \
445+
boxes_per_image[inds], scores_per_image[inds], labels_per_image[inds]
446+
other_outputs_per_image = [(k, v[inds]) for k, v in other_outputs_per_image]
447+
448+
for class_index in class_ids.tolist():
443449
# remove low scoring boxes
444450
inds = torch.gt(scores_per_image[:, class_index], self.score_thresh)
445451
boxes_per_class, scores_per_class, labels_per_class = \

0 commit comments

Comments
 (0)