33 changes: 33 additions & 0 deletions docs/source/en/using-diffusers/ip_adapter.md
@@ -234,6 +234,39 @@ export_to_gif(frames, "gummy_bear.gif")
> [!TIP]
> While calling `load_ip_adapter()`, pass `low_cpu_mem_usage=True` to speed up the loading time.

All pipelines supporting IP-Adapter accept an `ip_adapter_image_embeds` argument. If you need to run the IP-Adapter multiple times with the same image, you can encode the image once and save the embedding to disk.
**Collaborator Author commented:**

cc @stevhliu here for awareness.

I added a section to the IP-Adapter guide here. Let me know if you have any comments. If editing in a separate PR is easier, feel free to do so!

**Collaborator Author commented:**

Another very good use case for `ip_adapter_image_embeds` is probably the multi-IP-Adapter: https://huggingface.co/docs/diffusers/main/en/using-diffusers/ip_adapter#multi-ip-adapter

A common practice is to use a folder of 10+ images for styling, and to use the same styling images everywhere to create a consistent style. It would be nice to create an image embedding for these style images once, so you don't have to load the same images from a folder and encode them each time. A sketch of this workflow follows below.
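A minimal sketch of that workflow (assuming a local `style_images/` folder of PNGs and a pipeline whose IP-Adapter(s) are already loaded):

```py
from pathlib import Path

import torch
from diffusers.utils import load_image

# encode the folder of style images once
style_images = [load_image(str(p)) for p in sorted(Path("style_images").glob("*.png"))]

image_embeds = pipeline.prepare_ip_adapter_image_embeds(
    ip_adapter_image=[style_images],  # one entry per loaded IP-Adapter
    ip_adapter_image_embeds=None,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)
torch.save(image_embeds, "style_image_embeds.ipadpt")

# later runs skip loading and re-encoding the folder entirely
image_embeds = torch.load("style_image_embeds.ipadpt")
```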

**Member commented:**

I think we should definitely add that example motivating the use case. WDYT @asomoza?

**Member commented:**

I'll edit it in a separate PR, and I can also make a mention of `ip_adapter_image_embeds` in the multi IP-Adapter section 🙂

**Member commented:**

> I think we should definitely add that example motivating the use case. WDYT @asomoza?

Yeah, this is especially helpful when you use a lot of images and multiple IP-Adapters: you just need to save the embeddings, which makes it a lot easier to replicate results and saves a lot of space if you use high-quality images.

I'll try to do one with a style and a character and see how it goes. To see the real potential of this, though, we'll also need ControlNet and IP-Adapter masking, so the best use case would be a full demo with all of this incorporated.


```py
image_embeds = pipeline.prepare_ip_adapter_image_embeds(
    ip_adapter_image=image,
    ip_adapter_image_embeds=None,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)

torch.save(image_embeds, "image_embeds.ipadpt")
```

Load the image embedding and pass it to the pipeline as `ip_adapter_image_embeds`.

> [!TIP]
> ComfyUI image embeddings for IP-Adapters are fully compatible with Diffusers and should work out of the box.

```py
image_embeds = torch.load("image_embeds.ipadpt")

images = pipeline(
    prompt="a polar bear sitting in a chair drinking a milkshake",
    ip_adapter_image_embeds=image_embeds,
    negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
    num_inference_steps=100,
    generator=generator,
).images
```

**Member commented** (on `image_embeds = torch.load("image_embeds.ipadpt")`):

We don't know where this is coming from. Let's include a snippet to download that and explicitly mention that it's coming from ComfyUI.
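Such a download snippet might look like the following (a sketch; the URL is a hypothetical placeholder for wherever the ComfyUI-generated embedding is actually hosted):

```py
import requests

# hypothetical URL; replace with the real location of the ComfyUI-generated embedding
url = "https://example.com/comfyui/image_embeds.ipadpt"
with open("image_embeds.ipadpt", "wb") as f:
    f.write(requests.get(url, timeout=30).content)
```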

> [!TIP]
> If you use IP-Adapter with `ip_adapter_image_embeds` instead of `ip_adapter_image`, you can skip loading an image encoder by passing `image_encoder_folder=None` to `load_ip_adapter()`.
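For example, loading the adapter weights without the encoder might look like this (a sketch assuming the widely used `h94/IP-Adapter` checkpoint layout on the Hub):

```py
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="models",
    weight_name="ip-adapter_sd15.bin",
    image_encoder_folder=None,  # skip loading the image encoder
)

# the pipeline must now be called with `ip_adapter_image_embeds` rather than `ip_adapter_image`
image_embeds = torch.load("image_embeds.ipadpt")
```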

## Specific use cases

IP-Adapter's image prompting and compatibility with other adapters and models make it a versatile tool for a variety of use cases. This section covers some of the more popular applications of IP-Adapter, and we can't wait to see what you come up with!
6 changes: 4 additions & 2 deletions examples/community/pipeline_animatediff_controlnet.py
@@ -799,8 +799,10 @@ def __call__(
ip_adapter_image (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
Pre-generated image embeddings for IP-Adapter. It should be a list of the same length as the number of IP-Adapters.
Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
if `do_classifier_free_guidance` is set to `True`.
If not provided, embeddings are computed from the `ip_adapter_image` input argument.
conditioning_frames (`List[PipelineImageInput]`, *optional*):
The ControlNet input condition to provide guidance to the `unet` for generation. If multiple ControlNets
are specified, images must be passed as a list such that each element of the list can be correctly
6 changes: 4 additions & 2 deletions examples/community/pipeline_animatediff_img2video.py
@@ -798,8 +798,10 @@ def __call__(
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
Pre-generated image embeddings for IP-Adapter. It should be a list of the same length as the number of IP-Adapters.
Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
if `do_classifier_free_guidance` is set to `True`.
If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
51 changes: 37 additions & 14 deletions src/diffusers/loaders/ip_adapter.py
@@ -13,7 +13,7 @@
# limitations under the License.

from pathlib import Path
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union

import torch
from huggingface_hub.utils import validate_hf_hub_args
@@ -52,11 +52,12 @@ def load_ip_adapter(
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
subfolder: Union[str, List[str]],
weight_name: Union[str, List[str]],
image_encoder_folder: Optional[str] = "image_encoder",
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
Can be either:

- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
@@ -65,7 +66,18 @@
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
If a list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`subfolder`.
image_encoder_folder (`str`, *optional*, defaults to `"image_encoder"`):
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside `subfolder`,
you only need to pass the name of the folder that contains the image encoder weights, e.g. `image_encoder_folder="image_encoder"`.
If the image encoder is located in a folder other than `subfolder`, pass the path to that folder instead,
for example, `image_encoder_folder="different_subfolder/image_encoder"`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
@@ -87,8 +99,6 @@
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
@@ -184,16 +194,29 @@ def load_ip_adapter(

# load CLIP image encoder here if it has not been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
pretrained_model_name_or_path_or_dict,
subfolder=Path(subfolder, "image_encoder").as_posix(),
low_cpu_mem_usage=low_cpu_mem_usage,
).to(self.device, dtype=self.dtype)
self.register_modules(image_encoder=image_encoder)
if image_encoder_folder is not None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
if image_encoder_folder.count("/") == 0:
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
else:
image_encoder_subfolder = Path(image_encoder_folder).as_posix()

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
pretrained_model_name_or_path_or_dict,
subfolder=image_encoder_subfolder,
low_cpu_mem_usage=low_cpu_mem_usage,
).to(self.device, dtype=self.dtype)
self.register_modules(image_encoder=image_encoder)
else:
raise ValueError(
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
)
else:
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
logger.warning(
"image_encoder is not loaded since `image_encoder_folder=None` was passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter. "
"Use `ip_adapter_image_embeds` to pass pre-generated image embeddings instead."
)

# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
38 changes: 32 additions & 6 deletions src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -370,7 +370,7 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
@@ -394,13 +394,23 @@ def prepare_ip_adapter_image_embeds(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)

if self.do_classifier_free_guidance:
if do_classifier_free_guidance:
**Collaborator Author (@yiyixuxu) commented on Feb 29, 2024:**

Adding `do_classifier_free_guidance` as an argument to `prepare_ip_adapter_image_embeds` so we can use this pipeline method to save image embeddings:

```py
image_embeds = pipeline.prepare_ip_adapter_image_embeds(
    ip_adapter_image=image,
    ip_adapter_image_embeds=None,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)

torch.save(image_embeds, "image_embeds.ipadpt")
```

single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)

image_embeds.append(single_image_embeds)
else:
image_embeds = ip_adapter_image_embeds
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
image_embeds.append(single_image_embeds)

return image_embeds
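For reference, here is a shape-level sketch (dummy tensors, not part of the diff) of what the new `else` branch does to one pre-computed classifier-free-guidance embedding:

```py
import torch

num_images_per_prompt = 2
# a saved CFG embedding: negative and positive halves stacked along dim 0
single_image_embeds = torch.randn(2, 4, 768)  # (2 * batch_size, num_images, emb_dim)

negative, positive = single_image_embeds.chunk(2)
negative = negative.repeat(num_images_per_prompt, 1, 1)
positive = positive.repeat(num_images_per_prompt, 1, 1)
print(torch.cat([negative, positive]).shape)  # torch.Size([4, 4, 768])
```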

# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
@@ -494,6 +504,16 @@ def check_inputs(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
)

if ip_adapter_image_embeds is not None:
if not isinstance(ip_adapter_image_embeds, list):
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
**Member commented on lines +507 to +515:**

Do we need any checks on the shapes to conform to what's needed for classifier-free guidance?
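One possibility (a hedged sketch, not part of the PR, and assuming `check_inputs` also received the CFG flag, which it currently does not) is to require an even first dimension, matching the negative/positive concatenation convention used by `prepare_ip_adapter_image_embeds`:

```py
# hypothetical extra validation, assuming a do_classifier_free_guidance flag is available here
if ip_adapter_image_embeds is not None and do_classifier_free_guidance:
    for i, embeds in enumerate(ip_adapter_image_embeds):
        if embeds.shape[0] % 2 != 0:
            raise ValueError(
                f"`ip_adapter_image_embeds[{i}]` should contain negative and positive embeddings "
                f"concatenated along dim 0 when doing classifier-free guidance, so its first "
                f"dimension should be even, but it is {embeds.shape[0]}."
            )
```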


# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
@@ -612,8 +632,10 @@ def __call__(
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
Pre-generated image embeddings for IP-Adapter. It should be a list of the same length as the number of IP-Adapters.
Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
if `do_classifier_free_guidance` is set to `True`.
If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
@@ -717,7 +739,11 @@ def __call__(

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_videos_per_prompt,
self.do_classifier_free_guidance,
)

# 4. Prepare timesteps
src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
@@ -448,7 +448,7 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
@@ -472,13 +472,23 @@ def prepare_ip_adapter_image_embeds(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)

if self.do_classifier_free_guidance:
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)

image_embeds.append(single_image_embeds)
else:
image_embeds = ip_adapter_image_embeds
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
image_embeds.append(single_image_embeds)

return image_embeds

# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
@@ -523,6 +533,8 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
ip_adapter_image=None,
ip_adapter_image_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if strength < 0 or strength > 1:
@@ -567,6 +579,21 @@
if video is not None and latents is not None:
raise ValueError("Only one of `video` or `latents` should be provided")

if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
)

if ip_adapter_image_embeds is not None:
if not isinstance(ip_adapter_image_embeds, list):
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)

def get_timesteps(self, num_inference_steps, timesteps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
@@ -765,8 +792,10 @@ def __call__(
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
Pre-generated image embeddings for IP-Adapter. It should be a list of the same length as the number of IP-Adapters.
Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
if `do_classifier_free_guidance` is set to `True`.
If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
@@ -814,6 +843,8 @@ def __call__(
negative_prompt_embeds=negative_prompt_embeds,
video=video,
latents=latents,
ip_adapter_image=ip_adapter_image,
ip_adapter_image_embeds=ip_adapter_image_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)

@@ -855,7 +886,11 @@ def __call__(

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_videos_per_prompt,
self.do_classifier_free_guidance,
)

# 4. Prepare timesteps