Skip to content

Commit 24ecd45

Browse files
committed
Adding rpn_score_thresh param directly in fasterrcnn_mobilenet_v3_large_fpn.
1 parent 690ee55 commit 24ecd45

File tree

2 files changed

+9 -5 lines changed

references/detection/train.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,8 @@ def main(args):
9696
"trainable_backbone_layers": args.trainable_backbone_layers
9797
}
9898
if "rcnn" in args.model:
99-
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
99+
if args.rpn_score_thresh is not None:
100+
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
100101
model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained,
101102
**kwargs)
102103
model.to(device)
@@ -179,9 +180,9 @@ def main(args):
179180
parser.add_argument('--resume', default='', help='resume from checkpoint')
180181
parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
181182
parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
182-
parser.add_argument('--rpn-score-thresh', default=0.0, type=float, help='rpn score threshold for faster-rcnn')
183+
parser.add_argument('--rpn-score-thresh', default=None, type=float, help='rpn score threshold for faster-rcnn')
183184
parser.add_argument('--trainable-backbone-layers', default=None, type=int,
184-
help='number of trainable layers of backbone ')
185+
help='number of trainable layers of backbone')
185186
parser.add_argument(
186187
"--test-only",
187188
dest="test_only",

torchvision/models/detection/faster_rcnn.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -414,7 +414,8 @@ def fasterrcnn_mobilenet_v3_large(pretrained=False, progress=True, num_classes=9
414414

415415

416416
def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
417-
trainable_backbone_layers=None, min_size=320, max_size=640, **kwargs):
417+
trainable_backbone_layers=None, min_size=320, max_size=640, rpn_score_thresh=0.05,
418+
**kwargs):
418419
"""
419420
Constructs a Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly
420421
to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.
@@ -435,6 +436,8 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_class
435436
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
436437
min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
437438
max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
439+
rpn_score_thresh (float): during inference, only return proposals with a classification score
440+
greater than rpn_score_thresh
438441
"""
439442
trainable_backbone_layers = _validate_trainable_layers(
440443
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)
@@ -448,7 +451,7 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_class
448451
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
449452

450453
model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
451-
min_size=min_size, max_size=max_size, **kwargs)
454+
min_size=min_size, max_size=max_size, rpn_score_thresh=rpn_score_thresh, **kwargs)
452455
if pretrained:
453456
state_dict = load_state_dict_from_url(model_urls['fasterrcnn_mobilenet_v3_large_fpn_coco'], progress=progress)
454457
model.load_state_dict(state_dict)

0 commit comments

Comments
 (0)