Skip to content

Commit 2f1f578

Browse files
committed
Add experimental resnet50 backbone.
1 parent 730c5e1 commit 2f1f578

File tree

3 files changed

+115
-2
lines changed

3 files changed

+115
-2
lines changed
Binary file not shown.

test/test_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def get_available_video_models():
4545
"keypointrcnn_resnet50_fpn": lambda x: x[1],
4646
"retinanet_resnet50_fpn": lambda x: x[1],
4747
"ssd300_vgg16": lambda x: x[1],
48+
"ssd512_resnet50": lambda x: x[1],
4849
}
4950

5051

torchvision/models/detection/ssd.py

Lines changed: 114 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@
1010
from .anchor_utils import DefaultBoxGenerator
1111
from .backbone_utils import _validate_trainable_layers
1212
from .transform import GeneralizedRCNNTransform
13-
from .. import vgg
13+
from .. import vgg, resnet
1414
from ..utils import load_state_dict_from_url
1515
from ...ops import boxes as box_ops
1616

17-
__all__ = ['SSD', 'ssd300_vgg16']
17+
__all__ = ['SSD', 'ssd300_vgg16', 'ssd512_resnet50']
1818

1919
model_urls = {
2020
'ssd300_vgg16_coco': 'https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth',
21+
'ssd512_resnet50_coco': None, # TODO: add weights
2122
}
2223

2324
backbone_urls = {
@@ -562,3 +563,114 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i
562563
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
563564
model.load_state_dict(state_dict)
564565
return model
566+
567+
568+
class SSDFeatureExtractorResNet(nn.Module):
    """SSD feature extractor built on top of a ResNet backbone.

    The backbone's stem (``conv1``, ``bn1``, ``relu``, ``maxpool``) and its
    four residual stages are chained into ``features``. Five additional
    conv-BN-ReLU blocks, stored in ``extra``, are applied one after another
    on top of the backbone output; ``forward`` returns the backbone output
    plus the output of every extra block.

    Args:
        backbone: ResNet whose sub-modules are reused directly (they are not
            copied, so the stride patch below also mutates the passed-in
            backbone). The last block of ``layer4`` must expose a ``bn3``
            attribute, which is read to find the number of output channels
            (presumably a bottleneck-style block - confirm before using
            other ResNet variants).
    """

    def __init__(self, backbone: resnet.ResNet):
        super().__init__()

        stem_and_stages = [
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
        ]
        self.features = nn.Sequential(*stem_and_stages)

        # Force every `stride` attribute found in the first block of the last
        # stage to 1; the original author's note says this is needed to get
        # valid output sizes.
        last_stage = self.features[-1]
        for module in last_stage[0].modules():
            if hasattr(module, 'stride'):
                module.stride = 1

        backbone_out_channels = last_stage[-1].bn3.num_features

        # One entry per extra block:
        # (input channels, bottleneck channels, output channels,
        #  keyword arguments of the second convolution).
        extra_specs = [
            (backbone_out_channels, 256, 512, {'kernel_size': 3, 'padding': 1, 'stride': 2}),
            (512, 256, 512, {'kernel_size': 3, 'padding': 1, 'stride': 2}),
            (512, 128, 256, {'kernel_size': 3, 'padding': 1, 'stride': 2}),
            (256, 128, 256, {'kernel_size': 3}),
            (256, 128, 256, {'kernel_size': 2}),
        ]

        extra = nn.ModuleList()
        for in_channels, mid_channels, out_channels, conv_kwargs in extra_specs:
            # 1x1 convolution to shrink the channels, followed by the
            # block-specific second convolution; both are bias-free because a
            # BatchNorm layer follows each of them.
            extra.append(nn.Sequential(
                nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(mid_channels, out_channels, bias=False, **conv_kwargs),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ))

        # `_xavier_init` is defined elsewhere in this module; presumably it
        # re-initialises the convolution weights of the extra blocks.
        _xavier_init(extra)
        self.extra = extra

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the backbone and the extra blocks.

        Returns:
            An ``OrderedDict`` keyed by ``"0"``, ``"1"``, ... in which entry
            ``"0"`` is the output of ``features`` and every following entry
            is the output of the next block of ``extra``.
        """
        x = self.features(x)
        feature_maps = [x]

        # Each extra block consumes the output of the previous one.
        for extra_block in self.extra:
            x = extra_block(x)
            feature_maps.append(x)

        return OrderedDict([(str(idx), fmap) for idx, fmap in enumerate(feature_maps)])
643+
644+
645+
def _resnet_extractor(backbone_name: str, pretrained: bool, trainable_layers: int):
    """Build a ResNet and wrap it in ``SSDFeatureExtractorResNet``.

    Args:
        backbone_name: name of the constructor looked up in the ``resnet``
            module (e.g. ``"resnet50"``).
        pretrained: forwarded as ``pretrained=`` to that constructor.
        trainable_layers: how many entries, counted from the top of
            ``layer4, layer3, layer2, layer1, conv1``, stay trainable.
            Must be in ``[0, 5]``; with 5, ``bn1`` stays trainable as well.

    Returns:
        The feature extractor wrapping the partially frozen backbone.
    """
    backbone = resnet.__dict__[backbone_name](pretrained=pretrained)

    assert 0 <= trainable_layers <= 5
    stages_top_down = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1']
    trainable = stages_top_down[:trainable_layers]
    if trainable_layers == 5:
        # The stem is only fully trainable when its batch norm is included.
        trainable.append('bn1')
    trainable_prefixes = tuple(trainable)

    # Freeze every parameter whose name starts with none of the selected
    # prefixes (an empty prefix tuple matches nothing, so everything is
    # frozen in that case).
    for param_name, param in backbone.named_parameters():
        if not param_name.startswith(trainable_prefixes):
            param.requires_grad_(False)

    return SSDFeatureExtractorResNet(backbone)
657+
658+
659+
def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
                    pretrained_backbone: bool = True, trainable_backbone_layers: Optional[int] = None, **kwargs: Any):
    """Construct an SSD model with a 512x512 input size and a ResNet-50 backbone.

    Args:
        pretrained: if True, load the full-model checkpoint registered in
            ``model_urls`` under ``'ssd512_resnet50_coco'``.
        progress: forwarded to ``load_state_dict_from_url`` when a checkpoint
            is downloaded.
        num_classes: forwarded to the ``SSD`` constructor.
        pretrained_backbone: whether the ResNet-50 is built with
            ``pretrained=True``; ignored (treated as False) when
            ``pretrained`` is set.
        trainable_backbone_layers: number of trainable backbone layers,
            passed through ``_validate_trainable_layers`` before use.
        **kwargs: forwarded to the ``SSD`` constructor.

    Returns:
        The constructed ``SSD`` model.

    Raises:
        ValueError: if ``pretrained`` is set but ``model_urls`` holds no URL
            for ``'ssd512_resnet50_coco'``.
    """
    trainable_backbone_layers = _validate_trainable_layers(
        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5)

    # Backbone weights are not requested separately when a full-model
    # checkpoint is going to be loaded.
    use_pretrained_backbone = False if pretrained else pretrained_backbone

    feature_extractor = _resnet_extractor("resnet50", use_pretrained_backbone, trainable_backbone_layers)
    # One aspect-ratio list per feature map returned by the extractor.
    box_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]])
    model = SSD(feature_extractor, box_generator, (512, 512), num_classes, **kwargs)
    if not pretrained:
        return model

    weights_name = 'ssd512_resnet50_coco'
    checkpoint_url = model_urls.get(weights_name, None)
    if checkpoint_url is None:
        raise ValueError("No checkpoint is available for model {}".format(weights_name))
    model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress))
    return model

0 commit comments

Comments
 (0)