Skip to content

IP-Adapter attention masking warning logs #7033

@Honey-666

Description

@Honey-666

if len(unused_kwargs) > 0:

hi!
Is the `__call__` method of the `AttnProcessor2_0` class missing the `ip_adapter_masks` parameter? When I pass IP-Adapter masks through `cross_attention_kwargs`, the following warning is logged: `cross_attention_kwargs ['ip_adapter_masks'] are not expected by AttnProcessor2_0 and will be ignored.`

os:

diffusers==0.27.0.dev0

this is my code:

import torch
from PIL import Image
from diffusers import StableDiffusionXLPipeline, AutoencoderKL, DPMSolverMultistepScheduler
from diffusers.image_processor import IPAdapterMaskProcessor
from transformers import CLIPVisionModelWithProjection

# Reproduction script: SDXL text-to-image with two IP-Adapter reference images,
# each paired with its own mask passed through `cross_attention_kwargs`.

# Local checkpoint locations (relative to wherever the script is launched from).
model_path = '../../../aidazuo/models/Stable-diffusion/sd_xl_base_1.0'
vae_path = '../../../aidazuo/models/VAE/sdxl-vae-fp16-fix'
ip_adapter_path = '../../../aidazuo/models/IP-Adapter'

# Reference images first, then their masks; index i of each list belongs together.
adapter_img_lst = [
    Image.open(reference_path)
    for reference_path in (
        '../../../aidazuo/jupyter-script/test-img/ip_mask_girl1.png',
        '../../../aidazuo/jupyter-script/test-img/ip_mask_girl2.png',
    )
]
mask_img_lst = [
    Image.open(mask_path)
    for mask_path in (
        '../../../aidazuo/jupyter-script/test-img/ip_mask_mask1.png',
        '../../../aidazuo/jupyter-script/test-img/ip_mask_mask2.png',
    )
]

# Preprocess the mask images with the same height/width requested from the pipeline below.
masks = IPAdapterMaskProcessor().preprocess(mask_img_lst, height=1024, width=1024)

# Optional custom VAE (disabled):
# vae = AutoencoderKL.from_pretrained(vae_path, torch_dtype=torch.float16)

# Image encoder loaded from the IP-Adapter checkpoint folder, then moved to the GPU.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    ip_adapter_path,
    subfolder='models/image_encoder',
    torch_dtype=torch.float16,
)
image_encoder = image_encoder.to('cuda')

pipe = StableDiffusionXLPipeline.from_pretrained(
    model_path,
    safety_checker=None,
    variant="fp16",
    torch_dtype=torch.float16,
    image_encoder=image_encoder,
)
pipe = pipe.to("cuda")

# The same adapter weights file is loaded twice, one instance per reference image,
# and both instances get the same scale.
pipe.load_ip_adapter(
    ip_adapter_path,
    subfolder="sdxl_models",
    weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2,
)
pipe.set_ip_adapter_scale([0.7] * 2)

# Optional scheduler swap (disabled):
# pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)

# Passing the masks via `cross_attention_kwargs` is the call that produces the
# warning described in this issue.
generation = pipe(
    prompt='2 girls',
    ip_adapter_image=adapter_img_lst,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    # num_inference_steps=30,
    num_images_per_prompt=1,
    width=1024,
    height=1024,
    cross_attention_kwargs={"ip_adapter_masks": masks},
)
images = generation.images
pipe.unload_ip_adapter()

# Display every generated image.
for generated in images:
    generated.show()

adapter-attention-masking-warning

Metadata

Metadata

Assignees

No one assigned

    Labels

    stale: Issues that haven't received updates

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions