diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py index cfab70926a4a..5ea7c2c14551 100644 --- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py @@ -1419,7 +1419,6 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode=" if needs_upcasting: image = image.float() self.upcast_vae() - image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) x0 = self.vae.encode(image).latent_dist.mode() x0 = x0.to(dtype) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 5e7be370be01..d9380020b329 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -525,8 +525,8 @@ def prepare_image_latents( # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: + image = image.float() self.upcast_vae() - image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")