
Commit b65928b

fabiorigano, yiyixuxu, and patrickvonplaten authored
Add support for IPAdapterFull (#5911)
* Add support for IPAdapterFull

Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent 6bf1ca2 commit b65928b

File tree: 4 files changed, +122 -1 lines changed

docs/source/en/using-diffusers/loading_adapters.md

Lines changed: 63 additions & 0 deletions
@@ -485,6 +485,69 @@ image.save("sdxl_t2i.png")
</div>
</div>

You can use the IP-Adapter face model to apply specific faces to your images. It is an effective way to maintain consistent characters in your image generations. Weights are loaded with the same method used for the other IP-Adapters.

```python
# Load ip-adapter-full-face_sd15.bin
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
```

<Tip>

It is recommended to use `DDIMScheduler` or `EulerDiscreteScheduler` with the face model.

</Tip>

```python
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.utils import load_image

noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
).to("cuda")

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
pipeline.set_ip_adapter_scale(0.7)

image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png")
generator = torch.Generator(device="cpu").manual_seed(33)

image = pipeline(
    prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower",
    ip_adapter_image=image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=50,
    num_images_per_prompt=1,
    width=512,
    height=704,
    generator=generator,
).images[0]
```

<div class="flex flex-row gap-4">
  <div class="flex-1">
    <img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">input image</figcaption>
  </div>
  <div class="flex-1">
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ipadapter_full_face_output.png"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">output image</figcaption>
  </div>
</div>
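
The tip above suggests `DDIMScheduler` or `EulerDiscreteScheduler`. If a pipeline has already been constructed, as in the snippet above, it can be switched to `EulerDiscreteScheduler` by rebuilding the scheduler from the pipeline's current config instead of retyping the arguments; a minimal sketch, not part of this commit:

```python
from diffusers import EulerDiscreteScheduler

# Reuse the pipeline's existing scheduler config rather than re-listing
# the beta/timestep arguments by hand.
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
```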

### LCM-Lora

src/diffusers/loaders/unet.py

Lines changed: 28 additions & 1 deletion

```diff
@@ -22,7 +22,7 @@
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import nn
 
-from ..models.embeddings import ImageProjection, Resampler
+from ..models.embeddings import ImageProjection, MLPProjection, Resampler
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import (
     USE_PEFT_BACKEND,
```

```diff
@@ -675,6 +675,9 @@ def _load_ip_adapter_weights(self, state_dict):
         if "proj.weight" in state_dict["image_proj"]:
             # IP-Adapter
             num_image_text_embeds = 4
+        elif "proj.3.weight" in state_dict["image_proj"]:
+            # IP-Adapter Full Face
+            num_image_text_embeds = 257  # 256 CLIP tokens + 1 CLS token
         else:
             # IP-Adapter Plus
             num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
```

```diff
@@ -744,8 +747,32 @@ def _load_ip_adapter_weights(self, state_dict):
                     "norm.bias": state_dict["image_proj"]["norm.bias"],
                 }
             )
+            image_projection.load_state_dict(image_proj_state_dict)
+            del image_proj_state_dict
 
+        elif "proj.3.weight" in state_dict["image_proj"]:
+            clip_embeddings_dim = state_dict["image_proj"]["proj.0.weight"].shape[0]
+            cross_attention_dim = state_dict["image_proj"]["proj.3.weight"].shape[0]
+
+            image_projection = MLPProjection(
+                cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
+            )
+            image_projection.to(dtype=self.dtype, device=self.device)
+
+            # load image projection layer weights
+            image_proj_state_dict = {}
+            image_proj_state_dict.update(
+                {
+                    "ff.net.0.proj.weight": state_dict["image_proj"]["proj.0.weight"],
+                    "ff.net.0.proj.bias": state_dict["image_proj"]["proj.0.bias"],
+                    "ff.net.2.weight": state_dict["image_proj"]["proj.2.weight"],
+                    "ff.net.2.bias": state_dict["image_proj"]["proj.2.bias"],
+                    "norm.weight": state_dict["image_proj"]["proj.3.weight"],
+                    "norm.bias": state_dict["image_proj"]["proj.3.bias"],
+                }
+            )
             image_projection.load_state_dict(image_proj_state_dict)
+            del image_proj_state_dict
 
         else:
             # IP-Adapter Plus
```
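
The three branches above amount to sniffing which keys the checkpoint's `image_proj` dict carries. A minimal standalone sketch of that dispatch (the helper name `detect_image_projection` is hypothetical, not part of diffusers):

```python
def detect_image_projection(image_proj_sd: dict):
    """Classify an IP-Adapter image_proj state dict by its key layout."""
    if "proj.weight" in image_proj_sd:
        # Plain IP-Adapter: a single linear projection to 4 image tokens.
        return "ImageProjection", 4
    if "proj.3.weight" in image_proj_sd:
        # IP-Adapter Full Face: an MLP whose final LayerNorm is "proj.3".
        return "MLPProjection", 257  # 256 CLIP tokens + 1 CLS token
    # IP-Adapter Plus: a Resampler with learned latent queries.
    return "Resampler", image_proj_sd["latents"].shape[1]
```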

src/diffusers/models/embeddings.py

Lines changed: 12 additions & 0 deletions

```diff
@@ -461,6 +461,18 @@ def forward(self, image_embeds: torch.FloatTensor):
         return image_embeds
 
 
+class MLPProjection(nn.Module):
+    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
+        super().__init__()
+        from .attention import FeedForward
+
+        self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
+        self.norm = nn.LayerNorm(cross_attention_dim)
+
+    def forward(self, image_embeds: torch.FloatTensor):
+        return self.norm(self.ff(image_embeds))
+
+
 class CombinedTimestepLabelEmbeddings(nn.Module):
     def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
         super().__init__()
```
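
Because `FeedForward` here maps `image_embed_dim` to `cross_attention_dim` with `mult=1`, the projection preserves the token count and only changes the channel dimension. A quick illustrative shape check, assuming a checkout that includes this commit; the batch size and dimensions below are made up for demonstration:

```python
import torch

from diffusers.models.embeddings import MLPProjection

# 2 images, 257 tokens each (256 CLIP tokens + 1 CLS token), 1024-dim CLIP
# embeddings projected into a 768-dim cross-attention space.
proj = MLPProjection(image_embed_dim=1024, cross_attention_dim=768)
image_embeds = torch.randn(2, 257, 1024)

print(proj(image_embeds).shape)  # torch.Size([2, 257, 768])
```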

tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py

Lines changed: 19 additions & 0 deletions

```diff
@@ -182,6 +182,25 @@ def test_inpainting(self):
 
         assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
 
+    def test_text_to_image_full_face(self):
+        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
+        pipeline = StableDiffusionPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
+        )
+        pipeline.to(torch_device)
+        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
+        pipeline.set_ip_adapter_scale(0.7)
+
+        inputs = self.get_dummy_inputs()
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+
+        expected_slice = np.array(
+            [0.1706543, 0.1303711, 0.12573242, 0.21777344, 0.14550781, 0.14038086, 0.40820312, 0.41455078, 0.42529297]
+        )
+
+        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
 
 @slow
 @require_torch_gpu
```
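
To exercise only the new test locally, something like the following should work in a development checkout with the test dependencies installed; diffusers gates its integration tests behind the `RUN_SLOW` environment variable, so the sketch sets it before invoking pytest:

```python
import os

import pytest

# Integration tests in diffusers are skipped unless RUN_SLOW is set.
os.environ["RUN_SLOW"] = "1"

pytest.main(
    [
        "tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py",
        "-k",
        "test_text_to_image_full_face",
    ]
)
```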
