-
Notifications
You must be signed in to change notification settings - Fork 6.1k
add Min-SNR loss to Controlnet flax train script #3016
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The documentation is not available anymore as the PR was closed or merged. |
@@ -806,6 +812,17 @@ def main(): | |||
validation_rng, train_rngs = jax.random.split(rng) | |||
train_rngs = jax.random.split(train_rngs, jax.local_device_count()) | |||
|
|||
def compute_snr(timesteps): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you also add a comment on the reference like here?
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you so much for doing this quickly.
Let's maybe also update the README mentioning this feature briefly?
thank you so much, this is great! Can you kindly also remove the line from the doc that this is only for Pytorch: https://github.com/huggingface/diffusers/blob/main/docs/source/en/training/text2image.mdx#training-with-min-snr-weighting |
@kashif I will remove that line once we updated the flax script for text2img (I only updated controlnet in this PR) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks pretty nice!
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' | ||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' | ||
), | ||
default="wandb", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did we mean to change the default here? (Just asking)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just didn't want it to default to something that doesn't exist - cause wandb is the only method I implemented 😅 (and I don't think the other flax training scripts implemented any of these logging methods)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lol, fair enough :)
* add wandb team and min-snr loss * make style * apply feedbacks
* add wandb team and min-snr loss * make style * apply feedbacks
* add wandb team and min-snr loss * make style * apply feedbacks
* add wandb team and min-snr loss * make style * apply feedbacks
(a very short) wandb experiment here
https://wandb.ai/yiyixu/train_controlnet_flax?workspace=user-yiyixu