Commit c5d2868

Rewriting losses to remove branching.

1 parent e8d5822 · commit c5d2868


torchvision/models/detection/retinanet.py

Lines changed: 12 additions & 26 deletions
@@ -107,20 +107,16 @@ def compute_loss(self, targets, head_outputs, matched_idxs):
             # determine only the foreground
             foreground_idxs_per_image = matched_idxs_per_image >= 0
             num_foreground = foreground_idxs_per_image.sum()
-            # no matched_idxs means there were no annotations in this image
-            if matched_idxs_per_image.numel() == 0:
-                gt_classes_target = torch.zeros_like(cls_logits_per_image)
-                valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0], device=cls_logits_per_image.device)
-            else:
-                # create the target classification
-                gt_classes_target = torch.zeros_like(cls_logits_per_image)
-                gt_classes_target[
-                    foreground_idxs_per_image,
-                    targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
-                ] = 1.0
-
-                # find indices for which anchors should be ignored
-                valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
+
+            # create the target classification
+            gt_classes_target = torch.zeros_like(cls_logits_per_image)
+            gt_classes_target[
+                foreground_idxs_per_image,
+                targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
+            ] = 1.0
+
+            # find indices for which anchors should be ignored
+            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
 
             # compute the classification loss
             losses.append(sigmoid_focal_loss(
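
Note on the hunk above: the removed numel() == 0 branch becomes unnecessary once matched_idxs_per_image always has one entry per anchor (see the last hunk below). For an image with no annotations the foreground mask is all False, so the advanced-indexing assignment writes to an empty selection and gt_classes_target stays all zeros, which is exactly what the old branch produced explicitly; likewise, assuming the usual Matcher constants (BELOW_LOW_THRESHOLD = -1, BETWEEN_THRESHOLDS = -2), an all -1 tensor leaves every anchor valid, matching the old torch.arange fallback. A minimal standalone sketch of the empty-image case (the shapes and the labels tensor here are illustrative, not taken from the model):

import torch

num_anchors, num_classes = 5, 3
cls_logits_per_image = torch.randn(num_anchors, num_classes)
labels = torch.empty((0,), dtype=torch.int64)  # image with no ground-truth labels

# every anchor explicitly matched to background (-1), one entry per anchor
matched_idxs_per_image = torch.full((num_anchors,), -1, dtype=torch.int64)

foreground_idxs_per_image = matched_idxs_per_image >= 0  # all False
gt_classes_target = torch.zeros_like(cls_logits_per_image)
gt_classes_target[
    foreground_idxs_per_image,
    labels[matched_idxs_per_image[foreground_idxs_per_image]]
] = 1.0  # empty selection, nothing is written
assert gt_classes_target.eq(0).all()  # same all-zero target as the old branch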
@@ -190,22 +186,12 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
 
         for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in \
                 zip(targets, bbox_regression, anchors, matched_idxs):
-            # no matched_idxs means there were no annotations in this image
-            if matched_idxs_per_image.numel() == 0:
-                matched_gt_boxes_per_image = torch.zeros_like(bbox_regression_per_image)
-            else:
-                # get the targets corresponding GT for each proposal
-                # NB: need to clamp the indices because we can have a single
-                # GT in the image, and matched_idxs can be -2, which goes
-                # out of bounds
-                matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)]
-
             # determine only the foreground indices, ignore the rest
             foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
             num_foreground = foreground_idxs_per_image.numel()
 
             # select only the foreground boxes
-            matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
+            matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image[foreground_idxs_per_image]]
             bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
             anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
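
Note on the hunk above: looking up targets_per_image['boxes'] only at foreground matches means every index used is >= 0, so the clamp(min=0) workaround for -1/-2 matches is no longer needed, and an image without annotations simply yields zero foreground rows everywhere. A rough sketch of that empty case, under the same assumptions as the previous note (names and shapes are illustrative):

import torch

matched_idxs_per_image = torch.full((5,), -1, dtype=torch.int64)  # no annotations
boxes = torch.empty((0, 4))                                       # no ground-truth boxes
bbox_regression_per_image = torch.randn(5, 4)
anchors_per_image = torch.rand(5, 4)

foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
num_foreground = foreground_idxs_per_image.numel()  # 0

# every selected index refers to a real GT box (>= 0), so no clamp is required
matched_gt_boxes_per_image = boxes[matched_idxs_per_image[foreground_idxs_per_image]]
bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]

print(matched_gt_boxes_per_image.shape)  # torch.Size([0, 4]): zero foreground rows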

@@ -401,7 +387,7 @@ def compute_loss(self, targets, head_outputs, anchors):
         matched_idxs = []
         for anchors_per_image, targets_per_image in zip(anchors, targets):
             if targets_per_image['boxes'].numel() == 0:
-                matched_idxs.append(torch.empty((0,), dtype=torch.int64))
+                matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64))
                 continue
 
             match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image)
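
Note on the hunk above: this is the change that makes the two simplifications earlier in the diff safe. Instead of appending an empty tensor when an image has no boxes, the matching step now records one entry per anchor, all set to -1 (background), so downstream code always sees a matched_idxs tensor aligned with the anchors and just ends up with an empty foreground set. A quick comparison of the two representations (num_anchors is an illustrative value):

import torch

num_anchors = 4

old = torch.empty((0,), dtype=torch.int64)               # old: nothing recorded at all
new = torch.full((num_anchors,), -1, dtype=torch.int64)  # new: one -1 per anchor

# the old form forced callers to special-case numel() == 0;
# the new form keeps per-anchor alignment with an empty foreground subset
print(old.shape, new.shape)      # torch.Size([0]) torch.Size([4])
print((new >= 0).sum().item())   # 0 foreground anchors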
