
Commit f131bbe

Split per feature level.
1 parent 1001612 commit f131bbe

File tree

2 files changed: +55 -27 lines changed

Binary file not shown.

torchvision/models/detection/retinanet.py

Lines changed: 55 additions & 27 deletions
@@ -408,44 +408,56 @@ def compute_loss(self, targets, head_outputs, anchors):
         return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)
 
     def postprocess_detections(self, head_outputs, anchors, image_shapes):
-        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
-        # TODO: confirm that RetinaNet can't have other outputs like masks
+        # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
         class_logits = head_outputs['cls_logits']
         box_regression = head_outputs['bbox_regression']
 
-        num_classes = class_logits.shape[-1]
-
-        scores = torch.sigmoid(class_logits)
+        num_images = len(image_shapes)
 
         detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
 
-        for index, (box_regression_per_image, scores_per_image, anchors_per_image, image_shape) in \
-                enumerate(zip(box_regression, scores, anchors, image_shapes)):
-            # remove low scoring boxes
-            scores_per_image = scores_per_image.flatten()
-            keep_idxs = scores_per_image > self.score_thresh
-            scores_per_image = scores_per_image[keep_idxs]
-            topk_idxs = torch.where(keep_idxs)[0]
+        for index in range(num_images):
+            box_regression_per_image = [br[index] for br in box_regression]
+            logits_per_image = [cl[index] for cl in class_logits]
+            anchors_per_image, image_shape = anchors[index], image_shapes[index]
+
+            image_boxes = []
+            image_scores = []
+            image_labels = []
+
+            for box_regression_per_level, logits_per_level, anchors_per_level in \
+                    zip(box_regression_per_image, logits_per_image, anchors_per_image):
+                num_classes = logits_per_level.shape[-1]
 
-            # keep only topk scoring predictions
-            num_topk = min(self.detections_per_img, topk_idxs.size(0))
-            scores_per_image, idxs = scores_per_image.topk(num_topk)
-            topk_idxs = topk_idxs[idxs]
+                # remove low scoring boxes
+                scores_per_level = torch.sigmoid(logits_per_level).flatten()
+                keep_idxs = scores_per_level > self.score_thresh
+                scores_per_level = scores_per_level[keep_idxs]
+                topk_idxs = torch.where(keep_idxs)[0]
 
-            anchor_idxs = topk_idxs // num_classes
-            labels_per_image = topk_idxs % num_classes
+                # keep only topk scoring predictions
+                num_topk = min(self.detections_per_img, topk_idxs.size(0))
+                scores_per_level, idxs = scores_per_level.topk(num_topk)
+                topk_idxs = topk_idxs[idxs]
 
-            boxes_per_image = self.box_coder.decode_single(box_regression_per_image[anchor_idxs],
-                                                           anchors_per_image[anchor_idxs])
-            boxes_per_image = box_ops.clip_boxes_to_image(boxes_per_image, image_shape)
+                anchor_idxs = topk_idxs // num_classes
+                labels_per_level = topk_idxs % num_classes
 
-            # non-maximum suppression
-            keep = box_ops.batched_nms(boxes_per_image, scores_per_image, labels_per_image, self.nms_thresh)
+                boxes_per_level = self.box_coder.decode_single(box_regression_per_level[anchor_idxs],
+                                                               anchors_per_level[anchor_idxs])
+                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)
+
+                # non-maximum suppression
+                keep = box_ops.batched_nms(boxes_per_level, scores_per_level, labels_per_level, self.nms_thresh)
+
+                image_boxes.append(boxes_per_level[keep])
+                image_scores.append(scores_per_level[keep])
+                image_labels.append(labels_per_level[keep])
 
             detections.append({
-                'boxes': boxes_per_image[keep],
-                'scores': scores_per_image[keep],
-                'labels': labels_per_image[keep],
+                'boxes': torch.cat(image_boxes, dim=0),
+                'scores': torch.cat(image_scores, dim=0),
+                'labels': torch.cat(image_labels, dim=0),
             })
 
         return detections
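
Aside, not part of the commit: the per-level loop above flattens each (num_anchors, num_classes) score matrix before thresholding, so a surviving flat index decodes back to its anchor by integer division and to its class by the remainder. A minimal standalone sketch of that decoding, with toy shapes chosen purely for illustration:

import torch

num_anchors, num_classes = 4, 3
scores = torch.rand(num_anchors, num_classes)  # toy per-level scores

flat = scores.flatten()                 # shape (num_anchors * num_classes,)
top_scores, topk_idxs = flat.topk(2)    # keep the two best (anchor, class) pairs

anchor_idxs = topk_idxs // num_classes  # row of the original matrix: the anchor
label_idxs = topk_idxs % num_classes    # column: the class label

# the decoded indices point back at the same scores
assert torch.equal(top_scores, scores[anchor_idxs, label_idxs])
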
@@ -526,8 +538,24 @@ def forward(self, images, targets=None):
             # compute the losses
             losses = self.compute_loss(targets, head_outputs, anchors)
         else:
+            # recover level sizes
+            feature_sizes_per_level = [x.size(2) * x.size(3) for x in features]
+            HW = 0
+            for v in feature_sizes_per_level:
+                HW += v
+            HWA = head_outputs['cls_logits'].size(1)
+            A = HWA // HW
+            feature_sizes_per_level = [hw * A for hw in feature_sizes_per_level]
+
+            # split outputs per level
+            split_head_outputs: Dict[str, List[Tensor]] = {}
+            for k in head_outputs:
+                split_head_outputs[k] = [x.permute(1, 0, 2) for x in
+                                         head_outputs[k].permute(1, 0, 2).split_with_sizes(feature_sizes_per_level)]
+            split_anchors = [list(a.split_with_sizes(feature_sizes_per_level)) for a in anchors]
+
             # compute the detections
-            detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes)
+            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
             detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
 
             if torch.jit.is_scripting():
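
Aside, not part of the commit: the permute/split_with_sizes round trip above carves each (N, HWA, K) head output into per-level chunks along the anchor dimension. A minimal standalone sketch, with toy sizes, checking that it matches a direct split along dim=1:

import torch

N, K = 2, 5               # toy batch size and channels per anchor
sizes = [6, 4, 2]         # toy anchor counts per feature level, so HWA = 12
x = torch.rand(N, sum(sizes), K)

# the commit's round trip: move the anchor axis first, split, move it back
via_permute = [t.permute(1, 0, 2)
               for t in x.permute(1, 0, 2).split_with_sizes(sizes)]

# equivalent direct split along the anchor dimension
direct = x.split_with_sizes(sizes, dim=1)

for a, b in zip(via_permute, direct):
    assert torch.equal(a, b)
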
