Skip to content

Commit da2ce1a

Browse files
Allow return pt x4 (#3236)
* Add all files * update
1 parent e51f19a commit da2ce1a

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -697,25 +697,30 @@ def __call__(
697697
# 10. Post-processing
698698
# make sure the VAE is in float32 mode, as it overflows in float16
699699
self.vae.to(dtype=torch.float32)
700-
image = self.decode_latents(latents.float())
701-
702-
# Offload last model to CPU
703-
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
704-
self.final_offload_hook.offload()
705700

706701
# 11. Convert to PIL
707702
# has_nsfw_concept = False
708703
if output_type == "pil":
704+
image = self.decode_latents(latents.float())
709705
image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
710706

711707
image = self.numpy_to_pil(image)
712708

713709
# 11. Apply watermark
714710
if self.watermarker is not None:
715711
image = self.watermarker.apply_watermark(image)
712+
elif output_type == "pt":
713+
latents = 1 / self.vae.config.scaling_factor * latents.float()
714+
image = self.vae.decode(latents).sample
715+
has_nsfw_concept = None
716716
else:
717+
image = self.decode_latents(latents.float())
717718
has_nsfw_concept = None
718719

720+
# Offload last model to CPU
721+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
722+
self.final_offload_hook.offload()
723+
719724
if not return_dict:
720725
return (image, has_nsfw_concept)
721726

0 commit comments

Comments
 (0)