@@ -108,8 +108,14 @@ def main(args):
     optimizer = torch.optim.SGD(
         params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
 
-    # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
-    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
+    args.lr_scheduler = args.lr_scheduler.lower()
+    if args.lr_scheduler == 'multisteplr':
+        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
+    elif args.lr_scheduler == 'cosineannealinglr':
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(data_loader) * args.epochs)
+    else:
+        raise RuntimeError("Invalid lr scheduler '{}'. Only MultiStepLR and CosineAnnealingLR "
+                           "are supported.".format(args.lr_scheduler))
 
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
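
To try the new selection logic in isolation, here is a minimal, self-contained sketch. The model, optimizer hyper-parameters, scheduler name, and iteration counts below are placeholders rather than values from the training script; the script itself reads the name from the new --lr-scheduler flag and derives T_max from len(data_loader) * args.epochs.

import torch

# Placeholder model and optimizer; the reference script builds a real model and
# takes lr / momentum / weight_decay from the command line.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.9)

lr_scheduler_name = 'cosineannealinglr'   # lower-cased value of --lr-scheduler (hypothetical)
iters_per_epoch, epochs = 100, 26         # hypothetical sizes standing in for len(data_loader) and args.epochs

if lr_scheduler_name == 'multisteplr':
    # Drop the lr by gamma at the given milestone steps.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[16, 22], gamma=0.1)
elif lr_scheduler_name == 'cosineannealinglr':
    # Anneal the lr over T_max scheduler steps.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=iters_per_epoch * epochs)
else:
    raise RuntimeError("Invalid lr scheduler '{}'. Only MultiStepLR and CosineAnnealingLR "
                       "are supported.".format(lr_scheduler_name))

for _ in range(3):
    optimizer.step()
    lr_scheduler.step()
    print(optimizer.param_groups[0]['lr'])

Note that T_max is expressed in scheduler steps, so the cosine schedule only completes one full anneal if lr_scheduler.step() is called that many times.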
@@ -169,9 +175,13 @@ def main(args):
     parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                         metavar='W', help='weight decay (default: 1e-4)',
                         dest='weight_decay')
-    parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs')
-    parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int, help='decrease lr every step-size epochs')
-    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
+    parser.add_argument('--lr-scheduler', default="multisteplr", help='the lr scheduler (default: multisteplr)')
+    parser.add_argument('--lr-step-size', default=8, type=int,
+                        help='decrease lr every step-size epochs (multisteplr scheduler only)')
+    parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int,
+                        help='decrease lr every step-size epochs (multisteplr scheduler only)')
+    parser.add_argument('--lr-gamma', default=0.1, type=float,
+                        help='decrease lr by a factor of lr-gamma (multisteplr scheduler only)')
     parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
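
To see how the updated arguments behave, here is a small stand-alone argparse sketch covering only the scheduler-related flags, parsed with a hypothetical command-line value rather than a real invocation of the script.

import argparse

parser = argparse.ArgumentParser(description='scheduler flags only (sketch)')
parser.add_argument('--lr-scheduler', default="multisteplr", help='the lr scheduler (default: multisteplr)')
parser.add_argument('--lr-step-size', default=8, type=int,
                    help='decrease lr every step-size epochs (multisteplr scheduler only)')
parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int,
                    help='decrease lr every step-size epochs (multisteplr scheduler only)')
parser.add_argument('--lr-gamma', default=0.1, type=float,
                    help='decrease lr by a factor of lr-gamma (multisteplr scheduler only)')

# Hypothetical invocation equivalent to: train.py --lr-scheduler CosineAnnealingLR
args = parser.parse_args(['--lr-scheduler', 'CosineAnnealingLR'])
print(args.lr_scheduler.lower())  # 'cosineannealinglr'; main() lower-cases it, so the flag is case-insensitive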