@@ -17,7 +17,8 @@
 amp = None


-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
+                    print_freq, apex=False, model_ema=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter=" ")
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
@@ -45,11 +46,14 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
         metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
         metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))

+    if model_ema:
+        model_ema.update_parameters(model)

-def evaluate(model, criterion, data_loader, device, print_freq=100):
+
+def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=''):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter=" ")
-    header = 'Test:'
+    header = f'Test: {log_suffix}'
     with torch.no_grad():
         for image, target in metric_logger.log_every(data_loader, print_freq, header):
             image = image.to(device, non_blocking=True)
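The hunk above calls model_ema.update_parameters(model), and the main() changes further down construct the wrapper with utils.ExponentialMovingAverage, which lives in the references utilities and is not part of this diff. A minimal sketch of what such a helper could look like, assuming it builds on torch.optim.swa_utils.AveragedModel (the class below is illustrative only, not the actual utils.py code):

import torch

class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Tracks an exponential moving average of a model's parameters:
    ema_param = decay * ema_param + (1 - decay) * param
    """
    def __init__(self, model, decay, device='cpu'):
        def ema_avg(avg_param, param, num_averaged):
            # AveragedModel.update_parameters() calls this for every parameter.
            return decay * avg_param + (1 - decay) * param
        super().__init__(model, device=device, avg_fn=ema_avg)

Since AveragedModel is itself an nn.Module that forwards to the wrapped network, the resulting model_ema can be handed straight to evaluate(), as the training-loop changes below do.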
@@ -199,12 +203,18 @@ def main(args):
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         model_without_ddp = model.module

+    model_ema = None
+    if args.model_ema:
+        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)
+
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
         model_without_ddp.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
         args.start_epoch = checkpoint['epoch'] + 1
+        if model_ema:
+            model_ema.load_state_dict(checkpoint['model_ema'])

     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
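Because the update_parameters call sits after the batch loop in train_one_epoch, the average is refreshed once per epoch in this version of the script, so --model-ema-decay acts as a per-epoch factor: each epoch the running average keeps 99 % of its old value and mixes in 1 % of the current weights. A quick back-of-the-envelope check of how slowly the average forgets its starting point (the 90-epoch schedule is an illustrative choice, not part of this diff):

decay = 0.99            # default value of --model-ema-decay
epochs = 90             # illustrative schedule length
print(decay ** epochs)  # ~0.40: the earliest weights still carry roughly 40% of the average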
@@ -215,16 +225,20 @@
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex)
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema)
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
+        if model_ema:
+            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix='EMA')
         if args.output_dir:
             checkpoint = {
                 'model': model_without_ddp.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'lr_scheduler': lr_scheduler.state_dict(),
                 'epoch': epoch,
                 'args': args}
+            if model_ema:
+                checkpoint['model_ema'] = model_ema.state_dict()
             utils.save_on_master(
                 checkpoint,
                 os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
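The checkpoint written above now carries a 'model_ema' entry next to the plain weights. Assuming ExponentialMovingAverage wraps torch.optim.swa_utils.AveragedModel as sketched earlier, its state_dict stores the averaged weights under a 'module.' prefix plus an 'n_averaged' counter, so the EMA weights could later be loaded into a bare model for stand-alone evaluation roughly like this (the file name and the model constructor are placeholders):

import torch
import torchvision

checkpoint = torch.load('model_89.pth', map_location='cpu')  # hypothetical output of this script
model = torchvision.models.resnet50()                        # must match the trained architecture

ema_state = {k[len('module.'):]: v
             for k, v in checkpoint['model_ema'].items()
             if k != 'n_averaged'}
model.load_state_dict(ema_state)
model.eval()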
@@ -306,6 +320,12 @@ def get_args_parser(add_help=True):
     parser.add_argument('--world-size', default=1, type=int,
                         help='number of distributed processes')
     parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
+    parser.add_argument(
+        '--model-ema', action='store_true',
+        help='enable tracking Exponential Moving Average of model parameters')
+    parser.add_argument(
+        '--model-ema-decay', type=float, default=0.99,
+        help='decay factor for Exponential Moving Average of model parameters (default: 0.99)')

     return parser

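With the two flags wired into get_args_parser, EMA tracking is opt-in from the command line; only --model-ema and --model-ema-decay come from this diff, while the script name and the --model choice below are the usual references/classification options and are shown purely for illustration:

python train.py --model resnet50 --model-ema --model-ema-decay 0.99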