From 8348c3ab208264f11e14efbc43d8ab4aab3903a4 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 14 Jan 2021 13:41:59 +0000
Subject: [PATCH 01/11] Minor refactoring of a private method to make it reusable.

---
 test/test_models_detection_utils.py            | 12 ++++++------
 torchvision/models/detection/backbone_utils.py | 14 +++++++-------
 torchvision/models/detection/faster_rcnn.py    |  7 +++----
 torchvision/models/detection/keypoint_rcnn.py  |  7 +++----
 torchvision/models/detection/mask_rcnn.py      |  7 +++----
 torchvision/models/detection/retinanet.py      |  7 +++----
 6 files changed, 25 insertions(+), 29 deletions(-)

diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py
index bfb26f24eae..8af5c09b097 100644
--- a/test/test_models_detection_utils.py
+++ b/test/test_models_detection_utils.py
@@ -36,17 +36,17 @@ def test_resnet_fpn_backbone_frozen_layers(self):
 
     def test_validate_resnet_inputs_detection(self):
         # default number of backbone layers to train
-        ret = backbone_utils._validate_resnet_trainable_layers(
-            pretrained=True, trainable_backbone_layers=None)
+        ret = backbone_utils._validate_trainable_layers(
+            pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3)
         self.assertEqual(ret, 3)
         # can't go beyond 5
         with self.assertRaises(AssertionError):
-            ret = backbone_utils._validate_resnet_trainable_layers(
-                pretrained=True, trainable_backbone_layers=6)
+            ret = backbone_utils._validate_trainable_layers(
+                pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3)
         # if not pretrained, should use all trainable layers and warn
         with self.assertWarns(UserWarning):
-            ret = backbone_utils._validate_resnet_trainable_layers(
-                pretrained=False, trainable_backbone_layers=0)
+            ret = backbone_utils._validate_trainable_layers(
+                pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3)
         self.assertEqual(ret, 5)
 
     def test_transform_copy_targets(self):
diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py
index 746e0ee2f59..6306be67dc3 100644
--- a/torchvision/models/detection/backbone_utils.py
+++ b/torchvision/models/detection/backbone_utils.py
@@ -1,5 +1,4 @@
 import warnings
-from collections import OrderedDict
 
 from torch import nn
 from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
@@ -108,17 +107,18 @@ def resnet_fpn_backbone(
     return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
 
 
-def _validate_resnet_trainable_layers(pretrained, trainable_backbone_layers):
+def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value):
     # dont freeze any layers if pretrained model or backbone is not used
     if not pretrained:
         if trainable_backbone_layers is not None:
             warnings.warn(
                 "Changing trainable_backbone_layers has not effect if "
                 "neither pretrained nor pretrained_backbone have been set to True, "
-                "falling back to trainable_backbone_layers=5 so that all layers are trainable")
-        trainable_backbone_layers = 5
-    # by default, freeze first 2 blocks following Faster R-CNN
+                "falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value))
+        trainable_backbone_layers = max_value
+
+    # by default freeze first blocks
     if trainable_backbone_layers is None:
-        trainable_backbone_layers = 3
-    assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0
+        trainable_backbone_layers = default_value
+    assert 0 <= trainable_backbone_layers <= max_value
     return trainable_backbone_layers
diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py
index e42680d682d..89beb9de83f 100644
--- a/torchvision/models/detection/faster_rcnn.py
+++ b/torchvision/models/detection/faster_rcnn.py
@@ -15,7 +15,7 @@
 from .rpn import RPNHead, RegionProposalNetwork
 from .roi_heads import RoIHeads
 from .transform import GeneralizedRCNNTransform
-from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
+from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
 
 
 __all__ = [
@@ -353,9 +353,8 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
         trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
     """
-    # check default parameters and by default set it to 3 if possible
-    trainable_backbone_layers = _validate_resnet_trainable_layers(
-        pretrained or pretrained_backbone, trainable_backbone_layers)
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
 
     if pretrained:
         # no need to download the backbone if pretrained is set
diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py
index 0475994a5a0..ea4a078da9d 100644
--- a/torchvision/models/detection/keypoint_rcnn.py
+++ b/torchvision/models/detection/keypoint_rcnn.py
@@ -7,7 +7,7 @@
 from ..utils import load_state_dict_from_url
 
 from .faster_rcnn import FasterRCNN
-from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
+from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
 
 
 __all__ = [
@@ -322,9 +322,8 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
         trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
     """
-    # check default parameters and by default set it to 3 if possible
-    trainable_backbone_layers = _validate_resnet_trainable_layers(
-        pretrained or pretrained_backbone, trainable_backbone_layers)
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
 
     if pretrained:
         # no need to download the backbone if pretrained is set
diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py
index 4f065f3f917..7dac4d0a105 100644
--- a/torchvision/models/detection/mask_rcnn.py
+++ b/torchvision/models/detection/mask_rcnn.py
@@ -8,7 +8,7 @@
 from ..utils import load_state_dict_from_url
 
 from .faster_rcnn import FasterRCNN
-from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
+from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
 
 __all__ = [
     "MaskRCNN", "maskrcnn_resnet50_fpn",
 ]
@@ -317,9 +317,8 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
         trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
     """
-    # check default parameters and by default set it to 3 if possible
-    trainable_backbone_layers = _validate_resnet_trainable_layers(
-        pretrained or pretrained_backbone, trainable_backbone_layers)
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
 
     if pretrained:
         # no need to download the backbone if pretrained is set
diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py
index 9836b3316d8..7289090cad0 100644
--- a/torchvision/models/detection/retinanet.py
+++ b/torchvision/models/detection/retinanet.py
@@ -12,7 +12,7 @@
 from . import _utils as det_utils
 from .anchor_utils import AnchorGenerator
 from .transform import GeneralizedRCNNTransform
-from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
+from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
 from ...ops.feature_pyramid_network import LastLevelP6P7
 from ...ops import sigmoid_focal_loss
 from ...ops import boxes as box_ops
@@ -605,9 +605,8 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
         trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
     """
-    # check default parameters and by default set it to 3 if possible
-    trainable_backbone_layers = _validate_resnet_trainable_layers(
-        pretrained or pretrained_backbone, trainable_backbone_layers)
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
 
     if pretrained:
         # no need to download the backbone if pretrained is set

From c3dfca1d305157fc476ea601dbab62ce428649d2 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 14 Jan 2021 15:11:44 +0000
Subject: [PATCH 02/11] Adding FasterRCNN + MobileNetV3 models with & w/o FPN.

---
 ...t_fasterrcnn_mobilenet_v3_large_expect.pkl | Bin 0 -> 4109 bytes
 ...sterrcnn_mobilenet_v3_large_fpn_expect.pkl | Bin 0 -> 4109 bytes
 test/test_models.py                           |  4 +
 .../test_models_detection_negative_samples.py | 15 +--
 .../models/detection/backbone_utils.py        | 48 ++++++++++
 torchvision/models/detection/faster_rcnn.py   | 86 +++++++++++++++++-
 torchvision/models/detection/retinanet.py     |  2 +-
 7 files changed, 145 insertions(+), 10 deletions(-)
 create mode 100644 test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_expect.pkl
 create mode 100644 test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_fpn_expect.pkl

diff --git a/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_expect.pkl b/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_expect.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..5f2a44b2be4a39ce42594ed687fbcee0df541f55
GIT binary patch
literal 4109
[4109 bytes of base85-encoded binary pickle data omitted]

literal 0
HcmV?d00001

diff --git a/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_fpn_expect.pkl b/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_fpn_expect.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..3d2979297dc5035e64ae7ae641d8afad0d6a4232
GIT binary patch
literal 4109
[4109 bytes of base85-encoded binary pickle data omitted]

literal 0
HcmV?d00001

diff --git a/test/test_models.py b/test/test_models.py
index dfbaf88be6c..36b05a38c6f 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -37,6 +37,8 @@ def get_available_video_models():
     'googlenet': lambda x: x.logits,
     'inception_v3': lambda x: x.logits,
     "fasterrcnn_resnet50_fpn": lambda x: x[1],
+    "fasterrcnn_mobilenet_v3_large": lambda x: x[1],
+    "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
     "maskrcnn_resnet50_fpn": lambda x: x[1],
     "keypointrcnn_resnet50_fpn": lambda x: x[1],
     "retinanet_resnet50_fpn": lambda x: x[1],
@@ -105,6 +107,8 @@ def _test_detection_model(self, name, dev):
         if "retinanet" in name:
             # Reduce the default threshold to ensure the returned boxes are not empty.
             kwargs["score_thresh"] = 0.01
+        elif "fasterrcnn_mobilenet" in name:
+            kwargs["box_score_thresh"] = 0.02076
         model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
         model.eval().to(device=dev)
         input_shape = (3, 300, 300)
diff --git a/test/test_models_detection_negative_samples.py b/test/test_models_detection_negative_samples.py
index 650a565cdea..1eab8b72d08 100644
--- a/test/test_models_detection_negative_samples.py
+++ b/test/test_models_detection_negative_samples.py
@@ -97,14 +97,15 @@ def test_assign_targets_to_proposals(self):
         self.assertEqual(labels[0].dtype, torch.int64)
 
     def test_forward_negative_sample_frcnn(self):
-        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
-            num_classes=2, min_size=100, max_size=100)
+        for name in ["fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large", "fasterrcnn_mobilenet_v3_large_fpn"]:
+            model = torchvision.models.detection.__dict__[name](
+                num_classes=2, min_size=100, max_size=100)
 
-        images, targets = self._make_empty_sample()
-        loss_dict = model(images, targets)
+            images, targets = self._make_empty_sample()
+            loss_dict = model(images, targets)
 
-        self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
-        self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
+            self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
+            self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
 
     def test_forward_negative_sample_mrcnn(self):
         model = torchvision.models.detection.maskrcnn_resnet50_fpn(
@@ -130,7 +131,7 @@ def test_forward_negative_sample_krcnn(self):
 
     def test_forward_negative_sample_retinanet(self):
         model = torchvision.models.detection.retinanet_resnet50_fpn(
-            num_classes=2, min_size=100, max_size=100)
+            num_classes=2, min_size=100, max_size=100, pretrained_backbone=False)
 
         images, targets = self._make_empty_sample()
         loss_dict = model(images, targets)
diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py
index 6306be67dc3..6290f0373c2 100644
--- a/torchvision/models/detection/backbone_utils.py
+++ b/torchvision/models/detection/backbone_utils.py
@@ -4,6 +4,7 @@
 from torchvision.ops import misc as misc_nn_ops
 
 from .._utils import IntermediateLayerGetter
+from .. import mobilenet
 from .. import resnet
 
 
@@ -122,3 +123,50 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value,
         trainable_backbone_layers = default_value
     assert 0 <= trainable_backbone_layers <= max_value
     return trainable_backbone_layers
+
+
+def mobilenet_backbone(
+    backbone_name,
+    pretrained,
+    fpn,
+    norm_layer=misc_nn_ops.FrozenBatchNorm2d,
+    trainable_layers=2,
+    returned_layers=None,
+    extra_blocks=None
+):
+    backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
+
+    # Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
+    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
+    stage_indeces = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1]
+    num_stages = len(stage_indeces)
+
+    # find the index of the layer from which we wont freeze
+    assert 0 <= trainable_layers <= num_stages
+    freeze_before = num_stages if trainable_layers == 0 else stage_indeces[num_stages - trainable_layers]
+
+    # freeze layers only if pretrained backbone is used
+    for b in backbone[:freeze_before]:
+        for parameter in b.parameters():
+            parameter.requires_grad_(False)
+
+    out_channels = 256
+    if fpn:
+        if extra_blocks is None:
+            extra_blocks = LastLevelMaxPool()
+
+        if returned_layers is None:
+            returned_layers = [num_stages - 2, num_stages - 1]
+        assert min(returned_layers) >= 0 and max(returned_layers) < num_stages
+        return_layers = {f'{stage_indeces[k]}': str(v) for v, k in enumerate(returned_layers)}
+
+        in_channels_list = [backbone[stage_indeces[i]].out_channels for i in returned_layers]
+        return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
+    else:
+        m = nn.Sequential(
+            backbone,
+            # depthwise linear combination of channels to reduce their size
+            nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
+        )
+        m.out_channels = out_channels
+        return m
diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py
index 89beb9de83f..31845a598a3 100644
--- a/torchvision/models/detection/faster_rcnn.py
+++ b/torchvision/models/detection/faster_rcnn.py
@@ -15,11 +15,11 @@
 from .rpn import RPNHead, RegionProposalNetwork
 from .roi_heads import RoIHeads
 from .transform import GeneralizedRCNNTransform
-from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
+from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone
 
 
 __all__ = [
-    "FasterRCNN", "fasterrcnn_resnet50_fpn",
+    "FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large", "fasterrcnn_mobilenet_v3_large_fpn"
 ]
 
 
@@ -291,6 +291,8 @@ def forward(self, x):
 model_urls = {
     'fasterrcnn_resnet50_fpn_coco':
         'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth',
+    'fasterrcnn_mobilenet_v3_large_coco': None,
+    'fasterrcnn_mobilenet_v3_large_fpn_coco': None,
 }
@@ -367,3 +369,83 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
         model.load_state_dict(state_dict)
         overwrite_eps(model, 0.0)
     return model
+
+
+def fasterrcnn_mobilenet_v3_large(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
+                                  trainable_backbone_layers=None, **kwargs):
+    """
+    Constructs a Faster R-CNN model with a MobileNetV3-Large backbone. It works similarly
+    to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large(pretrained=True)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on COCO train2017
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int): number of output classes of the model (including the background)
+        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
+        trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
+            Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
+    """
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)
+
+    if pretrained:
+        pretrained_backbone = False
+    backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, False,
+                                  trainable_layers=trainable_backbone_layers)
+
+    anchor_sizes = ((32, 64, 128, 256, 512), )
+    aspect_ratios = ((0.5, 1.0, 2.0), )
+
+    model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
+                       **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls['fasterrcnn_mobilenet_v3_large_coco'], progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
+                                      trainable_backbone_layers=None, **kwargs):
+    """
+    Constructs a Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly
+    to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on COCO train2017
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int): number of output classes of the model (including the background)
+        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
+        trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
+            Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
+    """
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)
+
+    if pretrained:
+        pretrained_backbone = False
+    backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True,
+                                  trainable_layers=trainable_backbone_layers)
+
+    anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+
+    model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
+                       **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls['fasterrcnn_mobilenet_v3_large_fpn_coco'], progress=progress)
+        model.load_state_dict(state_dict)
+    return model
diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py
index 7289090cad0..60238485f60 100644
--- a/torchvision/models/detection/retinanet.py
+++ b/torchvision/models/detection/retinanet.py
@@ -19,7 +19,7 @@
 
 
 __all__ = [
-    "RetinaNet", "retinanet_resnet50_fpn",
+    "RetinaNet", "retinanet_resnet50_fpn"
 ]

From 30299cca1e32c732f66d3fe1b1329b06732a6df5 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 14 Jan 2021 16:40:42 +0000
Subject: [PATCH 03/11] Reducing resolution to 320-640 and anchor sizes to 16-256.
---
 ...t_fasterrcnn_mobilenet_v3_large_expect.pkl | Bin 4109 -> 4109 bytes
 ...sterrcnn_mobilenet_v3_large_fpn_expect.pkl | Bin 4109 -> 4109 bytes
 torchvision/models/detection/faster_rcnn.py   | 16 ++++++++++------
 3 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_expect.pkl b/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_expect.pkl
index 5f2a44b2be4a39ce42594ed687fbcee0df541f55..9a9207b73605db38072db7b2a7b5bfaf8a079b65 100644
GIT binary patch
literal 4109
[4109 bytes of base85-encoded binary pickle data omitted]

literal 4109
[4109 bytes of base85-encoded binary pickle data omitted]

diff --git a/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_fpn_expect.pkl b/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_fpn_expect.pkl
index 3d2979297dc5035e64ae7ae641d8afad0d6a4232..9be3d021af619dcdb574cf36c78bcbffa44b5c8c 100644
GIT binary patch
literal 4109
[4109 bytes of base85-encoded binary pickle data omitted]

literal 4109
[4109 bytes of base85-encoded binary pickle data omitted]

diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py
index 31845a598a3..1529298b1bc 100644
--- a/torchvision/models/detection/faster_rcnn.py
+++ b/torchvision/models/detection/faster_rcnn.py
@@ -372,7 +372,7 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
 
 
 def fasterrcnn_mobilenet_v3_large(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
-                                  trainable_backbone_layers=None, **kwargs):
+                                  trainable_backbone_layers=None, min_size=320, max_size=640, **kwargs):
     """
     Constructs a Faster R-CNN model with a MobileNetV3-Large backbone. It works similarly
     to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.
@@ -391,6 +391,8 @@ def fasterrcnn_mobilenet_v3_large(pretrained=False, progress=True, num_classes=9
         pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
         trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
             Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
+        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
     """
     trainable_backbone_layers = _validate_trainable_layers(
         pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)
@@ -400,11 +402,11 @@ def fasterrcnn_mobilenet_v3_large(pretrained=False, progress=True, num_classes=9
     backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, False,
                                   trainable_layers=trainable_backbone_layers)
 
-    anchor_sizes = ((32, 64, 128, 256, 512), )
+    anchor_sizes = ((16, 32, 64, 128, 256), )
     aspect_ratios = ((0.5, 1.0, 2.0), )
 
     model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
-                       **kwargs)
+                       min_size=min_size, max_size=max_size, **kwargs)
     if pretrained:
         state_dict = load_state_dict_from_url(model_urls['fasterrcnn_mobilenet_v3_large_coco'], progress=progress)
         model.load_state_dict(state_dict)
@@ -412,7 +414,7 @@ def fasterrcnn_mobilenet_v3_large(pretrained=False, progress=True, num_classes=9
 
 
 def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
-                                      trainable_backbone_layers=None, **kwargs):
+                                      trainable_backbone_layers=None, min_size=320, max_size=640, **kwargs):
     """
     Constructs a Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly
     to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.
@@ -431,6 +433,8 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_class
         pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
         trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
             Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
+        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
     """
     trainable_backbone_layers = _validate_trainable_layers(
         pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)
@@ -440,11 +444,11 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_class
     backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True,
                                   trainable_layers=trainable_backbone_layers)
 
-    anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3
+    anchor_sizes = ((16, 32, 64, 128, 256, ), ) * 3
     aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
 
     model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
-                       **kwargs)
+                       min_size=min_size, max_size=max_size, **kwargs)
     if pretrained:
         state_dict = load_state_dict_from_url(model_urls['fasterrcnn_mobilenet_v3_large_fpn_coco'], progress=progress)
         model.load_state_dict(state_dict)

From ba96894a894232ee5c2da34504f059f0d61dbfb6 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 15 Jan 2021 15:24:07 +0000
Subject: [PATCH 04/11] Increase anchor sizes.

---
 torchvision/models/detection/faster_rcnn.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py
index 1529298b1bc..54001fb4f76 100644
--- a/torchvision/models/detection/faster_rcnn.py
+++ b/torchvision/models/detection/faster_rcnn.py
@@ -402,7 +402,7 @@ def fasterrcnn_mobilenet_v3_large(pretrained=False, progress=True, num_classes=9
     backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, False,
                                   trainable_layers=trainable_backbone_layers)
 
-    anchor_sizes = ((16, 32, 64, 128, 256), )
+    anchor_sizes = ((32, 64, 128, 256, 512, ), )
     aspect_ratios = ((0.5, 1.0, 2.0), )
 
     model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
@@ -444,7 +444,7 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_class
     backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True,
                                   trainable_layers=trainable_backbone_layers)
 
-    anchor_sizes = ((16, 32, 64, 128, 256, ), ) * 3
+    anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3
     aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
 
     model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),

From 217e5b10d5f4f56cf65faa493bc3e1dda8a34082 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 15 Jan 2021 21:42:47 +0000
Subject: [PATCH 05/11] Adding rpn score threshold param on the train script.
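
Illustrative invocation, mirroring the existing commands in
references/detection/README.md (the value 0.05 is a hypothetical override;
the flag's default below is 0.0, which keeps the previous behaviour):

    python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
        --dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\
        --lr-steps 16 22 --aspect-ratio-group-factor 3 --rpn-score-thresh 0.05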
---
 references/detection/train.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/references/detection/train.py b/references/detection/train.py
index f3fe9bc9fff..986abe5a715 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -94,7 +94,7 @@ def main(args):
     print("Creating model")
     kwargs = {}
     if "rcnn" in args.model:
-        kwargs["rpn_score_thresh"] = 0.0
+        kwargs["rpn_score_thresh"] = args.rpn_score_thresh
     model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained,
                                                               **kwargs)
     model.to(device)
@@ -177,6 +177,7 @@ def main(args):
     parser.add_argument('--resume', default='', help='resume from checkpoint')
     parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
     parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
+    parser.add_argument('--rpn-score-thresh', default=0.0, type=float, help='rpn score threshold for faster-rcnn')
     parser.add_argument(
         "--test-only",
         dest="test_only",

From 690ee5586677a1c75248598c1a3538583615ecbc Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Sat, 16 Jan 2021 18:35:51 +0000
Subject: [PATCH 06/11] Adding trainable_backbone_layers param on the train script.

---
 references/detection/train.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/references/detection/train.py b/references/detection/train.py
index 986abe5a715..c39ae3e723c 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -92,7 +92,9 @@ def main(args):
         collate_fn=utils.collate_fn)
 
     print("Creating model")
-    kwargs = {}
+    kwargs = {
+        "trainable_backbone_layers": args.trainable_backbone_layers
+    }
     if "rcnn" in args.model:
         kwargs["rpn_score_thresh"] = args.rpn_score_thresh
     model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained,
@@ -178,6 +180,8 @@ def main(args):
     parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
     parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
     parser.add_argument('--rpn-score-thresh', default=0.0, type=float, help='rpn score threshold for faster-rcnn')
+    parser.add_argument('--trainable-backbone-layers', default=None, type=int,
+                        help='number of trainable layers of backbone ')
     parser.add_argument(
         "--test-only",
         dest="test_only",

From 24ecd45a59eb17bcc5896df2d5ab3678dfe54de6 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Sun, 17 Jan 2021 22:12:16 +0000
Subject: [PATCH 07/11] Adding rpn_score_thresh param directly in fasterrcnn_mobilenet_v3_large_fpn.
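
A hedged sketch of the resulting call site (assumes this series is applied;
0.05 is the new per-model default added below, and leaving the CLI flag
unset now keeps that default untouched):

    import torch
    import torchvision

    model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(
        pretrained=False, pretrained_backbone=False, num_classes=91,
        rpn_score_thresh=0.05)  # override explicitly, or omit to keep the default
    model.eval()
    predictions = model([torch.rand(3, 320, 480)])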
---
 references/detection/train.py               | 7 ++++---
 torchvision/models/detection/faster_rcnn.py | 7 +++++--
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/references/detection/train.py b/references/detection/train.py
index c39ae3e723c..7aa71314230 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -96,7 +96,8 @@ def main(args):
         "trainable_backbone_layers": args.trainable_backbone_layers
     }
     if "rcnn" in args.model:
-        kwargs["rpn_score_thresh"] = args.rpn_score_thresh
+        if args.rpn_score_thresh is not None:
+            kwargs["rpn_score_thresh"] = args.rpn_score_thresh
     model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained,
                                                               **kwargs)
     model.to(device)
@@ -179,9 +180,9 @@ def main(args):
     parser.add_argument('--resume', default='', help='resume from checkpoint')
     parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
     parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
-    parser.add_argument('--rpn-score-thresh', default=0.0, type=float, help='rpn score threshold for faster-rcnn')
+    parser.add_argument('--rpn-score-thresh', default=None, type=float, help='rpn score threshold for faster-rcnn')
     parser.add_argument('--trainable-backbone-layers', default=None, type=int,
-                        help='number of trainable layers of backbone ')
+                        help='number of trainable layers of backbone')
     parser.add_argument(
         "--test-only",
         dest="test_only",
diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py
index 54001fb4f76..8b1d6952271 100644
--- a/torchvision/models/detection/faster_rcnn.py
+++ b/torchvision/models/detection/faster_rcnn.py
@@ -414,7 +414,8 @@ def fasterrcnn_mobilenet_v3_large(pretrained=False, progress=True, num_classes=9
 
 
 def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
-                                      trainable_backbone_layers=None, min_size=320, max_size=640, **kwargs):
+                                      trainable_backbone_layers=None, min_size=320, max_size=640, rpn_score_thresh=0.05,
+                                      **kwargs):
     """
     Constructs a Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly
     to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.
@@ -435,6 +436,8 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_class
             Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
         min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
         max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
     """
     trainable_backbone_layers = _validate_trainable_layers(
         pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)
@@ -448,7 +451,7 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_class
     aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
 
     model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
-                       min_size=min_size, max_size=max_size, **kwargs)
+                       min_size=min_size, max_size=max_size, rpn_score_thresh=rpn_score_thresh, **kwargs)
     if pretrained:
         state_dict = load_state_dict_from_url(model_urls['fasterrcnn_mobilenet_v3_large_fpn_coco'], progress=progress)
         model.load_state_dict(state_dict)
     return model

From 11408d6ce334734774c9e2bf511048cca2625e9e Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Sun, 17 Jan 2021 22:51:01 +0000
Subject: [PATCH 08/11] Remove fasterrcnn_mobilenet_v3_large prototype and update expected file.

---
 ...t_fasterrcnn_mobilenet_v3_large_expect.pkl | Bin 4109 -> 0 bytes
 ...sterrcnn_mobilenet_v3_large_fpn_expect.pkl | Bin 4109 -> 4109 bytes
 test/test_models.py                           |  1 -
 .../test_models_detection_negative_samples.py |  2 +-
 torchvision/models/detection/faster_rcnn.py   | 50 +-----------------
 5 files changed, 3 insertions(+), 50 deletions(-)
 delete mode 100644 test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_expect.pkl

diff --git a/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_expect.pkl b/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_expect.pkl
deleted file mode 100644
index 9a9207b73605db38072db7b2a7b5bfaf8a079b65..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 4109
[4109 bytes of base85-encoded binary pickle data omitted]

diff --git a/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_fpn_expect.pkl b/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_fpn_expect.pkl
index 3d2979297dc5035e64ae7ae641d8afad0d6a4232..9be3d021af619dcdb574cf36c78bcbffa44b5c8c 100644
GIT binary patch
literal 4109
[4109 bytes of base85-encoded binary pickle data omitted]

literal 4109
[4109 bytes of base85-encoded binary pickle data omitted]

diff --git a/test/test_models.py b/test/test_models.py
index 36b05a38c6f..232a78234b9 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -37,7 +37,6 @@ def get_available_video_models():
     'googlenet': lambda x: x.logits,
     'inception_v3': lambda x: x.logits,
     "fasterrcnn_resnet50_fpn": lambda x: x[1],
-    "fasterrcnn_mobilenet_v3_large": lambda x: x[1],
     "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
     "maskrcnn_resnet50_fpn": lambda x: x[1],
     "keypointrcnn_resnet50_fpn": lambda x: x[1],
diff --git a/test/test_models_detection_negative_samples.py b/test/test_models_detection_negative_samples.py
index 1eab8b72d08..cb35f35894b 100644
--- a/test/test_models_detection_negative_samples.py
+++ b/test/test_models_detection_negative_samples.py
@@ -97,7 +97,7 @@ def test_assign_targets_to_proposals(self):
         self.assertEqual(labels[0].dtype, torch.int64)
 
     def test_forward_negative_sample_frcnn(self):
-        for name in ["fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large", "fasterrcnn_mobilenet_v3_large_fpn"]:
+        for name in ["fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn"]:
             model = torchvision.models.detection.__dict__[name](
                 num_classes=2, min_size=100, max_size=100)
 
diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py
index 8b1d6952271..ae5150205a7 100644
--- a/torchvision/models/detection/faster_rcnn.py
+++ b/torchvision/models/detection/faster_rcnn.py
@@ -1,10 +1,7 @@
-from collections import OrderedDict
-
 import torch
 from torch import nn
 import torch.nn.functional as F
 
-from torchvision.ops import misc as misc_nn_ops
 from torchvision.ops import MultiScaleRoIAlign
 
 from ._utils import overwrite_eps
@@ -19,7 +16,7 @@
 
 
 __all__ = [
-    "FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large", "fasterrcnn_mobilenet_v3_large_fpn"
+    "FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn"
 ]
 
 
@@ -291,8 +288,7 @@ def forward(self, x):
 model_urls = {
     'fasterrcnn_resnet50_fpn_coco':
         'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth',
-    'fasterrcnn_mobilenet_v3_large_coco': None,
-    'fasterrcnn_mobilenet_v3_large_fpn_coco': None,
+    'fasterrcnn_mobilenet_v3_large_fpn_coco': None,  # TODO: Add the final model url
 }
@@ -371,48 +367,6 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
     return model
 
 
-def fasterrcnn_mobilenet_v3_large(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
-                                  trainable_backbone_layers=None, min_size=320, max_size=640, **kwargs):
-    """
-    Constructs a Faster R-CNN model with a MobileNetV3-Large backbone. It works similarly
-    to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.
-
-    Example::
-
-        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large(pretrained=True)
-        >>> model.eval()
-        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
-        >>> predictions = model(x)
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on COCO train2017
-        progress (bool): If True, displays a progress bar of the download to stderr
-        num_classes (int): number of output classes of the model (including the background)
-        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
-        trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
-            Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
-        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
-        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
-    """
-    trainable_backbone_layers = _validate_trainable_layers(
-        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)
-
-    if pretrained:
-        pretrained_backbone = False
-    backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, False,
-                                  trainable_layers=trainable_backbone_layers)
-
-    anchor_sizes = ((32, 64, 128, 256, 512, ), )
-    aspect_ratios = ((0.5, 1.0, 2.0), )
-
-    model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
-                       min_size=min_size, max_size=max_size, **kwargs)
-    if pretrained:
-        state_dict = load_state_dict_from_url(model_urls['fasterrcnn_mobilenet_v3_large_coco'], progress=progress)
-        model.load_state_dict(state_dict)
-    return model
-
-
 def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
                                       trainable_backbone_layers=None, min_size=320, max_size=640, rpn_score_thresh=0.05,
                                       **kwargs):

From af92474c8798d12570a3f968c65d85c3e334af79 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Mon, 18 Jan 2021 11:05:17 +0000
Subject: [PATCH 09/11] Update documentation and add weights.

---
 docs/source/models.rst                      | 33 +++++++++++----------
 references/detection/README.md              |  9 +++++-
 torchvision/models/detection/faster_rcnn.py |  3 +-
 3 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/docs/source/models.rst b/docs/source/models.rst
index 658148e191c..fa2dfec14d9 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -358,13 +358,14 @@ models return the predictions of the following classes:
 Here are the summary of the accuracies for the models trained on
 the instances set of COCO train2017 and evaluated on COCO val2017.
 
-================================ ======= ======== ===========
-Network                          box AP  mask AP  keypoint AP
-================================ ======= ======== ===========
-Faster R-CNN ResNet-50 FPN       37.0    -        -
-RetinaNet ResNet-50 FPN          36.4    -        -
-Mask R-CNN ResNet-50 FPN         37.9    34.6     -
-================================ ======= ======== ===========
+================================== ======= ======== ===========
+Network                            box AP  mask AP  keypoint AP
+================================== ======= ======== ===========
+Faster R-CNN ResNet-50 FPN         37.0    -        -
+Faster R-CNN MobileNetV3-Large FPN 23.0    -        -
+RetinaNet ResNet-50 FPN            36.4    -        -
+Mask R-CNN ResNet-50 FPN           37.9    34.6     -
+================================== ======= ======== ===========
 
 For person keypoint detection, the accuracies for the pre-trained
 models are as follows
@@ -414,20 +415,22 @@ For test time, we report the time for the model evaluation and
 postprocessing (including mask pasting in image), but not the time for
 computing the precision-recall.
 
-============================== =================== ================== ===========
-Network                        train time (s / it) test time (s / it) memory (GB)
-============================== =================== ================== ===========
-Faster R-CNN ResNet-50 FPN     0.2288              0.0590             5.2
-RetinaNet ResNet-50 FPN        0.2514              0.0939             4.1
-Mask R-CNN ResNet-50 FPN       0.2728              0.0903             5.4
-Keypoint R-CNN ResNet-50 FPN   0.3789              0.1242             6.8
-============================== =================== ================== ===========
+================================== =================== ================== ===========
+Network                            train time (s / it) test time (s / it) memory (GB)
+================================== =================== ================== ===========
+Faster R-CNN ResNet-50 FPN         0.2288              0.0590             5.2
+Faster R-CNN MobileNetV3-Large FPN 0.0978              0.0376             0.6
+RetinaNet ResNet-50 FPN            0.2514              0.0939             4.1
+Mask R-CNN ResNet-50 FPN           0.2728              0.0903             5.4
+Keypoint R-CNN ResNet-50 FPN       0.3789              0.1242             6.8
+================================== =================== ================== ===========
 
 
 Faster R-CNN
 ------------
 
 .. autofunction:: torchvision.models.detection.fasterrcnn_resnet50_fpn
+.. autofunction:: torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn
 
 
 RetinaNet
diff --git a/references/detection/README.md b/references/detection/README.md
index f89e8149a71..e7ac6e48e11 100644
--- a/references/detection/README.md
+++ b/references/detection/README.md
@@ -20,13 +20,20 @@ You must modify the following flags:
 
 Except otherwise noted, all models have been trained on 8x V100 GPUs.
 
-### Faster R-CNN
+### Faster R-CNN ResNet-50 FPN
 ```
 python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
     --dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\
     --lr-steps 16 22 --aspect-ratio-group-factor 3
 ```
 
+### Faster R-CNN MobileNetV3-Large FPN
+```
+python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
+    --dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\
+    --lr-steps 16 22 --aspect-ratio-group-factor 3
+```
+
 ### RetinaNet
 ```
 python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py
index ae5150205a7..f5b4696e2ce 100644
--- a/torchvision/models/detection/faster_rcnn.py
+++ b/torchvision/models/detection/faster_rcnn.py
@@ -288,7 +288,8 @@ def forward(self, x):
 model_urls = {
     'fasterrcnn_resnet50_fpn_coco':
         'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth',
-    'fasterrcnn_mobilenet_v3_large_fpn_coco': None,  # TODO: Add the final model url
+    'fasterrcnn_mobilenet_v3_large_fpn_coco':
+        'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-907ea3f9.pth',
 }

From b3cda8bdbe1c1742884f473cb390a528c311a337 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Mon, 18 Jan 2021 14:10:22 +0000
Subject: [PATCH 10/11] Use built-in Identity.

---
 torchvision/models/mobilenetv3.py | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py
index 3e59a6f58cf..eba4823277b 100644
--- a/torchvision/models/mobilenetv3.py
+++ b/torchvision/models/mobilenetv3.py
@@ -18,16 +18,6 @@
 }
 
 
-class Identity(nn.Module):
-
-    def __init__(self, inplace: bool = False):
-        super().__init__()
-        self.inplace = inplace
-
-    def forward(self, input: Tensor) -> Tensor:
-        return input
-
-
 class SqueezeExcitation(nn.Module):
 
     def __init__(self, input_channels: int, squeeze_factor: int = 4):
@@ -88,7 +78,7 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod
         # project
         layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
-                                       activation_layer=Identity))
+                                       activation_layer=nn.Identity))
 
         self.block = nn.Sequential(*layers)
         self.out_channels = cnf.out_channels

From 681e5d11f4262cb650fa38295b4aa0f8a652e586 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Mon, 18 Jan 2021 15:37:02 +0000
Subject: [PATCH 11/11] Fix spelling.

---
 torchvision/models/detection/backbone_utils.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py
index 6290f0373c2..dce7f038370 100644
--- a/torchvision/models/detection/backbone_utils.py
+++ b/torchvision/models/detection/backbone_utils.py
@@ -138,12 +138,12 @@ def mobilenet_backbone(
 
     # Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
     # The first and last blocks are always included because they are the C0 (conv1) and Cn.
-    stage_indeces = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1]
-    num_stages = len(stage_indeces)
+    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1]
+    num_stages = len(stage_indices)
 
     # find the index of the layer from which we wont freeze
     assert 0 <= trainable_layers <= num_stages
-    freeze_before = num_stages if trainable_layers == 0 else stage_indeces[num_stages - trainable_layers]
+    freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
 
     # freeze layers only if pretrained backbone is used
@@ -158,9 +158,9 @@ def mobilenet_backbone(
         if returned_layers is None:
             returned_layers = [num_stages - 2, num_stages - 1]
         assert min(returned_layers) >= 0 and max(returned_layers) < num_stages
-        return_layers = {f'{stage_indeces[k]}': str(v) for v, k in enumerate(returned_layers)}
+        return_layers = {f'{stage_indices[k]}': str(v) for v, k in enumerate(returned_layers)}
 
-        in_channels_list = [backbone[stage_indeces[i]].out_channels for i in returned_layers]
+        in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
         return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
     else:
         m = nn.Sequential(
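
Taken together, a hedged end-to-end sketch of the helper after the rename
(assuming the full series is applied; mobilenet_backbone and the "is_strided"
stage detection it relies on are introduced by this series and are not part
of a released torchvision):

    import torch
    from torchvision.models.detection.backbone_utils import mobilenet_backbone

    # FPN flavour: a BackboneWithFPN built on the last two strided stages,
    # exposing out_channels == 256 for the detection heads.
    backbone = mobilenet_backbone("mobilenet_v3_large", pretrained=False, fpn=True,
                                  trainable_layers=3)
    features = backbone(torch.rand(1, 3, 320, 320))
    print(backbone.out_channels)  # 256
    print(list(features.keys()))  # ['0', '1', 'pool']: two stages plus LastLevelMaxPool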