diff --git a/test/expect/ModelTester.test_retinanet_mobilenet_v3_large_fpn_expect.pkl b/test/expect/ModelTester.test_retinanet_mobilenet_v3_large_fpn_expect.pkl new file mode 100644 index 00000000000..2f7ca93d30b Binary files /dev/null and b/test/expect/ModelTester.test_retinanet_mobilenet_v3_large_fpn_expect.pkl differ diff --git a/test/test_models.py b/test/test_models.py index dfbaf88be6c..165c2c07184 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -40,6 +40,7 @@ def get_available_video_models(): "maskrcnn_resnet50_fpn": lambda x: x[1], "keypointrcnn_resnet50_fpn": lambda x: x[1], "retinanet_resnet50_fpn": lambda x: x[1], + "retinanet_mobilenet_v3_large_fpn": lambda x: x[1], } @@ -104,7 +105,7 @@ def _test_detection_model(self, name, dev): kwargs = {} if "retinanet" in name: # Reduce the default threshold to ensure the returned boxes are not empty. - kwargs["score_thresh"] = 0.01 + kwargs["score_thresh"] = 0.0099999 model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs) model.eval().to(device=dev) input_shape = (3, 300, 300) diff --git a/test/test_models_detection_negative_samples.py b/test/test_models_detection_negative_samples.py index 6d767971f72..ac551598a96 100644 --- a/test/test_models_detection_negative_samples.py +++ b/test/test_models_detection_negative_samples.py @@ -129,13 +129,14 @@ def test_forward_negative_sample_krcnn(self): self.assertEqual(loss_dict["loss_keypoint"], torch.tensor(0.)) def test_forward_negative_sample_retinanet(self): - model = torchvision.models.detection.retinanet_resnet50_fpn( - num_classes=2, min_size=100, max_size=100) + for name in ["retinanet_resnet50_fpn", "retinanet_mobilenet_v3_large_fpn"]: + model = torchvision.models.detection.__dict__[name]( + num_classes=2, min_size=100, max_size=100, pretrained_backbone=False) - images, targets = self._make_empty_sample() - loss_dict = model(images, targets) + images, targets = self._make_empty_sample() + loss_dict = model(images, 
targets) - self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.)) + self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.)) if __name__ == '__main__': diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py index bfb26f24eae..8af5c09b097 100644 --- a/test/test_models_detection_utils.py +++ b/test/test_models_detection_utils.py @@ -36,17 +36,17 @@ def test_resnet_fpn_backbone_frozen_layers(self): def test_validate_resnet_inputs_detection(self): # default number of backbone layers to train - ret = backbone_utils._validate_resnet_trainable_layers( - pretrained=True, trainable_backbone_layers=None) + ret = backbone_utils._validate_trainable_layers( + pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3) self.assertEqual(ret, 3) # can't go beyond 5 with self.assertRaises(AssertionError): - ret = backbone_utils._validate_resnet_trainable_layers( - pretrained=True, trainable_backbone_layers=6) + ret = backbone_utils._validate_trainable_layers( + pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3) # if not pretrained, should use all trainable layers and warn with self.assertWarns(UserWarning): - ret = backbone_utils._validate_resnet_trainable_layers( - pretrained=False, trainable_backbone_layers=0) + ret = backbone_utils._validate_trainable_layers( + pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3) self.assertEqual(ret, 5) def test_transform_copy_targets(self): diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index 746e0ee2f59..b88da647a71 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -1,10 +1,10 @@ import warnings -from collections import OrderedDict from torch import nn from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool from torchvision.ops import misc as misc_nn_ops from 
.._utils import IntermediateLayerGetter +from .. import mobilenet from .. import resnet @@ -108,17 +108,55 @@ def resnet_fpn_backbone( return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks) -def _validate_resnet_trainable_layers(pretrained, trainable_backbone_layers): +def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value): # dont freeze any layers if pretrained model or backbone is not used if not pretrained: if trainable_backbone_layers is not None: warnings.warn( "Changing trainable_backbone_layers has not effect if " "neither pretrained nor pretrained_backbone have been set to True, " - "falling back to trainable_backbone_layers=5 so that all layers are trainable") - trainable_backbone_layers = 5 - # by default, freeze first 2 blocks following Faster R-CNN + "falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value)) + trainable_backbone_layers = max_value + + # by default freeze first blocks if trainable_backbone_layers is None: - trainable_backbone_layers = 3 - assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0 + trainable_backbone_layers = default_value + assert 0 <= trainable_backbone_layers <= max_value return trainable_backbone_layers + + +def mobilenet_fpn_backbone( + backbone_name, + pretrained, + norm_layer=misc_nn_ops.FrozenBatchNorm2d, + trainable_layers=2, + returned_layers=None, + extra_blocks=None +): + backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features + + # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. + # The first and last blocks are always included because they are the C0 (conv1) and Cn.
+ stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1] + num_stages = len(stage_indices) + + # find the index of the first block that stays trainable; with 0 trainable stages every block is frozen + assert 0 <= trainable_layers <= num_stages + freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers] + + # freeze every block that precedes the first trainable stage + for b in backbone[:freeze_before]: + for parameter in b.parameters(): + parameter.requires_grad_(False) + + if extra_blocks is None: + extra_blocks = LastLevelMaxPool() + + if returned_layers is None: + returned_layers = [num_stages - 2, num_stages - 1] + assert min(returned_layers) >= 0 and max(returned_layers) < num_stages + return_layers = {f'{stage_indices[k]}': str(v) for v, k in enumerate(returned_layers)} + + in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers] + out_channels = 256 + return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks) diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 7d896d5ec95..80ccc129f8f 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -15,7 +15,7 @@ from .rpn import RPNHead, RegionProposalNetwork from .roi_heads import RoIHeads from .transform import GeneralizedRCNNTransform -from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers +from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers __all__ = [ @@ -350,8 +350,8 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
""" # check default parameters and by default set it to 3 if possible - trainable_backbone_layers = _validate_resnet_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers) + trainable_backbone_layers = _validate_trainable_layers( + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3) if pretrained: # no need to download the backbone if pretrained is set diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 44df04819ff..4f375f818c0 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -7,7 +7,7 @@ from ..utils import load_state_dict_from_url from .faster_rcnn import FasterRCNN -from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers +from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers __all__ = [ @@ -319,8 +319,8 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True, Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. 
""" # check default parameters and by default set it to 3 if possible - trainable_backbone_layers = _validate_resnet_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers) + trainable_backbone_layers = _validate_trainable_layers( + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3) if pretrained: # no need to download the backbone if pretrained is set diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 565ef05f4cc..8f982ef02d5 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -8,7 +8,7 @@ from ..utils import load_state_dict_from_url from .faster_rcnn import FasterRCNN -from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers +from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers __all__ = [ "MaskRCNN", "maskrcnn_resnet50_fpn", @@ -314,8 +314,8 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True, Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. """ # check default parameters and by default set it to 3 if possible - trainable_backbone_layers = _validate_resnet_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers) + trainable_backbone_layers = _validate_trainable_layers( + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3) if pretrained: # no need to download the backbone if pretrained is set diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 9836b3316d8..26be5c7bfa4 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -12,14 +12,14 @@ from . 
import _utils as det_utils from .anchor_utils import AnchorGenerator from .transform import GeneralizedRCNNTransform -from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers +from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_fpn_backbone from ...ops.feature_pyramid_network import LastLevelP6P7 from ...ops import sigmoid_focal_loss from ...ops import boxes as box_ops __all__ = [ - "RetinaNet", "retinanet_resnet50_fpn", + "RetinaNet", "retinanet_resnet50_fpn", "retinanet_mobilenet_v3_large_fpn" ] @@ -557,7 +557,10 @@ def forward(self, images, targets=None): return self.eager_outputs(losses, detections) +# TODO: replace with pytorch links model_urls = { + 'retinanet_mobilenet_v3_large_fpn_coco': + 'https://github.com/datumbox/torchvision-models/raw/main/retinanet_mobilenet_v3_large_fpn-41c847a4.pth', 'retinanet_resnet50_fpn_coco': 'https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth', } @@ -606,8 +609,8 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True, Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. """ # check default parameters and by default set it to 3 if possible - trainable_backbone_layers = _validate_resnet_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers) + trainable_backbone_layers = _validate_trainable_layers( + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3) if pretrained: # no need to download the backbone if pretrained is set @@ -622,3 +625,44 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True, model.load_state_dict(state_dict) overwrite_eps(model, 0.0) return model + + +def retinanet_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, + trainable_backbone_layers=None, **kwargs): + """ + Constructs a RetinaNet model with a MobileNetV3-Large-FPN backbone. 
It works similarly + to RetinaNet with ResNet-50-FPN backbone. See `retinanet_resnet50_fpn` for more details. + + Example:: + + >>> model = torchvision.models.detection.retinanet_mobilenet_v3_large_fpn(pretrained=True) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + + Args: + pretrained (bool): If True, returns a model pre-trained on COCO train2017 + progress (bool): If True, displays a progress bar of the download to stderr + num_classes (int): number of output classes of the model (including the background) + pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet + trainable_backbone_layers (int): number of trainable (not frozen) backbone layers starting from final block. + Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. + """ + # check default parameters and by default set it to 3 if possible + trainable_backbone_layers = _validate_trainable_layers( + pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3) + + if pretrained: + pretrained_backbone = False + backbone = mobilenet_fpn_backbone("mobilenet_v3_large", pretrained_backbone, returned_layers=[4, 5], + trainable_layers=trainable_backbone_layers) + + anchor_sizes = ((128,), (256,), (512,)) + aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) + + model = RetinaNet(backbone, num_classes, anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls['retinanet_mobilenet_v3_large_fpn_coco'], + progress=progress) + model.load_state_dict(state_dict) + return model diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index 990429bacf9..12f25ef495c 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -53,6 +53,7 @@ def __init__( norm_layer(out_planes), activation_layer(inplace=True) ) + self.out_channels = out_planes #
necessary for backwards compatibility @@ -90,6 +91,8 @@ def __init__( norm_layer(oup), ]) self.conv = nn.Sequential(*layers) + self.out_channels = oup + self.is_strided = stride > 1 def forward(self, x: Tensor) -> Tensor: if self.use_res_connect: diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 6282cd45434..27b9f7e10b8 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -14,7 +14,7 @@ # TODO: add pretrained model_urls = { - "mobilenet_v3_large": None, + "mobilenet_v3_large": "https://github.com/datumbox/torchvision-models/raw/main/mobilenet_v3_large-8738ca79.pth", "mobilenet_v3_small": None, } @@ -48,12 +48,12 @@ def forward(self, input: Tensor) -> Tensor: class InvertedResidualConfig: - def __init__(self, input_channels: int, kernel: int, expanded_channels: int, output_channels: int, use_se: bool, + def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool, activation: str, stride: int, width_mult: float): self.input_channels = self.adjust_channels(input_channels, width_mult) self.kernel = kernel self.expanded_channels = self.adjust_channels(expanded_channels, width_mult) - self.output_channels = self.adjust_channels(output_channels, width_mult) + self.out_channels = self.adjust_channels(out_channels, width_mult) self.use_se = use_se self.use_hs = activation == "HS" self.stride = stride @@ -70,7 +70,7 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod if not (1 <= cnf.stride <= 2): raise ValueError('illegal stride value') - self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.output_channels + self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels layers: List[nn.Module] = [] activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU @@ -88,10 +88,12 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod 
layers.append(SqueezeExcitation(cnf.expanded_channels)) # project - layers.append(ConvBNActivation(cnf.expanded_channels, cnf.output_channels, kernel_size=1, norm_layer=norm_layer, + layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=Identity)) self.block = nn.Sequential(*layers) + self.out_channels = cnf.out_channels + self.is_strided = cnf.stride > 1 def forward(self, input: Tensor) -> Tensor: result = self.block(input) @@ -146,7 +148,7 @@ def __init__( layers.append(block(cnf, norm_layer)) # building last several layers - lastconv_input_channels = inverted_residual_setting[-1].output_channels + lastconv_input_channels = inverted_residual_setting[-1].out_channels lastconv_output_channels = 6 * lastconv_input_channels layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.Hardswish))