
Commit f883796

RetinaNet with MobileNetV3 FPN backbone (#3223)
* Initial implementation of MobileNet + RetinaNet.
* Adding expected file for mobilenetv3_large.
* Adding temp candidate pretrained model for mobilenetv3 large.
* Adding temp candidate pretrained model for mobilenetv3 large.
* Making `_validate_resnet_trainable_layers` generic and using it in mobilenet.
* Update tests.
* Rename output_channels to out_channels.
* Fixing comments.
* Update model.
* Better mobilenetv3 large backbone.
* Better mobilenetv3 + retinanet model.
* Update mobilenetv3 + retinanet model.
1 parent 81bd2b3 commit f883796

11 files changed (+127 lines, -38 lines)
Binary file not shown.

test/test_models.py

Lines changed: 2 additions & 1 deletion
@@ -40,6 +40,7 @@ def get_available_video_models():
     "maskrcnn_resnet50_fpn": lambda x: x[1],
     "keypointrcnn_resnet50_fpn": lambda x: x[1],
     "retinanet_resnet50_fpn": lambda x: x[1],
+    "retinanet_mobilenet_v3_large_fpn": lambda x: x[1],
 }


@@ -104,7 +105,7 @@ def _test_detection_model(self, name, dev):
         kwargs = {}
         if "retinanet" in name:
             # Reduce the default threshold to ensure the returned boxes are not empty.
-            kwargs["score_thresh"] = 0.01
+            kwargs["score_thresh"] = 0.0099999
         model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
         model.eval().to(device=dev)
         input_shape = (3, 300, 300)

test/test_models_detection_negative_samples.py

Lines changed: 6 additions & 5 deletions
@@ -129,13 +129,14 @@ def test_forward_negative_sample_krcnn(self):
         self.assertEqual(loss_dict["loss_keypoint"], torch.tensor(0.))

     def test_forward_negative_sample_retinanet(self):
-        model = torchvision.models.detection.retinanet_resnet50_fpn(
-            num_classes=2, min_size=100, max_size=100)
+        for name in ["retinanet_resnet50_fpn", "retinanet_mobilenet_v3_large_fpn"]:
+            model = torchvision.models.detection.__dict__[name](
+                num_classes=2, min_size=100, max_size=100, pretrained_backbone=False)

-        images, targets = self._make_empty_sample()
-        loss_dict = model(images, targets)
+            images, targets = self._make_empty_sample()
+            loss_dict = model(images, targets)

-        self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.))
+            self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.))


 if __name__ == '__main__':

test/test_models_detection_utils.py

Lines changed: 6 additions & 6 deletions
@@ -36,17 +36,17 @@ def test_resnet_fpn_backbone_frozen_layers(self):

     def test_validate_resnet_inputs_detection(self):
         # default number of backbone layers to train
-        ret = backbone_utils._validate_resnet_trainable_layers(
-            pretrained=True, trainable_backbone_layers=None)
+        ret = backbone_utils._validate_trainable_layers(
+            pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3)
         self.assertEqual(ret, 3)
         # can't go beyond 5
         with self.assertRaises(AssertionError):
-            ret = backbone_utils._validate_resnet_trainable_layers(
-                pretrained=True, trainable_backbone_layers=6)
+            ret = backbone_utils._validate_trainable_layers(
+                pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3)
         # if not pretrained, should use all trainable layers and warn
         with self.assertWarns(UserWarning):
-            ret = backbone_utils._validate_resnet_trainable_layers(
-                pretrained=False, trainable_backbone_layers=0)
+            ret = backbone_utils._validate_trainable_layers(
+                pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3)
         self.assertEqual(ret, 5)

     def test_transform_copy_targets(self):

torchvision/models/detection/backbone_utils.py

Lines changed: 45 additions & 7 deletions
@@ -1,10 +1,10 @@
 import warnings
-from collections import OrderedDict
 from torch import nn
 from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool

 from torchvision.ops import misc as misc_nn_ops
 from .._utils import IntermediateLayerGetter
+from .. import mobilenet
 from .. import resnet

@@ -108,17 +108,55 @@ def resnet_fpn_backbone(
     return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)


-def _validate_resnet_trainable_layers(pretrained, trainable_backbone_layers):
+def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value):
     # don't freeze any layers if pretrained model or backbone is not used
     if not pretrained:
         if trainable_backbone_layers is not None:
             warnings.warn(
                 "Changing trainable_backbone_layers has no effect if "
                 "neither pretrained nor pretrained_backbone have been set to True, "
-                "falling back to trainable_backbone_layers=5 so that all layers are trainable")
-            trainable_backbone_layers = 5
-    # by default, freeze first 2 blocks following Faster R-CNN
+                "falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value))
+            trainable_backbone_layers = max_value
+
+    # by default freeze first blocks
     if trainable_backbone_layers is None:
-        trainable_backbone_layers = 3
-    assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0
+        trainable_backbone_layers = default_value
+    assert 0 <= trainable_backbone_layers <= max_value
     return trainable_backbone_layers
+
+
+def mobilenet_fpn_backbone(
+    backbone_name,
+    pretrained,
+    norm_layer=misc_nn_ops.FrozenBatchNorm2d,
+    trainable_layers=2,
+    returned_layers=None,
+    extra_blocks=None
+):
+    backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
+
+    # Gather the indices of blocks which are strided. These are the locations of the C1, ..., Cn-1 blocks.
+    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
+    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1]
+    num_stages = len(stage_indices)
+
+    # find the index of the layer from which we won't freeze
+    assert 0 <= trainable_layers <= num_stages
+    freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
+
+    # freeze layers only if pretrained backbone is used
+    for b in backbone[:freeze_before]:
+        for parameter in b.parameters():
+            parameter.requires_grad_(False)
+
+    if extra_blocks is None:
+        extra_blocks = LastLevelMaxPool()
+
+    if returned_layers is None:
+        returned_layers = [num_stages - 2, num_stages - 1]
+    assert min(returned_layers) >= 0 and max(returned_layers) < num_stages
+    return_layers = {f'{stage_indices[k]}': str(v) for v, k in enumerate(returned_layers)}
+
+    in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
+    out_channels = 256
+    return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
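For context (not part of the diff), a minimal sketch of how the new mobilenet_fpn_backbone helper can be exercised once this commit is checked out. The input size, the trainable_layers value and the printed inspection below are illustrative assumptions, not values the PR itself prescribes:

import torch
from torchvision.models.detection.backbone_utils import mobilenet_fpn_backbone

# Build a MobileNetV3-Large FPN backbone with a randomly initialized trunk.
# returned_layers=[4, 5] mirrors what retinanet_mobilenet_v3_large_fpn passes;
# trainable_layers=6 leaves every stage unfrozen (num_stages is 6 for this model).
backbone = mobilenet_fpn_backbone("mobilenet_v3_large", pretrained=False,
                                  returned_layers=[4, 5], trainable_layers=6)

x = torch.rand(1, 3, 320, 320)      # dummy image batch
features = backbone(x)              # OrderedDict of FPN feature maps
for name, feature_map in features.items():
    print(name, feature_map.shape)  # every level has 256 output channels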

torchvision/models/detection/faster_rcnn.py

Lines changed: 3 additions & 3 deletions
@@ -15,7 +15,7 @@
 from .rpn import RPNHead, RegionProposalNetwork
 from .roi_heads import RoIHeads
 from .transform import GeneralizedRCNNTransform
-from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
+from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers


 __all__ = [

@@ -350,8 +350,8 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
     """
     # check default parameters and by default set it to 3 if possible
-    trainable_backbone_layers = _validate_resnet_trainable_layers(
-        pretrained or pretrained_backbone, trainable_backbone_layers)
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)

     if pretrained:
         # no need to download the backbone if pretrained is set

torchvision/models/detection/keypoint_rcnn.py

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@
 from ..utils import load_state_dict_from_url

 from .faster_rcnn import FasterRCNN
-from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
+from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers


 __all__ = [

@@ -319,8 +319,8 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
     """
     # check default parameters and by default set it to 3 if possible
-    trainable_backbone_layers = _validate_resnet_trainable_layers(
-        pretrained or pretrained_backbone, trainable_backbone_layers)
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)

     if pretrained:
         # no need to download the backbone if pretrained is set

torchvision/models/detection/mask_rcnn.py

Lines changed: 3 additions & 3 deletions
@@ -8,7 +8,7 @@
 from ..utils import load_state_dict_from_url

 from .faster_rcnn import FasterRCNN
-from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
+from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers

 __all__ = [
     "MaskRCNN", "maskrcnn_resnet50_fpn",

@@ -314,8 +314,8 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
     """
     # check default parameters and by default set it to 3 if possible
-    trainable_backbone_layers = _validate_resnet_trainable_layers(
-        pretrained or pretrained_backbone, trainable_backbone_layers)
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)

     if pretrained:
         # no need to download the backbone if pretrained is set

torchvision/models/detection/retinanet.py

Lines changed: 48 additions & 4 deletions
@@ -12,14 +12,14 @@
 from . import _utils as det_utils
 from .anchor_utils import AnchorGenerator
 from .transform import GeneralizedRCNNTransform
-from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
+from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_fpn_backbone
 from ...ops.feature_pyramid_network import LastLevelP6P7
 from ...ops import sigmoid_focal_loss
 from ...ops import boxes as box_ops


 __all__ = [
-    "RetinaNet", "retinanet_resnet50_fpn",
+    "RetinaNet", "retinanet_resnet50_fpn", "retinanet_mobilenet_v3_large_fpn"
 ]


@@ -557,7 +557,10 @@ def forward(self, images, targets=None):
         return self.eager_outputs(losses, detections)


+# TODO: replace with pytorch links
 model_urls = {
+    'retinanet_mobilenet_v3_large_fpn_coco':
+        'https://github.com/datumbox/torchvision-models/raw/main/retinanet_mobilenet_v3_large_fpn-41c847a4.pth',
     'retinanet_resnet50_fpn_coco':
         'https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth',
 }

@@ -606,8 +609,8 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
     """
     # check default parameters and by default set it to 3 if possible
-    trainable_backbone_layers = _validate_resnet_trainable_layers(
-        pretrained or pretrained_backbone, trainable_backbone_layers)
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)

     if pretrained:
         # no need to download the backbone if pretrained is set

@@ -622,3 +625,44 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
         model.load_state_dict(state_dict)
         overwrite_eps(model, 0.0)
     return model
+
+
+def retinanet_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
+                                     trainable_backbone_layers=None, **kwargs):
+    """
+    Constructs a RetinaNet model with a MobileNetV3-Large-FPN backbone. It works similarly
+    to RetinaNet with ResNet-50-FPN backbone. See `retinanet_resnet50_fpn` for more details.
+
+    Example::
+
+        >>> model = torchvision.models.detection.retinanet_mobilenet_v3_large_fpn(pretrained=True)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on COCO train2017
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int): number of output classes of the model (including the background)
+        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
+        trainable_backbone_layers (int): number of trainable (not frozen) mobilenet layers starting from final block.
+            Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
+    """
+    # check default parameters and by default set it to 3 if possible
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)
+
+    if pretrained:
+        pretrained_backbone = False
+    backbone = mobilenet_fpn_backbone("mobilenet_v3_large", pretrained_backbone, returned_layers=[4, 5],
+                                      trainable_layers=trainable_backbone_layers)
+
+    anchor_sizes = ((128,), (256,), (512,))
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+
+    model = RetinaNet(backbone, num_classes, anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls['retinanet_mobilenet_v3_large_fpn_coco'],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+    return model
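As a quick sanity check (illustrative only, mirroring how the updated negative-sample test drives the new entry point), the model can be built without downloading any weights; the num_classes value, image sizes and trainable_backbone_layers below are arbitrary choices, not values prescribed by the commit:

import torch
from torchvision.models.detection import retinanet_mobilenet_v3_large_fpn

# Randomly initialized model: no COCO weights, no ImageNet backbone weights.
model = retinanet_mobilenet_v3_large_fpn(pretrained=False, pretrained_backbone=False,
                                         num_classes=3, trainable_backbone_layers=6)

model.eval()
images = [torch.rand(3, 320, 320), torch.rand(3, 480, 360)]
with torch.no_grad():
    detections = model(images)      # list of dicts with 'boxes', 'scores', 'labels'
print(detections[0]["boxes"].shape)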

torchvision/models/mobilenetv2.py

Lines changed: 3 additions & 0 deletions
@@ -53,6 +53,7 @@ def __init__(
             norm_layer(out_planes),
             activation_layer(inplace=True)
         )
+        self.out_channels = out_planes


 # necessary for backwards compatibility

@@ -90,6 +91,8 @@ def __init__(
             norm_layer(oup),
         ])
         self.conv = nn.Sequential(*layers)
+        self.out_channels = oup
+        self.is_strided = stride > 1

     def forward(self, x: Tensor) -> Tensor:
         if self.use_res_connect:

torchvision/models/mobilenetv3.py

Lines changed: 8 additions & 6 deletions
@@ -14,7 +14,7 @@

 # TODO: add pretrained
 model_urls = {
-    "mobilenet_v3_large": None,
+    "mobilenet_v3_large": "https://github.com/datumbox/torchvision-models/raw/main/mobilenet_v3_large-8738ca79.pth",
     "mobilenet_v3_small": None,
 }


@@ -48,12 +48,12 @@ def forward(self, input: Tensor) -> Tensor:

 class InvertedResidualConfig:

-    def __init__(self, input_channels: int, kernel: int, expanded_channels: int, output_channels: int, use_se: bool,
+    def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool,
                  activation: str, stride: int, width_mult: float):
         self.input_channels = self.adjust_channels(input_channels, width_mult)
         self.kernel = kernel
         self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
-        self.output_channels = self.adjust_channels(output_channels, width_mult)
+        self.out_channels = self.adjust_channels(out_channels, width_mult)
         self.use_se = use_se
         self.use_hs = activation == "HS"
         self.stride = stride

@@ -70,7 +70,7 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod
         if not (1 <= cnf.stride <= 2):
             raise ValueError('illegal stride value')

-        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.output_channels
+        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

         layers: List[nn.Module] = []
         activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU

@@ -88,10 +88,12 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod
             layers.append(SqueezeExcitation(cnf.expanded_channels))

         # project
-        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.output_channels, kernel_size=1, norm_layer=norm_layer,
+        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
                                        activation_layer=Identity))

         self.block = nn.Sequential(*layers)
+        self.out_channels = cnf.out_channels
+        self.is_strided = cnf.stride > 1

     def forward(self, input: Tensor) -> Tensor:
         result = self.block(input)

@@ -146,7 +148,7 @@ def __init__(
             layers.append(block(cnf, norm_layer))

         # building last several layers
-        lastconv_input_channels = inverted_residual_setting[-1].output_channels
+        lastconv_input_channels = inverted_residual_setting[-1].out_channels
         lastconv_output_channels = 6 * lastconv_input_channels
         layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
                                        norm_layer=norm_layer, activation_layer=nn.Hardswish))
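For illustration (again, not part of the diff), the new out_channels and is_strided attributes can be inspected directly on a MobileNetV3-Large instance; this is the information mobilenet_fpn_backbone relies on to locate the C1..Cn stage boundaries. A small sketch, assuming this commit is installed:

from torchvision.models import mobilenet_v3_large

m = mobilenet_v3_large(pretrained=False)
for i, block in enumerate(m.features):
    # Every block now reports its output width; only InvertedResidual blocks carry is_strided.
    print(i, block.out_channels, getattr(block, "is_strided", False))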
