Commit d0f1073: adapter for StableDiffusionControlNetImg2ImgPipeline
(1 parent: 0eeee61)
File tree (2 files changed: +35 −6 lines)

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
tests/pipelines/controlnet/test_controlnet_img2img.py

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
(33 additions, 6 deletions)
@@ -19,10 +19,10 @@
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
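
The widened transformers import brings in CLIPVisionModelWithProjection for the new image encoder, and IPAdapterMixin in the loaders import is what exposes load_ip_adapter on the pipeline (it joins the class bases in the next hunk). A minimal usage sketch, assuming an SD 1.5 base checkpoint, a matching ControlNet, and the usual IP-Adapter weight layout on the Hub; all three model ids are illustrative assumptions, not part of this commit:

from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline

# illustrative checkpoint ids (assumptions, not taken from this diff)
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet
)
# provided by the newly inherited IPAdapterMixin
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")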
@@ -126,7 +126,7 @@ def prepare_image(image):


 class StableDiffusionControlNetImg2ImgPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for image-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -136,7 +136,7 @@ class StableDiffusionControlNetImg2ImgPipeline(

     The pipeline also inherits the following loading methods:
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
-
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
@@ -162,7 +162,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
     """

     model_cpu_offload_seq = "text_encoder->unet->vae"
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

@@ -176,6 +176,7 @@ def __init__(
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -208,6 +209,7 @@ def __init__(
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
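
Since image_encoder defaults to None and is listed in _optional_components, existing checkpoints without an image encoder load exactly as before; register_modules simply records None. The encoder can also be supplied explicitly at construction time. A sketch reusing controlnet from the example above and assuming the ViT-H encoder shipped alongside the common IP-Adapter weights (the repo layout is an assumption):

from transformers import CLIPVisionModelWithProjection

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder"  # assumed repo layout
)
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, image_encoder=image_encoder
)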
@@ -464,6 +466,20 @@ def encode_prompt(

         return prompt_embeds, negative_prompt_embeds

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        image_embeds = self.image_encoder(image).image_embeds
+        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+        uncond_image_embeds = torch.zeros_like(image_embeds)
+        return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
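
encode_image feeds the reference image through the CLIP feature extractor and vision tower, repeats the pooled projection once per requested image, and returns an all-zeros tensor as its unconditional counterpart. A rough shape check, continuing the sketch above; the width of 1024 assumes the ViT-H encoder and is not stated in the diff:

import PIL.Image

ref_image = PIL.Image.new("RGB", (512, 512))  # any PIL image or tensor works
image_embeds, uncond_image_embeds = pipe.encode_image(ref_image, pipe.device, num_images_per_prompt=2)
# image_embeds: roughly (2, 1024); uncond_image_embeds is zeros_like(image_embeds),
# so the negative guidance branch carries no image signal at all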
@@ -857,6 +873,7 @@ def __call__(
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -918,6 +935,7 @@ def __call__(
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -1049,6 +1067,11 @@ def __call__(
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

+        if ip_adapter_image is not None:
+            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
         # 4. Prepare image
         image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

@@ -1107,7 +1130,10 @@ def __call__(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        # 7.1 Create tensor stating which controlnets to keep
+        # 7.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
+        # 7.2 Create tensor stating which controlnets to keep
         controlnet_keep = []
         for i in range(len(timesteps)):
             keeps = [
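
added_cond_kwargs, forwarded to the UNet in the hunk below, is the channel through which the image embeddings reach the UNet's IP-Adapter attention processors; when no ip_adapter_image is passed it stays None and the UNet call is unchanged. The [negative, positive] concatenation earlier follows the same ordering the pipeline already uses for prompt_embeds, so the standard guidance split pairs the zero image embeddings with the negative prompt. As an illustrative reminder of that split (standard Stable Diffusion pipeline logic, not lines from this diff):

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)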
@@ -1167,6 +1193,7 @@ def __call__(
                     cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
+                    added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
                 )[0]
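
Putting the pieces together, an image-to-image call with both ControlNet and IP-Adapter conditioning might look like this after the commit, reusing pipe from the sketches above (URLs and parameter values are hypothetical placeholders):

from diffusers.utils import load_image

pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

init_image = load_image("https://example.com/room.png")     # hypothetical URL
control_image = load_image("https://example.com/canny.png")  # hypothetical URL
style_image = load_image("https://example.com/style.png")    # hypothetical URL

result = pipe(
    prompt="a photo of a living room",
    image=init_image,
    control_image=control_image,
    ip_adapter_image=style_image,  # the new argument added by this commit
    strength=0.8,
    num_inference_steps=30,
).images[0]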

tests/pipelines/controlnet/test_controlnet_img2img.py
(2 additions, 0 deletions)
@@ -134,6 +134,7 @@ def get_dummy_components(self):
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components

@@ -273,6 +274,7 @@ def init_weights(m):
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
