From 30f4d80337d072a940e3be94f8040e1da9d754a7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tolga=20Cang=C3=B6z?=
Date: Wed, 8 May 2024 09:31:17 +0300
Subject: [PATCH] Fix `latents.dtype` before `vae.decode()` at ROCm

Co-authored-by: Bagheera <59658056+bghira@users.noreply.github.com>
---
 src/diffusers/pipelines/controlnet/pipeline_controlnet.py    | 2 ++
 .../pipelines/controlnet/pipeline_controlnet_img2img.py      | 2 ++
 .../pipelines/controlnet/pipeline_controlnet_inpaint.py      | 2 ++
 .../pipelines/controlnet_xs/pipeline_controlnet_xs.py        | 2 ++
 .../stable_diffusion_variants/pipeline_cycle_diffusion.py    | 2 ++
 .../pipeline_stable_diffusion_inpaint_legacy.py              | 2 ++
 .../pipeline_stable_diffusion_model_editing.py               | 2 ++
 .../pipeline_stable_diffusion_pix2pix_zero.py                | 2 ++
 .../pipeline_versatile_diffusion_dual_guided.py              | 2 ++
 .../pipeline_versatile_diffusion_image_variation.py          | 2 ++
 .../pipeline_versatile_diffusion_text_to_image.py            | 2 ++
 .../pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py | 2 ++
 .../pipelines/paint_by_example/pipeline_paint_by_example.py  | 2 ++
 .../pipeline_semantic_stable_diffusion.py                    | 2 ++
 .../pipelines/stable_diffusion/pipeline_stable_diffusion.py  | 4 ++++
 .../stable_diffusion/pipeline_stable_diffusion_depth2img.py  | 2 ++
 .../pipeline_stable_diffusion_image_variation.py             | 2 ++
 .../stable_diffusion/pipeline_stable_diffusion_img2img.py    | 2 ++
 .../pipeline_stable_diffusion_instruct_pix2pix.py            | 2 ++
 .../pipeline_stable_diffusion_latent_upscale.py              | 2 ++
 .../stable_diffusion/pipeline_stable_diffusion_upscale.py    | 2 ++
 .../pipelines/stable_diffusion/pipeline_stable_unclip.py     | 2 ++
 .../stable_diffusion/pipeline_stable_unclip_img2img.py       | 2 ++
 .../pipeline_stable_diffusion_attend_and_excite.py           | 2 ++
 .../pipeline_stable_diffusion_diffedit.py                    | 2 ++
 .../pipeline_stable_diffusion_k_diffusion.py                 | 2 ++
 .../pipeline_stable_diffusion_panorama.py                    | 2 ++
 .../stable_diffusion_safe/pipeline_stable_diffusion_safe.py  | 2 ++
 .../stable_diffusion_sag/pipeline_stable_diffusion_sag.py    | 2 ++
 .../t2i_adapter/pipeline_stable_diffusion_adapter.py         | 2 ++
 30 files changed, 62 insertions(+)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index f099d54e57cd..9a4f958ff360 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -551,6 +551,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
index 022f30d819d8..0536da80c10a 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -544,6 +544,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 47dbb26eb3e4..80ff79925172 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -669,6 +669,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
index 622bac8c5f97..f07683234d50 100644
--- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
@@ -416,6 +416,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
index 0581effef2fe..464e1532b6cb 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
@@ -552,6 +552,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
index c7dff9eeefca..426411bcae9e 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
@@ -453,6 +453,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
index f44a1ca74ee4..aebfdcf362cc 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
@@ -391,6 +391,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
index 5f744578810b..d1af13b1b3cd 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
@@ -605,6 +605,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
index b1117044cf18..fb1f43568b46 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
@@ -305,6 +305,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
index 59aa370ec2f6..82a838ce4388 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
@@ -165,6 +165,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
index 0c76e5837b99..5cf043838f9d 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
@@ -222,6 +222,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
index 619be13a8f36..db2961ba72c0 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
@@ -406,6 +406,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
index 263d507bbc75..a81e411c6931 100644
--- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
+++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -252,6 +252,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
index fe83a860aeac..c2e02cca0b1c 100644
--- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
+++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
@@ -112,6 +112,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index ff933237a6ab..3d523767e191 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -559,6 +559,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -1015,6 +1017,8 @@ def __call__(
                         callback(step_idx, t, latents)
 
         if not output_type == "latent":
+            if torch.cuda.is_available() and torch.version.hip:
+                latents = latents.to(dtype=self.vae.dtype)
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                 0
             ]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
index 1f822971568f..b3362f70c28e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -384,6 +384,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
index c300c7a2f3f4..b92eae30dfec 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
@@ -176,6 +176,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 1b31c099b177..c909f3fcca3b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -605,6 +605,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
index 0bf5a92a4fcc..780311e17bf5 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -736,6 +736,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
index 918dffe5199d..3d1305a162f1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
@@ -201,6 +201,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index 2d04cf41d9b5..b2298daca4ad 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -407,6 +407,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
index 02ddc65c7111..098db40792ac 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
@@ -470,6 +470,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
index 134ec39effc5..261463a7439c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
@@ -434,6 +434,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
index 2adcb0a8c0a1..e9a046acb42d 100644
--- a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
+++ b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
@@ -482,6 +482,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
index e89d7f77e8b1..228340052e05 100644
--- a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
+++ b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
@@ -627,6 +627,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
index e2096be7e894..a34b1a89b7b8 100755
--- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -382,6 +382,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
index 514c8643ba36..e98f5f9d1f26 100644
--- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
@@ -516,6 +516,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
index 63b8c6108ac4..6bbb1d25ca54 100644
--- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
+++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
@@ -337,6 +337,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
index cb29ce386f32..1b4b3e27ad8d 100644
--- a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
@@ -456,6 +456,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
index fd5b1bb454b7..c4c7b9fd43c4 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -484,6 +484,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
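
Note (not part of the patch): every hunk adds the same two-line guard, so the intended behavior can be illustrated with a minimal standalone sketch. The helper name `cast_latents_for_rocm` and the usage lines below are hypothetical; the guard relies on the fact that `torch.version.hip` is `None` on CUDA/CPU builds of PyTorch and a version string (truthy) on ROCm builds, so the cast only takes effect on AMD GPUs.

```python
import torch


def cast_latents_for_rocm(latents: torch.Tensor, vae) -> torch.Tensor:
    """Match the latents' dtype to the VAE's dtype on ROCm (HIP) builds.

    On CUDA or CPU builds torch.version.hip is None, so this is a no-op there.
    """
    if torch.cuda.is_available() and torch.version.hip:
        latents = latents.to(dtype=vae.dtype)
    return latents


# Hypothetical usage with an already-loaded pipeline `pipe`:
# latents = cast_latents_for_rocm(latents, pipe.vae)
# image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
```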