@@ -21,6 +21,7 @@
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

+from ...image_processor import VaeImageProcessor
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor
@@ -34,6 +35,11 @@


 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):
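For callers that imported the module-level preprocess() helper directly, the replacement is VaeImageProcessor.preprocess. A minimal migration sketch (the vae_scale_factor value here is an illustrative assumption, not read from a real checkpoint):

    import numpy as np
    from diffusers.image_processor import VaeImageProcessor

    # Stand-in for pipe.image_processor; vae_scale_factor=4 is assumed for illustration.
    processor = VaeImageProcessor(vae_scale_factor=4, resample="bicubic")

    image = np.random.rand(512, 512, 3).astype(np.float32)  # HWC, values in [0, 1]
    tensor = processor.preprocess(image)  # NCHW torch.Tensor, rescaled to [-1, 1]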
@@ -125,6 +131,8 @@ def __init__(
             watermarker=watermarker,
             feature_extractor=feature_extractor,
         )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
         self.register_to_config(max_noise_level=max_noise_level)

     def enable_sequential_cpu_offload(self, gpu_id=0):
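Each entry in block_out_channels beyond the first corresponds to a 2x spatial downsampling in the VAE encoder, so the factor between pixel space and latent space is 2**(n_blocks - 1). A worked example with a hypothetical config:

    # Hypothetical VAE config; the real checkpoint's block_out_channels may differ.
    block_out_channels = [128, 256, 512]
    vae_scale_factor = 2 ** (len(block_out_channels) - 1)  # 2**2 == 4
    # A 512x512 image then maps to a 128x128 latent.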
@@ -432,14 +440,15 @@ def check_inputs(
         if (
             not isinstance(image, torch.Tensor)
             and not isinstance(image, PIL.Image.Image)
+            and not isinstance(image, np.ndarray)
             and not isinstance(image, list)
         ):
             raise ValueError(
-                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
+                f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}"
             )

-        # verify batch size of prompt and image are same if image is a list or tensor
-        if isinstance(image, list) or isinstance(image, torch.Tensor):
+        # verify batch size of prompt and image are same if image is a list, tensor, or numpy array
+        if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
             if isinstance(prompt, str):
                 batch_size = 1
             else:
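The widened check now accepts numpy arrays anywhere a tensor or PIL image was accepted. As a standalone sketch of the validation logic (not the pipeline's actual method):

    import numpy as np
    import PIL.Image
    import torch

    def is_valid_image(image):
        # numpy arrays join tensors, PIL images, and lists as accepted inputs
        return isinstance(image, (torch.Tensor, PIL.Image.Image, np.ndarray, list))

    assert is_valid_image(np.zeros((64, 64, 3), dtype=np.float32))
    assert not is_valid_image("not an image")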
@@ -483,7 +492,14 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         num_inference_steps: int = 75,
         guidance_scale: float = 9.0,
         noise_level: int = 20,
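With the broadened annotation, a numpy batch can be passed straight to the pipeline. A usage sketch, assuming this diff targets the Stable Diffusion x4 upscale pipeline (the checkpoint name is the public stabilityai/stable-diffusion-x4-upscaler; the input values are illustrative):

    import numpy as np
    import torch
    from diffusers import StableDiffusionUpscalePipeline

    pipe = StableDiffusionUpscalePipeline.from_pretrained(
        "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
    ).to("cuda")

    # HWC float array in [0, 1]; no manual PIL or tensor conversion required
    low_res = np.random.rand(128, 128, 3).astype(np.float32)
    upscaled = pipe(prompt="a white cat", image=low_res).images[0]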
@@ -506,7 +522,7 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
-            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                 `Image` or tensor representing an image batch which will be upscaled.
             num_inference_steps (`int`, *optional*, defaults to 75):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -627,7 +643,7 @@ def __call__(
             )

         # 4. Preprocess image
-        image = preprocess(image)
+        image = self.image_processor.preprocess(image)
         image = image.to(dtype=prompt_embeds.dtype, device=device)

         # 5. set timesteps
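For a single numpy input, the processor's effect is roughly the following (a simplified sketch of the conversion, not the library's implementation):

    import numpy as np
    import torch

    def preprocess_sketch(image: np.ndarray) -> torch.Tensor:
        # HWC array in [0, 1] -> NCHW tensor in [-1, 1], which is the range
        # the VAE encoder and the diffusion UNet expect.
        tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
        return 2.0 * tensor - 1.0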
@@ -723,25 +739,25 @@ def __call__(
         else:
             latents = latents.float()

-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.decode_latents(latents)
-
+        # post-processing
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
-
-            image = self.numpy_to_pil(image)
-
-            # 11. Apply watermark
-            if self.watermarker is not None:
-                image = self.watermarker.apply_watermark(image)
-        elif output_type == "pt":
-            latents = 1 / self.vae.config.scaling_factor * latents
-            image = self.vae.decode(latents).sample
-            has_nsfw_concept = None
         else:
-            image = self.decode_latents(latents)
+            image = latents
             has_nsfw_concept = None

+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        # 11. Apply watermark
+        if output_type == "pil" and self.watermarker is not None:
+            image = self.watermarker.apply_watermark(image)
+
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()
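The do_denormalize mask exists because the safety checker hands back flagged images already blanked out in [0, 1] space; denormalizing those from [-1, 1] would shift black to mid-gray. A simplified sketch of the per-image branch inside postprocess (assumed behavior, not the library's exact code):

    import torch

    def postprocess_sketch(image: torch.Tensor, do_denormalize: list) -> torch.Tensor:
        # Decoded images are in [-1, 1]; shift only the unflagged ones to [0, 1].
        return torch.stack(
            [img / 2 + 0.5 if denorm else img for img, denorm in zip(image, do_denormalize)]
        ).clamp(0, 1)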