
Commit a3491d2

Vectorize operations, across all feature levels.

1 parent: b480903

File tree

3 files changed: 30 additions and 55 deletions


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -22,5 +22,6 @@ htmlcov
 gen.yml
 .mypy_cache
 .vscode/
+.idea/
 *.orig
 *-checkpoint.ipynb
Binary file not shown.

torchvision/models/detection/retinanet.py

Lines changed: 29 additions & 55 deletions
@@ -410,74 +410,48 @@ def compute_loss(self, targets, head_outputs, anchors):
     def postprocess_detections(self, head_outputs, anchors, image_shapes):
         # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
         # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ?
-
         class_logits = head_outputs.pop('cls_logits')
         box_regression = head_outputs.pop('bbox_regression')
         other_outputs = head_outputs

-        device = class_logits.device
         num_classes = class_logits.shape[-1]

         scores = torch.sigmoid(class_logits)

-        # create labels for each score
-        labels = torch.arange(num_classes, device=device)
-        labels = labels.view(1, -1).expand_as(scores)
-
         detections = torch.jit.annotate(List[Dict[str, Tensor]], [])

-        for index, (box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape) in \
-                enumerate(zip(box_regression, scores, labels, anchors, image_shapes)):
+        for index, (box_regression_per_image, scores_per_image, anchors_per_image, image_shape) in \
+                enumerate(zip(box_regression, scores, anchors, image_shapes)):
+            # remove low scoring boxes
+            scores_per_image = scores_per_image.flatten()
+            keep_idxs = scores_per_image > self.score_thresh
+            scores_per_image = scores_per_image[keep_idxs]
+            topk_idxs = torch.where(keep_idxs)[0]
+
+            # keep only topk scoring predictions
+            num_topk = min(self.detections_per_img, topk_idxs.size(0))
+            scores_per_image, idxs = scores_per_image.topk(num_topk)
+            topk_idxs = topk_idxs[idxs]

-            boxes_per_image = self.box_coder.decode_single(box_regression_per_image, anchors_per_image)
+            anchor_idxs = topk_idxs // num_classes
+            labels_per_image = topk_idxs % num_classes
+
+            boxes_per_image = self.box_coder.decode_single(box_regression_per_image[anchor_idxs],
+                                                           anchors_per_image[anchor_idxs])
             boxes_per_image = box_ops.clip_boxes_to_image(boxes_per_image, image_shape)

-            other_outputs_per_image = [(k, v[index]) for k, v in other_outputs.items()]
-
-            image_boxes = []
-            image_scores = []
-            image_labels = []
-            image_other_outputs = torch.jit.annotate(Dict[str, List[Tensor]], {})
-
-            for class_index in range(num_classes):
-                # remove low scoring boxes
-                inds = torch.gt(scores_per_image[:, class_index], self.score_thresh)
-                boxes_per_class, scores_per_class, labels_per_class = \
-                    boxes_per_image[inds], scores_per_image[inds, class_index], labels_per_image[inds, class_index]
-                other_outputs_per_class = [(k, v[inds]) for k, v in other_outputs_per_image]
-
-                # remove empty boxes
-                keep = box_ops.remove_small_boxes(boxes_per_class, min_size=1e-2)
-                boxes_per_class, scores_per_class, labels_per_class = \
-                    boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
-                other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class]
-
-                # non-maximum suppression, independently done per class
-                keep = box_ops.nms(boxes_per_class, scores_per_class, self.nms_thresh)
-
-                # keep only topk scoring predictions
-                keep = keep[:self.detections_per_img]
-                boxes_per_class, scores_per_class, labels_per_class = \
-                    boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
-                other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class]
-
-                image_boxes.append(boxes_per_class)
-                image_scores.append(scores_per_class)
-                image_labels.append(labels_per_class)
-
-                for k, v in other_outputs_per_class:
-                    if k not in image_other_outputs:
-                        image_other_outputs[k] = []
-                    image_other_outputs[k].append(v)
-
-            detections.append({
-                'boxes': torch.cat(image_boxes, dim=0),
-                'scores': torch.cat(image_scores, dim=0),
-                'labels': torch.cat(image_labels, dim=0),
-            })
-
-            for k, v in image_other_outputs.items():
-                detections[-1].update({k: torch.cat(v, dim=0)})
+            # non-maximum suppression
+            keep = box_ops.batched_nms(boxes_per_image, scores_per_image, labels_per_image, self.nms_thresh)
+
+            det = {
+                'boxes': boxes_per_image[keep],
+                'scores': scores_per_image[keep],
+                'labels': labels_per_image[keep],
+            }
+            for k, v in other_outputs.items():
+                det[k] = v[index][keep]
+
+            detections.append(det)

         return detections
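The heart of the new loop is treating the per-image score tensor of shape (num_anchors, num_classes) as a flat vector: the threshold and top-k run once over all anchors and classes at all feature levels, and the surviving flat indices are decomposed back into anchor rows and class columns with integer division and modulo. A minimal standalone sketch of that decomposition (the toy sizes and the 0.5 threshold are illustrative choices, not values from the commit):

    import torch

    num_anchors, num_classes = 4, 3
    scores = torch.rand(num_anchors, num_classes)

    flat = scores.flatten()                 # shape (num_anchors * num_classes,)
    keep_idxs = flat > 0.5                  # score threshold
    kept_scores = flat[keep_idxs]
    topk_idxs = torch.where(keep_idxs)[0]   # flat indices of surviving scores

    num_topk = min(5, topk_idxs.size(0))    # cap, like self.detections_per_img
    kept_scores, idxs = kept_scores.topk(num_topk)
    topk_idxs = topk_idxs[idxs]

    # flatten() is row-major, so flat index = anchor * num_classes + class
    anchor_idxs = topk_idxs // num_classes  # which anchor each score came from
    class_idxs = topk_idxs % num_classes    # which class it belongs to

    assert torch.equal(scores[anchor_idxs, class_idxs], kept_scores)

Indexing box_regression_per_image and anchors_per_image with anchor_idxs then means only the top-scoring candidates are ever decoded, instead of decoding every anchor as the old code did.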

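The per-class NMS loop can be dropped because box_ops.batched_nms performs class-aware suppression in one call: it offsets boxes by their label so that boxes of different classes never overlap, then runs a single NMS pass. An illustrative example (the coordinates and threshold here are made up):

    import torch
    from torchvision.ops import batched_nms

    boxes = torch.tensor([[0., 0., 10., 10.],
                          [1., 1., 11., 11.],    # high IoU with box 0
                          [0., 0., 10., 10.]])   # identical to box 0
    scores = torch.tensor([0.9, 0.8, 0.7])
    labels = torch.tensor([0, 0, 1])

    # Box 1 is suppressed by box 0 (same label, IoU ~ 0.68 > 0.5), but the
    # identical box 2 survives because it carries a different label.
    keep = batched_nms(boxes, scores, labels, iou_threshold=0.5)
    print(keep)  # tensor([0, 2])

One behavioral difference visible in the diff itself: the old code applied the detections_per_img cap per class after NMS, while the new code applies it to the raw scores before NMS.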