Skip to content

Commit 2b555f6

Browse files
yiwen-songzhiqwangjdsgomes
authored andcommitted
[fbsync] fix bug when the target is empty in FCOS (#5267)
Summary: * fix bug when the target is empty * Add unittest for empty instance training Reviewed By: kazhang Differential Revision: D33927512 fbshipit-source-id: e92355380948d9181e135b7612596c5309afeeda Co-authored-by: Zhiqiang Wang <[email protected]> Co-authored-by: Joao Gomes <[email protected]>
1 parent 31408dd commit 2b555f6

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

test/test_models_detection_negative_samples.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,17 @@ def test_forward_negative_sample_retinanet(self):
143143

144144
assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0))
145145

146+
def test_forward_negative_sample_fcos(self):
147+
model = torchvision.models.detection.fcos_resnet50_fpn(
148+
num_classes=2, min_size=100, max_size=100, pretrained_backbone=False
149+
)
150+
151+
images, targets = self._make_empty_sample()
152+
loss_dict = model(images, targets)
153+
154+
assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0))
155+
assert_equal(loss_dict["bbox_ctrness"], torch.tensor(0.0))
156+
146157
def test_forward_negative_sample_ssd(self):
147158
model = torchvision.models.detection.ssd300_vgg16(num_classes=2, pretrained_backbone=False)
148159

torchvision/models/detection/fcos.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,13 @@ def compute_loss(
5959
all_gt_classes_targets = []
6060
all_gt_boxes_targets = []
6161
for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs):
62-
gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
62+
if len(targets_per_image["labels"]) == 0:
63+
gt_classes_targets = targets_per_image["labels"].new_zeros((len(matched_idxs_per_image),))
64+
gt_boxes_targets = targets_per_image["boxes"].new_zeros((len(matched_idxs_per_image), 4))
65+
else:
66+
gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
67+
gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
6368
gt_classes_targets[matched_idxs_per_image < 0] = -1 # backgroud
64-
gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
6569
all_gt_classes_targets.append(gt_classes_targets)
6670
all_gt_boxes_targets.append(gt_boxes_targets)
6771

@@ -95,13 +99,14 @@ def compute_loss(
9599
]
96100
bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0)
97101
if len(bbox_reg_targets) == 0:
98-
bbox_reg_targets.new_zeros(len(bbox_reg_targets))
99-
left_right = bbox_reg_targets[:, :, [0, 2]]
100-
top_bottom = bbox_reg_targets[:, :, [1, 3]]
101-
gt_ctrness_targets = torch.sqrt(
102-
(left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
103-
* (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
104-
)
102+
gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1])
103+
else:
104+
left_right = bbox_reg_targets[:, :, [0, 2]]
105+
top_bottom = bbox_reg_targets[:, :, [1, 3]]
106+
gt_ctrness_targets = torch.sqrt(
107+
(left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
108+
* (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
109+
)
105110
pred_centerness = bbox_ctrness.squeeze(dim=2)
106111
loss_bbox_ctrness = nn.functional.binary_cross_entropy_with_logits(
107112
pred_centerness[foregroud_mask], gt_ctrness_targets[foregroud_mask], reduction="sum"

0 commit comments

Comments
 (0)