diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index d4df7adacb88..f9d17b740d24 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -526,7 +526,7 @@ def transform_images(examples): logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} if args.use_ema: - logs["ema_decay"] = ema_model.decay + logs["ema_decay"] = ema_model.cur_decay_value progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) progress_bar.close() diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index c5449556a12f..8b05e886a9a3 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -122,6 +122,7 @@ def __init__( self.inv_gamma = inv_gamma self.power = power self.optimization_step = 0 + self.cur_decay_value = None # set in `step()` def get_decay(self, optimization_step: int) -> float: """ @@ -163,6 +164,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter]): # Compute the decay factor for the exponential moving average. decay = self.get_decay(self.optimization_step) + self.cur_decay_value = decay one_minus_decay = 1 - decay for s_param, param in zip(self.shadow_params, parameters): @@ -208,7 +210,7 @@ def state_dict(self) -> dict: # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict return { "decay": self.decay, - "min_decay": self.decay, + "min_decay": self.min_decay, "optimization_step": self.optimization_step, "update_after_step": self.update_after_step, "use_ema_warmup": self.use_ema_warmup,