Skip to content

Commit 1451d97

Browse files
authored
fix bug when the target is empty
1 parent 44ae1e5 commit 1451d97

File tree

1 file changed

+14
-9
lines changed
  • torchvision/models/detection

1 file changed

+14
-9
lines changed

torchvision/models/detection/fcos.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,13 @@ def compute_loss(
5959
all_gt_classes_targets = []
6060
all_gt_boxes_targets = []
6161
for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs):
62-
gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
62+
if len(targets_per_image["labels"]) == 0:
63+
gt_classes_targets = targets_per_image["labels"].new_zeros((len(matched_idxs_per_image),))
64+
gt_boxes_targets = targets_per_image["boxes"].new_zeros((len(matched_idxs_per_image), 4))
65+
else:
66+
gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
67+
gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
6368
gt_classes_targets[matched_idxs_per_image < 0] = -1 # backgroud
64-
gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
6569
all_gt_classes_targets.append(gt_classes_targets)
6670
all_gt_boxes_targets.append(gt_boxes_targets)
6771

@@ -95,13 +99,14 @@ def compute_loss(
9599
]
96100
bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0)
97101
if len(bbox_reg_targets) == 0:
98-
bbox_reg_targets.new_zeros(len(bbox_reg_targets))
99-
left_right = bbox_reg_targets[:, :, [0, 2]]
100-
top_bottom = bbox_reg_targets[:, :, [1, 3]]
101-
gt_ctrness_targets = torch.sqrt(
102-
(left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
103-
* (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
104-
)
102+
gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1])
103+
else:
104+
left_right = bbox_reg_targets[:, :, [0, 2]]
105+
top_bottom = bbox_reg_targets[:, :, [1, 3]]
106+
gt_ctrness_targets = torch.sqrt(
107+
(left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
108+
* (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
109+
)
105110
pred_centerness = bbox_ctrness.squeeze(dim=2)
106111
loss_bbox_ctrness = nn.functional.binary_cross_entropy_with_logits(
107112
pred_centerness[foregroud_mask], gt_ctrness_targets[foregroud_mask], reduction="sum"

0 commit comments

Comments
 (0)