@@ -19,10 +19,10 @@
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
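For context, the new `CLIPVisionModelWithProjection` import is what produces the projected image embeddings the IP-Adapter consumes. A minimal sketch of loading such an encoder standalone, assuming the commonly published `h94/IP-Adapter` image-encoder weights (the repo id is an assumption, not pinned by this diff):

```python
# Sketch: load a projected CLIP vision encoder and embed one image.
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder"  # assumed checkpoint
)
processor = CLIPImageProcessor()
pixel_values = processor(Image.new("RGB", (224, 224)), return_tensors="pt").pixel_values
with torch.no_grad():
    # .image_embeds is the projected embedding, shape (batch, projection_dim)
    image_embeds = encoder(pixel_values).image_embeds
```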
@@ -126,7 +126,7 @@ def prepare_image(image):
 
 
 class StableDiffusionControlNetImg2ImgPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for image-to-image generation using Stable Diffusion with ControlNet guidance.
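Mixing `IPAdapterMixin` into the class is what exposes the `load_ip_adapter` method on this pipeline. A usage sketch under assumed public checkpoints (none of the model ids below are part of this diff):

```python
# Sketch: after this change, IP-Adapter weights can be attached to the pipeline.
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet
)
# load_ip_adapter comes from IPAdapterMixin; repo/file names are assumptions
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
```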
@@ -136,7 +136,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
 
     The pipeline also inherits the following loading methods:
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
-
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
@@ -162,7 +162,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
162162 """
163163
164164 model_cpu_offload_seq = "text_encoder->unet->vae"
165- _optional_components = ["safety_checker" , "feature_extractor" ]
165+ _optional_components = ["safety_checker" , "feature_extractor" , "image_encoder" ]
166166 _exclude_from_cpu_offload = ["safety_checker" ]
167167 _callback_tensor_inputs = ["latents" , "prompt_embeds" , "negative_prompt_embeds" ]
168168
@@ -176,6 +176,7 @@ def __init__(
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -208,6 +209,7 @@ def __init__(
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
@@ -464,6 +466,20 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        image_embeds = self.image_encoder(image).image_embeds
+        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+        uncond_image_embeds = torch.zeros_like(image_embeds)
+        return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
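The `encode_image` contract is worth spelling out: it returns an `(image_embeds, uncond_image_embeds)` pair where the unconditional half is simply zeros of the same shape, and `repeat_interleave` duplicates each embedding once per requested image. A dependency-free sketch of that shape logic (the 1024 projection dim is an assumed placeholder):

```python
import torch

num_images_per_prompt = 2
image_embeds = torch.randn(1, 1024)  # stand-in for image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)  # zero embeds for the uncond branch
assert image_embeds.shape == uncond_image_embeds.shape == (2, 1024)
```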
@@ -857,6 +873,7 @@ def __call__(
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -918,6 +935,7 @@ def __call__(
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -1049,6 +1067,11 @@ def __call__(
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
+        if ip_adapter_image is not None:
+            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
         # 4. Prepare image
         image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
 
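Note the concatenation order: as with `prompt_embeds` just above, the unconditional (negative) embeddings come first, so the classifier-free-guidance split downstream keeps the unconditional and conditional halves aligned. A tiny sketch of that batching:

```python
import torch

image_embeds = torch.randn(2, 1024)                     # conditional image embeds
negative_image_embeds = torch.zeros_like(image_embeds)  # unconditional (zeros)
# negative first, then positive: same ordering as the prompt embeddings
image_embeds = torch.cat([negative_image_embeds, image_embeds])
assert image_embeds.shape == (4, 1024)
```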
@@ -1107,7 +1130,10 @@ def __call__(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 7.1 Create tensor stating which controlnets to keep
+        # 7.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
+        # 7.2 Create tensor stating which controlnets to keep
         controlnet_keep = []
         for i in range(len(timesteps)):
             keeps = [
@@ -1167,6 +1193,7 @@ def __call__(
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
+                    added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
                 )[0]
 
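Putting the whole diff together, an end-to-end sketch of the feature it enables; the model ids and image URLs below are illustrative assumptions, not pinned by this change:

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16  # assumed checkpoint
)
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

init_image = load_image("https://example.com/init.png")      # placeholder URL
control_image = load_image("https://example.com/canny.png")  # placeholder URL
style_image = load_image("https://example.com/style.png")    # placeholder URL

image = pipe(
    prompt="best quality, high quality",
    image=init_image,             # img2img starting image
    control_image=control_image,  # ControlNet conditioning
    ip_adapter_image=style_image, # new argument added by this diff
    num_inference_steps=30,
    strength=0.6,
).images[0]
```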