
Commit f4b907d

Add cosine annealing on trainer.

1 parent: f318332


references/detection/train.py

Lines changed: 15 additions & 5 deletions
@@ -108,8 +108,14 @@ def main(args):
     optimizer = torch.optim.SGD(
         params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
 
-    # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
-    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
+    args.lr_scheduler = args.lr_scheduler.lower()
+    if args.lr_scheduler == 'multisteplr':
+        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
+    elif args.lr_scheduler == 'cosineannealinglr':
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(data_loader) * args.epochs)
+    else:
+        raise RuntimeError("Invalid lr scheduler '{}'. Only MultiStepLR and CosineAnnealingLR "
+                           "are supported.".format(args.lr_scheduler))
 
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
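
Note on the T_max choice above: len(data_loader) * args.epochs is the total number of training iterations, so the cosine curve only completes if lr_scheduler.step() is called once per batch; with a per-epoch step the usual choice would be T_max=args.epochs. A minimal self-contained sketch of the per-iteration pattern, with placeholder model, data, and hyperparameters (not taken from train.py):

import torch

# Stand-ins so the sketch runs on its own; train.py builds these differently.
model = torch.nn.Linear(10, 2)
data_loader = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(5)]
epochs = 3

optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.9)
# T_max = total iterations -> the lr traces one half-cosine over the whole run.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=len(data_loader) * epochs)

for epoch in range(epochs):
    for images, targets in data_loader:
        loss = torch.nn.functional.cross_entropy(model(images), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # stepped once per iteration, matching the T_max above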
@@ -169,9 +175,13 @@ def main(args):
     parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                         metavar='W', help='weight decay (default: 1e-4)',
                         dest='weight_decay')
-    parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs')
-    parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int, help='decrease lr every step-size epochs')
-    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
+    parser.add_argument('--lr-scheduler', default="multisteplr", help='the lr scheduler (default: multisteplr)')
+    parser.add_argument('--lr-step-size', default=8, type=int,
+                        help='decrease lr every step-size epochs (multisteplr scheduler only)')
+    parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int,
+                        help='decrease lr every step-size epochs (multisteplr scheduler only)')
+    parser.add_argument('--lr-gamma', default=0.1, type=float,
+                        help='decrease lr by a factor of lr-gamma (multisteplr scheduler only)')
     parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
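
With the new --lr-scheduler flag the scheduler is selected on the command line; the value is lowercased before matching, so 'CosineAnnealingLR' and 'cosineannealinglr' both work, and any other value raises the RuntimeError above. Hypothetical invocations (the remaining required train.py arguments, such as dataset and model options, are omitted; an --epochs flag is assumed to exist since the new code reads args.epochs):

python train.py --lr-scheduler multisteplr --lr-steps 16 22 --lr-gamma 0.1 ...
python train.py --lr-scheduler cosineannealinglr --epochs 26 ...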
