-
Notifications
You must be signed in to change notification settings - Fork 7.1k
RetinaNet with MobileNetV3 FPN backbone #3223
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3023610
550f05e
a56fe27
06e3e72
6dc1724
0419dbc
6c53bfc
75933da
cc2def8
81800cd
1a014ef
7af35c3
98a46ff
6922e41
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -129,13 +129,14 @@ def test_forward_negative_sample_krcnn(self): | |
self.assertEqual(loss_dict["loss_keypoint"], torch.tensor(0.)) | ||
|
||
def test_forward_negative_sample_retinanet(self): | ||
model = torchvision.models.detection.retinanet_resnet50_fpn( | ||
num_classes=2, min_size=100, max_size=100) | ||
for name in ["retinanet_resnet50_fpn", "retinanet_mobilenet_v3_large_fpn"]: | ||
model = torchvision.models.detection.__dict__[name]( | ||
num_classes=2, min_size=100, max_size=100, pretrained_backbone=False) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The previous version seemed to download the weights of the backbone unnecessarily. I fix this in place by adding `pretrained_backbone=False`. |
||
|
||
images, targets = self._make_empty_sample() | ||
loss_dict = model(images, targets) | ||
images, targets = self._make_empty_sample() | ||
loss_dict = model(images, targets) | ||
|
||
self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.)) | ||
self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.)) | ||
|
||
|
||
if __name__ == '__main__': | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,10 @@ | ||
import warnings | ||
from collections import OrderedDict | ||
from torch import nn | ||
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool | ||
|
||
from torchvision.ops import misc as misc_nn_ops | ||
from .._utils import IntermediateLayerGetter | ||
from .. import mobilenet | ||
from .. import resnet | ||
|
||
|
||
|
@@ -108,17 +108,55 @@ def resnet_fpn_backbone( | |
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks) | ||
|
||
|
||
def _validate_resnet_trainable_layers(pretrained, trainable_backbone_layers): | ||
def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value): | ||
# dont freeze any layers if pretrained model or backbone is not used | ||
if not pretrained: | ||
if trainable_backbone_layers is not None: | ||
warnings.warn( | ||
"Changing trainable_backbone_layers has not effect if " | ||
"neither pretrained nor pretrained_backbone have been set to True, " | ||
"falling back to trainable_backbone_layers=5 so that all layers are trainable") | ||
trainable_backbone_layers = 5 | ||
# by default, freeze first 2 blocks following Faster R-CNN | ||
"falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value)) | ||
trainable_backbone_layers = max_value | ||
|
||
# by default freeze first blocks | ||
if trainable_backbone_layers is None: | ||
trainable_backbone_layers = 3 | ||
assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0 | ||
trainable_backbone_layers = default_value | ||
assert 0 <= trainable_backbone_layers <= max_value | ||
return trainable_backbone_layers | ||
|
||
|
||
def mobilenet_fpn_backbone( | ||
backbone_name, | ||
pretrained, | ||
norm_layer=misc_nn_ops.FrozenBatchNorm2d, | ||
trainable_layers=2, | ||
returned_layers=None, | ||
extra_blocks=None | ||
): | ||
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features | ||
|
||
# Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. | ||
# The first and last blocks are always included because they are the C0 (conv1) and Cn. | ||
stage_indeces = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Instead of adding meta-data with the location of the blocks that downsample, I get it by checking a new attribute called `is_strided`. Note that blocks at the first and last position of the features block are always included. |
||
num_stages = len(stage_indeces) | ||
|
||
# find the index of the layer from which we wont freeze | ||
assert 0 <= trainable_layers <= num_stages | ||
freeze_before = num_stages if trainable_layers == 0 else stage_indeces[num_stages - trainable_layers] | ||
|
||
# freeze layers only if pretrained backbone is used | ||
for b in backbone[:freeze_before]: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Unlike the resnet implementation, here we need to find the location of the first block that we fine-tune and mark everything before that as frozen. |
||
for parameter in b.parameters(): | ||
parameter.requires_grad_(False) | ||
|
||
if extra_blocks is None: | ||
extra_blocks = LastLevelMaxPool() | ||
|
||
if returned_layers is None: | ||
returned_layers = [num_stages - 2, num_stages - 1] | ||
assert min(returned_layers) >= 0 and max(returned_layers) < num_stages | ||
return_layers = {f'{stage_indeces[k]}': str(v) for v, k in enumerate(returned_layers)} | ||
|
||
in_channels_list = [backbone[stage_indeces[i]].out_channels for i in returned_layers] | ||
out_channels = 256 | ||
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,14 +12,14 @@ | |
from . import _utils as det_utils | ||
from .anchor_utils import AnchorGenerator | ||
from .transform import GeneralizedRCNNTransform | ||
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers | ||
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_fpn_backbone | ||
from ...ops.feature_pyramid_network import LastLevelP6P7 | ||
from ...ops import sigmoid_focal_loss | ||
from ...ops import boxes as box_ops | ||
|
||
|
||
__all__ = [ | ||
"RetinaNet", "retinanet_resnet50_fpn", | ||
"RetinaNet", "retinanet_resnet50_fpn", "retinanet_mobilenet_v3_large_fpn" | ||
] | ||
|
||
|
||
|
@@ -557,7 +557,10 @@ def forward(self, images, targets=None): | |
return self.eager_outputs(losses, detections) | ||
|
||
|
||
# TODO: replace with pytorch links | ||
model_urls = { | ||
'retinanet_mobilenet_v3_large_fpn_coco': | ||
'https://github.com/datumbox/torchvision-models/raw/main/retinanet_mobilenet_v3_large_fpn-41c847a4.pth', | ||
'retinanet_resnet50_fpn_coco': | ||
'https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth', | ||
} | ||
|
@@ -606,8 +609,8 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True, | |
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. | ||
""" | ||
# check default parameters and by default set it to 3 if possible | ||
trainable_backbone_layers = _validate_resnet_trainable_layers( | ||
pretrained or pretrained_backbone, trainable_backbone_layers) | ||
trainable_backbone_layers = _validate_trainable_layers( | ||
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3) | ||
|
||
if pretrained: | ||
# no need to download the backbone if pretrained is set | ||
|
@@ -622,3 +625,44 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True, | |
model.load_state_dict(state_dict) | ||
overwrite_eps(model, 0.0) | ||
return model | ||
|
||
|
||
def retinanet_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, | ||
trainable_backbone_layers=None, **kwargs): | ||
""" | ||
Constructs a RetinaNet model with a MobileNetV3-Large-FPN backbone. It works similarly | ||
to RetinaNet with ResNet-50-FPN backbone. See `retinanet_resnet50_fpn` for more details. | ||
|
||
Example:: | ||
|
||
>>> model = torchvision.models.detection.retinanet_mobilenet_v3_large_fpn(pretrained=True) | ||
>>> model.eval() | ||
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] | ||
>>> predictions = model(x) | ||
|
||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on COCO train2017 | ||
progress (bool): If True, displays a progress bar of the download to stderr | ||
num_classes (int): number of output classes of the model (including the background) | ||
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet | ||
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. | ||
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. | ||
""" | ||
# check default parameters and by default set it to 3 if possible | ||
trainable_backbone_layers = _validate_trainable_layers( | ||
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3) | ||
|
||
if pretrained: | ||
pretrained_backbone = False | ||
backbone = mobilenet_fpn_backbone("mobilenet_v3_large", pretrained_backbone, returned_layers=[4, 5], | ||
trainable_layers=trainable_backbone_layers) | ||
|
||
anchor_sizes = ((128,), (256,), (512,)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Anchor sizes for C4, C5 and pool. It's important to note that C4 and C5 have the same output stride of 32. |
||
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) | ||
|
||
model = RetinaNet(backbone, num_classes, anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs) | ||
if pretrained: | ||
state_dict = load_state_dict_from_url(model_urls['retinanet_mobilenet_v3_large_fpn_coco'], | ||
progress=progress) | ||
model.load_state_dict(state_dict) | ||
return model |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,6 +53,7 @@ def __init__( | |
norm_layer(out_planes), | ||
activation_layer(inplace=True) | ||
) | ||
self.out_channels = out_planes | ||
|
||
|
||
# necessary for backwards compatibility | ||
|
@@ -90,6 +91,8 @@ def __init__( | |
norm_layer(oup), | ||
]) | ||
self.conv = nn.Sequential(*layers) | ||
self.out_channels = oup | ||
self.is_strided = stride > 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Metadata are added in the blocks to make it easier to detect the C1...Cn blocks and the out_channels in detection models. We do this on both MobileNetV2 and MobileNetV3. |
||
|
||
def forward(self, x: Tensor) -> Tensor: | ||
if self.use_res_connect: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Slight adjustment necessary for getting non-zero results on MobileNetV3.