Skip to content

Commit bb1b76d

Browse files
a-r-r-o-w and yiyixuxu authored
IPAdapterTesterMixin (#6862)
* begin IPAdapterTesterMixin --------- Co-authored-by: YiYi Xu <[email protected]>
1 parent e4b8f17 commit bb1b76d

31 files changed: +379 −66 lines changed

src/diffusers/pipelines/animatediff/pipeline_animatediff.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,11 @@ def __call__(
797797
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
798798

799799
# 7. Add image embeds for IP-Adapter
800-
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
800+
added_cond_kwargs = (
801+
{"image_embeds": image_embeds}
802+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
803+
else None
804+
)
801805

802806
num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
803807
for free_init_iter in range(num_free_init_iters):

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,41 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
441441

442442
return image_embeds, uncond_image_embeds
443443

444+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
445+
def prepare_ip_adapter_image_embeds(
446+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
447+
):
448+
if ip_adapter_image_embeds is None:
449+
if not isinstance(ip_adapter_image, list):
450+
ip_adapter_image = [ip_adapter_image]
451+
452+
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
453+
raise ValueError(
454+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
455+
)
456+
457+
image_embeds = []
458+
for single_ip_adapter_image, image_proj_layer in zip(
459+
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
460+
):
461+
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
462+
single_image_embeds, single_negative_image_embeds = self.encode_image(
463+
single_ip_adapter_image, device, 1, output_hidden_state
464+
)
465+
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
466+
single_negative_image_embeds = torch.stack(
467+
[single_negative_image_embeds] * num_images_per_prompt, dim=0
468+
)
469+
470+
if self.do_classifier_free_guidance:
471+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
472+
single_image_embeds = single_image_embeds.to(device)
473+
474+
image_embeds.append(single_image_embeds)
475+
else:
476+
image_embeds = ip_adapter_image_embeds
477+
return image_embeds
478+
444479
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
445480
def decode_latents(self, latents):
446481
latents = 1 / self.vae.config.scaling_factor * latents
@@ -735,6 +770,7 @@ def __call__(
735770
prompt_embeds: Optional[torch.FloatTensor] = None,
736771
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
737772
ip_adapter_image: Optional[PipelineImageInput] = None,
773+
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
738774
output_type: Optional[str] = "pil",
739775
return_dict: bool = True,
740776
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -784,6 +820,9 @@ def __call__(
784820
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
785821
ip_adapter_image: (`PipelineImageInput`, *optional*):
786822
Optional image input to work with IP Adapters.
823+
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
824+
Pre-generated image embeddings for IP-Adapter. If not
825+
provided, embeddings are computed from the `ip_adapter_image` input argument.
787826
output_type (`str`, *optional*, defaults to `"pil"`):
788827
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
789828
`np.array`.
@@ -870,13 +909,10 @@ def __call__(
870909
if self.do_classifier_free_guidance:
871910
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
872911

873-
if ip_adapter_image is not None:
874-
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
875-
image_embeds, negative_image_embeds = self.encode_image(
876-
ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
912+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
913+
image_embeds = self.prepare_ip_adapter_image_embeds(
914+
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
877915
)
878-
if self.do_classifier_free_guidance:
879-
image_embeds = torch.cat([negative_image_embeds, image_embeds])
880916

881917
# 4. Prepare timesteps
882918
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -902,7 +938,11 @@ def __call__(
902938
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
903939

904940
# 7. Add image embeds for IP-Adapter
905-
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
941+
added_cond_kwargs = (
942+
{"image_embeds": image_embeds}
943+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
944+
else None
945+
)
906946

907947
num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
908948
for free_init_iter in range(num_free_init_iters):

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,11 @@ def __call__(
12061206
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
12071207

12081208
# 7.1 Add image embeds for IP-Adapter
1209-
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
1209+
added_cond_kwargs = (
1210+
{"image_embeds": image_embeds}
1211+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
1212+
else None
1213+
)
12101214

12111215
# 7.2 Create tensor stating which controlnets to keep
12121216
controlnet_keep = []

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,11 @@ def __call__(
12061206
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
12071207

12081208
# 7.1 Add image embeds for IP-Adapter
1209-
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
1209+
added_cond_kwargs = (
1210+
{"image_embeds": image_embeds}
1211+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
1212+
else None
1213+
)
12101214

12111215
# 7.2 Create tensor stating which controlnets to keep
12121216
controlnet_keep = []

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1495,7 +1495,11 @@ def __call__(
14951495
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
14961496

14971497
# 7.1 Add image embeds for IP-Adapter
1498-
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
1498+
added_cond_kwargs = (
1499+
{"image_embeds": image_embeds}
1500+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
1501+
else None
1502+
)
14991503

15001504
# 7.2 Create tensor stating which controlnets to keep
15011505
controlnet_keep = []

src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -477,8 +477,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
477477

478478
return image_embeds, uncond_image_embeds
479479

480+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
480481
def prepare_ip_adapter_image_embeds(
481-
self, ip_adapter_image, ip_adapter_image_embeds, do_classifier_free_guidance, device, num_images_per_prompt
482+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
482483
):
483484
if ip_adapter_image_embeds is None:
484485
if not isinstance(ip_adapter_image, list):
@@ -502,7 +503,7 @@ def prepare_ip_adapter_image_embeds(
502503
[single_negative_image_embeds] * num_images_per_prompt, dim=0
503504
)
504505

505-
if do_classifier_free_guidance:
506+
if self.do_classifier_free_guidance:
506507
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
507508
single_image_embeds = single_image_embeds.to(device)
508509

@@ -699,6 +700,10 @@ def cross_attention_kwargs(self):
699700
def clip_skip(self):
700701
return self._clip_skip
701702

703+
@property
704+
def do_classifier_free_guidance(self):
705+
return False
706+
702707
@property
703708
def num_timesteps(self):
704709
return self._num_timesteps
@@ -845,7 +850,7 @@ def __call__(
845850

846851
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
847852
image_embeds = self.prepare_ip_adapter_image_embeds(
848-
ip_adapter_image, ip_adapter_image_embeds, False, device, batch_size * num_images_per_prompt
853+
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
849854
)
850855

851856
# 3. Encode input prompt
@@ -860,7 +865,7 @@ def __call__(
860865
prompt,
861866
device,
862867
num_images_per_prompt,
863-
False,
868+
self.do_classifier_free_guidance,
864869
negative_prompt=None,
865870
prompt_embeds=prompt_embeds,
866871
negative_prompt_embeds=None,
@@ -906,7 +911,11 @@ def __call__(
906911
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)
907912

908913
# 7.1 Add image embeds for IP-Adapter
909-
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
914+
added_cond_kwargs = (
915+
{"image_embeds": image_embeds}
916+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
917+
else None
918+
)
910919

911920
# 8. LCM Multistep Sampling Loop
912921
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,41 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
461461

462462
return image_embeds, uncond_image_embeds
463463

464+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
465+
def prepare_ip_adapter_image_embeds(
466+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
467+
):
468+
if ip_adapter_image_embeds is None:
469+
if not isinstance(ip_adapter_image, list):
470+
ip_adapter_image = [ip_adapter_image]
471+
472+
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
473+
raise ValueError(
474+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
475+
)
476+
477+
image_embeds = []
478+
for single_ip_adapter_image, image_proj_layer in zip(
479+
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
480+
):
481+
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
482+
single_image_embeds, single_negative_image_embeds = self.encode_image(
483+
single_ip_adapter_image, device, 1, output_hidden_state
484+
)
485+
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
486+
single_negative_image_embeds = torch.stack(
487+
[single_negative_image_embeds] * num_images_per_prompt, dim=0
488+
)
489+
490+
if self.do_classifier_free_guidance:
491+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
492+
single_image_embeds = single_image_embeds.to(device)
493+
494+
image_embeds.append(single_image_embeds)
495+
else:
496+
image_embeds = ip_adapter_image_embeds
497+
return image_embeds
498+
464499
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
465500
def run_safety_checker(self, image, device, dtype):
466501
if self.safety_checker is None:
@@ -590,6 +625,10 @@ def cross_attention_kwargs(self):
590625
def clip_skip(self):
591626
return self._clip_skip
592627

628+
@property
629+
def do_classifier_free_guidance(self):
630+
return False
631+
593632
@property
594633
def num_timesteps(self):
595634
return self._num_timesteps
@@ -610,6 +649,7 @@ def __call__(
610649
latents: Optional[torch.FloatTensor] = None,
611650
prompt_embeds: Optional[torch.FloatTensor] = None,
612651
ip_adapter_image: Optional[PipelineImageInput] = None,
652+
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
613653
output_type: Optional[str] = "pil",
614654
return_dict: bool = True,
615655
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -660,6 +700,9 @@ def __call__(
660700
provided, text embeddings are generated from the `prompt` input argument.
661701
ip_adapter_image: (`PipelineImageInput`, *optional*):
662702
Optional image input to work with IP Adapters.
703+
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
704+
Pre-generated image embeddings for IP-Adapter. If not
705+
provided, embeddings are computed from the `ip_adapter_image` input argument.
663706
output_type (`str`, *optional*, defaults to `"pil"`):
664707
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
665708
return_dict (`bool`, *optional*, defaults to `True`):
@@ -726,12 +769,10 @@ def __call__(
726769
batch_size = prompt_embeds.shape[0]
727770

728771
device = self._execution_device
729-
# do_classifier_free_guidance = guidance_scale > 1.0
730772

731-
if ip_adapter_image is not None:
732-
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
733-
image_embeds, negative_image_embeds = self.encode_image(
734-
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
773+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
774+
image_embeds = self.prepare_ip_adapter_image_embeds(
775+
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
735776
)
736777

737778
# 3. Encode input prompt
@@ -746,7 +787,7 @@ def __call__(
746787
prompt,
747788
device,
748789
num_images_per_prompt,
749-
False,
790+
self.do_classifier_free_guidance,
750791
negative_prompt=None,
751792
prompt_embeds=prompt_embeds,
752793
negative_prompt_embeds=None,
@@ -786,7 +827,11 @@ def __call__(
786827
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)
787828

788829
# 7.1 Add image embeds for IP-Adapter
789-
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
830+
added_cond_kwargs = (
831+
{"image_embeds": image_embeds}
832+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
833+
else None
834+
)
790835

791836
# 8. LCM MultiStep Sampling Loop:
792837
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

src/diffusers/pipelines/pia/pipeline_pia.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -987,7 +987,11 @@ def __call__(
987987
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
988988

989989
# 7. Add image embeds for IP-Adapter
990-
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
990+
added_cond_kwargs = (
991+
{"image_embeds": image_embeds}
992+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
993+
else None
994+
)
991995

992996
# 8. Denoising loop
993997
num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1111,7 +1111,11 @@ def __call__(
11111111
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
11121112

11131113
# 7.1 Add image embeds for IP-Adapter
1114-
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
1114+
added_cond_kwargs = (
1115+
{"image_embeds": image_embeds}
1116+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
1117+
else None
1118+
)
11151119

11161120
# 7.2 Optionally get Guidance Scale Embedding
11171121
timestep_cond = None

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1397,7 +1397,11 @@ def __call__(
13971397
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
13981398

13991399
# 9.1 Add image embeds for IP-Adapter
1400-
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
1400+
added_cond_kwargs = (
1401+
{"image_embeds": image_embeds}
1402+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
1403+
else None
1404+
)
14011405

14021406
# 9.2 Optionally get Guidance Scale Embedding
14031407
timestep_cond = None

0 commit comments

Comments (0)