Skip to content

Fix schedulers zero SNR and rescale classifier free guidance #3664

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,64 @@ image = pipe(prompt, guidance_scale=9, num_inference_steps=25).images[0]
image.save("astronaut.png")
```

#### Experimental: "Common Diffusion Noise Schedules and Sample Steps are Flawed":

The paper **[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891)**
claims that a mismatch between the training and inference settings leads to suboptimal inference generation results for Stable Diffusion.

The abstract reads as follows:

*We discover that common diffusion noise schedules do not enforce the last timestep to have zero signal-to-noise ratio (SNR),
and some implementations of diffusion samplers do not start from the last timestep.
Such designs are flawed and do not reflect the fact that the model is given pure Gaussian noise at inference, creating a discrepancy between training and inference.
We show that the flawed design causes real problems in existing implementations.
In Stable Diffusion, it severely limits the model to only generate images with medium brightness and
prevents it from generating very bright and dark samples. We propose a few simple fixes:
- (1) rescale the noise schedule to enforce zero terminal SNR;
- (2) train the model with v prediction;
- (3) change the sampler to always start from the last timestep;
- (4) rescale classifier-free guidance to prevent over-exposure.
These simple changes ensure the diffusion process is congruent between training and inference and
allow the model to generate samples more faithful to the original data distribution.*

You can apply all of these changes in `diffusers` when using [`DDIMScheduler`]:
- (1) rescale the noise schedule to enforce zero terminal SNR;
```py
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, rescale_betas_zero_snr=True)
```
- (2) train the model with v prediction;
Continue fine-tuning a checkpoint with [`train_text_to_image.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [`train_text_to_image_lora.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)
and `--prediction_type="v_prediction"`.
- (3) change the sampler to always start from the last timestep;
```py
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_scaling="trailing")
```
- (4) rescale classifier-free guidance to prevent over-exposure.
```py
pipe(..., guidance_rescale=0.7)
```

An example is to use [this checkpoint](https://huggingface.co/ptx0/pseudo-journey-v2)
which has been fine-tuned using the `"v_prediction"`.

The checkpoint can then be run in inference as follows:

```py
from diffusers import DiffusionPipeline, DDIMScheduler

pipe = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", torch_dtype=torch.float16)
pipe.scheduler = DDIMScheduler.from_config(
pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_scaling="trailing"
)
pipe.to("cuda")

prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k"
image = pipeline(prompt, guidance_rescale=0.7).images[0]
```

## DDIMScheduler
[[autodoc]] DDIMScheduler

### Image Inpainting

- *Image Inpainting (512x512 resolution)*: [stabilityai/stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) with [`StableDiffusionInpaintPipeline`]
Expand Down
63 changes: 62 additions & 1 deletion docs/source/en/api/schedulers/ddim.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,71 @@ specific language governing permissions and limitations under the License.

The abstract of the paper is the following:

Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.
*Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training,
yet they require simulating a Markov chain for many steps to produce a sample.
To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models
with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process.
We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from.
We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off
computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.*

The original codebase of this paper can be found here: [ermongroup/ddim](https://github.com/ermongroup/ddim).
For questions, feel free to contact the author on [tsong.me](https://tsong.me/).

### Experimental: "Common Diffusion Noise Schedules and Sample Steps are Flawed":

The paper **[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891)**
claims that a mismatch between the training and inference settings leads to suboptimal inference generation results for Stable Diffusion.

The abstract reads as follows:

*We discover that common diffusion noise schedules do not enforce the last timestep to have zero signal-to-noise ratio (SNR),
and some implementations of diffusion samplers do not start from the last timestep.
Such designs are flawed and do not reflect the fact that the model is given pure Gaussian noise at inference, creating a discrepancy between training and inference.
We show that the flawed design causes real problems in existing implementations.
In Stable Diffusion, it severely limits the model to only generate images with medium brightness and
prevents it from generating very bright and dark samples. We propose a few simple fixes:
- (1) rescale the noise schedule to enforce zero terminal SNR;
- (2) train the model with v prediction;
- (3) change the sampler to always start from the last timestep;
- (4) rescale classifier-free guidance to prevent over-exposure.
These simple changes ensure the diffusion process is congruent between training and inference and
allow the model to generate samples more faithful to the original data distribution.*

You can apply all of these changes in `diffusers` when using [`DDIMScheduler`]:
- (1) rescale the noise schedule to enforce zero terminal SNR;
```py
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, rescale_betas_zero_snr=True)
```
- (2) train the model with v prediction;
Continue fine-tuning a checkpoint with [`train_text_to_image.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [`train_text_to_image_lora.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)
and `--prediction_type="v_prediction"`.
- (3) change the sampler to always start from the last timestep;
```py
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_scaling="trailing")
```
- (4) rescale classifier-free guidance to prevent over-exposure.
```py
pipe(..., guidance_rescale=0.7)
```

An example is to use [this checkpoint](https://huggingface.co/ptx0/pseudo-journey-v2)
which has been fine-tuned using the `"v_prediction"`.

The checkpoint can then be run in inference as follows:

```py
from diffusers import DiffusionPipeline, DDIMScheduler

pipe = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", torch_dtype=torch.float16)
pipe.scheduler = DDIMScheduler.from_config(
pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_scaling="trailing"
)
pipe.to("cuda")

prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k"
image = pipeline(prompt, guidance_rescale=0.7).images[0]
```

## DDIMScheduler
[[autodoc]] DDIMScheduler
10 changes: 10 additions & 0 deletions examples/text_to_image/train_text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,12 @@ def parse_args():
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--prediction_type",
type=str,
default=None,
help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.",
)
parser.add_argument(
"--hub_model_id",
type=str,
Expand Down Expand Up @@ -848,6 +854,10 @@ def collate_fn(examples):
encoder_hidden_states = text_encoder(batch["input_ids"])[0]

# Get the target for loss depending on the prediction type
if args.prediction_type is not None:
# set prediction_type of scheduler if defined
noise_scheduler.register_to_config(prediction_type=args.prediction_type)

if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
Expand Down
10 changes: 10 additions & 0 deletions examples/text_to_image/train_text_to_image_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,12 @@ def parse_args():
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--prediction_type",
type=str,
default=None,
help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.",
)
parser.add_argument(
"--hub_model_id",
type=str,
Expand Down Expand Up @@ -749,6 +755,10 @@ def collate_fn(examples):
encoder_hidden_states = text_encoder(batch["input_ids"])[0]

# Get the target for loss depending on the prediction type
if args.prediction_type is not None:
# set prediction_type of scheduler if defined
noise_scheduler.register_to_config(prediction_type=args.prediction_type)

if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
Expand Down
25 changes: 25 additions & 0 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,21 @@
"""


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Expand Down Expand Up @@ -567,6 +582,7 @@ def __call__(
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -627,6 +643,11 @@ def __call__(
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
guidance_rescale (`float`, *optional*, defaults to 0.7):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Guidance rescale factor should fix overexposure when using zero terminal SNR.

Examples:

Expand Down Expand Up @@ -717,6 +738,10 @@ def __call__(
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,20 @@
"""


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg


class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
Expand Down Expand Up @@ -568,6 +582,7 @@ def __call__(
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -628,6 +643,11 @@ def __call__(
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
guidance_rescale (`float`, *optional*, defaults to 0.7):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Guidance rescale factor should fix overexposure when using zero terminal SNR.

Examples:

Expand Down Expand Up @@ -718,6 +738,10 @@ def __call__(
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

Expand Down
Loading