Detection recipe enhancements #5715
Changes from all commits
The first file in the diff is the detection reference training script; the hunks below touch its argument parser and its `main` function.
```diff
@@ -68,6 +68,7 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
     )
+    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
     parser.add_argument(
         "--lr",
         default=0.02,
```
```diff
@@ -84,6 +85,12 @@ def get_args_parser(add_help=True):
         help="weight decay (default: 1e-4)",
         dest="weight_decay",
     )
+    parser.add_argument(
+        "--norm-weight-decay",
+        default=None,
+        type=float,
+        help="weight decay for Normalization layers (default: None, same value as --wd)",
+    )
     parser.add_argument(
         "--lr-scheduler", default="multisteplr", type=str, help="name of lr scheduler (default: multisteplr)"
     )
```
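As a quick illustration of how the two new flags behave when parsed, here is a minimal sketch (not the actual `get_args_parser`); the `--wd`/`--weight-decay` alias and the example values are assumptions based on the surrounding diff:

```python
import argparse

# Minimal sketch of the two new flags added above; --norm-weight-decay
# defaults to None, which the training script interprets as "use the
# regular --wd value for every parameter".
parser = argparse.ArgumentParser()
parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
parser.add_argument("--wd", "--weight-decay", default=1e-4, type=float, dest="weight_decay")
parser.add_argument("--norm-weight-decay", default=None, type=float)

args = parser.parse_args(["--opt", "adamw", "--norm-weight-decay", "0.0"])
print(args.opt, args.weight_decay, args.norm_weight_decay)  # adamw 0.0001 0.0
```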
```diff
@@ -176,6 +183,8 @@ def main(args):
 
     print("Creating model")
     kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
+    if args.data_augmentation in ["multiscale", "lsj"]:
+        kwargs["_skip_resize"] = True
     if "rcnn" in args.model:
         if args.rpn_score_thresh is not None:
             kwargs["rpn_score_thresh"] = args.rpn_score_thresh
```
```diff
@@ -191,8 +200,26 @@ def main(args):
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         model_without_ddp = model.module
 
-    params = [p for p in model.parameters() if p.requires_grad]
-    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    if args.norm_weight_decay is None:
+        parameters = [p for p in model.parameters() if p.requires_grad]
+    else:
+        param_groups = torchvision.ops._utils.split_normalization_params(model)
+        wd_groups = [args.norm_weight_decay, args.weight_decay]
+        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
+
+    opt_name = args.opt.lower()
+    if opt_name.startswith("sgd"):
+        optimizer = torch.optim.SGD(
+            parameters,
+            lr=args.lr,
+            momentum=args.momentum,
+            weight_decay=args.weight_decay,
+            nesterov="nesterov" in opt_name,
+        )
+    elif opt_name == "adamw":
+        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
+    else:
+        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD and AdamW are supported.")
 
     scaler = torch.cuda.amp.GradScaler() if args.amp else None
```
Review discussion on this block:

> Straight copy-paste from classification.

> Since it is copy-pasted, I think it can stay as is and, if we want to change it, we can do so in a different PR, but I still wonder if there is a reason why for SGD we use …

> You are right! I've indeed copy-pasted it from classification but replaced the previous SGD optimizer line. The problem is that the existing recipe didn't contain the Nesterov momentum update. I've just updated the file to support it; it's not something I've used so far, but it's a simple update.
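To make the parameter-splitting idea easier to follow, here is a rough, self-contained approximation of what the private `torchvision.ops._utils.split_normalization_params` helper does and how the resulting groups reach the optimizer. `split_norm_params`, the toy model, and the hyperparameter values are illustrative stand-ins, not the torchvision implementation:

```python
import torch
from torch import nn

# Rough sketch of the idea behind --norm-weight-decay: collect the parameters
# of normalization layers into their own group so they can get a different
# weight decay than the rest of the model. This approximates the private
# torchvision.ops._utils.split_normalization_params helper; it is not the
# actual implementation.
def split_norm_params(model: nn.Module):
    norm_classes = (nn.modules.batchnorm._BatchNorm, nn.GroupNorm, nn.LayerNorm)
    norm_params, other_params = [], []
    for module in model.modules():
        if next(module.children(), None) is not None:
            continue  # only inspect leaf modules
        params = [p for p in module.parameters() if p.requires_grad]
        if isinstance(module, norm_classes):
            norm_params.extend(params)
        else:
            other_params.extend(params)
    return norm_params, other_params


# Toy model and hypothetical hyperparameters, mirroring the training script:
# group 0 (normalization params) gets --norm-weight-decay, group 1 gets --wd.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(), nn.Conv2d(8, 8, 3))
norm_params, other_params = split_norm_params(model)
parameters = [
    {"params": norm_params, "weight_decay": 0.0},    # e.g. --norm-weight-decay 0.0
    {"params": other_params, "weight_decay": 1e-4},  # e.g. --wd 1e-4
]
optimizer = torch.optim.SGD(parameters, lr=0.02, momentum=0.9, nesterov=True)
print([len(g["params"]) for g in optimizer.param_groups])  # [2, 4]
```

Note also the naming convention used by the new `--opt` handling in the diff: any value starting with `sgd` selects SGD, the substring `nesterov` (e.g. `sgd_nesterov`) turns Nesterov momentum on, and `adamw` selects AdamW.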
The second file in the diff touches the model tests, removing a leftover debug print:
```diff
@@ -64,7 +64,6 @@ def test_get_weight(name, weight):
 )
 def test_naming_conventions(model_fn):
     weights_enum = _get_model_weights(model_fn)
-    print(weights_enum)
     assert weights_enum is not None
     assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")
 
```

Review comment on this block:

> profit!
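For context, the invariant this test asserts can also be checked by hand; the sketch below assumes a recent torchvision where `get_model_weights` is public (the test itself uses the private `_get_model_weights` helper):

```python
from torchvision.models import get_model_weights

# Assumes torchvision >= 0.14, where get_model_weights is public. The test's
# invariant: a weights enum is either empty or exposes a DEFAULT alias.
weights_enum = get_model_weights("resnet50")
assert weights_enum is not None
assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")
print(list(weights_enum))  # e.g. [ResNet50_Weights.IMAGENET1K_V1, ResNet50_Weights.IMAGENET1K_V2]
```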