 import numpy as np
 import PIL
 import torch
+import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...loaders import TextualInversionLoaderMixin
@@ -698,10 +699,22 @@ def __call__(
         # make sure the VAE is in float32 mode, as it overflows in float16
         self.vae.to(dtype=torch.float32)
 
+        # TODO(Patrick, William) - clean up when attention is refactored
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
+        use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_attn or use_xformers:
+            self.vae.post_quant_conv.to(latents.dtype)
+            self.vae.decoder.conv_in.to(latents.dtype)
+            self.vae.decoder.mid_block.to(latents.dtype)
+        else:
+            latents = latents.float()
+
         # 11. Convert to PIL
-        # has_nsfw_concept = False
         if output_type == "pil":
-            image = self.decode_latents(latents.float())
+            image = self.decode_latents(latents)
+
             image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
 
             image = self.numpy_to_pil(image)
@@ -710,11 +723,11 @@ def __call__(
             if self.watermarker is not None:
                 image = self.watermarker.apply_watermark(image)
         elif output_type == "pt":
-            latents = 1 / self.vae.config.scaling_factor * latents.float()
+            latents = 1 / self.vae.config.scaling_factor * latents
             image = self.vae.decode(latents).sample
             has_nsfw_concept = None
         else:
-            image = self.decode_latents(latents.float())
+            image = self.decode_latents(latents)
             has_nsfw_concept = None
 
         # Offload last model to CPU
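For context (not part of the diff): a minimal, hedged sketch of the fp16 path this change targets. The checkpoint id, prompt, placeholder image, and output filename are illustrative assumptions, not taken from the PR. When PyTorch 2.0's `scaled_dot_product_attention` or xformers is available, the latents stay in half precision through the decoder's attention block; otherwise the pipeline upcasts them to float32 as shown above.

```python
# Hedged usage sketch, assuming the x4 upscaler checkpoint; not part of this PR.
import torch
import PIL.Image
from diffusers import StableDiffusionUpscalePipeline

# Load the upscaling pipeline in half precision (model id is an example).
pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
).to("cuda")

# Any small RGB image works as the low-resolution input; a blank placeholder here.
low_res_img = PIL.Image.new("RGB", (128, 128))

# With torch>=2.0 (F.scaled_dot_product_attention available) or xformers enabled,
# the decoder's post_quant_conv / conv_in / mid_block run at the latents' dtype;
# otherwise the latents are upcast to float32 before decoding, as in the diff above.
image = pipe(prompt="a photo of a cat", image=low_res_img).images[0]
image.save("upscaled.png")
```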