Skip to content

Commit 4ab46e5

Browse files
authored
Support for image with no annotations in RetinaNet (#3032)
* Enable support for images without annotations * Ensuring gradient propagates to RegressionHead. * Rewriting losses to remove branching. * Fix the seed on DeformConv autocast test.
1 parent 9e71fda commit 4ab46e5

File tree

3 files changed

+25
-30
lines changed

3 files changed

+25
-30
lines changed

test/test_models_detection_negative_samples.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,15 @@ def test_forward_negative_sample_krcnn(self):
128128
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
129129
self.assertEqual(loss_dict["loss_keypoint"], torch.tensor(0.))
130130

131+
def test_forward_negative_sample_retinanet(self):
132+
model = torchvision.models.detection.retinanet_resnet50_fpn(
133+
num_classes=2, min_size=100, max_size=100)
134+
135+
images, targets = self._make_empty_sample()
136+
loss_dict = model(images, targets)
137+
138+
self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.))
139+
131140

132141
if __name__ == '__main__':
133142
unittest.main()

test/test_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from common_utils import set_rng_seed
12
import math
23
import unittest
34

@@ -655,6 +656,7 @@ def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
655656

656657
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
657658
def test_autocast(self):
659+
set_rng_seed(0)
658660
for dtype in (torch.float, torch.half):
659661
with torch.cuda.amp.autocast():
660662
self._test_forward(torch.device("cuda"), False, dtype=dtype)

torchvision/models/detection/retinanet.py

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -107,21 +107,16 @@ def compute_loss(self, targets, head_outputs, matched_idxs):
107107
# determine only the foreground
108108
foreground_idxs_per_image = matched_idxs_per_image >= 0
109109
num_foreground = foreground_idxs_per_image.sum()
110-
# no matched_idxs means there were no annotations in this image
111-
# TODO: enable support for images without annotations that works on distributed
112-
if False: # matched_idxs_per_image.numel() == 0:
113-
gt_classes_target = torch.zeros_like(cls_logits_per_image)
114-
valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0])
115-
else:
116-
# create the target classification
117-
gt_classes_target = torch.zeros_like(cls_logits_per_image)
118-
gt_classes_target[
119-
foreground_idxs_per_image,
120-
targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
121-
] = 1.0
122-
123-
# find indices for which anchors should be ignored
124-
valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
110+
111+
# create the target classification
112+
gt_classes_target = torch.zeros_like(cls_logits_per_image)
113+
gt_classes_target[
114+
foreground_idxs_per_image,
115+
targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
116+
] = 1.0
117+
118+
# find indices for which anchors should be ignored
119+
valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
125120

126121
# compute the classification loss
127122
losses.append(sigmoid_focal_loss(
@@ -191,23 +186,12 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
191186

192187
for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in \
193188
zip(targets, bbox_regression, anchors, matched_idxs):
194-
# no matched_idxs means there were no annotations in this image
195-
# TODO enable support for images without annotations with distributed support
196-
# if matched_idxs_per_image.numel() == 0:
197-
# continue
198-
199-
# get the targets corresponding GT for each proposal
200-
# NB: need to clamp the indices because we can have a single
201-
# GT in the image, and matched_idxs can be -2, which goes
202-
# out of bounds
203-
matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)]
204-
205189
# determine only the foreground indices, ignore the rest
206-
foreground_idxs_per_image = matched_idxs_per_image >= 0
207-
num_foreground = foreground_idxs_per_image.sum()
190+
foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
191+
num_foreground = foreground_idxs_per_image.numel()
208192

209193
# select only the foreground boxes
210-
matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
194+
matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image[foreground_idxs_per_image]]
211195
bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
212196
anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
213197

@@ -403,7 +387,7 @@ def compute_loss(self, targets, head_outputs, anchors):
403387
matched_idxs = []
404388
for anchors_per_image, targets_per_image in zip(anchors, targets):
405389
if targets_per_image['boxes'].numel() == 0:
406-
matched_idxs.append(torch.empty((0,), dtype=torch.int32))
390+
matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64))
407391
continue
408392

409393
match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image)

0 commit comments

Comments
 (0)