Skip to content

Commit 5198385

Browse files
committed
Adding rmsprop support on the train.py
1 parent 25f8b26 commit 5198385

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

references/classification/train.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,15 @@ def main(args):
173173

174174
criterion = nn.CrossEntropyLoss()
175175

176-
optimizer = torch.optim.SGD(
177-
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
176+
opt_name = args.opt.lower()
177+
if opt_name == 'sgd':
178+
optimizer = torch.optim.SGD(
179+
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
180+
elif opt_name == 'rmsprop':
181+
optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
182+
weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
183+
else:
184+
raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
178185

179186
if args.apex:
180187
model, optimizer = amp.initialize(model, optimizer,
@@ -238,6 +245,7 @@ def parse_args():
238245
help='number of total epochs to run')
239246
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
240247
help='number of data loading workers (default: 16)')
248+
parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
241249
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
242250
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
243251
help='momentum')

0 commit comments

Comments
 (0)