@@ -42,9 +42,10 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         }
 
     def forward(self, x):
-        cls_logits = [self.classification_head(feature) for feature in x]
-        bbox_reg = [self.regression_head(feature) for feature in x]
-        return dict(cls_logits=cls_logits, bbox_reg=bbox_reg)
+        return {
+            'cls_logits': self.classification_head(x),
+            'bbox_regression': self.regression_head(x)
+        }
 
 
 def sigmoid_focal_loss(
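
This hunk changes the head's contract: each sub-head now takes the whole list of FPN feature maps and returns a single tensor concatenated over levels, instead of a per-level list, and the 'bbox_reg' key becomes 'bbox_regression'. A minimal sketch of the resulting shape, with assumed sizes (N=2 images, K=10 classes, A=9 anchors per location, two levels of 8x8 and 4x4):

    import torch

    N, K, A = 2, 10, 9
    # Per-level head outputs after the internal permute: (N, H*W*A, K).
    per_level = [torch.randn(N, H * W * A, K) for (H, W) in [(8, 8), (4, 4)]]
    cls_logits = torch.cat(per_level, dim=1)
    assert cls_logits.shape == (N, (8 * 8 + 4 * 4) * A, K)  # (2, 720, 10)
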
@@ -127,45 +128,48 @@ def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01
     def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         loss = []
 
-        def permute_classification(tensor):
-            """ Permute classification output from (N, A * K, H, W) to (N, HWA, K). """
-            N, _, H, W = tensor.shape
-            tensor = tensor.view(N, -1, self.num_classes, H, W)
-            tensor = tensor.permute(0, 3, 4, 1, 2)
-            tensor = tensor.reshape(N, -1, self.num_classes)  # Size=(N, HWA, 4)
-            return tensor
-
-        predicted_classification = head_outputs['cls_logits']
-        predicted_classification = [permute_classification(cls) for cls in predicted_classification]
-        predicted_classification = torch.cat(predicted_classification, dim=1)
+        cls_logits = head_outputs['cls_logits']
 
-        for targets_per_image, predicted_classification_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_classification, anchors, matched_idxs):
+        for targets_per_image, cls_logits_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, cls_logits, anchors, matched_idxs):
             # determine only the foreground
             foreground_idxs_per_image = matched_idxs_per_image >= 0
             num_foreground = foreground_idxs_per_image.sum()
 
             # create the target classification
-            gt_classes_target = torch.zeros_like(predicted_classification_per_image)
+            gt_classes_target = torch.zeros_like(cls_logits_per_image)
             gt_classes_target[foreground_idxs_per_image, targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]] = 1
 
             # find indices for which anchors should be ignored
             valid_idxs_per_image = matched_idxs_per_image != det_utils.Matcher.BETWEEN_THRESHOLDS
 
             # compute the classification loss
-            loss.append(sigmoid_focal_loss_jit(
-                predicted_classification_per_image[valid_idxs_per_image],
+            loss.append(sigmoid_focal_loss(
+                cls_logits_per_image[valid_idxs_per_image],
                 gt_classes_target[valid_idxs_per_image],
                 reduction='sum',
             ) / max(1, num_foreground))
 
         return sum(loss) / len(loss)
 
     def forward(self, x):
-        x = F.relu(self.conv1(x))
-        x = F.relu(self.conv2(x))
-        x = F.relu(self.conv3(x))
-        x = F.relu(self.conv4(x))
-        return self.cls_logits(x)
+        all_cls_logits = []
+
+        for features in x:
+            cls_logits = F.relu(self.conv1(features))
+            cls_logits = F.relu(self.conv2(cls_logits))
+            cls_logits = F.relu(self.conv3(cls_logits))
+            cls_logits = F.relu(self.conv4(cls_logits))
+            cls_logits = self.cls_logits(cls_logits)
+
+            # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
+            N, _, H, W = cls_logits.shape
+            cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
+            cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
+            cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, K)
+
+            all_cls_logits.append(cls_logits)
+
+        return torch.cat(all_cls_logits, dim=1)
 
 
 class RetinaNetRegressionHead(nn.Module):
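
The (N, A * K, H, W) -> (N, HWA, K) permutation moves out of compute_loss and into forward, so the loss only ever sees the flattened layout; the loss call also switches from sigmoid_focal_loss_jit to plain sigmoid_focal_loss. A standalone sketch of that permutation, with assumed sizes:

    import torch

    N, A, K, H, W = 2, 9, 10, 4, 4
    raw = torch.randn(N, A * K, H, W)  # conv output: anchors x classes in channels
    out = raw.view(N, -1, K, H, W)     # split channels into (A, K)
    out = out.permute(0, 3, 4, 1, 2)   # -> (N, H, W, A, K)
    out = out.reshape(N, -1, K)        # -> (N, HWA, K)
    assert out.shape == (N, H * W * A, K)
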
@@ -194,19 +198,9 @@ def __init__(self, in_channels, num_anchors):
     def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
         loss = []
 
-        def permute_bbox_reg(tensor):
-            """ Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4). """
-            N, _, H, W = tensor.shape
-            tensor = tensor.view(N, -1, 4, H, W)
-            tensor = tensor.permute(0, 3, 4, 1, 2)
-            tensor = tensor.reshape(N, -1, 4)  # Size=(N, HWA, 4)
-            return tensor
-
-        predicted_regression = head_outputs['bbox_reg']
-        predicted_regression = [permute_bbox_reg(reg) for reg in predicted_regression]
-        predicted_regression = torch.cat(predicted_regression, dim=1)
+        bbox_regression = head_outputs['bbox_regression']
 
-        for targets_per_image, predicted_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, predicted_regression, anchors, matched_idxs):
+        for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(targets, bbox_regression, anchors, matched_idxs):
             # get the corresponding GT for each proposal
             # NB: need to clamp the indices because we can have a single
             # GT in the image, and matched_idxs can be -2, which goes
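
As the comment above notes, matched_idxs holds negative values for unmatched anchors (-1 for background, -2 for between-thresholds), so indices are clamped before gathering GT boxes and the bogus rows are masked out afterwards. A small illustration with made-up values:

    import torch

    matched_idxs = torch.tensor([0, -1, -2, 0])       # one entry per anchor
    gt_boxes = torch.tensor([[0., 0., 10., 10.]])     # a single GT box
    matched_gt = gt_boxes[matched_idxs.clamp(min=0)]  # safe gather; rows 1-2 are bogus
    foreground = matched_idxs >= 0
    matched_gt = matched_gt[foreground]               # keep only the real matches
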
@@ -219,23 +213,36 @@ def permute_bbox_reg(tensor):
 
             # select only the foreground boxes
             matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
-            predicted_regression_per_image = predicted_regression_per_image[foreground_idxs_per_image, :]
+            bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
             anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
 
             # compute the regression targets
             target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
 
             # compute the loss
-            loss.append(torch.nn.SmoothL1Loss(reduction='sum')(predicted_regression_per_image, target_regression) / max(1, num_foreground))
+            loss.append(torch.nn.SmoothL1Loss(reduction='sum')(bbox_regression_per_image, target_regression) / max(1, num_foreground))
 
         return sum(loss) / max(1, len(loss))
 
     def forward(self, x):
-        x = F.relu(self.conv1(x))
-        x = F.relu(self.conv2(x))
-        x = F.relu(self.conv3(x))
-        x = F.relu(self.conv4(x))
-        return self.bbox_reg(x)
+        all_bbox_regression = []
+
+        for features in x:
+            bbox_regression = F.relu(self.conv1(features))
+            bbox_regression = F.relu(self.conv2(bbox_regression))
+            bbox_regression = F.relu(self.conv3(bbox_regression))
+            bbox_regression = F.relu(self.conv4(bbox_regression))
+            bbox_regression = self.bbox_reg(bbox_regression)
+
+            # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
+            N, _, H, W = bbox_regression.shape
+            bbox_regression = bbox_regression.view(N, -1, 4, H, W)
+            bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
+            bbox_regression = bbox_regression.reshape(N, -1, 4)  # Size=(N, HWA, 4)
+
+            all_bbox_regression.append(bbox_regression)
+
+        return torch.cat(all_bbox_regression, dim=1)
 
 
 class RetinaNet(nn.Module):
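
Both heads normalize per image by the number of foreground anchors, with max(1, ...) guarding against division by zero when an image has no foreground; the regression head additionally guards the final mean with max(1, len(loss)). A toy rendering of that pattern, with made-up tensors:

    import torch

    losses = []
    for pred, target, num_fg in [(torch.randn(5, 4), torch.randn(5, 4), 5),
                                 (torch.randn(0, 4), torch.randn(0, 4), 0)]:
        per_image = torch.nn.SmoothL1Loss(reduction='sum')(pred, target)
        losses.append(per_image / max(1, num_fg))  # empty image -> 0 / 1
    total = sum(losses) / max(1, len(losses))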