diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py index 980adf273763..c7dff9eeefca 100644 --- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py @@ -528,15 +528,12 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py index 9bb68c1d3ec9..206c3436bb3d 100644 --- a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py @@ -716,15 +716,12 @@ def check_source_inputs( f" `source_negative_prompt_embeds` {source_negative_prompt_embeds.shape}." ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index 5a37886e2289..316d657e8dd6 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -434,7 +434,11 @@ def add_noise( # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index dfbb8af82dc3..27213c0af1a9 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -768,10 +768,14 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - # begin_index is None when the scheduler is used for training + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 4780a341354c..1c8ec1090b54 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -1011,10 +1011,14 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - # begin_index is None when the scheduler is used for training + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index 96962c315eae..22aeba57501d 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -543,7 +543,11 @@ def add_noise( # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index f0a417f83c6d..3113d61a9424 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -961,10 +961,14 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - # begin_index is None when the scheduler is used for training + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py index e98f37ca6aa5..f490da07abd1 100644 --- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py @@ -669,7 +669,11 @@ def add_noise( # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_edm_euler.py b/src/diffusers/schedulers/scheduling_edm_euler.py index a005b4f20c59..7f0ef9029443 100644 --- a/src/diffusers/schedulers/scheduling_edm_euler.py +++ b/src/diffusers/schedulers/scheduling_edm_euler.py @@ -367,7 +367,11 @@ def add_noise( # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 01c202414806..81009b4fe7ce 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -467,7 +467,11 @@ def add_noise( # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index afe2d1456e60..62d0caa4e2e1 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -562,7 +562,11 @@ def add_noise( # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 2de5a2913cb5..14ab01da744f 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -468,7 +468,11 @@ def add_noise( # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 6390a187f216..af50fcf54f99 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -494,7 +494,11 @@ def add_noise( # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 32638000d684..888ba311a901 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -469,7 +469,11 @@ def add_noise( # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 9f759683b5e0..c8abade3e534 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -461,7 +461,11 @@ def add_noise( # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index e3f0db8494dc..a844bcaec360 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -862,10 +862,14 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - # begin_index is None when the scheduler is used for training + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index e4e97d7bfc83..dec62e7e46ae 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -29,6 +29,7 @@ AutoencoderKL, DDIMScheduler, DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, LCMScheduler, LMSDiscreteScheduler, PNDMScheduler, @@ -557,6 +558,29 @@ def test_stable_diffusion_inpaint_2_images(self): image_slice2 = images[1, -3:, -3:, -1] assert np.abs(image_slice1.flatten() - image_slice2.flatten()).max() > 1e-2 + def test_stable_diffusion_inpaint_euler(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(time_cond_proj_dim=256) + sd_pipe = StableDiffusionInpaintPipeline(**components) + sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device, output_pil=False) + half_dim = inputs["image"].shape[2] // 2 + inputs["mask_image"][0, 0, :half_dim, :half_dim] = 0 + + inputs["num_inference_steps"] = 4 + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + + expected_slice = np.array( + [[0.6387283, 0.5564158, 0.58631873, 0.5539942, 0.5494673, 0.6461868, 0.5251618, 0.5497595, 0.5508756]] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-4 + @slow @require_torch_gpu