Conversation

@fabiorigano
Contributor

@fabiorigano fabiorigano commented Mar 2, 2024

What does this PR do?

Fixes #7014 #6935

  • Switch to PEFT
  • Move to core
  • Add tests
  • Add docs

@yiyixuxu @sayakpaul

Create face embeddings

import torch
import cv2
import numpy as np
from diffusers.utils import load_image
from diffusers import AutoPipelineForText2Image, AutoencoderKL, DDIMScheduler
from insightface.app import FaceAnalysis

image1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png")
image2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png")

ref_images_embeds = []
ref_unc_images_embeds = []
app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))
for im in [image1, image2]:
    image = cv2.cvtColor(np.asarray(im), cv2.COLOR_BGR2RGB)
    faces = app.get(image)
    image = torch.from_numpy(faces[0].normed_embedding)
    image_embeds = image.unsqueeze(0)
    uncond_image_embeds = torch.zeros_like(image_embeds)
    ref_images_embeds.append(image_embeds)
    ref_unc_images_embeds.append(uncond_image_embeds)
ref_images_embeds = torch.stack(ref_images_embeds, dim=0)
ref_unc_images_embeds = torch.stack(ref_unc_images_embeds, dim=0)
single_image_embeds = torch.cat([ref_unc_images_embeds, ref_images_embeds], dim=0).to(dtype=torch.float16)
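As a quick sanity check, the snippet above produces a tensor with the following layout (this follows directly from the stacking and concatenation and is only a verification step):

print(single_image_embeds.shape)
# torch.Size([4, 1, 512]): the two zero (unconditional) embeddings are concatenated in
# front of the two conditional embeddings; buffalo_l yields 512-dim face embeddings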

IP Adapter Face ID (SD 1.5)

base_model_path ="SG161222/Realistic_Vision_V4.0_noVAE"

noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)
pipeline = AutoPipelineForText2Image.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    scheduler=noise_scheduler
)

pipeline.to("cuda")

pipeline.load_ip_adapter("h94/IP-Adapter-FaceID",
                         subfolder=None, 
                         weight_name="ip-adapter-faceid_sd15.bin", 
                         image_encoder_folder=None)
pipeline.set_ip_adapter_scale(0.7)

pipeline.enable_model_cpu_offload()
generator = torch.Generator(device="cpu").manual_seed(42)

num_images=2
images = pipeline(
    prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower",
    ip_adapter_image=[single_image_embeds],
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
    num_inference_steps=20, num_images_per_prompt=num_images, width=512, height=704, 
    generator=generator, 
).images

IP Adapter Face ID XL (SDXL)

base_model_path ="SG161222/RealVisXL_V3.0"

pipeline = AutoPipelineForText2Image.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    scheduler=noise_scheduler
)

pipeline.to("cuda")

pipeline.load_ip_adapter("h94/IP-Adapter-FaceID",
                         subfolder=None, 
                         weight_name="ip-adapter-faceid_sdxl.bin", 
                         image_encoder_folder=None)
pipeline.set_ip_adapter_scale(0.7)

pipeline.enable_model_cpu_offload()
generator = torch.Generator(device="cpu").manual_seed(42)

num_images=2

images = pipeline(
    prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower",
    ip_adapter_image=[single_image_embeds], guidance_scale=7.5,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
    num_inference_steps=30, num_images_per_prompt=2,
    generator=generator
).images

Member

@sayakpaul sayakpaul left a comment

Refreshing change!

@jfischoff

I'm getting the error

table_diffusion_xl.py", line 497, in encode_image
AttributeError: 'NoneType' object has no attribute 'parameters'

when I try to use this.

@fabiorigano
Contributor Author

fabiorigano commented Mar 9, 2024

I'm getting the error

table_diffusion_xl.py", line 497, in encode_image
AttributeError: 'NoneType' object has no attribute 'parameters'

when I try to use this.

You cannot use Face ID with SDXL; the current changes only affect the Stable Diffusion pipeline.

@fabiorigano
Contributor Author

@jfischoff you can use it now; I also updated the example code.

@yiyixuxu
Collaborator

yiyixuxu commented Mar 9, 2024

@fabiorigano is this ready for a review?

@fabiorigano
Contributor Author

@yiyixuxu I have to add some checks on the inputs, but I would appreciate your feedback. Thanks :)

@fabiorigano
Contributor Author

fabiorigano commented Mar 9, 2024

Since neither the Face ID adapter nor Face ID XL uses an image encoder, I tested the multi-adapter feature by separately extracting and then concatenating the image embeddings of Face ID XL and another IP Adapter, Plus Face SDXL.

I think prepare_ip_adapter_image_embeds would become too specific for this use case if we had to support an ip_adapter_image input containing both images and insightface embeddings.

Here is the code for the test:

# Create a SDXL pipeline
# ...

# Load sample images
image1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png")
image2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png")

# Extract Face features using insightface
ref_images = []
app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))
for im in [image1, image2]:
    image = cv2.cvtColor(np.asarray(im), cv2.COLOR_BGR2RGB)
    faces = app.get(image)
    image = torch.from_numpy(faces[0].normed_embedding)
    ref_images.append(image.unsqueeze(0))
ref_images = torch.cat(ref_images, dim=0)

# Load Face ID XL adapter into the pipeline
pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", 
    subfolder=None, 
    weight_name="ip-adapter-faceid_sdxl.bin", 
    image_encoder_folder=None
)

# Generate Face ID image embeddings and save them locally
image_embeds = pipeline.prepare_ip_adapter_image_embeds(
    ip_adapter_image=ref_images,
    ip_adapter_image_embeds=None,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)
torch.save(image_embeds, "faceid_xl.ipadpt")

# Unload ip adapter and lora
# ...

# Load Plus SDXL adapter into the pipeline
pipeline.load_ip_adapter("h94/IP-Adapter", 
    subfolder="sdxl_models", 
    weight_name="ip-adapter-plus-face_sdxl_vit-h.safetensors")

# Generate Plus SDXL image embeddings and save them locally
ip_images = [[image1, image2]]
image_embeds = pipeline.prepare_ip_adapter_image_embeds(
    ip_adapter_image=ip_images,
    ip_adapter_image_embeds=None,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)
torch.save(image_embeds, "plus_face_xl.ipadpt")

# Unload the IP adapter
# ...

# Load both IP Adapters
pipeline.load_ip_adapter(["h94/IP-Adapter", "h94/IP-Adapter-FaceID"], 
    subfolder=["sdxl_models", None], 
    weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors", "ip-adapter-faceid_sdxl.bin"]
)
pipeline.set_ip_adapter_scale([0.7]*2)

# Load image embeddings and run inference
generator = torch.Generator(device="cpu").manual_seed(42)

t1 = torch.load("plus_face_xl.ipadpt")
t2 = torch.load("faceid_xl.ipadpt")
t = [t1[0], t2[0]]

images = pipeline(
    prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower",
    ip_adapter_image_embeds=t, guidance_scale=7.5,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
    num_inference_steps=30, num_images_per_prompt=num_images, width=1024, height=1024, 
    generator=generator
).images

@yiyixuxu
Collaborator

@fabiorigano

I think prepare_ip_adapter_image_embeds would become too specific for this use case if we had to support an ip_adapter_image input containing both images and insightface embeddings.

Good news is that we do not want to support ip_adapter_image for Face ID! :) We should make it clear in the docs.

Also, to make it easier to test, can you upload the ref_images embeddings somewhere, maybe an HF dataset, so that we can just use them as input directly?

# Extract Face features using insightface
ref_images = []
app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))
for im in [image1, image2]:
    image = cv2.cvtColor(np.asarray(im), cv2.COLOR_BGR2RGB)
    faces = app.get(image)
    image = torch.from_numpy(faces[0].normed_embedding)
    ref_images.append(image.unsqueeze(0))
ref_images = torch.cat(ref_images, dim=0)
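For example, the saved embeddings could be reused roughly like this (a minimal sketch; the dataset repo id and file name below are placeholders, not existing artifacts):

import torch
from huggingface_hub import hf_hub_download

# one-time: persist the insightface embeddings computed above, then upload the
# file to a Hub dataset repo
torch.save(ref_images, "faceid_ref_images.pt")

# in tests: download and load them back without running insightface at all
path = hf_hub_download(
    repo_id="username/testing-embeddings",  # placeholder repo id
    filename="faceid_ref_images.pt",
    repo_type="dataset",
)
ref_images = torch.load(path)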

@yiyixuxu
Collaborator

yiyixuxu commented Mar 10, 2024

@fabiorigano

I tested the multi-adapter feature by separately extracting and then concatenating the image embeddings of Face ID XL and another IP Adapter, Plus Face SDXL

Can you combine Face ID with other IP-Adapter models? I thought it required its own attention processor.

@fabiorigano
Contributor Author

@fabiorigano

I tested the multi-adapter feature by separately extracting and then concatenating the image embeddings of Face ID XL and another IP Adapter, Plus Face SDXL

Can you combine Face ID with other IP-Adapter models? I thought it required its own attention processor.

@yiyixuxu I used PEFT to load the LoRA weights, so we don't need additional attention processors :)
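In practice that looks roughly like this (a sketch; the auto-generated adapter name "faceid_0" reflects how this PR registers the LoRA and may differ):

pipeline.load_ip_adapter(
    "h94/IP-Adapter-FaceID",
    subfolder=None,
    weight_name="ip-adapter-faceid_sd15.bin",
    image_encoder_folder=None,
)
# the Face ID LoRA is registered through the PEFT backend like any other LoRA,
# so the standard adapter APIs work on it
print(pipeline.get_active_adapters())                       # e.g. ["faceid_0"]
pipeline.set_adapters(["faceid_0"], adapter_weights=[0.8])  # rescale its contribution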

@fabiorigano
Contributor Author

I uploaded some tensors here https://huggingface.co/datasets/fabiorigano/testing-images/tree/main

Some of my tests and the results (input image embeddings are computed from "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png"):

Face ID SD 1.5 only

pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sd15.bin", image_encoder_folder=None)
pipeline.set_ip_adapter_scale(0.6)
image_embeds = load_pt("https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt")
images = pipeline(
    prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower",
    ip_adapter_image_embeds=image_embeds,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
    num_inference_steps=20, num_images_per_prompt=1, width=512, height=704, 
    generator=torch.Generator(device="cpu").manual_seed(0)
).images

Output image: p1_0

Plus Face SD 1.5 only

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus-face_sd15.bin")
pipeline.set_ip_adapter_scale(0.6)
image_embeds  = load_pt("https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/clip_ai_face2.ipadpt")
images = pipeline(
    prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower",
    ip_adapter_image_embeds=image_embeds,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
    num_inference_steps=20, num_images_per_prompt=1, width=512, height=704, 
    generator=torch.Generator(device="cpu").manual_seed(0)
).images

Output image: plus15_scale0.6

Plus Face SD 1.5 + Face ID SD 1.5

pipeline.load_ip_adapter(["h94/IP-Adapter", "h94/IP-Adapter-FaceID"], subfolder=["models", None], weight_name=["ip-adapter-plus-face_sd15.safetensors", "ip-adapter-faceid_sd15.bin"])
pipeline.set_ip_adapter_scale([0.5, 0.5])
t1 = load_pt("https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/clip_ai_face2.ipadpt")
t2 = load_pt("https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt")
image_embeds = [t1[0], t2[0]]
images = pipeline(
    prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower",
    ip_adapter_image_embeds=image_embeds,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
    num_inference_steps=20, num_images_per_prompt=1, width=512, height=704, 
    generator=torch.Generator(device="cpu").manual_seed(0)
).images

Output image: plus15-faceid15_scale0.5

@fabiorigano fabiorigano changed the title from "[WIP] Move IP Adapter Face ID to core" to "Move IP Adapter Face ID to core" on Mar 10, 2024
@fabiorigano fabiorigano requested a review from sayakpaul March 11, 2024 08:03
@fabiorigano
Contributor Author

@yiyixuxu it is ready for review

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@yiyixuxu yiyixuxu left a comment

Thanks! I left some comments and questions.

logger.warning(
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
"image_encoder is not loaded since `image_encoder_folder=None` passed. `ip_adapter_image` is allowed only if you are loading an IP-Adapter Face ID model."
Collaborator

A little bit confused here: I thought it was the opposite, i.e. we do not allow using ip_adapter_image with the Face ID model.

Contributor Author

Conceptually, Face ID embeddings are image embeddings, but the tensor as provided doesn't include the unconditioned part, so encode_image updates it into the expected format.

Do you think it is better to leave this to the user?
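To illustrate what "adding the unconditioned part" means, this is roughly what happens (a conceptual sketch, not the library code):

import torch

faceid_embeds = torch.randn(1, 512)              # a batched insightface normed_embedding
uncond_embeds = torch.zeros_like(faceid_embeds)  # the missing "unconditioned part"
# for classifier-free guidance the negative embedding is prepended, mirroring
# what encode_image produces for regular IP-Adapter images
image_embeds = torch.cat([uncond_embeds, faceid_embeds], dim=0)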

Collaborator

Yes, let's make it clear in the docs how to create the ip_adapter_image_embeds for Face ID.

]
}
)
key_id += 1
Collaborator

Is there more than one Face ID checkpoint right now? Does it make sense for us to support more than one?

Contributor Author

Face ID and Face ID XL are both supported by this PR.
Face ID Plus models have different image projection layers.

Comment on lines 967 to 976
"""Forward pass.
Args:
----
id_embeds (torch.Tensor): Input Tensor (ID embeds).
Returns:
-------
torch.Tensor: Output Tensor.
"""
Member

I think this needs to follow our doc-string format?

Contributor Author

OK, I will update it (and also IPAdapterPlusImageProjection, for code consistency).

Comment on lines 950 to 961
                nn.LayerNorm(embed_dims),
                nn.LayerNorm(embed_dims),
                Attention(
                    query_dim=embed_dims,
                    dim_head=dim_head,
                    heads=heads,
                    out_bias=False,
                ),
                nn.Sequential(
                    nn.LayerNorm(embed_dims),
                    FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
                ),
Member

I don't have strong opinions here but perhaps we could create a small block consisting of these layers and use that block here instead. Then

        for ln0, ln1, attn, ff in self.layers:
            residual = latents

            encoder_hidden_states = ln0(x)
            latents = ln1(latents)
            encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
            latents = attn(latents, encoder_hidden_states) + residual
            latents = ff(latents) + latents

could become:

for block in self.blocks:
    ...

If the checkpoint needs to be rejigged to match this structure, we could have a load state dict hook to deal with the modifications. But I would wait for @yiyixuxu to comment further before making any changes.
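For illustration, such a block might look roughly like this (a sketch assuming it wraps exactly the four sub-layers quoted above; the class name and imports are illustrative, not the final implementation):

import torch
import torch.nn as nn
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention

class ProjectionBlock(nn.Module):
    def __init__(self, embed_dims, dim_head, heads, ffn_ratio):
        super().__init__()
        self.ln0 = nn.LayerNorm(embed_dims)
        self.ln1 = nn.LayerNorm(embed_dims)
        self.attn = Attention(query_dim=embed_dims, dim_head=dim_head, heads=heads, out_bias=False)
        self.ff = nn.Sequential(
            nn.LayerNorm(embed_dims),
            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
        )

    def forward(self, x, latents, residual):
        # same computation as the unpacked (ln0, ln1, attn, ff) loop body above
        encoder_hidden_states = self.ln0(x)
        latents = self.ln1(latents)
        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
        latents = self.attn(latents, encoder_hidden_states) + residual
        return self.ff(latents) + latents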

Collaborator

Nice, but I don't think it is a big deal. If it requires a lot of effort from @fabiorigano, I don't think it's worth it.

Member

Yeah totally fine by me. It was just a suggestion.

Contributor Author

Hi, I added IPAdapterPlusImageProjectionBlock; let me know if it works for you.

# load ip-adapter into unet
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
extra_loras = unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
Member

Interesting. To reduce the maintenance burden and to promote better readability, perhaps we could separate out the LoRA-related code from _load_ip_adapter_weights()?

Member

@sayakpaul sayakpaul left a comment

Thank you! Left a couple of comments.

Comment on lines 231 to 238
        extra_loras = unet._load_ip_adapter_loras(state_dicts)
        if extra_loras != {}:
            # apply the IP Adapter Face ID LoRA weights
            peft_config = getattr(unet, "peft_config", {})
            for k, lora in extra_loras.items():
                if f"faceid_{k}" not in peft_config:
                    self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
                    self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
Member

Sleek!

heads=heads,
id_embeddings_dim=id_embeddings_dim,
)
print(state_dict.keys())
Member

Needs to go away.

Comment on lines 822 to 823
print(updated_state_dict.keys())
print(image_projection.state_dict().keys())
Member

Needs to go away.

max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 5e-4

def test_text_to_image_face_id(self):
Member

Can we add a fast test as well?

Contributor Author

Yes
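A hypothetical outline of such a fast test (the pipeline fixture, sizes, and assertion below are placeholders rather than the actual test-suite helpers):

import torch

def test_text_to_image_face_id_fast(pipeline):
    # [uncond, cond] x 1 image x 512-dim insightface embedding, mirroring the docs example
    torch.manual_seed(0)
    id_embeds = torch.cat([torch.zeros(1, 1, 512), torch.randn(1, 1, 512)], dim=0)
    images = pipeline(
        prompt="a photo of a person",
        ip_adapter_image_embeds=[id_embeds],
        num_inference_steps=2,
        output_type="np",
    ).images
    assert images.shape[0] == 1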

Member

@sayakpaul sayakpaul left a comment

The PR is looking quite nice to me. Thanks a lot for working on it. Also, do we need to add a check like so

if not USE_PEFT_BACKEND:

when there's a call to use the IP Adapter Face ID weights?

I will defer to @yiyixuxu to merge this. I would just run the concerned slow tests on our CI infrastructure as well to ensure nothing's breaking. @yiyixuxu could you do that before merging?

@fabiorigano
Contributor Author

Also, do we need to add a check like so

if not USE_PEFT_BACKEND:

when there's a call to use the IP Adapter Face ID weights?

I will add it
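A minimal sketch of that guard (the exact placement and error message are assumptions):

from diffusers.utils import USE_PEFT_BACKEND

if not USE_PEFT_BACKEND:
    raise ValueError("PEFT backend is required to load the IP Adapter Face ID LoRA weights.")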

@yiyixuxu yiyixuxu merged commit b5c8b55 into huggingface:main Apr 19, 2024
@yiyixuxu
Collaborator

Great work as always! Thanks a lot :) @fabiorigano

@fabiorigano fabiorigano deleted the faceidcore branch April 19, 2024 01:27
sayakpaul added a commit that referenced this pull request Dec 23, 2024
* Switch to peft and multi proj layers

* Move Face ID loading and inference to core

---------

Co-authored-by: Sayak Paul <[email protected]>