Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,16 @@ def __call__(
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
else:
raise ValueError(
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
)

if callback_on_step_end is not None:
callback_kwargs = {}
Expand Down Expand Up @@ -1228,6 +1237,14 @@ def __call__(
if needs_upcasting:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != self.vae.dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
self.vae = self.vae.to(latents.dtype)
else:
raise ValueError(
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
)

# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1370,7 +1370,16 @@ def denoising_value_valid(dnv):
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
else:
raise ValueError(
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
)

if callback_on_step_end is not None:
callback_kwargs = {}
Expand Down Expand Up @@ -1405,6 +1414,14 @@ def denoising_value_valid(dnv):
if needs_upcasting:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != self.vae.dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
self.vae = self.vae.to(latents.dtype)
else:
raise ValueError(
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
)

# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1720,7 +1720,16 @@ def denoising_value_valid(dnv):
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
else:
raise ValueError(
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
)

if num_channels_unet == 4:
init_latents_proper = image_latents
Expand Down Expand Up @@ -1772,6 +1781,14 @@ def denoising_value_valid(dnv):
if needs_upcasting:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != self.vae.dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
self.vae = self.vae.to(latents.dtype)
else:
raise ValueError(
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
)

# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,16 @@ def __call__(
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
else:
raise ValueError(
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
)

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
Expand All @@ -937,6 +946,14 @@ def __call__(
if needs_upcasting:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != self.vae.dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
self.vae = self.vae.to(latents.dtype)
else:
raise ValueError(
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
)

# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
Expand Down