Skip to content

Commit 385e077

Browse files
committed
Add RMSprop support and allow warm restarts in train.py
1 parent 25f8b26 commit 385e077

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

references/classification/train.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,15 @@ def main(args):
173173

174174
criterion = nn.CrossEntropyLoss()
175175

176-
optimizer = torch.optim.SGD(
177-
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
176+
opt_name = args.opt.lower()
177+
if opt_name == 'sgd':
178+
optimizer = torch.optim.SGD(
179+
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
180+
elif opt_name == 'rmsprop':
181+
optimizer = torch.optim.RMSprop(
182+
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
183+
else:
184+
raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
178185

179186
if args.apex:
180187
model, optimizer = amp.initialize(model, optimizer,
@@ -191,9 +198,11 @@ def main(args):
191198
if args.resume:
192199
checkpoint = torch.load(args.resume, map_location='cpu')
193200
model_without_ddp.load_state_dict(checkpoint['model'])
194-
optimizer.load_state_dict(checkpoint['optimizer'])
195-
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
196-
args.start_epoch = checkpoint['epoch'] + 1
201+
if not args.no_resume_opt:
202+
optimizer.load_state_dict(checkpoint['optimizer'])
203+
if not args.no_resume_sched:
204+
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
205+
args.start_epoch = checkpoint['epoch'] + 1
197206

198207
if args.test_only:
199208
evaluate(model, criterion, data_loader_test, device=device)
@@ -238,6 +247,7 @@ def parse_args():
238247
help='number of total epochs to run')
239248
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
240249
help='number of data loading workers (default: 16)')
250+
parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
241251
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
242252
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
243253
help='momentum')
@@ -275,6 +285,18 @@ def parse_args():
275285
help="Use pre-trained models from the modelzoo",
276286
action="store_true",
277287
)
288+
parser.add_argument(
289+
"--no-resume-opt",
290+
dest="no_resume_opt",
291+
help="When resuming from checkpoint it ignores the optimizer state",
292+
action="store_true",
293+
)
294+
parser.add_argument(
295+
"--no-resume-sched",
296+
dest="no_resume_sched",
297+
help="When resuming from checkpoint it ignores the scheduler state",
298+
action="store_true",
299+
)
278300

279301
# Mixed precision training parameters
280302
parser.add_argument('--apex', action='store_true',

0 commit comments

Comments
 (0)