@@ -13,6 +13,7 @@
     AutoencoderKL,
     DDIMScheduler,
     DiffusionPipeline,
+    DPMSolverMultistepScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,
     UNet2DConditionModel,
@@ -140,7 +141,7 @@ def __init__(
         clip_model: CLIPModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
+        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
         feature_extractor: CLIPFeatureExtractor,
     ):
         super().__init__()
@@ -263,17 +264,12 @@ def cond_fn(
     ):
         latents = latents.detach().requires_grad_()

-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            sigma = self.scheduler.sigmas[index]
-            # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-            latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
-        else:
-            latent_model_input = latents
+        latent_model_input = self.scheduler.scale_model_input(latents, timestep)

         # predict the noise residual
         noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample

-        if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)):
+        if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):
             alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
             beta_prod_t = 1 - alpha_prod_t
             # compute predicted original sample from predicted noise also called
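
For reference, `scale_model_input` performs the same `(sigma**2 + 1) ** 0.5` latent scaling that the removed K-LMS branch did by hand and is effectively an identity for schedulers that need no scaling, which is what lets `cond_fn` stay scheduler-agnostic. Below is a minimal usage sketch for the newly supported scheduler; the base model id, CLIP checkpoint, and the `clip_guided_stable_diffusion` custom-pipeline name are assumptions for illustration, not part of this diff.

```python
# Minimal sketch, assuming this diff targets the CLIP-guided Stable Diffusion
# community pipeline; the checkpoints and settings below are illustrative only.
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from transformers import CLIPFeatureExtractor, CLIPModel

clip_id = "openai/clip-vit-large-patch14"  # assumed CLIP checkpoint
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_id)
clip_model = CLIPModel.from_pretrained(clip_id)

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",                # assumed base model
    custom_pipeline="clip_guided_stable_diffusion",  # assumed community pipeline name
    clip_model=clip_model,
    feature_extractor=feature_extractor,
)

# With DPMSolverMultistepScheduler now in the accepted Union, it can be swapped in
# directly, and cond_fn takes the alphas_cumprod branch via the extended isinstance check.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

image = pipe(
    "a watercolor painting of a lighthouse at dusk",
    num_inference_steps=25,
    clip_guidance_scale=100,  # illustrative guidance strength
).images[0]
image.save("lighthouse.png")
```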