@@ -628,7 +628,7 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
628
628
629
629
630
630
def retinanet_mobilenet_v3_large_fpn (pretrained = False , progress = True , num_classes = 91 , pretrained_backbone = True ,
631
- trainable_backbone_layers = None , ** kwargs ):
631
+ trainable_backbone_layers = None , min_size = 600 , max_size = 1000 , ** kwargs ):
632
632
"""
633
633
Constructs a RetinaNet model with a MobileNetV3-Large-FPN backbone. It works similarly
634
634
to RetinaNet with ResNet-50-FPN backbone. See `retinanet_resnet50_fpn` for more details.
@@ -647,6 +647,8 @@ def retinanet_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classe
647
647
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
648
648
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
649
649
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
650
+ min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
651
+ max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
650
652
"""
651
653
# check default parameters and by default set it to 3 if possible
652
654
trainable_backbone_layers = _validate_trainable_layers (
@@ -657,10 +659,11 @@ def retinanet_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classe
657
659
backbone = mobilenet_fpn_backbone ("mobilenet_v3_large" , pretrained_backbone , returned_layers = [4 , 5 ],
658
660
trainable_layers = trainable_backbone_layers )
659
661
660
- anchor_sizes = ((128 ,), ( 256 ,), ( 512 ,))
662
+ anchor_sizes = ((32 , 64 , 128 , 256 , 512 ), ) * 3
661
663
aspect_ratios = ((0.5 , 1.0 , 2.0 ),) * len (anchor_sizes )
662
664
663
- model = RetinaNet (backbone , num_classes , anchor_generator = AnchorGenerator (anchor_sizes , aspect_ratios ), ** kwargs )
665
+ model = RetinaNet (backbone , num_classes , anchor_generator = AnchorGenerator (anchor_sizes , aspect_ratios ),
666
+ min_size = min_size , max_size = max_size , ** kwargs )
664
667
if pretrained :
665
668
state_dict = load_state_dict_from_url (model_urls ['retinanet_mobilenet_v3_large_fpn_coco' ],
666
669
progress = progress )
0 commit comments