@@ -173,8 +173,15 @@ def main(args):
 
     criterion = nn.CrossEntropyLoss()
 
-    optimizer = torch.optim.SGD(
-        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    opt_name = args.opt.lower()
+    if opt_name == 'sgd':
+        optimizer = torch.optim.SGD(
+            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    elif opt_name == 'rmsprop':
+        optimizer = torch.optim.RMSprop(
+            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    else:
+        raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
 
     if args.apex:
         model, optimizer = amp.initialize(model, optimizer,
@@ -191,9 +198,11 @@ def main(args):
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
         model_without_ddp.load_state_dict(checkpoint['model'])
-        optimizer.load_state_dict(checkpoint['optimizer'])
-        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
-        args.start_epoch = checkpoint['epoch'] + 1
+        if not args.no_resume_opt:
+            optimizer.load_state_dict(checkpoint['optimizer'])
+        if not args.no_resume_sched:
+            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+        args.start_epoch = checkpoint['epoch'] + 1
 
     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
@@ -238,6 +247,7 @@ def parse_args():
                         help='number of total epochs to run')
     parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                         help='number of data loading workers (default: 16)')
+    parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
     parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
     parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                         help='momentum')
@@ -275,6 +285,18 @@ def parse_args():
         help="Use pre-trained models from the modelzoo",
         action="store_true",
     )
+    parser.add_argument(
+        "--no-resume-opt",
+        dest="no_resume_opt",
+        help="When resuming from checkpoint it ignores the optimizer state",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--no-resume-sched",
+        dest="no_resume_sched",
+        help="When resuming from checkpoint it ignores the scheduler state",
+        action="store_true",
+    )
 
     # Mixed precision training parameters
     parser.add_argument('--apex', action='store_true',
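
For reference, below is a condensed, self-contained sketch of how the new --opt, --no-resume-opt, and --no-resume-sched options interact. The toy model, the StepLR scheduler, and the hyperparameter values are placeholders for illustration and are not part of this patch; only the flag handling mirrors the diff above.

# A condensed, runnable sketch of the behaviour added by this diff.
# The toy model, StepLR scheduler and hyperparameter values are
# placeholders for illustration; only the flag handling mirrors the patch.
import argparse

import torch
import torch.nn as nn


def build_optimizer(opt_name, model, lr, momentum, weight_decay):
    # Same selection logic as the patch: only SGD and RMSprop are accepted.
    opt_name = opt_name.lower()
    if opt_name == 'sgd':
        return torch.optim.SGD(
            model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    elif opt_name == 'rmsprop':
        return torch.optim.RMSprop(
            model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(opt_name))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
    parser.add_argument('--resume', default='', help='checkpoint to resume from')
    parser.add_argument('--no-resume-opt', dest='no_resume_opt', action='store_true',
                        help='When resuming from checkpoint it ignores the optimizer state')
    parser.add_argument('--no-resume-sched', dest='no_resume_sched', action='store_true',
                        help='When resuming from checkpoint it ignores the scheduler state')
    args = parser.parse_args()

    model = nn.Linear(8, 2)  # stand-in for the real model
    optimizer = build_optimizer(args.opt, model, lr=0.1, momentum=0.9, weight_decay=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        # The new flags keep the resumed weights while optionally restarting
        # the optimizer state and/or the LR schedule from scratch.
        if not args.no_resume_opt:
            optimizer.load_state_dict(checkpoint['optimizer'])
        if not args.no_resume_sched:
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

For example, passing --resume checkpoint.pth --no-resume-sched to such a script keeps the saved weights and optimizer state but restarts the learning-rate schedule from scratch.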