Skip to content

IP_Adapters shape mismatch when generating images on v0.25.0_dev using SDXL? #6162

@salahzoubi

Description

@salahzoubi

Describe the bug

When using IP-Adapters together with ControlNets and SDXL (whether SDXL-Turbo or SDXL 1.0), image generation fails with a shape mismatch. If you remove the IP-Adapter, things start working again. I'm not sure what the problem might be here.

Reproduction

Here's what I'm doing:


# Reproduction script: SDXL-Turbo + canny ControlNet + IP-Adapter Plus (ViT-H).
# The pipeline call at the bottom is the statement that raises
# "RuntimeError: mat1 and mat2 shapes cannot be multiplied (514x1664 and 1280x1280)".
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler, AutoencoderTiny, ControlNetModel
# load_image() is called below but was never imported, so the script stopped
# with a NameError before it could reach the reported error.
from diffusers.utils import load_image
import torch
from PIL import Image

net_id = "diffusers/controlnet-canny-sdxl-1.0"
controlnet = ControlNetModel.from_pretrained(net_id, torch_dtype=torch.float16)

#stabilityai/sdxl-turbo
vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
# NOTE(review): the traceback shows this resolving to StableDiffusionXLPipeline,
# whose __call__ lists no `image` parameter (only **kwargs). The ControlNet is
# therefore presumably not in effect here -- confirm whether a ControlNet
# pipeline class is needed.
pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", vae=vae, torch_dtype=torch.float16, controlnet=controlnet)
# NOTE(review): the error shows image embeddings with feature size 1664 reaching
# a projection that expects 1280. Presumably the image encoder loaded alongside
# the `sdxl_models` subfolder does not match the `vit-h` adapter weights --
# confirm against the IP-Adapter loading documentation.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus_sdxl_vit-h.safetensors")
# pipe.image_encoder = CLIPVisionModelWithProjection.from_pretrained("image_encoder_xl/")

pipe = pipe.to("cuda")

# Local input files: the ControlNet conditioning image and the IP-Adapter reference image.
control_image = load_image("1.png")
ip_image = load_image("2.png")
prompt = "person having fun"

# Fails inside the UNet's image-embedding projection (see the traceback under "Logs").
images = pipe(
    prompt=prompt,
    image=control_image,
    ip_adapter_image=ip_image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=4,
).images[0]


Logs

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[25], line 1
----> 1 images = pipe(
      2     prompt='cute anime girl smiling, girl smile, laughing, cute', 
      3     image=control_image,
      4     ip_adapter_image=ip_image,
      5     negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
      6     num_inference_steps=4,
      7 ).images[0]

File ~/miniconda3/envs/sd_diff/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py:1208, in StableDiffusionXLPipeline.__call__(self, prompt, prompt_2, height, width, num_inference_steps, timesteps, denoising_end, guidance_scale, negative_prompt, negative_prompt_2, num_images_per_prompt, eta, generator, latents, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ip_adapter_image, output_type, return_dict, cross_attention_kwargs, guidance_rescale, original_size, crops_coords_top_left, target_size, negative_original_size, negative_crops_coords_top_left, negative_target_size, clip_skip, callback_on_step_end, callback_on_step_end_tensor_inputs, **kwargs)
   1206 if ip_adapter_image is not None:
   1207     added_cond_kwargs["image_embeds"] = image_embeds
-> 1208 noise_pred = self.unet(
   1209     latent_model_input,
   1210     t,
   1211     encoder_hidden_states=prompt_embeds,
   1212     timestep_cond=timestep_cond,
   1213     cross_attention_kwargs=self.cross_attention_kwargs,
   1214     added_cond_kwargs=added_cond_kwargs,
   1215     return_dict=False,
   1216 )[0]
   1218 # perform guidance
   1219 if self.do_classifier_free_guidance:

File ~/miniconda3/envs/sd_diff/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/sd_diff/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/diffusers/src/diffusers/models/unet_2d_condition.py:1068, in UNet2DConditionModel.forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, down_intrablock_additional_residuals, encoder_attention_mask, return_dict)
   1064         raise ValueError(
   1065             f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
   1066         )
   1067     image_embeds = added_cond_kwargs.get("image_embeds")
-> 1068     image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
   1069     encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
   1071 # 2. pre-process

File ~/miniconda3/envs/sd_diff/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/sd_diff/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/diffusers/src/diffusers/models/embeddings.py:881, in Resampler.forward(self, x)
    869 """Forward pass.
    870 
    871 Args:
   (...)
    877     torch.Tensor: Output Tensor.
    878 """
    879 latents = self.latents.repeat(x.size(0), 1, 1)
--> 881 x = self.proj_in(x)
    883 for ln0, ln1, attn, ff in self.layers:
    884     residual = latents

File ~/miniconda3/envs/sd_diff/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/sd_diff/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/miniconda3/envs/sd_diff/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (514x1664 and 1280x1280)

System Info

  • diffusers version: 0.25.0.dev0
  • Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.31
  • Python version: 3.10.13
  • PyTorch version (GPU?): 2.1.1 (True)
  • Huggingface_hub version: 0.19.4
  • Transformers version: 4.36.0
  • Accelerate version: 0.24.1
  • xFormers version: not installed
  • Using GPU in script?: True
  • Using distributed or parallel set-up in script?: False

Who can help?

No response

Metadata

Metadata

Assignees

Labels

bug (Something isn't working) · stale (Issues that haven't received updates)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions