@@ -21,6 +21,7 @@
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

+from ...image_processor import VaeImageProcessor
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor
@@ -125,6 +126,7 @@ def __init__(
             watermarker=watermarker,
             feature_extractor=feature_extractor,
         )
+        self.image_processor = VaeImageProcessor(vae_scale_factor=64, resample="bicubic")
         self.register_to_config(max_noise_level=max_noise_level)

     def enable_sequential_cpu_offload(self, gpu_id=0):
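For readers unfamiliar with `VaeImageProcessor`: the instance registered above centralizes the resize/normalize logic this pipeline previously handled with a module-level helper. A minimal round-trip sketch, assuming the public `preprocess`/`postprocess` API from `diffusers.image_processor` (the example array and shapes are illustrative):

```python
import numpy as np
from diffusers.image_processor import VaeImageProcessor

# Same configuration as registered in __init__ above: inputs are resized to a
# multiple of 64 (bicubic resampling) and normalized from [0, 1] to [-1, 1].
processor = VaeImageProcessor(vae_scale_factor=64, resample="bicubic")

low_res = np.random.rand(128, 128, 3).astype(np.float32)  # HWC float in [0, 1]
batch = processor.preprocess(low_res)  # torch.Tensor, shape (1, 3, 128, 128), range [-1, 1]

# postprocess() undoes the normalization and converts to the requested output type.
pil_images = processor.postprocess(batch, output_type="pil")
```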
@@ -432,14 +434,15 @@ def check_inputs(
         if (
             not isinstance(image, torch.Tensor)
             and not isinstance(image, PIL.Image.Image)
+            and not isinstance(image, np.ndarray)
             and not isinstance(image, list)
         ):
             raise ValueError(
-                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
+                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray` or `list` but is {type(image)}"
             )

-        # verify batch size of prompt and image are same if image is a list or tensor
-        if isinstance(image, list) or isinstance(image, torch.Tensor):
+        # verify batch size of prompt and image are same if image is a list or tensor or numpy array
+        if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
             if isinstance(prompt, str):
                 batch_size = 1
             else:
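Collapsed into a predicate, the widened type gate is a single `isinstance` check over a tuple; a hypothetical mirror of it (`image_type_ok` is not part of the pipeline, just an illustration of what now passes versus still raises):

```python
import numpy as np
import PIL.Image
import torch

def image_type_ok(image) -> bool:
    """Standalone mirror of the widened type gate above (illustrative only)."""
    return isinstance(image, (torch.Tensor, PIL.Image.Image, np.ndarray, list))

assert image_type_ok(np.zeros((64, 64, 3), dtype=np.float32))  # newly accepted
assert image_type_ok(PIL.Image.new("RGB", (64, 64)))           # still accepted
assert not image_type_ok("low_res.png")                        # still a ValueError in the pipeline
```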
@@ -483,7 +486,14 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         num_inference_steps: int = 75,
         guidance_scale: float = 9.0,
         noise_level: int = 20,
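With the widened signature, a raw NumPy array can be passed straight to the pipeline. A usage sketch, assuming this diff targets `StableDiffusionUpscalePipeline` and using the `stabilityai/stable-diffusion-x4-upscaler` checkpoint as one example it is commonly loaded from:

```python
import numpy as np
import torch
from diffusers import StableDiffusionUpscalePipeline

pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
).to("cuda")

# HWC float32 in [0, 1]; previously this had to be converted to PIL or torch first.
low_res = np.random.rand(128, 128, 3).astype(np.float32)
upscaled = pipe(prompt="a white cat", image=low_res, noise_level=20).images[0]
upscaled.save("upscaled_cat.png")
```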
@@ -506,7 +516,7 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
-            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                 `Image`, or tensor representing an image batch which will be upscaled.
             num_inference_steps (`int`, *optional*, defaults to 75):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -627,7 +637,7 @@ def __call__(
             )

         # 4. Preprocess image
-        image = preprocess(image)
+        image = self.image_processor.preprocess(image)
         image = image.to(dtype=prompt_embeds.dtype, device=device)

         # 5. set timesteps
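The module-level `preprocess` helper this line replaces performed roughly the following conversion (a from-memory sketch, not a verbatim copy of the removed helper): PIL input becomes a float32 NCHW batch rescaled from [0, 255] to [-1, 1], tensors pass through unchanged. `image_processor.preprocess` fulfills the same contract while additionally handling `np.ndarray`:

```python
import numpy as np
import PIL.Image
import torch

def legacy_preprocess(image):
    # Approximation of the removed module-level helper.
    if isinstance(image, torch.Tensor):
        return image  # tensors were passed through unchanged
    if isinstance(image, PIL.Image.Image):
        image = [image]
    arr = np.stack([np.array(i.convert("RGB")) for i in image]).astype(np.float32) / 255.0
    arr = arr.transpose(0, 3, 1, 2)  # NHWC -> NCHW
    return torch.from_numpy(2.0 * arr - 1.0)  # [0, 1] -> [-1, 1]
```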
@@ -723,25 +733,24 @@ def __call__(
         else:
             latents = latents.float()

-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.decode_latents(latents)
-
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
-
-            image = self.numpy_to_pil(image)
-
-            # 11. Apply watermark
-            if self.watermarker is not None:
-                image = self.watermarker.apply_watermark(image)
-        elif output_type == "pt":
-            latents = 1 / self.vae.config.scaling_factor * latents
-            image = self.vae.decode(latents).sample
-            has_nsfw_concept = None
         else:
-            image = self.decode_latents(latents)
+            image = latents
             has_nsfw_concept = None

+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        # 11. Apply watermark
+        if output_type == "pil" and self.watermarker is not None:
+            image = self.watermarker.apply_watermark(image)
+
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()
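The `do_denormalize` list exists so that images the safety checker zeroed out are not pushed through the [-1, 1] -> [0, 1] denormalization: a zero tensor denormalized would come out mid-gray, whereas skipping the step keeps it black. Internally, `postprocess` applies per-image logic along these lines (a sketch of my reading of `VaeImageProcessor`; `denormalize_batch` is a hypothetical name):

```python
import torch

def denormalize_batch(image: torch.Tensor, do_denormalize: list) -> torch.Tensor:
    # [-1, 1] -> [0, 1] only where the per-image flag is True; flagged (zeroed)
    # images are left untouched so they render black rather than mid-gray.
    return torch.stack(
        [(image[i] / 2 + 0.5).clamp(0, 1) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
    )
```

Note also that the watermark is now applied only for `output_type == "pil"`, since the watermarker operates on PIL images; `"latent"` output skips decoding, the safety checker, and the watermark entirely.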