@@ -173,8 +173,15 @@ def main(args):
173
173
174
174
criterion = nn .CrossEntropyLoss ()
175
175
176
- optimizer = torch .optim .SGD (
177
- model .parameters (), lr = args .lr , momentum = args .momentum , weight_decay = args .weight_decay )
176
+ opt_name = args .opt .lower ()
177
+ if opt_name == 'sgd' :
178
+ optimizer = torch .optim .SGD (
179
+ model .parameters (), lr = args .lr , momentum = args .momentum , weight_decay = args .weight_decay )
180
+ elif opt_name == 'rmsprop' :
181
+ optimizer = torch .optim .RMSprop (
182
+ model .parameters (), lr = args .lr , momentum = args .momentum , weight_decay = args .weight_decay )
183
+ else :
184
+ raise RuntimeError ("Invalid optimizer {}. Only SGD and RMSprop are supported." .format (args .opt ))
178
185
179
186
if args .apex :
180
187
model , optimizer = amp .initialize (model , optimizer ,
@@ -238,6 +245,7 @@ def parse_args():
238
245
help = 'number of total epochs to run' )
239
246
parser .add_argument ('-j' , '--workers' , default = 16 , type = int , metavar = 'N' ,
240
247
help = 'number of data loading workers (default: 16)' )
248
+ parser .add_argument ('--opt' , default = 'sgd' , type = str , help = 'optimizer' )
241
249
parser .add_argument ('--lr' , default = 0.1 , type = float , help = 'initial learning rate' )
242
250
parser .add_argument ('--momentum' , default = 0.9 , type = float , metavar = 'M' ,
243
251
help = 'momentum' )
0 commit comments