Skip to content

Commit f40c8df

Browse files
xiaohu2015datumbox
andauthored
Simplify EMA to use Pytorch's update_parameters (#5469)
Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 5568744 commit f40c8df

File tree

1 file changed

+1
-11
lines changed

1 file changed

+1
-11
lines changed

references/classification/utils.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -166,17 +166,7 @@ def __init__(self, model, decay, device="cpu"):
166166
def ema_avg(avg_model_param, model_param, num_averaged):
167167
return decay * avg_model_param + (1 - decay) * model_param
168168

169-
super().__init__(model, device, ema_avg)
170-
171-
def update_parameters(self, model):
172-
for p_swa, p_model in zip(self.module.state_dict().values(), model.state_dict().values()):
173-
device = p_swa.device
174-
p_model_ = p_model.detach().to(device)
175-
if self.n_averaged == 0:
176-
p_swa.detach().copy_(p_model_)
177-
else:
178-
p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_, self.n_averaged.to(device)))
179-
self.n_averaged += 1
169+
super().__init__(model, device, ema_avg, use_buffers=True)
180170

181171

182172
def accuracy(output, target, topk=(1,)):

0 commit comments

Comments
 (0)