|
5 | 5 | from torch import nn |
6 | 6 | from torch.nn import functional as F |
7 | 7 |
|
8 | | -from diffusers import AutoencoderKL, DiffusionPipeline, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel |
| 8 | +from diffusers import ( |
| 9 | + AutoencoderKL, |
| 10 | + DDIMScheduler, |
| 11 | + DiffusionPipeline, |
| 12 | + LMSDiscreteScheduler, |
| 13 | + PNDMScheduler, |
| 14 | + UNet2DConditionModel, |
| 15 | +) |
9 | 16 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput |
10 | 17 | from torchvision import transforms |
11 | 18 | from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer |
@@ -56,7 +63,7 @@ def __init__( |
56 | 63 | clip_model: CLIPModel, |
57 | 64 | tokenizer: CLIPTokenizer, |
58 | 65 | unet: UNet2DConditionModel, |
59 | | - scheduler: Union[PNDMScheduler, LMSDiscreteScheduler], |
| 66 | + scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler], |
60 | 67 | feature_extractor: CLIPFeatureExtractor, |
61 | 68 | ): |
62 | 69 | super().__init__() |
@@ -123,7 +130,7 @@ def cond_fn( |
123 | 130 | # predict the noise residual |
124 | 131 | noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample |
125 | 132 |
|
126 | | - if isinstance(self.scheduler, PNDMScheduler): |
| 133 | + if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): |
127 | 134 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep] |
128 | 135 | beta_prod_t = 1 - alpha_prod_t |
129 | 136 | # compute predicted original sample from predicted noise also called |
@@ -176,6 +183,7 @@ def __call__( |
176 | 183 | num_inference_steps: Optional[int] = 50, |
177 | 184 | guidance_scale: Optional[float] = 7.5, |
178 | 185 | num_images_per_prompt: Optional[int] = 1, |
| 186 | + eta: float = 0.0, |
179 | 187 | clip_guidance_scale: Optional[float] = 100, |
180 | 188 | clip_prompt: Optional[Union[str, List[str]]] = None, |
181 | 189 | num_cutouts: Optional[int] = 4, |
@@ -275,6 +283,20 @@ def __call__( |
275 | 283 | # scale the initial noise by the standard deviation required by the scheduler |
276 | 284 | latents = latents * self.scheduler.init_noise_sigma |
277 | 285 |
|
| 286 | + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature |
| 287 | + # eta (η) is only used with the DDIMScheduler; it is ignored for other schedulers. |
| 288 | + # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502 |
| 289 | + # and should be in the range [0, 1] |
| 290 | + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) |
| 291 | + extra_step_kwargs = {} |
| 292 | + if accepts_eta: |
| 293 | + extra_step_kwargs["eta"] = eta |
| 294 | + |
| 295 | + # check if the scheduler accepts generator |
| 296 | + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) |
| 297 | + if accepts_generator: |
| 298 | + extra_step_kwargs["generator"] = generator |
| 299 | + |
278 | 300 | for i, t in enumerate(self.progress_bar(timesteps_tensor)): |
279 | 301 | # expand the latents if we are doing classifier free guidance |
280 | 302 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
@@ -306,7 +328,7 @@ def __call__( |
306 | 328 | ) |
307 | 329 |
|
308 | 330 | # compute the previous noisy sample x_t -> x_t-1 |
309 | | - latents = self.scheduler.step(noise_pred, t, latents).prev_sample |
| 331 | + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample |
310 | 332 |
|
311 | 333 | # scale and decode the image latents with vae |
312 | 334 | latents = 1 / 0.18215 * latents |
|
0 commit comments