Skip to content

Commit 8455b37

Browse files
committed
Add back nesterov momentum
1 parent ff6d641 commit 8455b37

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

references/detection/train.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,13 @@ def main(args):
209209

210210
opt_name = args.opt.lower()
211211
if opt_name.startswith("sgd"):
212-
optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
212+
optimizer = torch.optim.SGD(
213+
parameters,
214+
lr=args.lr,
215+
momentum=args.momentum,
216+
weight_decay=args.weight_decay,
217+
nesterov="nesterov" in opt_name,
218+
)
213219
elif opt_name == "adamw":
214220
optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
215221
else:

0 commit comments

Comments
 (0)