Skip to content

Commit 24b8643

Browse files
committed
Rename and extend FastRCNNConvFCHead to support arbitrary FCs
1 parent eb649e8 commit 24b8643

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

torchvision/models/detection/faster_rcnn.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -304,31 +304,34 @@ def forward(self, x):
304304
return x
305305

306306

307-
class FastRCNNHeads(nn.Sequential):
307+
class FastRCNNConvFCHead(nn.Sequential):
308308
def __init__(
309309
self,
310310
input_size: Tuple[int, int, int],
311-
layers: List[int],
312-
output_channels: int,
311+
conv_layers: List[int],
312+
fc_layers: List[int],
313313
norm_layer: Optional[Callable[..., nn.Module]] = None,
314314
):
315315
"""
316316
Args:
317317
input_size (Tuple[int, int, int]): the input size in CHW format.
318-
layers (list): feature dimensions of each FCN layer
319-
output_channels (int): output channels
318+
conv_layers (list): feature dimensions of each Convolution layer
319+
fc_layers (list): feature dimensions of each fully connected (FC) layer
320320
norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
321321
"""
322322
in_channels, in_height, in_width = input_size
323323

324324
blocks = []
325325
previous_channels = in_channels
326-
for layer_channels in layers:
327-
blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, layer_channels, norm_layer=norm_layer))
328-
previous_channels = layer_channels
326+
for current_channels in conv_layers:
327+
blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
328+
previous_channels = current_channels
329329
blocks.append(nn.Flatten())
330-
blocks.append(nn.Linear(previous_channels * in_height * in_width, output_channels))
331-
blocks.append(nn.ReLU(inplace=True))
330+
previous_channels = previous_channels * in_height * in_width
331+
for current_channels in fc_layers:
332+
blocks.append(nn.Linear(previous_channels, current_channels))
333+
blocks.append(nn.ReLU(inplace=True))
334+
previous_channels = current_channels
332335

333336
super().__init__(*blocks)
334337
for layer in self.modules():
@@ -567,7 +570,7 @@ def fasterrcnn_resnet50_fpn_v2(
567570
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
568571
rpn_anchor_generator = _default_anchorgen()
569572
rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
570-
box_head = FastRCNNHeads((backbone.out_channels, 7, 7), [256, 256, 256, 256], 1024, norm_layer=nn.BatchNorm2d)
573+
box_head = FastRCNNConvFCHead((backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d)
571574
model = FasterRCNN(
572575
backbone,
573576
num_classes=num_classes,

torchvision/models/detection/mask_rcnn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ..resnet import ResNet50_Weights, resnet50
1313
from ._utils import overwrite_eps
1414
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
15-
from .faster_rcnn import FasterRCNN, FastRCNNHeads, RPNHead, _default_anchorgen
15+
from .faster_rcnn import FasterRCNN, FastRCNNConvFCHead, RPNHead, _default_anchorgen
1616

1717

1818
__all__ = [
@@ -508,7 +508,7 @@ def maskrcnn_resnet50_fpn_v2(
508508
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
509509
rpn_anchor_generator = _default_anchorgen()
510510
rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
511-
box_head = FastRCNNHeads((backbone.out_channels, 7, 7), [256, 256, 256, 256], 1024, norm_layer=nn.BatchNorm2d)
511+
box_head = FastRCNNConvFCHead((backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d)
512512
mask_head = MaskRCNNHeads(backbone.out_channels, [256, 256, 256, 256], 1, norm_layer=nn.BatchNorm2d)
513513
model = MaskRCNN(
514514
backbone,

0 commit comments

Comments
 (0)