diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index 165906b2a643..ee8bb21eb018 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -562,6 +562,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
index 4cdec5b3cf5f..7eec53da0727 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -540,6 +540,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index da5a02d14108..4ac3a9abf9b9 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -542,6 +542,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
index ca10e65de8a4..37a344a5b621 100644
--- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
@@ -422,6 +422,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
index 777be883cb9d..d039343b8f99 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
@@ -553,6 +553,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
index ce7ad3b0dfe9..04c298f931cb 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
@@ -454,6 +454,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
index 9e91986896bd..78fbc9834e19 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
@@ -392,6 +392,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
index 2978972200c7..df683fa70d27 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
@@ -606,6 +606,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
index 2212651fbb5b..c198a7999dd9 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
@@ -305,6 +305,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
index 62d3e83a4790..e06d43ee4a58 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
@@ -165,6 +165,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
index de4c2ac9b7f4..9c1d92f4be3b 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
@@ -222,6 +222,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
index f0f71080d0a3..0bb17157cd88 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
@@ -418,6 +418,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
index b225fd71edf8..b059704d966d 100644
--- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
+++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -252,6 +252,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
index 6f83071f3e85..93c1bb4d1e91 100644
--- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
+++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
@@ -112,6 +112,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 4fd6a43a955a..72e27cb90f17 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -589,6 +589,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -1060,6 +1062,8 @@ def __call__(
                     xm.mark_step()
 
         if not output_type == "latent":
+            if torch.cuda.is_available() and torch.version.hip:
+                latents = latents.to(dtype=self.vae.dtype)
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                 0
             ]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
index 7801b0d01dff..2c7ea793f96f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -385,6 +385,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
index 93a8bd160318..39e886e5507d 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
@@ -176,6 +176,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 9cd5673c9359..2e9e5132815c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -616,6 +616,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
index fd89b195c778..2ba3c068cde2 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -751,6 +751,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
index ffe02ae679e5..575c2144a199 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
@@ -281,6 +281,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index 4cbbe17531ef..b12e50c8891d 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -410,6 +410,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
index 41811f8f2c0e..8ac2b7d90966 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
@@ -473,6 +473,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
index 2556d5e57b6d..c16f4e380c2a 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
@@ -435,6 +435,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
index 8f40fa72a25c..e5551867b698 100644
--- a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
+++ b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
@@ -483,6 +483,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
index 2b86470dbff1..c3de4f6944eb 100644
--- a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
+++ b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
@@ -626,6 +626,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
index 122701ff923f..8be9c6f6018a 100755
--- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -383,6 +383,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
index 2fc79c0610f0..062e2fdfd818 100644
--- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
@@ -542,6 +542,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
index cd59cf51869d..e726cc797cd9 100644
--- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
+++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
@@ -337,6 +337,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
index c32052d2e4d0..1a4aa2093dba 100644
--- a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
@@ -457,6 +457,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
index 1a938aaf9423..9a61491d2b3e 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -500,6 +500,8 @@ def decode_latents(self, latents):
         deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
 
         latents = 1 / self.vae.config.scaling_factor * latents
+        if torch.cuda.is_available() and torch.version.hip:
+            latents = latents.to(dtype=self.vae.dtype)
         image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
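Every hunk in this diff adds the same two-line guard: when PyTorch is a ROCm (HIP) build, the latents are cast to the VAE's parameter dtype before `vae.decode` is called, so a fp16/fp32 mismatch on AMD GPUs does not raise a dtype error; on CUDA and CPU builds `torch.version.hip` is `None` and the cast is skipped. The sketch below shows the pattern in isolation. It is illustrative only and not part of the diff; the helper name and the standalone form are assumptions, and `vae` stands for any diffusers `AutoencoderKL`-style module.

```python
import torch


def decode_with_hip_dtype_guard(vae, latents):
    """Minimal sketch of the ROCm/HIP dtype guard applied by the hunks above.

    Scales the latents by 1 / scaling_factor, aligns their dtype with the VAE's
    parameters on HIP builds, and decodes them to an image tensor.
    """
    latents = 1 / vae.config.scaling_factor * latents
    # torch.version.hip is a version string on ROCm builds and None elsewhere,
    # so this branch only runs on AMD GPUs.
    if torch.cuda.is_available() and torch.version.hip:
        latents = latents.to(dtype=vae.dtype)
    return vae.decode(latents, return_dict=False)[0]
```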