Skip to content

EMA: fix state_dict() and load_state_dict() & add cur_decay_value #2146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 8, 2023
Merged

EMA: fix state_dict() and load_state_dict() & add cur_decay_value #2146

merged 5 commits into from
Feb 8, 2023

Conversation

chenguolin
Copy link
Contributor

  1. fix the saved value for min_decay in EMA self.state_dict().

  2. add an interface self.cur_decay_value to track the current value for decay.
    (as self.decay is a constant value meaning "max decay value")

  3. track the current EMA decay value by self.cur_decay_value in "unconditional image generation" examples, instead of self.decay.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jan 28, 2023

The documentation is not available anymore as the PR was closed or merged.

'float' object (`state_dict["power"]`) has no attribute 'get'.
@chenguolin
Copy link
Contributor Author

  1. fix a bug in EMA load_static_dict(): "AttributeError: 'float' object has no attribute 'get'"

@chenguolin chenguolin changed the title EMA: fix state_dict() & add cur_decay_value EMA: fix state_dict() and load_state_dict() & add cur_decay_value Jan 29, 2023
@patrickvonplaten
Copy link
Contributor

This looks good to me! @chenguolin could you maybe also check if the EMAModel class is correctly used in: https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py and potentially fix it? :-)

@chenguolin
Copy link
Contributor Author

It looks good. I have changed logs["ema_decay"] to the new ema_model.cur_decay_value to track the current decay value at:

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This LGTM, @pcuenca @patil-suraj could you take a look?

@patrickvonplaten
Copy link
Contributor

Actually we should probably also apply the same changes to: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py

In case you have 5min it'd be great if you could take a look at applying the same changes to train_text_to_image as wel @chenguolin :-)

cc @patil-suraj wdyt of this PR

@chenguolin
Copy link
Contributor Author

Hi @patrickvonplaten, the only necessary change to example "train_unconditional.py" is logs["ema_decay"] = ema_model.decay -> logs["ema_decay"] = ema_model.cur_decay_value.

While other examples don't use ema_model.decay, maybe there is no need to change them.

Copy link
Contributor

@patil-suraj patil-suraj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for fixing this!

@patil-suraj
Copy link
Contributor

@chenguolin there's a merge conflict in train_unconditional_ort.py could you please fix it? It should be good to merge after that.

Copy link
Member

@pcuenca pcuenca left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

@chenguolin
Copy link
Contributor Author

Hi @patil-suraj, I have just deleted train_unconditional_ort.py ane merged.

Copy link
Contributor

@williamberman williamberman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Failing test is un-related :)

@patil-suraj patil-suraj merged commit 9d0d070 into huggingface:main Feb 8, 2023
yiyixuxu pushed a commit to evinpinar/diffusers-attend-and-excite-pipeline that referenced this pull request Feb 16, 2023
…e` (huggingface#2146)

* EMA: fix `state_dict()` & add `cur_decay_value`

* EMA: fix a bug in `load_state_dict()`

'float' object (`state_dict["power"]`) has no attribute 'get'.

* del train_unconditional_ort.py
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
…e` (huggingface#2146)

* EMA: fix `state_dict()` & add `cur_decay_value`

* EMA: fix a bug in `load_state_dict()`

'float' object (`state_dict["power"]`) has no attribute 'get'.

* del train_unconditional_ort.py
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
…e` (huggingface#2146)

* EMA: fix `state_dict()` & add `cur_decay_value`

* EMA: fix a bug in `load_state_dict()`

'float' object (`state_dict["power"]`) has no attribute 'get'.

* del train_unconditional_ort.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants