Skip to content

Commit c3dfca1

Browse files
committed
Adding a FasterRCNN + MobileNetV3 with & w/o FPN models.
1 parent 8348c3a commit c3dfca1

7 files changed

+145
-10
lines changed
Binary file not shown.
Binary file not shown.

test/test_models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ def get_available_video_models():
3737
'googlenet': lambda x: x.logits,
3838
'inception_v3': lambda x: x.logits,
3939
"fasterrcnn_resnet50_fpn": lambda x: x[1],
40+
"fasterrcnn_mobilenet_v3_large": lambda x: x[1],
41+
"fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
4042
"maskrcnn_resnet50_fpn": lambda x: x[1],
4143
"keypointrcnn_resnet50_fpn": lambda x: x[1],
4244
"retinanet_resnet50_fpn": lambda x: x[1],
@@ -105,6 +107,8 @@ def _test_detection_model(self, name, dev):
105107
if "retinanet" in name:
106108
# Reduce the default threshold to ensure the returned boxes are not empty.
107109
kwargs["score_thresh"] = 0.01
110+
elif "fasterrcnn_mobilenet" in name:
111+
kwargs["box_score_thresh"] = 0.02076
108112
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
109113
model.eval().to(device=dev)
110114
input_shape = (3, 300, 300)

test/test_models_detection_negative_samples.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,15 @@ def test_assign_targets_to_proposals(self):
9797
self.assertEqual(labels[0].dtype, torch.int64)
9898

9999
def test_forward_negative_sample_frcnn(self):
    """Check that every Faster R-CNN variant yields zero box-regression losses
    on the sample produced by ``self._make_empty_sample()``."""
    model_names = [
        "fasterrcnn_resnet50_fpn",
        "fasterrcnn_mobilenet_v3_large",
        "fasterrcnn_mobilenet_v3_large_fpn",
    ]
    for name in model_names:
        # pretrained_backbone=False keeps the test offline: the builders default to
        # pretrained_backbone=True, which would fetch backbone weights.
        model = torchvision.models.detection.__dict__[name](
            num_classes=2, min_size=100, max_size=100, pretrained_backbone=False)

        images, targets = self._make_empty_sample()
        loss_dict = model(images, targets)

        self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
        self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
108109

109110
def test_forward_negative_sample_mrcnn(self):
110111
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
@@ -130,7 +131,7 @@ def test_forward_negative_sample_krcnn(self):
130131

131132
def test_forward_negative_sample_retinanet(self):
132133
model = torchvision.models.detection.retinanet_resnet50_fpn(
133-
num_classes=2, min_size=100, max_size=100)
134+
num_classes=2, min_size=100, max_size=100, pretrained_backbone=False)
134135

135136
images, targets = self._make_empty_sample()
136137
loss_dict = model(images, targets)

torchvision/models/detection/backbone_utils.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from torchvision.ops import misc as misc_nn_ops
66
from .._utils import IntermediateLayerGetter
7+
from .. import mobilenet
78
from .. import resnet
89

910

@@ -122,3 +123,50 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value,
122123
trainable_backbone_layers = default_value
123124
assert 0 <= trainable_backbone_layers <= max_value
124125
return trainable_backbone_layers
126+
127+
128+
def mobilenet_backbone(
    backbone_name,
    pretrained,
    fpn,
    norm_layer=misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers=2,
    returned_layers=None,
    extra_blocks=None
):
    """
    Builds a MobileNet feature extractor for detection models, with or without an FPN on top.

    Args:
        backbone_name (str): name of a model constructor exposed by the ``mobilenet`` module,
            e.g. ``"mobilenet_v3_large"``.
        pretrained (bool): forwarded to the backbone constructor as ``pretrained``.
        fpn (bool): If True, the features are wrapped in a ``BackboneWithFPN``; otherwise a 1x1
            convolution projecting the last block's channels to 256 is appended.
        norm_layer: normalization layer forwarded to the backbone constructor.
        trainable_layers (int): number of trainable (not frozen) stages, counted from the final one.
            Valid values are between 0 and the number of stages; 0 freezes the whole backbone.
        returned_layers (list of int): indices of the stages whose outputs feed the FPN.
            Defaults to the last two stages. Only used when ``fpn`` is True.
        extra_blocks: extra block passed to ``BackboneWithFPN``. Defaults to ``LastLevelMaxPool()``.
            Only used when ``fpn`` is True.

    Returns:
        A ``BackboneWithFPN`` when ``fpn`` is True, otherwise an ``nn.Sequential`` whose
        ``out_channels`` attribute is set to 256.
    """
    backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features

    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1]
    num_stages = len(stage_indices)

    # Find the index of the first block that stays trainable.
    assert 0 <= trainable_layers <= num_stages
    if trainable_layers == 0:
        # Freeze every block. The bound must be the number of blocks, not the number of
        # stages, otherwise the blocks after index ``num_stages`` would remain trainable.
        freeze_before = len(backbone)
    else:
        freeze_before = stage_indices[num_stages - trainable_layers]

    # Freezing is applied here regardless of ``pretrained``; callers control it via ``trainable_layers``.
    for b in backbone[:freeze_before]:
        for parameter in b.parameters():
            parameter.requires_grad_(False)

    out_channels = 256
    if fpn:
        if extra_blocks is None:
            extra_blocks = LastLevelMaxPool()

        if returned_layers is None:
            returned_layers = [num_stages - 2, num_stages - 1]
        assert min(returned_layers) >= 0 and max(returned_layers) < num_stages
        # Map the block index of each returned stage to the name of its output ("0", "1", ...).
        return_layers = {f'{stage_indices[k]}': str(v) for v, k in enumerate(returned_layers)}

        in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
        return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
    else:
        m = nn.Sequential(
            backbone,
            # 1x1 (pointwise) convolution: a linear combination of channels that reduces their number
            nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
        )
        m.out_channels = out_channels
        return m

torchvision/models/detection/faster_rcnn.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
from .rpn import RPNHead, RegionProposalNetwork
1616
from .roi_heads import RoIHeads
1717
from .transform import GeneralizedRCNNTransform
18-
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
18+
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone
1919

2020

2121
__all__ = [
22-
"FasterRCNN", "fasterrcnn_resnet50_fpn",
22+
"FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large", "fasterrcnn_mobilenet_v3_large_fpn"
2323
]
2424

2525

@@ -291,6 +291,8 @@ def forward(self, x):
291291
model_urls = {
292292
'fasterrcnn_resnet50_fpn_coco':
293293
'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth',
294+
'fasterrcnn_mobilenet_v3_large_coco': None,
295+
'fasterrcnn_mobilenet_v3_large_fpn_coco': None,
294296
}
295297

296298

@@ -367,3 +369,83 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
367369
model.load_state_dict(state_dict)
368370
overwrite_eps(model, 0.0)
369371
return model
372+
373+
374+
def fasterrcnn_mobilenet_v3_large(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
                                  trainable_backbone_layers=None, **kwargs):
    """
    Constructs a Faster R-CNN model with a MobileNetV3-Large backbone. It works similarly
    to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.

    Example::

        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
        trainable_backbone_layers (int): number of trainable (not frozen) backbone layers starting from final block.
            Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
        **kwargs: forwarded to the ``FasterRCNN`` constructor.

    Raises:
        ValueError: if ``pretrained`` is True but no checkpoint URL is registered for this model.
    """
    weights_name = 'fasterrcnn_mobilenet_v3_large_coco'
    trainable_backbone_layers = _validate_trainable_layers(
        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)

    if pretrained:
        # Fail early with a clear message instead of passing ``None`` to the downloader
        # after the whole model has already been built.
        if model_urls.get(weights_name, None) is None:
            raise ValueError("No checkpoint is available for model {}".format(weights_name))
        # The full-model checkpoint already contains the backbone weights.
        pretrained_backbone = False
    backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, False,
                                  trainable_layers=trainable_backbone_layers)

    # Single feature map (no FPN): one tuple of sizes and one tuple of aspect ratios.
    anchor_sizes = ((32, 64, 128, 256, 512), )
    aspect_ratios = ((0.5, 1.0, 2.0), )

    model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
                       **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
        model.load_state_dict(state_dict)
    return model
412+
413+
414+
def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
                                      trainable_backbone_layers=None, **kwargs):
    """
    Constructs a Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly
    to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.

    Example::

        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
        trainable_backbone_layers (int): number of trainable (not frozen) backbone layers starting from final block.
            Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
        **kwargs: forwarded to the ``FasterRCNN`` constructor.

    Raises:
        ValueError: if ``pretrained`` is True but no checkpoint URL is registered for this model.
    """
    weights_name = 'fasterrcnn_mobilenet_v3_large_fpn_coco'
    trainable_backbone_layers = _validate_trainable_layers(
        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)

    if pretrained:
        # Fail early with a clear message instead of passing ``None`` to the downloader
        # after the whole model has already been built.
        if model_urls.get(weights_name, None) is None:
            raise ValueError("No checkpoint is available for model {}".format(weights_name))
        # The full-model checkpoint already contains the backbone weights.
        pretrained_backbone = False
    backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True,
                                  trainable_layers=trainable_backbone_layers)

    # The same set of anchor sizes is repeated three times, with one matching
    # aspect-ratio tuple per entry.
    anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)

    model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
                       **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
        model.load_state_dict(state_dict)
    return model

torchvision/models/detection/retinanet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020

2121
__all__ = [
22-
"RetinaNet", "retinanet_resnet50_fpn",
22+
"RetinaNet", "retinanet_resnet50_fpn"
2323
]
2424

2525

0 commit comments

Comments
 (0)