Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -528,15 +528,12 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)

return timesteps, num_inference_steps - t_start

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -716,15 +716,12 @@ def check_source_inputs(
f" `source_negative_prompt_embeds` {source_negative_prompt_embeds.shape}."
)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)

return timesteps, num_inference_steps - t_start

Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/schedulers/scheduling_consistency_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,11 @@ def add_noise(
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called bevore first denoising step to create inital latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/schedulers/scheduling_deis_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,10 +768,14 @@ def add_noise(
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)

# begin_index is None when the scheduler is used for training
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called bevore first denoising step to create inital latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,10 +1011,14 @@ def add_noise(
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)

# begin_index is None when the scheduler is used for training
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called bevore first denoising step to create inital latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/schedulers/scheduling_dpmsolver_sde.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,11 @@ def add_noise(
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called bevore first denoising step to create inital latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,10 +961,14 @@ def add_noise(
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)

# begin_index is None when the scheduler is used for training
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called bevore first denoising step to create inital latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,11 @@ def add_noise(
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called bevore first denoising step to create inital latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/schedulers/scheduling_edm_euler.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,11 @@ def add_noise(
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called bevore first denoising step to create inital latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,11 @@ def add_noise(
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called bevore first denoising step to create inital latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/schedulers/scheduling_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,11 @@ def add_noise(
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called bevore first denoising step to create inital latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/schedulers/scheduling_heun_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,11 @@ def add_noise(
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called bevore first denoising step to create inital latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,11 @@ def add_noise(
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called bevore first denoising step to create inital latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,11 @@ def add_noise(
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called bevore first denoising step to create inital latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,11 @@ def add_noise(
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called bevore first denoising step to create inital latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/schedulers/scheduling_unipc_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,10 +862,14 @@ def add_noise(
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)

# begin_index is None when the scheduler is used for training
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called bevore first denoising step to create inital latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
Expand Down
24 changes: 24 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
LCMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
Expand Down Expand Up @@ -557,6 +558,29 @@ def test_stable_diffusion_inpaint_2_images(self):
image_slice2 = images[1, -3:, -3:, -1]
assert np.abs(image_slice1.flatten() - image_slice2.flatten()).max() > 1e-2

def test_stable_diffusion_inpaint_euler(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components(time_cond_proj_dim=256)
sd_pipe = StableDiffusionInpaintPipeline(**components)
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device, output_pil=False)
half_dim = inputs["image"].shape[2] // 2
inputs["mask_image"][0, 0, :half_dim, :half_dim] = 0

inputs["num_inference_steps"] = 4
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]

assert image.shape == (1, 64, 64, 3)

expected_slice = np.array(
[[0.6387283, 0.5564158, 0.58631873, 0.5539942, 0.5494673, 0.6461868, 0.5251618, 0.5497595, 0.5508756]]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-4


@slow
@require_torch_gpu
Expand Down