diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 9e2e428eaf91..11a0512d7d93 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -21,10 +21,10 @@
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
@@ -237,7 +237,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
 
 
 class StableDiffusionControlNetInpaintPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for image inpainting using Stable Diffusion with ControlNet guidance.
@@ -247,6 +247,7 @@ class StableDiffusionControlNetInpaintPipeline(
 
     The pipeline also inherits the following loading methods:
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     <Tip>
 
@@ -284,7 +285,7 @@ class StableDiffusionControlNetInpaintPipeline(
     """
 
     model_cpu_offload_seq = "text_encoder->unet->vae"
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
@@ -298,6 +299,7 @@ def __init__(
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -330,6 +332,7 @@ def __init__(
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -589,6 +592,20 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        image_embeds = self.image_encoder(image).image_embeds
+        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+        uncond_image_embeds = torch.zeros_like(image_embeds)
+        return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
@@ -1049,6 +1066,7 @@ def __call__(
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -1127,6 +1145,7 @@ def __call__(
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -1260,6 +1279,11 @@ def __call__(
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
+        if ip_adapter_image is not None:
+            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
         # 4. Prepare image
         if isinstance(controlnet, ControlNetModel):
             control_image = self.prepare_control_image(
@@ -1295,7 +1319,7 @@ def __call__(
         else:
             assert False
 
-        # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
+        # 4.1 Preprocess mask and image - resizes image and mask w.r.t. height and width
         init_image = self.image_processor.preprocess(image, height=height, width=width)
         init_image = init_image.to(dtype=torch.float32)
 
@@ -1356,7 +1380,10 @@ def __call__(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 7.1 Create tensor stating which controlnets to keep
+        # 7.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
+        # 7.2 Create tensor stating which controlnets to keep
         controlnet_keep = []
         for i in range(len(timesteps)):
             keeps = [
@@ -1419,6 +1446,7 @@ def __call__(
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
+                    added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
                 )[0]
 
diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py
index a9140f3d5a31..7c3371c197d4 100644
--- a/tests/pipelines/controlnet/test_controlnet_inpaint.py
+++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py
@@ -132,6 +132,7 @@ def get_dummy_components(self):
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
 
@@ -248,6 +249,7 @@ def get_dummy_components(self):
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
 
@@ -342,6 +344,7 @@ def init_weights(m):
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
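Usage note (not part of the diff): a minimal sketch of how the new `ip_adapter_image` argument is meant to be exercised end to end. The checkpoint names and file paths below are illustrative assumptions chosen for this sketch, not taken from the PR; any SD 1.5 inpainting checkpoint with compatible ControlNet and IP-Adapter weights should behave the same.

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from diffusers.utils import load_image

# Checkpoints assumed for illustration only.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# load_ip_adapter comes from the newly inherited IPAdapterMixin; it installs the
# adapter weights into the UNet's attention processors (and, as implemented in
# the mixin, fills in the CLIP image encoder when the pipeline has none).
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

# Hypothetical local files: the image to inpaint, its mask (white = repaint),
# the ControlNet conditioning image, and the IP-Adapter style reference.
init_image = load_image("input.png")
mask_image = load_image("mask.png")
control_image = load_image("control.png")
style_reference = load_image("style_reference.png")

result = pipe(
    prompt="a red sofa in a bright living room",
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    ip_adapter_image=style_reference,  # new argument added by this PR
    num_inference_steps=30,
).images[0]
result.save("out.png")
```

Because `image_encoder` is appended to `_optional_components`, pipelines saved before this change still load cleanly with `image_encoder=None`; the encoder is only needed once `ip_adapter_image` is actually passed, and `encode_image` substitutes `torch.zeros_like(image_embeds)` as the unconditional embedding for classifier-free guidance.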