Commit 0467c9d

datumbox and fmassa authored
Vectorize RetinaNet's postprocessing (#2828)
* Vectorize operations across all feature levels.
* Remove unnecessary other_outputs variable.
* Split per feature level.
* Perform batched_nms across feature levels.
* Add extra parameter for limiting detections before and after nms.
* Restoring default threshold.
* Apply suggestions from code review.
* Renaming variable.

Co-authored-by: Francisco Massa <[email protected]>
1 parent d94b45d · commit 0467c9d
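The core of the change is visible in the new per-level loop in `postprocess_detections` below: instead of iterating over classes, the logits of each feature level are flattened so thresholding and top-k run in one vectorized pass, and the anchor/class indices are recovered with integer division and modulo. A minimal standalone sketch of that indexing trick (illustrative shapes and thresholds, not the library code itself):

```python
import torch

# Sketch of the per-level selection this commit introduces. Shapes, the
# score threshold (0.05) and the top-k budget (1000) are illustrative,
# mirroring the defaults added in the diff.
num_anchors, num_classes = 8, 4
logits = torch.randn(num_anchors, num_classes)

scores = torch.sigmoid(logits).flatten()   # shape: [num_anchors * num_classes]
keep_idxs = scores > 0.05                  # remove low scoring boxes
scores = scores[keep_idxs]
topk_idxs = torch.where(keep_idxs)[0]      # flat indices of survivors

num_topk = min(1000, topk_idxs.size(0))    # keep only topk scoring predictions
scores, idxs = scores.topk(num_topk)
topk_idxs = topk_idxs[idxs]

anchor_idxs = topk_idxs // num_classes     # which anchor each score came from
labels = topk_idxs % num_classes           # which class each score belongs to
```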

File tree

3 files changed: +59 -54 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -22,5 +22,6 @@ htmlcov
 gen.yml
 .mypy_cache
 .vscode/
+.idea/
 *.orig
 *-checkpoint.ipynb
Binary file not shown.

torchvision/models/detection/retinanet.py

Lines changed: 58 additions & 54 deletions
@@ -291,6 +291,7 @@ class RetinaNet(nn.Module):
             considered as positive during training.
         bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
             considered as negative during training.
+        topk_candidates (int): Number of best detections to keep before NMS.
 
     Example:
 
@@ -339,7 +340,8 @@ def __init__(self, backbone, num_classes,
                  score_thresh=0.05,
                  nms_thresh=0.5,
                  detections_per_img=300,
-                 fg_iou_thresh=0.5, bg_iou_thresh=0.4):
+                 fg_iou_thresh=0.5, bg_iou_thresh=0.4,
+                 topk_candidates=1000):
         super().__init__()
 
         if not hasattr(backbone, "out_channels"):
@@ -382,6 +384,7 @@ def __init__(self, backbone, num_classes,
         self.score_thresh = score_thresh
         self.nms_thresh = nms_thresh
         self.detections_per_img = detections_per_img
+        self.topk_candidates = topk_candidates
 
         # used only on torchscript mode
         self._has_warned = False
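For context, a hedged usage sketch of the new constructor knob. It assumes keyword arguments to `retinanet_resnet50_fpn` are forwarded to `RetinaNet`, which holds for this version of torchvision:

```python
import torchvision

# Usage sketch: topk_candidates joins the existing postprocessing thresholds.
# pretrained flags are disabled only to avoid weight downloads in this sketch.
model = torchvision.models.detection.retinanet_resnet50_fpn(
    pretrained=False,
    pretrained_backbone=False,
    score_thresh=0.05,        # drop low-confidence boxes before top-k
    topk_candidates=1000,     # new: best detections kept per level, before NMS
    detections_per_img=300,   # detections kept per image, after NMS
)
model.eval()
```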
@@ -408,77 +411,63 @@ def compute_loss(self, targets, head_outputs, anchors):
         return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)
 
     def postprocess_detections(self, head_outputs, anchors, image_shapes):
-        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
-        # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ?
+        # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
+        class_logits = head_outputs['cls_logits']
+        box_regression = head_outputs['bbox_regression']
 
-        class_logits = head_outputs.pop('cls_logits')
-        box_regression = head_outputs.pop('bbox_regression')
-        other_outputs = head_outputs
-
-        device = class_logits.device
-        num_classes = class_logits.shape[-1]
-
-        scores = torch.sigmoid(class_logits)
-
-        # create labels for each score
-        labels = torch.arange(num_classes, device=device)
-        labels = labels.view(1, -1).expand_as(scores)
+        num_images = len(image_shapes)
 
         detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
 
-        for index, (box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape) in \
-                enumerate(zip(box_regression, scores, labels, anchors, image_shapes)):
-
-            boxes_per_image = self.box_coder.decode_single(box_regression_per_image, anchors_per_image)
-            boxes_per_image = box_ops.clip_boxes_to_image(boxes_per_image, image_shape)
-
-            other_outputs_per_image = [(k, v[index]) for k, v in other_outputs.items()]
+        for index in range(num_images):
+            box_regression_per_image = [br[index] for br in box_regression]
+            logits_per_image = [cl[index] for cl in class_logits]
+            anchors_per_image, image_shape = anchors[index], image_shapes[index]
 
             image_boxes = []
             image_scores = []
             image_labels = []
-            image_other_outputs = torch.jit.annotate(Dict[str, List[Tensor]], {})
 
-            for class_index in range(num_classes):
+            for box_regression_per_level, logits_per_level, anchors_per_level in \
+                    zip(box_regression_per_image, logits_per_image, anchors_per_image):
+                num_classes = logits_per_level.shape[-1]
+
                 # remove low scoring boxes
-                inds = torch.gt(scores_per_image[:, class_index], self.score_thresh)
-                boxes_per_class, scores_per_class, labels_per_class = \
-                    boxes_per_image[inds], scores_per_image[inds, class_index], labels_per_image[inds, class_index]
-                other_outputs_per_class = [(k, v[inds]) for k, v in other_outputs_per_image]
+                scores_per_level = torch.sigmoid(logits_per_level).flatten()
+                keep_idxs = scores_per_level > self.score_thresh
+                scores_per_level = scores_per_level[keep_idxs]
+                topk_idxs = torch.where(keep_idxs)[0]
 
-                # remove empty boxes
-                keep = box_ops.remove_small_boxes(boxes_per_class, min_size=1e-2)
-                boxes_per_class, scores_per_class, labels_per_class = \
-                    boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
-                other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class]
+                # keep only topk scoring predictions
+                num_topk = min(self.topk_candidates, topk_idxs.size(0))
+                scores_per_level, idxs = scores_per_level.topk(num_topk)
+                topk_idxs = topk_idxs[idxs]
 
-                # non-maximum suppression, independently done per class
-                keep = box_ops.nms(boxes_per_class, scores_per_class, self.nms_thresh)
+                anchor_idxs = topk_idxs // num_classes
+                labels_per_level = topk_idxs % num_classes
 
-                # keep only topk scoring predictions
-                keep = keep[:self.detections_per_img]
-                boxes_per_class, scores_per_class, labels_per_class = \
-                    boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
-                other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class]
+                boxes_per_level = self.box_coder.decode_single(box_regression_per_level[anchor_idxs],
+                                                               anchors_per_level[anchor_idxs])
+                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)
+
+                image_boxes.append(boxes_per_level)
+                image_scores.append(scores_per_level)
+                image_labels.append(labels_per_level)
 
-                image_boxes.append(boxes_per_class)
-                image_scores.append(scores_per_class)
-                image_labels.append(labels_per_class)
+            image_boxes = torch.cat(image_boxes, dim=0)
+            image_scores = torch.cat(image_scores, dim=0)
+            image_labels = torch.cat(image_labels, dim=0)
 
-                for k, v in other_outputs_per_class:
-                    if k not in image_other_outputs:
-                        image_other_outputs[k] = []
-                    image_other_outputs[k].append(v)
+            # non-maximum suppression
+            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
+            keep = keep[:self.detections_per_img]
 
             detections.append({
-                'boxes': torch.cat(image_boxes, dim=0),
-                'scores': torch.cat(image_scores, dim=0),
-                'labels': torch.cat(image_labels, dim=0),
+                'boxes': image_boxes[keep],
+                'scores': image_scores[keep],
+                'labels': image_labels[keep],
             })
 
-            for k, v in image_other_outputs.items():
-                detections[-1].update({k: torch.cat(v, dim=0)})
-
         return detections
 
     def forward(self, images, targets=None):
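The hunk above also replaces the per-class `box_ops.nms` loop with a single `batched_nms` call over all levels and classes at once. `batched_nms` is a real torchvision op that internally offsets boxes by their label so boxes of different classes never suppress each other, while still running as one vectorized pass. A tiny illustrative example (toy boxes, not from the commit):

```python
import torch
from torchvision.ops import batched_nms

# Three toy boxes: 0 and 1 overlap within the same class; 2 is identical
# to 0 but carries a different class label, so it must survive.
boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],    # IoU ~0.68 with box 0, same class
                      [0., 0., 10., 10.]])   # same coords as box 0, other class
scores = torch.tensor([0.9, 0.8, 0.7])
labels = torch.tensor([0, 0, 1])

keep = batched_nms(boxes, scores, labels, iou_threshold=0.5)
print(keep)  # tensor([0, 2]): box 1 suppressed, box 2 kept across classes
```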
@@ -557,8 +546,23 @@ def forward(self, images, targets=None):
             # compute the losses
             losses = self.compute_loss(targets, head_outputs, anchors)
         else:
+            # recover level sizes
+            num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
+            HW = 0
+            for v in num_anchors_per_level:
+                HW += v
+            HWA = head_outputs['cls_logits'].size(1)
+            A = HWA // HW
+            num_anchors_per_level = [hw * A for hw in num_anchors_per_level]
+
+            # split outputs per level
+            split_head_outputs: Dict[str, List[Tensor]] = {}
+            for k in head_outputs:
+                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
+            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]
+
             # compute the detections
-            detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes)
+            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
             detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
 
         if torch.jit.is_scripting():
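The final hunk's "recover level sizes" step exists because the head concatenates every level's predictions into one `[N, HWA, ...]` tensor; the number of anchors per level (`H*W*A`) is recomputed from the feature maps and used to split the outputs back per level before postprocessing. A self-contained sketch with toy sizes (only the shapes are illustrative; the logic mirrors the diff):

```python
import torch

# Two fake FPN levels and a concatenated classification output with
# A = 9 anchors per spatial location and 91 classes.
features = [torch.zeros(1, 256, 32, 32), torch.zeros(1, 256, 16, 16)]
cls_logits = torch.zeros(1, (32 * 32 + 16 * 16) * 9, 91)

num_anchors_per_level = [x.size(2) * x.size(3) for x in features]  # [1024, 256]
HW = sum(num_anchors_per_level)
HWA = cls_logits.size(1)
A = HWA // HW                                  # anchors per location, recovered
num_anchors_per_level = [hw * A for hw in num_anchors_per_level]   # [9216, 2304]

split_logits = list(cls_logits.split(num_anchors_per_level, dim=1))
print([t.shape for t in split_logits])
# [torch.Size([1, 9216, 91]), torch.Size([1, 2304, 91])]
```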
