Skip to content

Commit cd77a03

Browse files
authored
[CLIPGuidedStableDiffusion] support DDIM scheduler (#1190)
add ddim in clip guided
1 parent 663f0c1 commit cd77a03

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

examples/community/clip_guided_stable_diffusion.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,14 @@
55
from torch import nn
66
from torch.nn import functional as F
77

8-
from diffusers import AutoencoderKL, DiffusionPipeline, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
8+
from diffusers import (
9+
AutoencoderKL,
10+
DDIMScheduler,
11+
DiffusionPipeline,
12+
LMSDiscreteScheduler,
13+
PNDMScheduler,
14+
UNet2DConditionModel,
15+
)
916
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
1017
from torchvision import transforms
1118
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
@@ -56,7 +63,7 @@ def __init__(
5663
clip_model: CLIPModel,
5764
tokenizer: CLIPTokenizer,
5865
unet: UNet2DConditionModel,
59-
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler],
66+
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
6067
feature_extractor: CLIPFeatureExtractor,
6168
):
6269
super().__init__()
@@ -123,7 +130,7 @@ def cond_fn(
123130
# predict the noise residual
124131
noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
125132

126-
if isinstance(self.scheduler, PNDMScheduler):
133+
if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)):
127134
alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
128135
beta_prod_t = 1 - alpha_prod_t
129136
# compute predicted original sample from predicted noise also called
@@ -176,6 +183,7 @@ def __call__(
176183
num_inference_steps: Optional[int] = 50,
177184
guidance_scale: Optional[float] = 7.5,
178185
num_images_per_prompt: Optional[int] = 1,
186+
eta: float = 0.0,
179187
clip_guidance_scale: Optional[float] = 100,
180188
clip_prompt: Optional[Union[str, List[str]]] = None,
181189
num_cutouts: Optional[int] = 4,
@@ -275,6 +283,20 @@ def __call__(
275283
# scale the initial noise by the standard deviation required by the scheduler
276284
latents = latents * self.scheduler.init_noise_sigma
277285

286+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
287+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
288+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
289+
# and should be between [0, 1]
290+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
291+
extra_step_kwargs = {}
292+
if accepts_eta:
293+
extra_step_kwargs["eta"] = eta
294+
295+
# check if the scheduler accepts generator
296+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
297+
if accepts_generator:
298+
extra_step_kwargs["generator"] = generator
299+
278300
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
279301
# expand the latents if we are doing classifier free guidance
280302
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -306,7 +328,7 @@ def __call__(
306328
)
307329

308330
# compute the previous noisy sample x_t -> x_t-1
309-
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
331+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
310332

311333
# scale and decode the image latents with vae
312334
latents = 1 / 0.18215 * latents

0 commit comments

Comments
 (0)