
Commit 12fd3a6

Added Exponential Moving Average support to classification reference script (#4381)
* Added Exponential Moving Average support to classification reference script
* Addressed review comments
* Updated model argument
1 parent c50d0fc commit 12fd3a6

File tree

references/classification/train.py
references/classification/utils.py

2 files changed: +36 -4 lines changed


references/classification/train.py

Lines changed: 24 additions & 4 deletions
@@ -17,7 +17,8 @@
     amp = None
 
 
-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
+                    print_freq, apex=False, model_ema=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
@@ -45,11 +46,14 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
         metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
         metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
 
+        if model_ema:
+            model_ema.update_parameters(model)
 
-def evaluate(model, criterion, data_loader, device, print_freq=100):
+
+def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=''):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
-    header = 'Test:'
+    header = f'Test: {log_suffix}'
     with torch.no_grad():
         for image, target in metric_logger.log_every(data_loader, print_freq, header):
             image = image.to(device, non_blocking=True)
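For reference, AveragedModel.update_parameters copies the model's parameters on its first call and applies the averaging function on every later call, so invoking it once per training step (as above) produces a per-step EMA. A minimal hand-rolled sketch of the equivalent update, for illustration only (the decay value and tensors are stand-ins, not from this commit):

import torch

decay = 0.99
model_params = [torch.randn(3)]                 # stand-in for model.parameters()
ema_params = [p.clone() for p in model_params]  # stand-in for the EMA shadow copy

with torch.no_grad():
    for ema_p, p in zip(ema_params, model_params):
        # ema = decay * ema + (1 - decay) * param, the same rule as
        # utils.ExponentialMovingAverage defined further down
        ema_p.mul_(decay).add_(p, alpha=1 - decay)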
@@ -199,12 +203,18 @@ def main(args):
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         model_without_ddp = model.module
 
+    model_ema = None
+    if args.model_ema:
+        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)
+
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
         model_without_ddp.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
         args.start_epoch = checkpoint['epoch'] + 1
+        if model_ema:
+            model_ema.load_state_dict(checkpoint['model_ema'])
 
     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
@@ -215,16 +225,20 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex)
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema)
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
+        if model_ema:
+            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix='EMA')
         if args.output_dir:
             checkpoint = {
                 'model': model_without_ddp.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'lr_scheduler': lr_scheduler.state_dict(),
                 'epoch': epoch,
                 'args': args}
+            if model_ema:
+                checkpoint['model_ema'] = model_ema.state_dict()
             utils.save_on_master(
                 checkpoint,
                 os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
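Because ExponentialMovingAverage (added to utils.py below) subclasses torch.optim.swa_utils.AveragedModel, which is itself an nn.Module, its state_dict()/load_state_dict() pair behaves like any module's, so the 'model_ema' entry saved above round-trips through args.resume. A hedged sketch of that round trip, assuming the reference utils.py is importable and using resnet18 purely as a stand-in model:

import torch
import torchvision

import utils  # references/classification/utils.py, patched below

model = torchvision.models.resnet18()
model_ema = utils.ExponentialMovingAverage(model, decay=0.99, device='cpu')

# Save, mirroring the checkpoint dict built in main():
checkpoint = {'model': model.state_dict(), 'model_ema': model_ema.state_dict()}
torch.save(checkpoint, 'model_0.pth')

# Resume, mirroring the args.resume branch above:
loaded = torch.load('model_0.pth', map_location='cpu')
model_ema.load_state_dict(loaded['model_ema'])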
@@ -306,6 +320,12 @@ def get_args_parser(add_help=True):
     parser.add_argument('--world-size', default=1, type=int,
                         help='number of distributed processes')
     parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
+    parser.add_argument(
+        '--model-ema', action='store_true',
+        help='enable tracking Exponential Moving Average of model parameters')
+    parser.add_argument(
+        '--model-ema-decay', type=float, default=0.99,
+        help='decay factor for Exponential Moving Average of model parameters (default: 0.99)')
 
     return parser
 
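With these two flags in place, EMA tracking would be switched on from the command line along the lines of python train.py --model resnet18 --model-ema --model-ema-decay 0.99 (the --model flag already exists in this reference script; resnet18 is only an example model name).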

references/classification/utils.py

Lines changed: 12 additions & 0 deletions
@@ -161,6 +161,18 @@ def log_every(self, iterable, print_freq, header=None):
         print('{} Total time: {}'.format(header, total_time_str))
 
 
+class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
+    """Maintains moving averages of model parameters using an exponential decay.
+    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
+    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
+    is used to compute the EMA.
+    """
+    def __init__(self, model, decay, device='cpu'):
+        ema_avg = (lambda avg_model_param, model_param, num_averaged:
+                   decay * avg_model_param + (1 - decay) * model_param)
+        super().__init__(model, device, ema_avg)
+
+
 def accuracy(output, target, topk=(1,)):
     """Computes the accuracy over the k top predictions for the specified values of k"""
     with torch.no_grad():
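To see the averaging behavior concretely, here is a small self-contained check; the class is restated inline so the snippet runs without the reference utils.py on the path, and the one-weight Linear model and 0.9 decay are arbitrary illustration choices:

import torch
from torch.optim.swa_utils import AveragedModel


class ExponentialMovingAverage(AveragedModel):
    # Same definition as in the diff above, restated so this snippet is standalone.
    def __init__(self, model, decay, device='cpu'):
        ema_avg = (lambda avg_model_param, model_param, num_averaged:
                   decay * avg_model_param + (1 - decay) * model_param)
        super().__init__(model, device, ema_avg)


model = torch.nn.Linear(1, 1, bias=False)
torch.nn.init.constant_(model.weight, 1.0)

ema = ExponentialMovingAverage(model, decay=0.9)
ema.update_parameters(model)   # first call copies the parameters: EMA weight = 1.0

torch.nn.init.constant_(model.weight, 2.0)
ema.update_parameters(model)   # EMA weight = 0.9 * 1.0 + 0.1 * 2.0 = 1.1

print(ema.module.weight.item())  # prints roughly 1.1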
