Skip to content

Commit ec67986

Browse files
author
bghira
committed
diffusers#7426 fix stable diffusion xl inference on MPS when dtypes shift unexpectedly due to pytorch bugs
1 parent 8244146 commit ec67986

File tree

4 files changed

+29
-1
lines changed

4 files changed

+29
-1
lines changed

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def encode_prompt(
349349
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
350350
)
351351

352-
if prompt_embeds is None:
352+
if prompt_embeds is None:
353353
prompt_2 = prompt_2 or prompt
354354
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
355355

@@ -1193,7 +1193,11 @@ def __call__(
11931193
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
11941194

11951195
# compute the previous noisy sample x_t -> x_t-1
1196+
old_dtype = latents.dtype
11961197
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1198+
if latents.dtype != old_dtype:
1199+
# some platforms (eg. apple mps) misbehave due to a pytorch bug, this is a workaround
1200+
latents = latents.to(old_dtype)
11971201

11981202
if callback_on_step_end is not None:
11991203
callback_kwargs = {}
@@ -1228,6 +1232,9 @@ def __call__(
12281232
if needs_upcasting:
12291233
self.upcast_vae()
12301234
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1235+
elif torch.backends.mps.is_available() and latents.dtype != self.vae.dtype:
1236+
# some platforms (eg. apple mps) misbehave due to a pytorch bug, this is a workaround
1237+
self.vae = self.vae.to(latents.dtype)
12311238

12321239
# unscale/denormalize the latents
12331240
# denormalize with the mean and std if available and not None

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1370,7 +1370,11 @@ def denoising_value_valid(dnv):
13701370
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
13711371

13721372
# compute the previous noisy sample x_t -> x_t-1
1373+
old_dtype = latents.dtype
13731374
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1375+
if latents.dtype != old_dtype:
1376+
# some platforms (eg. apple mps) misbehave due to a pytorch bug, this is a workaround
1377+
latents = latents.to(old_dtype)
13741378

13751379
if callback_on_step_end is not None:
13761380
callback_kwargs = {}
@@ -1405,6 +1409,9 @@ def denoising_value_valid(dnv):
14051409
if needs_upcasting:
14061410
self.upcast_vae()
14071411
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1412+
elif torch.backends.mps.is_available() and latents.dtype != self.vae.dtype:
1413+
# some platforms (eg. apple mps) misbehave due to a pytorch bug, this is a workaround
1414+
self.vae = self.vae.to(latents.dtype)
14081415

14091416
# unscale/denormalize the latents
14101417
# denormalize with the mean and std if available and not None

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,7 +1720,11 @@ def denoising_value_valid(dnv):
17201720
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
17211721

17221722
# compute the previous noisy sample x_t -> x_t-1
1723+
old_dtype = latents.dtype
17231724
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1725+
if latents.dtype != old_dtype:
1726+
# some platforms (eg. apple mps) misbehave due to a pytorch bug, this is a workaround
1727+
latents = latents.to(old_dtype)
17241728

17251729
if num_channels_unet == 4:
17261730
init_latents_proper = image_latents
@@ -1772,6 +1776,9 @@ def denoising_value_valid(dnv):
17721776
if needs_upcasting:
17731777
self.upcast_vae()
17741778
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1779+
elif torch.backends.mps.is_available() and latents.dtype != self.vae.dtype:
1780+
# some platforms (eg. apple mps) misbehave due to a pytorch bug, this is a workaround
1781+
self.vae = self.vae.to(latents.dtype)
17751782

17761783
# unscale/denormalize the latents
17771784
# denormalize with the mean and std if available and not None

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,11 @@ def __call__(
918918
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
919919

920920
# compute the previous noisy sample x_t -> x_t-1
921+
old_dtype = latents.dtype
921922
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
923+
if latents.dtype != old_dtype:
924+
# some platforms (eg. apple mps) misbehave due to a pytorch bug, this is a workaround
925+
latents = latents.to(old_dtype)
922926

923927
# call the callback, if provided
924928
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -937,6 +941,9 @@ def __call__(
937941
if needs_upcasting:
938942
self.upcast_vae()
939943
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
944+
elif torch.backends.mps.is_available() and latents.dtype != self.vae.dtype:
945+
# some platforms (eg. apple mps) misbehave due to a pytorch bug, this is a workaround
946+
self.vae = self.vae.to(latents.dtype)
940947

941948
# unscale/denormalize the latents
942949
# denormalize with the mean and std if available and not None

0 commit comments

Comments
 (0)