Skip to content

Commit 690ee55

Browse files
committed
Adding trainable_backbone_layers param on the train script.
1 parent 217e5b1 commit 690ee55

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

references/detection/train.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ def main(args):
9292
collate_fn=utils.collate_fn)
9393

9494
print("Creating model")
95-
kwargs = {}
95+
kwargs = {
96+
"trainable_backbone_layers": args.trainable_backbone_layers
97+
}
9698
if "rcnn" in args.model:
9799
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
98100
model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained,
@@ -178,6 +180,8 @@ def main(args):
178180
parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
179181
parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
180182
parser.add_argument('--rpn-score-thresh', default=0.0, type=float, help='rpn score threshold for faster-rcnn')
183+
parser.add_argument('--trainable-backbone-layers', default=None, type=int,
184+
help='number of trainable layers of backbone ')
181185
parser.add_argument(
182186
"--test-only",
183187
dest="test_only",

0 commit comments

Comments
 (0)