Skip to content

Commit 9d0d070

Browse files
authored
EMA: fix state_dict() and load_state_dict() & add cur_decay_value (#2146)
* EMA: fix `state_dict()` & add `cur_decay_value`
* EMA: fix a bug in `load_state_dict()`: 'float' object (`state_dict["power"]`) has no attribute 'get'.
* Delete train_unconditional_ort.py
1 parent c1971a5 commit 9d0d070

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -563,7 +563,7 @@ def transform_images(examples):
563563

564564
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
565565
if args.use_ema:
566-
logs["ema_decay"] = ema_model.decay
566+
logs["ema_decay"] = ema_model.cur_decay_value
567567
progress_bar.set_postfix(**logs)
568568
accelerator.log(logs, step=global_step)
569569
progress_bar.close()

src/diffusers/training_utils.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -124,6 +124,7 @@ def __init__(
124124
self.inv_gamma = inv_gamma
125125
self.power = power
126126
self.optimization_step = 0
127+
self.cur_decay_value = None # set in `step()`
127128

128129
self.model_cls = model_cls
129130
self.model_config = model_config
@@ -194,6 +195,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
194195

195196
# Compute the decay factor for the exponential moving average.
196197
decay = self.get_decay(self.optimization_step)
198+
self.cur_decay_value = decay
197199
one_minus_decay = 1 - decay
198200

199201
for s_param, param in zip(self.shadow_params, parameters):
@@ -239,7 +241,7 @@ def state_dict(self) -> dict:
239241
# https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
240242
return {
241243
"decay": self.decay,
242-
"min_decay": self.decay,
244+
"min_decay": self.min_decay,
243245
"optimization_step": self.optimization_step,
244246
"update_after_step": self.update_after_step,
245247
"use_ema_warmup": self.use_ema_warmup,

0 commit comments

Comments (0)