Commit 12a232e

Authored by Max-We, with MaxWe00 and patrickvonplaten
Fix schedulers zero SNR and rescale classifier free guidance (#3664)
* Implement option for rescaling betas to zero terminal SNR
* Implement rescale classifier free guidance in pipeline_stable_diffusion.py
* focus on DDIM
* make style
* make style
* make style
* make style
* Apply suggestions from Peter Lin
* Apply suggestions from Peter Lin
* make style
* Apply suggestions from code review
* Apply suggestions from code review
* make style
* make style

---------

Co-authored-by: MaxWe00 <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent 74fd735 commit 12a232e

File tree

10 files changed

+310
-6
lines changed

docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_2.mdx

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,64 @@ image = pipe(prompt, guidance_scale=9, num_inference_steps=25).images[0]
7171
image.save("astronaut.png")
7272
```
7373

74+
#### Experimental: "Common Diffusion Noise Schedules and Sample Steps are Flawed":
75+
76+
The paper **[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891)**
77+
claims that a mismatch between the training and inference settings leads to suboptimal generation results at inference time for Stable Diffusion.
78+
79+
The abstract reads as follows:
80+
81+
*We discover that common diffusion noise schedules do not enforce the last timestep to have zero signal-to-noise ratio (SNR),
82+
and some implementations of diffusion samplers do not start from the last timestep.
83+
Such designs are flawed and do not reflect the fact that the model is given pure Gaussian noise at inference, creating a discrepancy between training and inference.
84+
We show that the flawed design causes real problems in existing implementations.
85+
In Stable Diffusion, it severely limits the model to only generate images with medium brightness and
86+
prevents it from generating very bright and dark samples. We propose a few simple fixes:
87+
- (1) rescale the noise schedule to enforce zero terminal SNR;
88+
- (2) train the model with v prediction;
89+
- (3) change the sampler to always start from the last timestep;
90+
- (4) rescale classifier-free guidance to prevent over-exposure.
91+
These simple changes ensure the diffusion process is congruent between training and inference and
92+
allow the model to generate samples more faithful to the original data distribution.*
93+
94+
You can apply all of these changes in `diffusers` when using [`DDIMScheduler`]:
95+
- (1) rescale the noise schedule to enforce zero terminal SNR;
96+
```py
97+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, rescale_betas_zero_snr=True)
98+
```
99+
- (2) train the model with v prediction;
100+
Continue fine-tuning a checkpoint with [`train_text_to_image.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [`train_text_to_image_lora.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)
101+
and `--prediction_type="v_prediction"`.
102+
- (3) change the sampler to always start from the last timestep;
103+
```py
104+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
105+
```
106+
- (4) rescale classifier-free guidance to prevent over-exposure.
107+
```py
108+
pipe(..., guidance_rescale=0.7)
109+
```
110+
111+
An example is to use [this checkpoint](https://huggingface.co/ptx0/pseudo-journey-v2)
112+
which has been fine-tuned with `v_prediction`.
113+
114+
The checkpoint can then be run in inference as follows:
115+
116+
```py
117+
import torch
from diffusers import DiffusionPipeline, DDIMScheduler
118+
119+
pipe = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", torch_dtype=torch.float16)
120+
pipe.scheduler = DDIMScheduler.from_config(
121+
    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
122+
)
123+
pipe.to("cuda")
124+
125+
prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k"
126+
image = pipe(prompt, guidance_rescale=0.7).images[0]
127+
```
128+
129+
## DDIMScheduler
130+
[[autodoc]] DDIMScheduler
131+
74132
### Image Inpainting
75133

76134
- *Image Inpainting (512x512 resolution)*: [stabilityai/stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) with [`StableDiffusionInpaintPipeline`]
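
To make fix (1) from the documentation above concrete, here is a minimal pure-Python sketch of the rescaling the paper describes: shift the square root of the cumulative alpha product so that its last value is zero, then scale it so that its first value is unchanged. This is an illustration under stated assumptions, not the code this commit adds to `DDIMScheduler`; the function name `rescale_zero_terminal_snr` and the linear example schedule are chosen only for the example.

```python
import math


def rescale_zero_terminal_snr(betas):
    """Return betas rescaled so that the last cumulative alpha (terminal SNR) is zero.

    Hypothetical helper for illustration; operates on a plain list of floats.
    """
    # sqrt(alpha_bar_t): square root of the cumulative product of (1 - beta).
    alphas_bar_sqrt = []
    prod = 1.0
    for beta in betas:
        prod *= 1.0 - beta
        alphas_bar_sqrt.append(math.sqrt(prod))

    first, last = alphas_bar_sqrt[0], alphas_bar_sqrt[-1]
    # Shift so the last value becomes 0, then scale so the first value is kept.
    scale = first / (first - last)
    alphas_bar = [((a - last) * scale) ** 2 for a in alphas_bar_sqrt]

    # Convert the cumulative products back into per-step betas.
    new_betas = [1.0 - alphas_bar[0]]
    for prev, cur in zip(alphas_bar, alphas_bar[1:]):
        new_betas.append(1.0 - cur / prev)
    return new_betas


# Illustrative linear schedule (an assumption, not a specific checkpoint's config).
betas = [1e-4 + i * (0.02 - 1e-4) / 999 for i in range(1000)]
new_betas = rescale_zero_terminal_snr(betas)
# The final beta becomes 1.0, so the cumulative alpha at the last step is 0.
```

Because the final beta is exactly 1, the last training timestep corresponds to pure noise, which matches what the sampler is given at inference.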

docs/source/en/api/schedulers/ddim.mdx

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,71 @@ specific language governing permissions and limitations under the License.
1818

1919
The abstract of the paper is the following:
2020

21-
Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.
21+
*Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training,
22+
yet they require simulating a Markov chain for many steps to produce a sample.
23+
To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models
24+
with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process.
25+
We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from.
26+
We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off
27+
computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.*
2228

2329
The original codebase of this paper can be found here: [ermongroup/ddim](https://github.com/ermongroup/ddim).
2430
For questions, feel free to contact the author on [tsong.me](https://tsong.me/).
2531

32+
### Experimental: "Common Diffusion Noise Schedules and Sample Steps are Flawed":
33+
34+
The paper **[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891)**
35+
claims that a mismatch between the training and inference settings leads to suboptimal generation results at inference time for Stable Diffusion.
36+
37+
The abstract reads as follows:
38+
39+
*We discover that common diffusion noise schedules do not enforce the last timestep to have zero signal-to-noise ratio (SNR),
40+
and some implementations of diffusion samplers do not start from the last timestep.
41+
Such designs are flawed and do not reflect the fact that the model is given pure Gaussian noise at inference, creating a discrepancy between training and inference.
42+
We show that the flawed design causes real problems in existing implementations.
43+
In Stable Diffusion, it severely limits the model to only generate images with medium brightness and
44+
prevents it from generating very bright and dark samples. We propose a few simple fixes:
45+
- (1) rescale the noise schedule to enforce zero terminal SNR;
46+
- (2) train the model with v prediction;
47+
- (3) change the sampler to always start from the last timestep;
48+
- (4) rescale classifier-free guidance to prevent over-exposure.
49+
These simple changes ensure the diffusion process is congruent between training and inference and
50+
allow the model to generate samples more faithful to the original data distribution.*
51+
52+
You can apply all of these changes in `diffusers` when using [`DDIMScheduler`]:
53+
- (1) rescale the noise schedule to enforce zero terminal SNR;
54+
```py
55+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, rescale_betas_zero_snr=True)
56+
```
57+
- (2) train the model with v prediction;
58+
Continue fine-tuning a checkpoint with [`train_text_to_image.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [`train_text_to_image_lora.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)
59+
and `--prediction_type="v_prediction"`.
60+
- (3) change the sampler to always start from the last timestep;
61+
```py
62+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
63+
```
64+
- (4) rescale classifier-free guidance to prevent over-exposure.
65+
```py
66+
pipe(..., guidance_rescale=0.7)
67+
```
68+
69+
An example is to use [this checkpoint](https://huggingface.co/ptx0/pseudo-journey-v2)
70+
which has been fine-tuned with `v_prediction`.
71+
72+
The checkpoint can then be run in inference as follows:
73+
74+
```py
75+
import torch
from diffusers import DiffusionPipeline, DDIMScheduler
76+
77+
pipe = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", torch_dtype=torch.float16)
78+
pipe.scheduler = DDIMScheduler.from_config(
79+
    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
80+
)
81+
pipe.to("cuda")
82+
83+
prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k"
84+
image = pipe(prompt, guidance_rescale=0.7).images[0]
85+
```
86+
2687
## DDIMScheduler
2788
[[autodoc]] DDIMScheduler
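
Fix (3) in the documentation above is about which timesteps the sampler visits. The sketch below contrasts a "leading" selection, which never reaches the last training timestep, with the "trailing" selection the paper recommends. It is a simplified illustration: the function names are hypothetical, any scheduler-specific step offset is ignored, and the library's actual implementation may differ in detail.

```python
def leading_timesteps(num_train_timesteps, num_inference_steps):
    """Evenly spaced timesteps counted up from 0 (hypothetical helper)."""
    step = num_train_timesteps // num_inference_steps
    return [i * step for i in range(num_inference_steps)][::-1]


def trailing_timesteps(num_train_timesteps, num_inference_steps):
    """Evenly spaced timesteps counted down from the last one (hypothetical helper)."""
    step = num_train_timesteps / num_inference_steps
    return [round(num_train_timesteps - i * step) - 1 for i in range(num_inference_steps)]


leading = leading_timesteps(1000, 25)
trailing = trailing_timesteps(1000, 25)
print(leading[0], leading[-1])    # first and last timestep visited, leading
print(trailing[0], trailing[-1])  # first and last timestep visited, trailing
```

With 1000 training timesteps and 25 inference steps, the leading selection starts at timestep 960, while the trailing selection starts at 999, the timestep at which a zero-terminal-SNR schedule is pure noise.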

examples/text_to_image/train_text_to_image.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,12 @@ def parse_args():
307307
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
308308
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
309309
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
310+
parser.add_argument(
311+
"--prediction_type",
312+
type=str,
313+
default=None,
314+
help="The prediction type to use for training. Choose 'epsilon' or 'v_prediction', or leave as `None`. If `None`, the scheduler's default prediction type, `noise_scheduler.config.prediction_type`, is used.",
315+
)
310316
parser.add_argument(
311317
"--hub_model_id",
312318
type=str,
@@ -848,6 +854,10 @@ def collate_fn(examples):
848854
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
849855

850856
# Get the target for loss depending on the prediction type
857+
if args.prediction_type is not None:
858+
# set prediction_type of scheduler if defined
859+
noise_scheduler.register_to_config(prediction_type=args.prediction_type)
860+
851861
if noise_scheduler.config.prediction_type == "epsilon":
852862
target = noise
853863
elif noise_scheduler.config.prediction_type == "v_prediction":
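
The hunk above ends at the branch that selects the training target. As a hedged illustration of what the two prediction types mean, the sketch below computes both targets for a single sample using the standard v-prediction definition, v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x0. The helper `training_target` and its list-based signature are assumptions made for the example; the training script itself works on tensors through the scheduler.

```python
import math


def training_target(prediction_type, latents, noise, alpha_bar_t):
    """Return the regression target for one sample (hypothetical helper).

    latents and noise are flat lists of floats; alpha_bar_t is the cumulative
    alpha product at the sampled timestep.
    """
    if prediction_type == "epsilon":
        # The model is trained to predict the added noise.
        return list(noise)
    if prediction_type == "v_prediction":
        # The model is trained to predict the velocity v.
        a = math.sqrt(alpha_bar_t)
        s = math.sqrt(1.0 - alpha_bar_t)
        return [a * n - s * x for x, n in zip(latents, noise)]
    raise ValueError(f"Unknown prediction type: {prediction_type}")


latents = [1.0, 2.0]
noise = [0.5, -0.5]
# At zero terminal SNR (alpha_bar_t == 0) the v target is -x0, so it still
# carries information about the data, whereas the epsilon target does not.
v_at_terminal = training_target("v_prediction", latents, noise, alpha_bar_t=0.0)
```

This is one way to see why the paper pairs zero terminal SNR with v prediction: at the last timestep the input is pure noise, and predicting that same noise is a trivial task.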

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,12 @@ def parse_args():
272272
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
273273
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
274274
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
275+
parser.add_argument(
276+
"--prediction_type",
277+
type=str,
278+
default=None,
279+
help="The prediction type to use for training. Choose 'epsilon' or 'v_prediction', or leave as `None`. If `None`, the scheduler's default prediction type, `noise_scheduler.config.prediction_type`, is used.",
280+
)
275281
parser.add_argument(
276282
"--hub_model_id",
277283
type=str,
@@ -749,6 +755,10 @@ def collate_fn(examples):
749755
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
750756

751757
# Get the target for loss depending on the prediction type
758+
if args.prediction_type is not None:
759+
# set prediction_type of scheduler if defined
760+
noise_scheduler.register_to_config(prediction_type=args.prediction_type)
761+
752762
if noise_scheduler.config.prediction_type == "epsilon":
753763
target = noise
754764
elif noise_scheduler.config.prediction_type == "v_prediction":
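
The new `--prediction_type` flag follows an "override if given, else keep the scheduler default" pattern. The sketch below isolates that pattern with the standard library only. The `choices` restriction and the `resolve_prediction_type` helper are assumptions added for illustration; the commit itself accepts any string and writes the value into the scheduler config.

```python
import argparse


def build_parser():
    """Minimal parser with only the flag discussed here (illustrative)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prediction_type",
        type=str,
        default=None,
        choices=["epsilon", "v_prediction"],  # assumption: not in the commit
        help="Prediction type for training; None keeps the scheduler default.",
    )
    return parser


def resolve_prediction_type(cli_value, scheduler_default):
    """Return the CLI value if set, otherwise the scheduler's default."""
    return scheduler_default if cli_value is None else cli_value


args = build_parser().parse_args(["--prediction_type", "v_prediction"])
effective = resolve_prediction_type(args.prediction_type, scheduler_default="epsilon")
```

Restricting `choices` would reject a misspelled value at parse time rather than later in the training loop; whether that is desirable for the example scripts is a design choice left to the maintainers.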

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,21 @@
5151
"""
5252

5353

54+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
55+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
56+
"""
57+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
58+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
59+
"""
60+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
61+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
62+
# rescale the results from guidance (fixes overexposure)
63+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
64+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
65+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
66+
return noise_cfg
67+
68+
5469
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
5570
class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
5671
r"""
@@ -567,6 +582,7 @@ def __call__(
567582
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
568583
callback_steps: int = 1,
569584
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
585+
guidance_rescale: float = 0.0,
570586
):
571587
r"""
572588
Function invoked when calling the pipeline for generation.
@@ -627,6 +643,11 @@ def __call__(
627643
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
628644
`self.processor` in
629645
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
646+
guidance_rescale (`float`, *optional*, defaults to 0.0):
647+
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
648+
Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
649+
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
650+
Guidance rescale factor should fix overexposure when using zero terminal SNR.
630651
631652
Examples:
632653
@@ -717,6 +738,10 @@ def __call__(
717738
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
718739
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
719740

741+
if do_classifier_free_guidance and guidance_rescale > 0.0:
742+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
743+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
744+
720745
# compute the previous noisy sample x_t -> x_t-1
721746
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
722747
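
To show what the rescaling step in the hunk above does numerically, here is a pure-Python sketch for a single sample stored as a flat list. It mirrors the arithmetic of `rescale_noise_cfg` from this commit, but it is a simplified stand-in: the real function works on batched tensors and computes the standard deviation per sample over all non-batch dimensions. The `classifier_free_guidance` helper name is hypothetical.

```python
import statistics


def classifier_free_guidance(noise_uncond, noise_text, guidance_scale):
    """Standard classifier-free guidance combination for one sample."""
    return [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """List-based sketch of the rescaling for one sample."""
    std_text = statistics.stdev(noise_pred_text)
    std_cfg = statistics.stdev(noise_cfg)
    # Rescale the guided prediction to the std of the text-conditioned one.
    rescaled = [x * (std_text / std_cfg) for x in noise_cfg]
    # Blend the rescaled and original guided predictions.
    return [
        guidance_rescale * r + (1.0 - guidance_rescale) * x
        for r, x in zip(rescaled, noise_cfg)
    ]


noise_uncond = [0.0, 0.0, 0.0, 0.0]
noise_text = [1.0, -1.0, 2.0, -2.0]
noise_cfg = classifier_free_guidance(noise_uncond, noise_text, guidance_scale=7.5)
unchanged = rescale_noise_cfg(noise_cfg, noise_text, guidance_rescale=0.0)
matched = rescale_noise_cfg(noise_cfg, noise_text, guidance_rescale=1.0)
```

With `guidance_rescale=0.0` the guided prediction is returned unchanged, which is why 0.0 works as a backward-compatible default. With `guidance_rescale=1.0` the result has the same standard deviation as the text-conditioned prediction, and intermediate values such as 0.7 blend the two.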

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@
5555
"""
5656

5757

58+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
59+
"""
60+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
61+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
62+
"""
63+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
64+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
65+
# rescale the results from guidance (fixes overexposure)
66+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
67+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
68+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
69+
return noise_cfg
70+
71+
5872
class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
5973
r"""
6074
Pipeline for text-to-image generation using Stable Diffusion.
@@ -568,6 +582,7 @@ def __call__(
568582
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
569583
callback_steps: int = 1,
570584
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
585+
guidance_rescale: float = 0.0,
571586
):
572587
r"""
573588
Function invoked when calling the pipeline for generation.
@@ -628,6 +643,11 @@ def __call__(
628643
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
629644
`self.processor` in
630645
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
646+
guidance_rescale (`float`, *optional*, defaults to 0.0):
647+
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
648+
Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
649+
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
650+
Guidance rescale factor should fix overexposure when using zero terminal SNR.
631651
632652
Examples:
633653
@@ -718,6 +738,10 @@ def __call__(
718738
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
719739
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
720740

741+
if do_classifier_free_guidance and guidance_rescale > 0.0:
742+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
743+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
744+
721745
# compute the previous noisy sample x_t -> x_t-1
722746
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
723747
