diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_2.mdx b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_2.mdx index e922072e4e31..7162626ebbde 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_2.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_2.mdx @@ -71,6 +71,64 @@ image = pipe(prompt, guidance_scale=9, num_inference_steps=25).images[0] image.save("astronaut.png") ``` +#### Experimental: "Common Diffusion Noise Schedules and Sample Steps are Flawed": + +The paper **[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891)** +claims that a mismatch between the training and inference settings leads to suboptimal inference generation results for Stable Diffusion. + +The abstract reads as follows: + +*We discover that common diffusion noise schedules do not enforce the last timestep to have zero signal-to-noise ratio (SNR), +and some implementations of diffusion samplers do not start from the last timestep. +Such designs are flawed and do not reflect the fact that the model is given pure Gaussian noise at inference, creating a discrepancy between training and inference. +We show that the flawed design causes real problems in existing implementations. +In Stable Diffusion, it severely limits the model to only generate images with medium brightness and +prevents it from generating very bright and dark samples. We propose a few simple fixes: +- (1) rescale the noise schedule to enforce zero terminal SNR; +- (2) train the model with v prediction; +- (3) change the sampler to always start from the last timestep; +- (4) rescale classifier-free guidance to prevent over-exposure. +These simple changes ensure the diffusion process is congruent between training and inference and +allow the model to generate samples more faithful to the original data distribution.* + +You can apply all of these changes in `diffusers` when using [`DDIMScheduler`]: +- (1) rescale the noise schedule to enforce zero terminal SNR; +```py +pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, rescale_betas_zero_snr=True) +``` +- (2) train the model with v prediction; +Continue fine-tuning a checkpoint with [`train_text_to_image.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [`train_text_to_image_lora.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) +and `--prediction_type="v_prediction"`. +- (3) change the sampler to always start from the last timestep; +```py +pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_scaling="trailing") +``` +- (4) rescale classifier-free guidance to prevent over-exposure. +```py +pipe(..., guidance_rescale=0.7) +``` + +An example is to use [this checkpoint](https://huggingface.co/ptx0/pseudo-journey-v2) +which has been fine-tuned using the `"v_prediction"`. + +The checkpoint can then be run in inference as follows: + +```py +from diffusers import DiffusionPipeline, DDIMScheduler + +pipe = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", torch_dtype=torch.float16) +pipe.scheduler = DDIMScheduler.from_config( + pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_scaling="trailing" +) +pipe.to("cuda") + +prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k" +image = pipeline(prompt, guidance_rescale=0.7).images[0] +``` + +## DDIMScheduler +[[autodoc]] DDIMScheduler + ### Image Inpainting - *Image Inpainting (512x512 resolution)*: [stabilityai/stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) with [`StableDiffusionInpaintPipeline`] diff --git a/docs/source/en/api/schedulers/ddim.mdx b/docs/source/en/api/schedulers/ddim.mdx index 51b0cc3e9a09..0db5e4f4e2b5 100644 --- a/docs/source/en/api/schedulers/ddim.mdx +++ b/docs/source/en/api/schedulers/ddim.mdx @@ -18,10 +18,71 @@ specific language governing permissions and limitations under the License. The abstract of the paper is the following: -Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space. +*Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, +yet they require simulating a Markov chain for many steps to produce a sample. +To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models +with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. +We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. +We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off +computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.* The original codebase of this paper can be found here: [ermongroup/ddim](https://github.com/ermongroup/ddim). For questions, feel free to contact the author on [tsong.me](https://tsong.me/). +### Experimental: "Common Diffusion Noise Schedules and Sample Steps are Flawed": + +The paper **[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891)** +claims that a mismatch between the training and inference settings leads to suboptimal inference generation results for Stable Diffusion. + +The abstract reads as follows: + +*We discover that common diffusion noise schedules do not enforce the last timestep to have zero signal-to-noise ratio (SNR), +and some implementations of diffusion samplers do not start from the last timestep. +Such designs are flawed and do not reflect the fact that the model is given pure Gaussian noise at inference, creating a discrepancy between training and inference. +We show that the flawed design causes real problems in existing implementations. +In Stable Diffusion, it severely limits the model to only generate images with medium brightness and +prevents it from generating very bright and dark samples. We propose a few simple fixes: +- (1) rescale the noise schedule to enforce zero terminal SNR; +- (2) train the model with v prediction; +- (3) change the sampler to always start from the last timestep; +- (4) rescale classifier-free guidance to prevent over-exposure. +These simple changes ensure the diffusion process is congruent between training and inference and +allow the model to generate samples more faithful to the original data distribution.* + +You can apply all of these changes in `diffusers` when using [`DDIMScheduler`]: +- (1) rescale the noise schedule to enforce zero terminal SNR; +```py +pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, rescale_betas_zero_snr=True) +``` +- (2) train the model with v prediction; +Continue fine-tuning a checkpoint with [`train_text_to_image.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [`train_text_to_image_lora.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) +and `--prediction_type="v_prediction"`. +- (3) change the sampler to always start from the last timestep; +```py +pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_scaling="trailing") +``` +- (4) rescale classifier-free guidance to prevent over-exposure. +```py +pipe(..., guidance_rescale=0.7) +``` + +An example is to use [this checkpoint](https://huggingface.co/ptx0/pseudo-journey-v2) +which has been fine-tuned using the `"v_prediction"`. + +The checkpoint can then be run in inference as follows: + +```py +from diffusers import DiffusionPipeline, DDIMScheduler + +pipe = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", torch_dtype=torch.float16) +pipe.scheduler = DDIMScheduler.from_config( + pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_scaling="trailing" +) +pipe.to("cuda") + +prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k" +image = pipeline(prompt, guidance_rescale=0.7).images[0] +``` + ## DDIMScheduler [[autodoc]] DDIMScheduler diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index bbf7bf9b85bb..0965c77eea96 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -307,6 +307,12 @@ def parse_args(): parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) parser.add_argument( "--hub_model_id", type=str, @@ -848,6 +854,10 @@ def collate_fn(examples): encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config(prediction_type=args.prediction_type) + if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 806637f04c53..30d527efd22d 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -272,6 +272,12 @@ def parse_args(): parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) parser.add_argument( "--hub_model_id", type=str, @@ -749,6 +755,10 @@ def collate_fn(examples): encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config(prediction_type=args.prediction_type) + if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 64ca06a53a7b..b79e4f72144b 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -51,6 +51,21 @@ """ +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" @@ -567,6 +582,7 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, ): r""" Function invoked when calling the pipeline for generation. @@ -627,6 +643,11 @@ def __call__( A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. Examples: @@ -717,6 +738,10 @@ def __call__( noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index f7374452a5f6..8368668ebea7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -55,6 +55,20 @@ """ +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin): r""" Pipeline for text-to-image generation using Stable Diffusion. @@ -568,6 +582,7 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, ): r""" Function invoked when calling the pipeline for generation. @@ -628,6 +643,11 @@ def __call__( A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. Examples: @@ -718,6 +738,10 @@ def __call__( noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 6b62d8893482..bab6f8acea03 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -76,6 +76,42 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class DDIMScheduler(SchedulerMixin, ConfigMixin): """ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising @@ -122,6 +158,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -143,6 +187,8 @@ def __init__( dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -159,6 +205,10 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) @@ -251,12 +301,26 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic ) self.num_inference_steps = num_inference_steps - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + + # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + self.timesteps = torch.from_numpy(timesteps).to(device) - self.timesteps += self.config.steps_offset def step( self, diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index 87a960c7d1a4..33cc7f638ec2 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -208,6 +208,27 @@ def test_stable_diffusion_k_euler(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_unflawed(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + components["scheduler"] = DDIMScheduler.from_config( + components["scheduler"].config, timestep_spacing="trailing" + ) + sd_pipe = StableDiffusionPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["guidance_rescale"] = 0.7 + inputs["num_inference_steps"] = 10 + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4736, 0.5405, 0.4705, 0.4955, 0.5675, 0.4812, 0.5310, 0.4967, 0.5064]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_long_prompt(self): components = self.get_dummy_components() components["scheduler"] = LMSDiscreteScheduler.from_config(components["scheduler"].config) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py index d1a2c856659f..21862ba6a216 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py @@ -384,6 +384,29 @@ def test_stable_diffusion_text2img_pipeline_v_pred_default(self): assert image.shape == (768, 768, 3) assert np.abs(expected_image - image).max() < 9e-1 + def test_stable_diffusion_text2img_pipeline_unflawed(self): + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/" + "sd2-text2img/lion_galaxy.npy" + ) + + pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1") + pipe.scheduler = DDIMScheduler.from_config( + pipe.scheduler.config, timestep_scaling="trailing", rescale_betas_zero_snr=True + ) + pipe.to(torch_device) + pipe.enable_attention_slicing() + pipe.set_progress_bar_config(disable=None) + + prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k" + + generator = torch.manual_seed(0) + output = pipe(prompt=prompt, guidance_scale=7.5, guidance_rescale=0.7, generator=generator, output_type="np") + image = output.images[0] + + assert image.shape == (768, 768, 3) + assert np.abs(expected_image - image).max() < 5e-1 + def test_stable_diffusion_text2img_pipeline_v_pred_fp16(self): expected_image = load_numpy( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/" diff --git a/tests/schedulers/test_scheduler_ddim.py b/tests/schedulers/test_scheduler_ddim.py index e9c85314d558..156b02b2208e 100644 --- a/tests/schedulers/test_scheduler_ddim.py +++ b/tests/schedulers/test_scheduler_ddim.py @@ -69,6 +69,14 @@ def test_clip_sample(self): for clip_sample in [True, False]: self.check_over_configs(clip_sample=clip_sample) + def test_timestep_spacing(self): + for timestep_spacing in ["trailing", "leading"]: + self.check_over_configs(timestep_spacing=timestep_spacing) + + def test_rescale_betas_zero_snr(self): + for rescale_betas_zero_snr in [True, False]: + self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) + def test_thresholding(self): self.check_over_configs(thresholding=False) for threshold in [0.5, 1.0, 2.0]: