Commit bc7c7e8

Author: bghira (committed)
mps: fix XL pipeline inference at training time due to upstream pytorch bug
1 parent 8244146 commit bc7c7e8

File tree: 4 files changed, +28 lines, −0 lines
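
All four pipelines below receive the same two-part workaround. Inside the denoising loop, the latents dtype is recorded before the scheduler step and restored afterwards, because on Apple MPS the step can silently return a different dtype (pytorch/pytorch#99272). At decode time, if the VAE was not upcast to float32, the VAE is instead moved to the latents dtype. A minimal, standalone sketch of the first part follows; the helper name restore_dtype_if_mps is illustrative and not part of the diff:

import torch

def restore_dtype_if_mps(latents: torch.Tensor, expected_dtype: torch.dtype) -> torch.Tensor:
    # Illustrative helper: on Apple MPS the scheduler step can hand back
    # latents in a different dtype than it was given (pytorch/pytorch#99272),
    # so cast back to keep the UNet and VAE in the expected precision.
    if torch.backends.mps.is_available() and latents.dtype != expected_dtype:
        latents = latents.to(expected_dtype)
    return latents

# Usage around a scheduler step, with scheduler/noise_pred/t/latents as in the pipelines:
#   latents_dtype = latents.dtype
#   latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
#   latents = restore_dtype_if_mps(latents, latents_dtype)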

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 7 additions & 0 deletions
@@ -1193,7 +1193,11 @@ def __call__(
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
+               latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+               if torch.backends.mps.is_available() and latents.dtype != latents_dtype:
+                   # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                   latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}

@@ -1228,6 +1232,9 @@ def __call__(
            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+           elif torch.backends.mps.is_available() and latents.dtype != self.vae.dtype:
+               # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+               self.vae = self.vae.to(latents.dtype)

            # unscale/denormalize the latents
            # denormalize with the mean and std if available and not None
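
The second hunk in each file covers decode time: when needs_upcasting is false and the VAE stays in half precision, the VAE itself is moved to the latents dtype so decoding does not mix dtypes on MPS. A standalone sketch of that branch, with vae and needs_upcasting as stand-ins for the pipeline attributes:

import torch

def align_vae_dtype(vae: torch.nn.Module, latents: torch.Tensor, needs_upcasting: bool) -> torch.nn.Module:
    # Illustrative stand-in for the elif branch added above: if the VAE was
    # not explicitly upcast to float32, move it to the latents dtype so the
    # decode does not hit mixed-dtype errors on MPS (pytorch/pytorch#99272).
    vae_dtype = next(vae.parameters()).dtype
    if not needs_upcasting and torch.backends.mps.is_available() and latents.dtype != vae_dtype:
        vae = vae.to(latents.dtype)
    return vae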

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 7 additions & 0 deletions
@@ -1370,7 +1370,11 @@ def denoising_value_valid(dnv):
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
+               latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+               if torch.backends.mps.is_available() and latents.dtype != latents_dtype:
+                   # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                   latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}

@@ -1405,6 +1409,9 @@ def denoising_value_valid(dnv):
            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+           elif torch.backends.mps.is_available() and latents.dtype != self.vae.dtype:
+               # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+               self.vae = self.vae.to(latents.dtype)

            # unscale/denormalize the latents
            # denormalize with the mean and std if available and not None

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

Lines changed: 7 additions & 0 deletions
@@ -1720,7 +1720,11 @@ def denoising_value_valid(dnv):
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
+               latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+               if torch.backends.mps.is_available() and latents.dtype != latents_dtype:
+                   # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                   latents = latents.to(latents_dtype)

                if num_channels_unet == 4:
                    init_latents_proper = image_latents

@@ -1772,6 +1776,9 @@ def denoising_value_valid(dnv):
            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+           elif torch.backends.mps.is_available() and latents.dtype != self.vae.dtype:
+               # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+               self.vae = self.vae.to(latents.dtype)

            # unscale/denormalize the latents
            # denormalize with the mean and std if available and not None

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py

Lines changed: 7 additions & 0 deletions
@@ -918,7 +918,11 @@ def __call__(
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
+               latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+               if torch.backends.mps.is_available() and latents.dtype != latents_dtype:
+                   # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                   latents = latents.to(latents_dtype)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

@@ -937,6 +941,9 @@ def __call__(
            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+           elif torch.backends.mps.is_available() and latents.dtype != self.vae.dtype:
+               # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+               self.vae = self.vae.to(latents.dtype)

            # unscale/denormalize the latents
            # denormalize with the mean and std if available and not None
