@@ -304,31 +304,34 @@ def forward(self, x):
304
304
return x
305
305
306
306
307
- class FastRCNNHeads (nn .Sequential ):
307
+ class FastRCNNConvFCHead (nn .Sequential ):
308
308
def __init__ (
309
309
self ,
310
310
input_size : Tuple [int , int , int ],
311
- layers : List [int ],
312
- output_channels : int ,
311
+ conv_layers : List [int ],
312
+ fc_layers : List [ int ] ,
313
313
norm_layer : Optional [Callable [..., nn .Module ]] = None ,
314
314
):
315
315
"""
316
316
Args:
317
317
input_size (Tuple[int, int, int]): the input size in CHW format.
318
- layers (list): feature dimensions of each FCN layer
319
- output_channels (int ): output channels
318
+ conv_layers (list): feature dimensions of each Convolution layer
319
+ fc_layers (list ): feature dimensions of each FCN layer
320
320
norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
321
321
"""
322
322
in_channels , in_height , in_width = input_size
323
323
324
324
blocks = []
325
325
previous_channels = in_channels
326
- for layer_channels in layers :
327
- blocks .append (misc_nn_ops .Conv2dNormActivation (previous_channels , layer_channels , norm_layer = norm_layer ))
328
- previous_channels = layer_channels
326
+ for current_channels in conv_layers :
327
+ blocks .append (misc_nn_ops .Conv2dNormActivation (previous_channels , current_channels , norm_layer = norm_layer ))
328
+ previous_channels = current_channels
329
329
blocks .append (nn .Flatten ())
330
- blocks .append (nn .Linear (previous_channels * in_height * in_width , output_channels ))
331
- blocks .append (nn .ReLU (inplace = True ))
330
+ previous_channels = previous_channels * in_height * in_width
331
+ for current_channels in fc_layers :
332
+ blocks .append (nn .Linear (previous_channels , current_channels ))
333
+ blocks .append (nn .ReLU (inplace = True ))
334
+ previous_channels = current_channels
332
335
333
336
super ().__init__ (* blocks )
334
337
for layer in self .modules ():
@@ -567,7 +570,7 @@ def fasterrcnn_resnet50_fpn_v2(
567
570
backbone = _resnet_fpn_extractor (backbone , trainable_backbone_layers , norm_layer = nn .BatchNorm2d )
568
571
rpn_anchor_generator = _default_anchorgen ()
569
572
rpn_head = RPNHead (backbone .out_channels , rpn_anchor_generator .num_anchors_per_location ()[0 ], conv_depth = 2 )
570
- box_head = FastRCNNHeads ((backbone .out_channels , 7 , 7 ), [256 , 256 , 256 , 256 ], 1024 , norm_layer = nn .BatchNorm2d )
573
+ box_head = FastRCNNConvFCHead ((backbone .out_channels , 7 , 7 ), [256 , 256 , 256 , 256 ], [ 1024 ] , norm_layer = nn .BatchNorm2d )
571
574
model = FasterRCNN (
572
575
backbone ,
573
576
num_classes = num_classes ,
0 commit comments