Fix schedulers zero SNR and rescale classifier free guidance #3664
Conversation
I didn't implement the beta rescaling in all the …
@@ -76,6 +76,41 @@ def alpha_bar(time_step):
     return torch.tensor(betas, dtype=torch.float32)


def rescale_zero_terminal_snr(alphas_cumprod):
Should we refactor this method to take betas as the parameter instead of alphas_cumprod? It's weird to accept alphas_cumprod and return betas, and your docstring lists betas as the argument anyway.
After all, do we really want the scheduler to expose an argument that does the rescaling? Technically, people can always do the betas transformation outside and pass it in as trained_betas.
Also, the cosine schedule should not use this to rescale; it should just remove the beta clipping.
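To illustrate that last point, here is a small sketch (my own, not from the PR) in the style of diffusers' betas_for_alpha_bar: with the usual 0.999 clip the terminal SNR stays positive, while removing the clip makes the final beta round to 1.0 and the terminal SNR zero, with no rescaling needed.

import math
import torch

# Simplified cosine schedule in the style of betas_for_alpha_bar.
def cosine_betas(num_timesteps, max_beta=0.999):
    alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
    betas = []
    for i in range(num_timesteps):
        t1, t2 = i / num_timesteps, (i + 1) / num_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)

clipped = cosine_betas(1000)                  # final beta capped at 0.999
unclipped = cosine_betas(1000, max_beta=1.0)  # final beta rounds to 1.0

print(torch.cumprod(1 - clipped, dim=0)[-1])    # small but nonzero: terminal SNR > 0
print(torch.cumprod(1 - unclipped, dim=0)[-1])  # 0.0: zero terminal SNR, no rescaling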
Taking betas as an argument sounds good to me. I just changed it to alphas_cumprod because we had that variable already calculated in the schedulers.
@@ -560,6 +560,7 @@ def __call__(
     callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
     callback_steps: int = 1,
     cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     guidance_rescale: float = 0.7,
We should default it to 0, so it doesn't break old behavior.
@@ -706,8 +712,20 @@ def __call__(
     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
     std_text = torch.std(noise_pred_text)
This is incorrect. The std needs to be calculated as a single real number per image, not as one value shared across the whole batch. Change it to:

std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_pred = noise_pred.std(dim=list(range(1, noise_pred.ndim)), keepdim=True)
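A quick sketch of the difference (my own, with hypothetical shapes): torch.std(x) collapses the whole tensor to one scalar, while reducing over all non-batch dims gives one std per image that broadcasts back over that image's pixels.

import torch

noise_pred_text = torch.randn(4, 4, 64, 64)  # hypothetical (batch, channels, h, w)

# One number for the entire batch:
scalar_std = torch.std(noise_pred_text)
# One number per image, broadcastable over that image:
per_image_std = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)

print(scalar_std.shape)     # torch.Size([])
print(per_image_std.shape)  # torch.Size([4, 1, 1, 1])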
Hey @Max-We,

Great job opening the PR! I've played around quite a bit with all the improvements, and my conclusion is that we simply cannot say the suggestions from the paper always improve model performance for our pretrained checkpoints, which are used >1M times per month. I've changed this PR so that only the DDIM scheduler and the Stable Diffusion pipeline are adapted (the changes also didn't work with most of the other schedulers):

pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_scaling="trailing", rescale_betas_zero_snr=True)
...
pipe(..., guidance_rescale=0.7)

The fourth improvement requires training (Section 3.2), which requires new checkpoints. It's pretty easy to fine-tune a checkpoint on v-loss and v-prediction using our text-to-image script here.

I ran the changes on a prompt from the paper for both v1.5 and SD v2.1:

from diffusers import StableDiffusionPipeline, DDIMScheduler
import torch

path = "runwayml/stable-diffusion-v1-5"
# path = "stabilityai/stable-diffusion-2-1"

pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k"

for TIMESTEP_TYPE in ["trailing", "leading"]:
    for RESCALE_BETAS_ZERO_SNR in [True, False]:
        for GUIDANCE_RESCALE in [0.0, 0.7]:
            pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_scaling=TIMESTEP_TYPE, rescale_betas_zero_snr=RESCALE_BETAS_ZERO_SNR)
            generator = torch.Generator(device="cpu").manual_seed(0)
            images = pipe(prompt=prompt, generator=generator, num_images_per_prompt=4, num_inference_steps=40, guidance_rescale=GUIDANCE_RESCALE).images

Results in the next post.
Note: SD v1.5, default images vs. images with the new settings (image comparison omitted).
Note: SD v2.1, default images vs. images with the new settings (image comparison omitted).
Long story short: I don't think this paper is all that ground-breaking or universally applicable. With this PR we give users the possibility to figure out for themselves what they like best, but IMO there is no need to make it available everywhere yet.
@Max-We @PeterL1n @sayakpaul thoughts?
@patrickvonplaten please try with … (edit: i have set it up as a test bed, now.)
@patrickvonplaten I think it's a good idea to keep the changes as an opt-in feature. The changes at this point look good to me. Thanks for your assistance (also to you, @PeterL1n)!
def rescale_zero_terminal_snr(alphas_cumprod):
    """
    Rescales betas to have zero terminal SNR (signal-to-noise-ratio). Based on https://arxiv.org/pdf/2305.08891.pdf
    (Algorithm 1)

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas
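As a sanity check, here is a minimal sketch (my addition, assuming the function above is in scope) showing that the rescaling drives the terminal alphas_cumprod, and hence the terminal SNR, to zero on SD's scaled-linear schedule:

import torch

betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000, dtype=torch.float64) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
print(alphas_cumprod[-1])  # > 0, so SNR = alphas_cumprod / (1 - alphas_cumprod) > 0 at t = T

rescaled_betas = rescale_zero_terminal_snr(alphas_cumprod)
new_alphas_cumprod = torch.cumprod(1.0 - rescaled_betas, dim=0)
print(new_alphas_cumprod[-1])  # 0.0, i.e. zero terminal SNR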
How about creating a separate util module for this function? If we want to use it in more than one place, it would be cleaner to just import the function and avoid duplicate code.
We use the "# Copied from ..." mechanism to deal with the maintenance there. This is mainly because we want to keep the code as self-consistent as possible so that there's a minimal number of redirects. But @patrickvonplaten can share more.
https://tripleback.net/public/woman_trailing_true_stacked.png (125 MiB): here's what it looks like when you stack all of the possible CFG configs row by row. Each row is a 1.0 jump in CFG; I included 0 and 1 just for completeness even though they don't change. Each column runs from 0.0 to 1.0 CFG rescale. The "true" in the URI indicates the rescaled betas are applied.

https://tripleback.net/public/woman_trailing_false_stacked.png (128 MiB): same settings layout, but this is "trailing" without patching the betas. Notice we get less useful output: the number of broken/noisy images is higher, higher CFG is needed to get the same result, and the response to rescaling CFG is less robust.

https://tripleback.net/public/woman_leading_true_stacked.png: here is the "leading" configuration.

https://tripleback.net/public/woman_leading_false_stacked.png: "leading" without patching the betas.
@@ -51,6 +51,23 @@
 """


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_pred
def rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=0.0):
I think this is quite a weird design. Let's move the classifier-free guidance calculation into this function as well, because currently the argument noise_pred is confusing: does it mean the unconditional prediction or the prediction after CFG? (The latter is correct, but the name kind of implies the former.) So let's move CFG into this function too:

def classifier_free_guidance(pred_pos, pred_neg, guidance_weight, guidance_rescale=0):
    # Apply classifier-free guidance.
    pred_cfg = pred_neg + guidance_weight * (pred_pos - pred_neg)
    # Apply guidance rescale. From the paper [Common Diffusion Noise Schedules
    # and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf), section 3.4.
    if guidance_rescale != 0:
        std_pos = pred_pos.std(dim=list(range(1, pred_pos.ndim)), keepdim=True)
        std_cfg = pred_cfg.std(dim=list(range(1, pred_cfg.ndim)), keepdim=True)
        # Fuse equations 15 and 16 for more efficient computation.
        pred_cfg *= guidance_rescale * (std_pos / std_cfg) + (1 - guidance_rescale)
    return pred_cfg

We wrote equations 15 and 16 separately in the paper for easier understanding, but the computation can be fused as above for efficiency.
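For reference, the fusion follows directly from the paper's two equations (my notation; phi is guidance_rescale):

x_rescaled = x_cfg * (std_pos / std_cfg)                         (eq. 15)
x_final    = phi * x_rescaled + (1 - phi) * x_cfg                (eq. 16)
           = x_cfg * (phi * (std_pos / std_cfg) + (1 - phi))     (fused form used above)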
If ok, I'd like to keep them separated, to better separate classifier-free guidance (which is used by essentially every diffusion pipeline) from the rescaling, which is a bit newer and IMO not yet as established. Happy to rename noise_pred to pred_cfg though.
    timesteps += self.config.steps_offset
elif self.config.timestep_scaling == "trailing":
    timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64).copy()
    timesteps -= 1
- What's the point of doing copy()?
- Why do we use np and then convert to a torch tensor? Why not just use torch.arange()?
We could remove the copy(); keeping it in numpy is mainly just a style choice.
@@ -254,9 +303,20 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
    step_ratio = self.config.num_train_timesteps // self.num_inference_steps
step_ratio must not be floored to an int here. It should be kept as a float so that "trailing" spaces more accurately below. See the paper: we deliberately do not use a floor operation to calculate the interval.

So this is tricky. Unlike "leading", which always has the same integer spacing, "trailing" and "linspace" may have uneven integer spacing. We need to change the step() function to make sure it supports this too.

Overall, I think we should support user-defined custom sampling timesteps such as [0, 99, 499, 999] in the future. The step() function should not expect even spacing!
Ok, let's move to an uneven step ratio for "trailing"; I don't see a problem with this :-)
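To make the spacing difference concrete, here is a small sketch (my own, assuming 1000 training timesteps and 40 inference steps):

import numpy as np

num_train_timesteps, num_inference_steps = 1000, 40

# "leading": integer step ratio, evenly spaced, starts at 0 and never reaches 999.
step_ratio = num_train_timesteps // num_inference_steps
leading = (np.arange(0, num_inference_steps) * step_ratio)[::-1].astype(np.int64)

# "trailing": float step ratio so the schedule ends exactly at 999; after
# rounding, the integer gaps between steps can be uneven.
step_ratio = num_train_timesteps / num_inference_steps
trailing = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64) - 1

print(leading[:3], leading[-1])    # [975 950 925] ... 0
print(trailing[:3], trailing[-1])  # [999 974 949] ... 24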
@@ -143,6 +186,8 @@ def __init__(
     dynamic_thresholding_ratio: float = 0.995,
     clip_sample_range: float = 1.0,
     sample_max_value: float = 1.0,
     timestep_scaling: str = "leading",
Maybe sample_step_selection_mode is better?
True, scaling is not very well chosen. Maybe timestep_spacing? I think we mostly refer to how the timestep values are spaced, no? Also want to try not to make it too long.
Now using the fine-tuned checkpoint from @bghira. Note: default images vs. images with the new settings (image comparison omitted).
cc @sayakpaul could you give this a review? :-)
@patrickvonplaten i am having difficulties implementing this properly, i think because i can't train using DDIM as the noise scheduler. i have to use DDPMScheduler or the multistep solver, neither of which supports the new changes, right? that said, i've started again, training with your implementation here rather than my own, and i have seen some really great results, stuff that you simply cannot achieve with 1.5 or 2.1 on their own. you should try it. kodachrome slides trained in now have proper black levels: (image omitted)
# std_text = torch.std(noise_pred_text)
# std_pred = torch.std(noise_pred)
Let's remove these comments.
This can enable the model to generate very bright and dark samples instead of limiting it to samples with medium brightness.
Could this potentially deprecate offset noising?
Guess it's worth linking to it!
From my own testing of zero terminal SNR + CFG rescaling, I believe so. I'm producing samples with great variety in contrast, similar to offset noise.

Offset noising does not seem to be stable and eventually diverges, and it requires some other yet-to-be-invented control to adjust how much offset noise is used. The blog post's value of 0.1 works only for some number of steps, and I've found 0.1 to be far too high for much beyond a short DreamBooth-style training.

Typically the above would produce a grey gradient background as the model tries to paint an image with mean brightness 0.5. Likewise, it's sometimes difficult to get bright white backgrounds.

To reproduce the above with offset noise, it may take a few attempts at guessing an appropriate constant to multiply the offset noise by (0.01, 0.02, etc.). Zero terminal SNR instead appears to simply be stable and requires no tuning.
I think the implementation can replace offset noise; see Section 5.3 in the paper.
it does! offset noise doesn't work as well as zero SNR.
and when you enable offset noise with zero SNR, they fight each other, and the model can't learn properly.
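For readers unfamiliar with the technique being compared: offset noise (from the blog post referenced above) adds a small per-channel constant shift to the training noise. A rough sketch, with hypothetical tensors and the commonly cited 0.1 strength:

import torch

latents = torch.randn(4, 4, 64, 64)  # hypothetical training latents (b, c, h, w)
noise = torch.randn_like(latents)

# One extra random value per (image, channel), broadcast over H and W; this
# shifts the mean of the noise so the model can learn overall brightness.
offset_strength = 0.1
noise = noise + offset_strength * torch.randn(latents.shape[0], latents.shape[1], 1, 1)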
timestep_spacing: str = "leading",
rescale_betas_zero_snr: bool = False,
@patrickvonplaten this is for my understanding: do we not need to register these variables so that they can be accessed from the config?
The register_to_config decorator automatically does this.
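A minimal illustration (my own sketch) of what that means in practice:

from diffusers import DDIMScheduler

# __init__ arguments decorated with @register_to_config are recorded on the
# config automatically, so they round-trip through save/load and from_config.
scheduler = DDIMScheduler(timestep_spacing="trailing", rescale_betas_zero_snr=True)
print(scheduler.config.timestep_spacing)        # "trailing"
print(scheduler.config.rescale_betas_zero_snr)  # True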
Pretty good stuff. I think this should be documented properly, ideally in the section where we have https://huggingface.co/docs/diffusers/main/en/stable_diffusion. WDYT @patrickvonplaten @stevhliu?
Added comments, this is good to go! Nice work, everybody. Thanks a lot for your comments @PeterL1n; I couldn't include all of them, as stated in the comments above, but I hope the current implementation works for you. If the community starts heavily adopting the new DDIMScheduler settings, I think we would be more than happy to apply them to all the other schedulers as well.
Sorry for fiddling with your PR so much @Max-We & thanks a mille for getting this effort started!
…face#3664)

* Implement option for rescaling betas to zero terminal SNR
* Implement rescale classifier free guidance in pipeline_stable_diffusion.py
* focus on DDIM
* make style
* Apply suggestions from Peter Lin
* Apply suggestions from code review
* make style

Co-authored-by: MaxWe00 <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Have these changes been propagated to the other schedulers now, or is it still only for DDIM? I'm also a little bit confused about one thing: if the formulation is more exact mathematically, with results in the paper that are quite convincing, did anyone investigate why @patrickvonplaten observed that "we cannot say it's better for every model"? It doesn't make sense to me to keep the flawed version as the default and make the mathematically correct one opt-in. (I'm working on an extremely detail-sensitive situation where a small SNR can be perceivable, so I want to be sure where, and on which schedulers, I have to expect problems coming from such issues.)
well, it's been about a year since the initial paper was released, and at the time we didn't have much insight. i've trained three models with this schedule, and it is rather finicky but still hard to judge adequately. at first, the DDIM scheduler had a couple of issues that prevented the samples from working correctly, as the default scheduler config for the SDXL model (and also SD 2.x/1.x) has set_alpha_to_one=False. now that the problems in DDIM were discovered while @Beinsezii ported the zsnr changes to Euler, broader testing of the models' capabilities is possible. in light of this, it's obvious in hindsight that the schedule changes should have been propagated across as many of the available options as possible; it is always nice to have another option to plug in to see whether a problem comes from the weights or otherwise. zsnr is available in most schedulers now: DDIM, Euler, Euler-A, UniPC, DPM++ 2M, and DDPM. so even when training using DDPMScheduler, you can just set rescale_betas_zero_snr=True.
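For example, a sketch of that opt-in (my own; the checkpoint path is illustrative, and the kwarg overrides the loaded config):

from diffusers import DDPMScheduler

scheduler = DDPMScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="scheduler",
    rescale_betas_zero_snr=True,
)
print(scheduler.config.rescale_betas_zero_snr)  # True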
@bghira Thanks for your answer. Can you explain what the issue is with set_alpha_to_one? My understanding from the paper was that I had to use timestep_spacing='trailing' and rescale_betas_zero_snr=True to get the correct mathematical formulation, but you seem to be pointing to another issue? Additionally, what about the param steps_offset, which seems to be related? Cf. the doc:

EDIT: Ok, so from what I've understood from reading the ~7 recent issues related to this, set_alpha_to_one=False is basically nonsense and steps_offset kind of is too. I do have additional questions:

I'm very confused why there are so many params with so little explanation of their interaction, when the mathematical formulation should be ~exact. Thx!
This PR implements two of the three suggestions from the Common Diffusion Noise Schedules and Sample Steps are Flawed paper, which is discussed in #3475:

pipeline_stable_diffusion.py

Another suggestion, about sampling from the last timestep (3.3. in the paper), is not implemented in this PR currently, but may be done if it's desired.

Suggestion 2 (Rescale Classifier-Free Guidance) is only implemented for pipeline_stable_diffusion.py, but can be done for the others as well if desired.