Skip to content

Commit e06be30

Browse files
prabhat00155 authored and facebook-github-bot committed
[fbsync] Simplify EMA to use Pytorch's update_parameters (#5469)
Summary: Co-authored-by: Vasilis Vryniotis <[email protected]> Reviewed By: datumbox Differential Revision: D34579515 fbshipit-source-id: 6f563a48305dc1c9d99274d40c15416075c9b20f
1 parent dde9098 commit e06be30

File tree

1 file changed

+1
-11
lines changed

1 file changed

+1
-11
lines changed

references/classification/utils.py

Lines changed: 1 addition & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -166,17 +166,7 @@ def __init__(self, model, decay, device="cpu"):
166166
def ema_avg(avg_model_param, model_param, num_averaged):
167167
return decay * avg_model_param + (1 - decay) * model_param
168168

169-
super().__init__(model, device, ema_avg)
170-
171-
def update_parameters(self, model):
172-
for p_swa, p_model in zip(self.module.state_dict().values(), model.state_dict().values()):
173-
device = p_swa.device
174-
p_model_ = p_model.detach().to(device)
175-
if self.n_averaged == 0:
176-
p_swa.detach().copy_(p_model_)
177-
else:
178-
p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_, self.n_averaged.to(device)))
179-
self.n_averaged += 1
169+
super().__init__(model, device, ema_avg, use_buffers=True)
180170

181171

182172
def accuracy(output, target, topk=(1,)):

0 commit comments

Comments (0)