Commit 5c15a2c

Reduce resolution and increase number of anchor sizes.
1 parent 75933da

2 files changed: 6 additions and 3 deletions. One changed file is binary and not shown.

torchvision/models/detection/retinanet.py

Lines changed: 6 additions & 3 deletions
@@ -628,7 +628,7 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
 
 
 def retinanet_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
-                                     trainable_backbone_layers=None, **kwargs):
+                                     trainable_backbone_layers=None, min_size=600, max_size=1000, **kwargs):
     """
     Constructs a RetinaNet model with a MobileNetV3-Large-FPN backbone. It works similarly
     to RetinaNet with ResNet-50-FPN backbone. See `retinanet_resnet50_fpn` for more details.
@@ -647,6 +647,8 @@ def retinanet_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classe
         pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
         trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
+        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
     """
     # check default parameters and by default set it to 3 if possible
     trainable_backbone_layers = _validate_trainable_layers(
@@ -657,10 +659,11 @@ def retinanet_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classe
     backbone = mobilenet_fpn_backbone("mobilenet_v3_large", pretrained_backbone, returned_layers=[4, 5],
                                       trainable_layers=trainable_backbone_layers)
 
-    anchor_sizes = ((128,), (256,), (512,))
+    anchor_sizes = ((32, 64, 128, 256, 512), ) * 3
     aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
 
-    model = RetinaNet(backbone, num_classes, anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs)
+    model = RetinaNet(backbone, num_classes, anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
+                      min_size=min_size, max_size=max_size, **kwargs)
     if pretrained:
         state_dict = load_state_dict_from_url(model_urls['retinanet_mobilenet_v3_large_fpn_coco'],
                                               progress=progress)
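
For reference, a minimal usage sketch of the updated builder. It assumes a torchvision build that includes this commit and that `retinanet_mobilenet_v3_large_fpn` is importable from `torchvision.models.detection` like the other detection builders; the example names and tensor shapes are illustrative only.

```python
# Minimal usage sketch (not part of the commit), assuming a torchvision build
# that includes this change and exposes the builder from
# torchvision.models.detection.
import torch
from torchvision.models.detection import retinanet_mobilenet_v3_large_fpn

# min_size/max_size are the keyword arguments introduced here; 600/1000 are
# their defaults, so passing them explicitly is optional. Per the new docstring,
# they bound the size the image is rescaled to before it reaches the backbone.
model = retinanet_mobilenet_v3_large_fpn(
    pretrained=False,
    pretrained_backbone=False,  # skip weight downloads for this sketch
    num_classes=91,
    min_size=600,
    max_size=1000,
)
model.eval()

# With anchor_sizes = ((32, 64, 128, 256, 512),) * 3 and three aspect ratios,
# the AnchorGenerator now places 5 * 3 = 15 anchors per location on each of
# the three feature maps, instead of one size per map as before.
images = [torch.rand(3, 480, 640)]
with torch.no_grad():
    predictions = model(images)
print(predictions[0]["boxes"].shape, predictions[0]["scores"].shape)
```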
