@@ -59,9 +59,13 @@ def compute_loss(
59
59
all_gt_classes_targets = []
60
60
all_gt_boxes_targets = []
61
61
for targets_per_image , matched_idxs_per_image in zip (targets , matched_idxs ):
62
- gt_classes_targets = targets_per_image ["labels" ][matched_idxs_per_image .clip (min = 0 )]
62
+ if len (targets_per_image ["labels" ]) == 0 :
63
+ gt_classes_targets = targets_per_image ["labels" ].new_zeros ((len (matched_idxs_per_image ),))
64
+ gt_boxes_targets = targets_per_image ["boxes" ].new_zeros ((len (matched_idxs_per_image ), 4 ))
65
+ else :
66
+ gt_classes_targets = targets_per_image ["labels" ][matched_idxs_per_image .clip (min = 0 )]
67
+ gt_boxes_targets = targets_per_image ["boxes" ][matched_idxs_per_image .clip (min = 0 )]
63
68
gt_classes_targets [matched_idxs_per_image < 0 ] = - 1 # backgroud
64
- gt_boxes_targets = targets_per_image ["boxes" ][matched_idxs_per_image .clip (min = 0 )]
65
69
all_gt_classes_targets .append (gt_classes_targets )
66
70
all_gt_boxes_targets .append (gt_boxes_targets )
67
71
@@ -95,13 +99,14 @@ def compute_loss(
95
99
]
96
100
bbox_reg_targets = torch .stack (bbox_reg_targets , dim = 0 )
97
101
if len (bbox_reg_targets ) == 0 :
98
- bbox_reg_targets .new_zeros (len (bbox_reg_targets ))
99
- left_right = bbox_reg_targets [:, :, [0 , 2 ]]
100
- top_bottom = bbox_reg_targets [:, :, [1 , 3 ]]
101
- gt_ctrness_targets = torch .sqrt (
102
- (left_right .min (dim = - 1 )[0 ] / left_right .max (dim = - 1 )[0 ])
103
- * (top_bottom .min (dim = - 1 )[0 ] / top_bottom .max (dim = - 1 )[0 ])
104
- )
102
+ gt_ctrness_targets = bbox_reg_targets .new_zeros (bbox_reg_targets .size ()[:- 1 ])
103
+ else :
104
+ left_right = bbox_reg_targets [:, :, [0 , 2 ]]
105
+ top_bottom = bbox_reg_targets [:, :, [1 , 3 ]]
106
+ gt_ctrness_targets = torch .sqrt (
107
+ (left_right .min (dim = - 1 )[0 ] / left_right .max (dim = - 1 )[0 ])
108
+ * (top_bottom .min (dim = - 1 )[0 ] / top_bottom .max (dim = - 1 )[0 ])
109
+ )
105
110
pred_centerness = bbox_ctrness .squeeze (dim = 2 )
106
111
loss_bbox_ctrness = nn .functional .binary_cross_entropy_with_logits (
107
112
pred_centerness [foregroud_mask ], gt_ctrness_targets [foregroud_mask ], reduction = "sum"
0 commit comments