Commit bf211da

Add MobileNetV3 architecture for Detection (#3253)
* Minor refactoring of a private method to make it reusable.
* Adding a FasterRCNN + MobileNetV3 with & w/o FPN models.
* Reducing Resolution to 320-640 and anchor sizes to 16-256.
* Increase anchor sizes.
* Adding rpn score threshold param on the train script.
* Adding trainable_backbone_layers param on the train script.
* Adding rpn_score_thresh param directly in fasterrcnn_mobilenet_v3_large_fpn.
* Remove fasterrcnn_mobilenet_v3_large prototype and update expected file.
* Update documentation and adding weights.
* Use built-in Identity.
* Fix spelling.
1 parent 0985533 commit bf211da

13 files changed, with 168 additions and 70 deletions

docs/source/models.rst

Lines changed: 18 additions & 15 deletions
@@ -358,13 +358,14 @@ models return the predictions of the following classes:
358358
Here is the summary of the accuracies for the models trained on
359359
the instances set of COCO train2017 and evaluated on COCO val2017.
360360

361-
================================  =======  ========  ===========
362-
Network                           box AP   mask AP   keypoint AP
363-
================================  =======  ========  ===========
364-
Faster R-CNN ResNet-50 FPN        37.0     -         -
365-
RetinaNet ResNet-50 FPN           36.4     -         -
366-
Mask R-CNN ResNet-50 FPN          37.9     34.6      -
367-
================================  =======  ========  ===========
361+
==================================  =======  ========  ===========
362+
Network                             box AP   mask AP   keypoint AP
363+
==================================  =======  ========  ===========
364+
Faster R-CNN ResNet-50 FPN          37.0     -         -
365+
Faster R-CNN MobileNetV3-Large FPN  23.0     -         -
366+
RetinaNet ResNet-50 FPN             36.4     -         -
367+
Mask R-CNN ResNet-50 FPN            37.9     34.6      -
368+
==================================  =======  ========  ===========
368369

369370
For person keypoint detection, the accuracies for the pre-trained
370371
models are as follows
@@ -414,20 +415,22 @@ For test time, we report the time for the model evaluation and postprocessing
414415
(including mask pasting in image), but not the time for computing the
415416
precision-recall.
416417

417-
==============================  ===================  ==================  ===========
418-
Network                         train time (s / it)  test time (s / it)  memory (GB)
419-
==============================  ===================  ==================  ===========
420-
Faster R-CNN ResNet-50 FPN      0.2288               0.0590              5.2
421-
RetinaNet ResNet-50 FPN         0.2514               0.0939              4.1
422-
Mask R-CNN ResNet-50 FPN        0.2728               0.0903              5.4
423-
Keypoint R-CNN ResNet-50 FPN    0.3789               0.1242              6.8
424-
==============================  ===================  ==================  ===========
418+
==================================  ===================  ==================  ===========
419+
Network                             train time (s / it)  test time (s / it)  memory (GB)
420+
==================================  ===================  ==================  ===========
421+
Faster R-CNN ResNet-50 FPN          0.2288               0.0590              5.2
422+
Faster R-CNN MobileNetV3-Large FPN  0.0978               0.0376              0.6
423+
RetinaNet ResNet-50 FPN             0.2514               0.0939              4.1
424+
Mask R-CNN ResNet-50 FPN            0.2728               0.0903              5.4
425+
Keypoint R-CNN ResNet-50 FPN        0.3789               0.1242              6.8
426+
==================================  ===================  ==================  ===========
425427

426428

427429
Faster R-CNN
428430
------------
429431

430432
.. autofunction:: torchvision.models.detection.fasterrcnn_resnet50_fpn
433+
.. autofunction:: torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn
431434

432435

433436
RetinaNet
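
To put the two new table rows in perspective, the following minimal sketch computes how the MobileNetV3-Large FPN figures compare with the ResNet-50 FPN ones. The numbers are copied from the tables in this diff; the dictionary and function names are illustrative only and are not part of torchvision.

```python
# Figures copied from the accuracy and runtime tables in docs/source/models.rst.
resnet50_fpn = {"box_ap": 37.0, "train_s_per_it": 0.2288, "test_s_per_it": 0.0590, "memory_gb": 5.2}
mobilenet_v3_fpn = {"box_ap": 23.0, "train_s_per_it": 0.0978, "test_s_per_it": 0.0376, "memory_gb": 0.6}


def ratio(key):
    """Return how many times larger the ResNet-50 figure is than the MobileNetV3 one."""
    return resnet50_fpn[key] / mobilenet_v3_fpn[key]


train_speedup = ratio("train_s_per_it")
test_speedup = ratio("test_s_per_it")
memory_reduction = ratio("memory_gb")
ap_drop = resnet50_fpn["box_ap"] - mobilenet_v3_fpn["box_ap"]

print(f"train speedup:    {train_speedup:.2f}x")
print(f"test speedup:     {test_speedup:.2f}x")
print(f"memory reduction: {memory_reduction:.2f}x")
print(f"box AP drop:      {ap_drop:.1f} points")
```

Read this way, the new model trades roughly 14 points of box AP for a bit more than a 2x faster training iteration and a much smaller memory footprint on the reported setup.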

references/detection/README.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,20 @@ You must modify the following flags:
2020

2121
Unless otherwise noted, all models have been trained on 8x V100 GPUs.
2222

23-
### Faster R-CNN
23+
### Faster R-CNN ResNet-50 FPN
2424
```
2525
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
2626
--dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\
2727
--lr-steps 16 22 --aspect-ratio-group-factor 3
2828
```
2929

30+
### Faster R-CNN MobileNetV3-Large FPN
31+
```
32+
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
33+
--dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\
34+
--lr-steps 16 22 --aspect-ratio-group-factor 3
35+
```
36+
3037
### RetinaNet
3138
```
3239
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\

references/detection/train.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,12 @@ def main(args):
9292
collate_fn=utils.collate_fn)
9393

9494
print("Creating model")
95-
kwargs = {}
95+
kwargs = {
96+
"trainable_backbone_layers": args.trainable_backbone_layers
97+
}
9698
if "rcnn" in args.model:
97-
kwargs["rpn_score_thresh"] = 0.0
99+
if args.rpn_score_thresh is not None:
100+
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
98101
model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained,
99102
**kwargs)
100103
model.to(device)
@@ -177,6 +180,9 @@ def main(args):
177180
parser.add_argument('--resume', default='', help='resume from checkpoint')
178181
parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
179182
parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
183+
parser.add_argument('--rpn-score-thresh', default=None, type=float, help='rpn score threshold for faster-rcnn')
184+
parser.add_argument('--trainable-backbone-layers', default=None, type=int,
185+
help='number of trainable layers of backbone')
180186
parser.add_argument(
181187
"--test-only",
182188
dest="test_only",
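
The interaction between the two new flags and the model keyword arguments can be sketched in isolation. The snippet below mirrors the logic added to `train.py` in this hunk, but `build_parser` and `build_model_kwargs` are hypothetical helper names introduced only for this illustration; the real script builds `kwargs` inline in `main`.

```python
import argparse


def build_parser():
    # Only the arguments relevant to this change; the real script defines many more.
    parser = argparse.ArgumentParser(description="sketch of the new detection flags")
    parser.add_argument('--model', default='fasterrcnn_mobilenet_v3_large_fpn')
    parser.add_argument('--rpn-score-thresh', default=None, type=float,
                        help='rpn score threshold for faster-rcnn')
    parser.add_argument('--trainable-backbone-layers', default=None, type=int,
                        help='number of trainable layers of backbone')
    return parser


def build_model_kwargs(args):
    # trainable_backbone_layers is always forwarded, even when it is None,
    # so the model builder can pick its own default.
    kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
    if "rcnn" in args.model:
        # The threshold is forwarded only when given, so each builder keeps its default.
        if args.rpn_score_thresh is not None:
            kwargs["rpn_score_thresh"] = args.rpn_score_thresh
    return kwargs


defaults = build_model_kwargs(build_parser().parse_args([]))
explicit = build_model_kwargs(build_parser().parse_args(
    ['--rpn-score-thresh', '0.0', '--trainable-backbone-layers', '3']))
print(defaults)
print(explicit)
```

Because the check is `is not None` rather than a truthiness test, an explicit `--rpn-score-thresh 0.0` is still passed through, which reproduces the value that was previously hard-coded.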
Binary file not shown.

test/test_models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def get_available_video_models():
3737
'googlenet': lambda x: x.logits,
3838
'inception_v3': lambda x: x.logits,
3939
"fasterrcnn_resnet50_fpn": lambda x: x[1],
40+
"fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
4041
"maskrcnn_resnet50_fpn": lambda x: x[1],
4142
"keypointrcnn_resnet50_fpn": lambda x: x[1],
4243
"retinanet_resnet50_fpn": lambda x: x[1],
@@ -105,6 +106,8 @@ def _test_detection_model(self, name, dev):
105106
if "retinanet" in name:
106107
# Reduce the default threshold to ensure the returned boxes are not empty.
107108
kwargs["score_thresh"] = 0.01
109+
elif "fasterrcnn_mobilenet" in name:
110+
kwargs["box_score_thresh"] = 0.02076
108111
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
109112
model.eval().to(device=dev)
110113
input_shape = (3, 300, 300)

test/test_models_detection_negative_samples.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,15 @@ def test_assign_targets_to_proposals(self):
9797
self.assertEqual(labels[0].dtype, torch.int64)
9898

9999
def test_forward_negative_sample_frcnn(self):
100-
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
101-
num_classes=2, min_size=100, max_size=100)
100+
for name in ["fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn"]:
101+
model = torchvision.models.detection.__dict__[name](
102+
num_classes=2, min_size=100, max_size=100)
102103

103-
images, targets = self._make_empty_sample()
104-
loss_dict = model(images, targets)
104+
images, targets = self._make_empty_sample()
105+
loss_dict = model(images, targets)
105106

106-
self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
107-
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
107+
self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
108+
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
108109

109110
def test_forward_negative_sample_mrcnn(self):
110111
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
@@ -130,7 +131,7 @@ def test_forward_negative_sample_krcnn(self):
130131

131132
def test_forward_negative_sample_retinanet(self):
132133
model = torchvision.models.detection.retinanet_resnet50_fpn(
133-
num_classes=2, min_size=100, max_size=100)
134+
num_classes=2, min_size=100, max_size=100, pretrained_backbone=False)
134135

135136
images, targets = self._make_empty_sample()
136137
loss_dict = model(images, targets)

test/test_models_detection_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,17 @@ def test_resnet_fpn_backbone_frozen_layers(self):
3636

3737
def test_validate_resnet_inputs_detection(self):
3838
# default number of backbone layers to train
39-
ret = backbone_utils._validate_resnet_trainable_layers(
40-
pretrained=True, trainable_backbone_layers=None)
39+
ret = backbone_utils._validate_trainable_layers(
40+
pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3)
4141
self.assertEqual(ret, 3)
4242
# can't go beyond 5
4343
with self.assertRaises(AssertionError):
44-
ret = backbone_utils._validate_resnet_trainable_layers(
45-
pretrained=True, trainable_backbone_layers=6)
44+
ret = backbone_utils._validate_trainable_layers(
45+
pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3)
4646
# if not pretrained, should use all trainable layers and warn
4747
with self.assertWarns(UserWarning):
48-
ret = backbone_utils._validate_resnet_trainable_layers(
49-
pretrained=False, trainable_backbone_layers=0)
48+
ret = backbone_utils._validate_trainable_layers(
49+
pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3)
5050
self.assertEqual(ret, 5)
5151

5252
def test_transform_copy_targets(self):
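
The behaviour these assertions exercise can be reproduced without torchvision. The following standalone sketch follows the logic of the generalized helper as it appears in this commit; the public-looking name `validate_trainable_layers` and the shortened warning text are choices made for this illustration only.

```python
import warnings


def validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value):
    # Without pretrained weights nothing is frozen, so every layer is trainable.
    if not pretrained:
        if trainable_backbone_layers is not None:
            warnings.warn(
                "Changing trainable_backbone_layers has no effect without pretrained weights, "
                "falling back to trainable_backbone_layers={}".format(max_value))
        trainable_backbone_layers = max_value

    # With pretrained weights and no explicit request, use the per-backbone default.
    if trainable_backbone_layers is None:
        trainable_backbone_layers = default_value
    assert 0 <= trainable_backbone_layers <= max_value
    return trainable_backbone_layers


# ResNet-style limits (max 5, default 3) and MobileNetV3-style limits (max 6, default 3).
print(validate_trainable_layers(True, None, max_value=5, default_value=3))
print(validate_trainable_layers(True, None, max_value=6, default_value=3))
print(validate_trainable_layers(False, None, max_value=6, default_value=3))
```

Passing `max_value` and `default_value` explicitly is what lets the same check serve backbones with a different number of stages.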

torchvision/models/detection/backbone_utils.py

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import warnings
2-
from collections import OrderedDict
32
from torch import nn
43
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
54

65
from torchvision.ops import misc as misc_nn_ops
76
from .._utils import IntermediateLayerGetter
7+
from .. import mobilenet
88
from .. import resnet
99

1010

@@ -108,17 +108,65 @@ def resnet_fpn_backbone(
108108
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
109109

110110

111-
def _validate_resnet_trainable_layers(pretrained, trainable_backbone_layers):
111+
def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value):
112112
    # don't freeze any layers if pretrained model or backbone is not used
113113
    if not pretrained:
114114
        if trainable_backbone_layers is not None:
115115
            warnings.warn(
116116
                "Changing trainable_backbone_layers has no effect if "
117117
                "neither pretrained nor pretrained_backbone have been set to True, "
118-
                "falling back to trainable_backbone_layers=5 so that all layers are trainable")
119-
        trainable_backbone_layers = 5
120-
    # by default, freeze first 2 blocks following Faster R-CNN
118+
                "falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value))
119+
        trainable_backbone_layers = max_value
120+
121+
    # by default freeze first blocks
121122
    if trainable_backbone_layers is None:
122-
        trainable_backbone_layers = 3
123-
    assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0
123+
        trainable_backbone_layers = default_value
124+
    assert 0 <= trainable_backbone_layers <= max_value
124125
    return trainable_backbone_layers
126+
127+
128+
def mobilenet_backbone(
129+
    backbone_name,
130+
    pretrained,
131+
    fpn,
132+
    norm_layer=misc_nn_ops.FrozenBatchNorm2d,
133+
    trainable_layers=2,
134+
    returned_layers=None,
135+
    extra_blocks=None
136+
):
137+
    backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
138+
139+
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
140+
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
141+
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1]
142+
    num_stages = len(stage_indices)
143+
144+
    # find the index of the layer from which we won't freeze
145+
    assert 0 <= trainable_layers <= num_stages
146+
    # use len(backbone) rather than num_stages so that trainable_layers == 0 freezes every block
    freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
147+
148+
    # freeze layers only if pretrained backbone is used
149+
    for b in backbone[:freeze_before]:
150+
        for parameter in b.parameters():
151+
            parameter.requires_grad_(False)
152+
153+
    out_channels = 256
154+
    if fpn:
155+
        if extra_blocks is None:
156+
            extra_blocks = LastLevelMaxPool()
157+
158+
        if returned_layers is None:
159+
            returned_layers = [num_stages - 2, num_stages - 1]
160+
        assert min(returned_layers) >= 0 and max(returned_layers) < num_stages
161+
        return_layers = {f'{stage_indices[k]}': str(v) for v, k in enumerate(returned_layers)}
162+
163+
        in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
164+
        return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
165+
    else:
166+
        m = nn.Sequential(
167+
            backbone,
168+
            # 1x1 convolution: a per-pixel linear combination of channels that reduces their number
169+
            nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
170+
        )
171+
        m.out_channels = out_channels
172+
        return m
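
The stage bookkeeping in `mobilenet_backbone` is easier to follow on a toy example. The sketch below reuses the same index arithmetic on a made-up list of eight blocks, three of which are marked as strided; the block layout and the helper name `plan_backbone` are assumptions for illustration and do not describe the real MobileNetV3 layout. For `trainable_layers == 0` the sketch uses `len(blocks)` as the cut-off so that every block is frozen.

```python
from types import SimpleNamespace


def plan_backbone(blocks, trainable_layers, returned_layers=None):
    # Stage boundaries: the first block, every strided block, and the last block.
    stage_indices = [0] + [i for i, b in enumerate(blocks) if getattr(b, "is_strided", False)] + [len(blocks) - 1]
    num_stages = len(stage_indices)

    # Blocks with an index below freeze_before would have requires_grad disabled.
    assert 0 <= trainable_layers <= num_stages
    freeze_before = len(blocks) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

    # By default the last two stages feed the FPN.
    if returned_layers is None:
        returned_layers = [num_stages - 2, num_stages - 1]
    assert min(returned_layers) >= 0 and max(returned_layers) < num_stages
    return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}
    return stage_indices, freeze_before, return_layers


# Hypothetical backbone: 8 blocks, with blocks 2, 4 and 6 marked as strided.
toy_blocks = [SimpleNamespace(is_strided=(i in {2, 4, 6})) for i in range(8)]

stage_indices, freeze_before, return_layers = plan_backbone(toy_blocks, trainable_layers=3)
print(stage_indices)   # [0, 2, 4, 6, 7]
print(freeze_before)   # 4
print(return_layers)   # {'6': '0', '7': '1'}
```

With three trainable stages, blocks 0 to 3 are frozen and the feature maps taken from blocks 6 and 7 are exposed to the FPN under the names `'0'` and `'1'`.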

torchvision/models/detection/faster_rcnn.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
from collections import OrderedDict
2-
31
import torch
42
from torch import nn
53
import torch.nn.functional as F
64

7-
from torchvision.ops import misc as misc_nn_ops
85
from torchvision.ops import MultiScaleRoIAlign
96

107
from ._utils import overwrite_eps
@@ -15,11 +12,11 @@
1512
from .rpn import RPNHead, RegionProposalNetwork
1613
from .roi_heads import RoIHeads
1714
from .transform import GeneralizedRCNNTransform
18-
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
15+
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone
1916

2017

2118
__all__ = [
22-
"FasterRCNN", "fasterrcnn_resnet50_fpn",
19+
"FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn"
2320
]
2421

2522

@@ -291,6 +288,8 @@ def forward(self, x):
291288
model_urls = {
292289
'fasterrcnn_resnet50_fpn_coco':
293290
'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth',
291+
'fasterrcnn_mobilenet_v3_large_fpn_coco':
292+
'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-907ea3f9.pth',
294293
}
295294

296295

@@ -353,9 +352,8 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
353352
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
354353
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
355354
"""
356-
# check default parameters and by default set it to 3 if possible
357-
trainable_backbone_layers = _validate_resnet_trainable_layers(
358-
pretrained or pretrained_backbone, trainable_backbone_layers)
355+
trainable_backbone_layers = _validate_trainable_layers(
356+
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
359357

360358
if pretrained:
361359
# no need to download the backbone if pretrained is set
@@ -368,3 +366,48 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
368366
model.load_state_dict(state_dict)
369367
overwrite_eps(model, 0.0)
370368
return model
369+
370+
371+
def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
372+
                                      trainable_backbone_layers=None, min_size=320, max_size=640, rpn_score_thresh=0.05,
373+
                                      **kwargs):
374+
    """
375+
    Constructs a Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly
376+
    to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.
377+
378+
    Example::
379+
380+
        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
381+
        >>> model.eval()
382+
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
383+
        >>> predictions = model(x)
384+
385+
    Args:
386+
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
387+
        progress (bool): If True, displays a progress bar of the download to stderr
388+
        num_classes (int): number of output classes of the model (including the background)
389+
        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on ImageNet
390+
        trainable_backbone_layers (int): number of trainable (not frozen) backbone layers starting from final block.
391+
            Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
392+
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
393+
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
394+
        rpn_score_thresh (float): during inference, only return proposals with a classification score
395+
            greater than rpn_score_thresh
396+
    """
397+
    trainable_backbone_layers = _validate_trainable_layers(
398+
        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)
399+
400+
    if pretrained:
401+
        pretrained_backbone = False
402+
    backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True,
403+
                                  trainable_layers=trainable_backbone_layers)
404+
405+
    anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3
406+
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
407+
408+
    model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
409+
                       min_size=min_size, max_size=max_size, rpn_score_thresh=rpn_score_thresh, **kwargs)
410+
    if pretrained:
411+
        state_dict = load_state_dict_from_url(model_urls['fasterrcnn_mobilenet_v3_large_fpn_coco'], progress=progress)
412+
        model.load_state_dict(state_dict)
413+
    return model
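
The anchor configuration used by the new builder can be inspected without constructing the model. The sketch below assumes, as torchvision's `AnchorGenerator` does, that every anchor size is combined with every aspect ratio on each feature map; the `* 3` presumably corresponds to the two returned backbone stages plus the extra max-pool level. The variable `anchors_per_location` is introduced only for this illustration.

```python
# Same tuples as in fasterrcnn_mobilenet_v3_large_fpn: one entry per feature map.
anchor_sizes = ((32, 64, 128, 256, 512),) * 3
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)

# Each spatial location gets one anchor per (size, aspect ratio) pair.
anchors_per_location = [len(sizes) * len(ratios) for sizes, ratios in zip(anchor_sizes, aspect_ratios)]

print(len(anchor_sizes))       # 3
print(anchors_per_location)    # [15, 15, 15]
```

Unlike the ResNet-50 FPN configuration, which assigns a single size to each pyramid level, this setup places the full range of sizes on every feature map.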

torchvision/models/detection/keypoint_rcnn.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ..utils import load_state_dict_from_url
88

99
from .faster_rcnn import FasterRCNN
10-
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
10+
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
1111

1212

1313
__all__ = [
@@ -322,9 +322,8 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
322322
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
323323
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
324324
"""
325-
# check default parameters and by default set it to 3 if possible
326-
trainable_backbone_layers = _validate_resnet_trainable_layers(
327-
pretrained or pretrained_backbone, trainable_backbone_layers)
325+
trainable_backbone_layers = _validate_trainable_layers(
326+
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
328327

329328
if pretrained:
330329
# no need to download the backbone if pretrained is set

torchvision/models/detection/mask_rcnn.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ..utils import load_state_dict_from_url
99

1010
from .faster_rcnn import FasterRCNN
11-
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
11+
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
1212

1313
__all__ = [
1414
"MaskRCNN", "maskrcnn_resnet50_fpn",
@@ -317,9 +317,8 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
317317
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
318318
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
319319
"""
320-
# check default parameters and by default set it to 3 if possible
321-
trainable_backbone_layers = _validate_resnet_trainable_layers(
322-
pretrained or pretrained_backbone, trainable_backbone_layers)
320+
trainable_backbone_layers = _validate_trainable_layers(
321+
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
323322

324323
if pretrained:
325324
# no need to download the backbone if pretrained is set
