3 changes: 2 additions & 1 deletion src/diffusers/loaders/ip_adapter.py
@@ -136,6 +136,7 @@ def load_ip_adapter(
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+        target_blocks = kwargs.pop("target_blocks", ["block"])

         if low_cpu_mem_usage and not is_accelerate_available():
             low_cpu_mem_usage = False
@@ -226,7 +227,7 @@ def load_ip_adapter(

         # load ip-adapter into unet
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+        unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage, target_blocks=target_blocks)

     def set_ip_adapter_scale(self, scale):
         """
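A minimal usage sketch of the new kwarg, for review context only; the pipeline, checkpoint, and block pattern below are illustrative assumptions rather than code from this PR:

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# target_blocks is matched as a substring against attention-processor names.
# The default ["block"] matches every name (they all contain "block"), keeping
# the previous behavior; a narrower pattern activates the IP-Adapter only there.
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name="ip-adapter_sdxl.bin",
    target_blocks=["up_blocks.0.attentions.1"],  # illustrative block pattern
)
pipe.set_ip_adapter_scale(1.0)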
11 changes: 8 additions & 3 deletions src/diffusers/loaders/unet.py
@@ -800,7 +800,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us

         return image_projection

-    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False, target_blocks=["block"]):
         from ..models.attention_processor import (
             AttnProcessor,
             AttnProcessor2_0,
@@ -864,11 +864,14 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
                         num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]

                 with init_context():
+                    selected = any(block_name in name for block_name in target_blocks)
+
                     attn_procs[name] = attn_processor_class(
                         hidden_size=hidden_size,
                         cross_attention_dim=cross_attention_dim,
                         scale=1.0,
                         num_tokens=num_image_text_embeds,
+                        skip=not selected,
                     )

                 value_dict = {}
@@ -887,14 +890,16 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F

         return attn_procs

-    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
+    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False, target_blocks=["block"]):
         if not isinstance(state_dicts, list):
             state_dicts = [state_dicts]
         # Set encoder_hid_proj after loading ip_adapter weights,
         # because `IPAdapterPlusImageProjection` also has `attn_processors`.
         self.encoder_hid_proj = None

-        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+        attn_procs = self._convert_ip_adapter_attn_to_diffusers(
+            state_dicts, low_cpu_mem_usage=low_cpu_mem_usage, target_blocks=target_blocks
+        )
         self.set_attn_processor(attn_procs)

         # convert IP-Adapter Image Projection layers to diffusers
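The selection above is plain substring matching of target_blocks against the UNet's attention-processor names. A small standalone sketch of that check; the names are representative examples of such keys, chosen here for illustration:

names = [
    "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor",
    "mid_block.attentions.0.transformer_blocks.0.attn2.processor",
    "up_blocks.0.attentions.1.transformer_blocks.0.attn2.processor",
]

# Default: "block" occurs in every name, so every processor is selected.
target_blocks = ["block"]
print([any(block_name in name for block_name in target_blocks) for name in names])
# [True, True, True]

# Narrower pattern: only the matching processor keeps an active IP-Adapter;
# the others are constructed with skip=not selected, i.e. skip=True.
target_blocks = ["up_blocks.0.attentions.1"]
print([any(block_name in name for block_name in target_blocks) for name in names])
# [False, False, True]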
98 changes: 51 additions & 47 deletions src/diffusers/models/attention_processor.py
@@ -2108,7 +2108,7 @@ class IPAdapterAttnProcessor(nn.Module):
             the weight scale of image prompt.
     """

-    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
+    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0, skip=False):
         super().__init__()

         self.hidden_size = hidden_size
@@ -2117,6 +2117,7 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale
         if not isinstance(num_tokens, (tuple, list)):
             num_tokens = [num_tokens]
         self.num_tokens = num_tokens
+        self.skip = skip

         if not isinstance(scale, list):
             scale = [scale] * len(num_tokens)
@@ -2208,29 +2209,30 @@ def __call__(
             ip_adapter_masks = [None] * len(self.scale)

         # for ip-adapter
-        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
-            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
-        ):
-            ip_key = to_k_ip(current_ip_hidden_states)
-            ip_value = to_v_ip(current_ip_hidden_states)
-
-            ip_key = attn.head_to_batch_dim(ip_key)
-            ip_value = attn.head_to_batch_dim(ip_value)
-
-            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
-            current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
-            current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
-
-            if mask is not None:
-                mask_downsample = IPAdapterMaskProcessor.downsample(
-                    mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
-                )
+        if not self.skip:
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+            ):
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+
+                ip_key = attn.head_to_batch_dim(ip_key)
+                ip_value = attn.head_to_batch_dim(ip_value)
+
+                ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+                current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+                current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+
+                if mask is not None:
+                    mask_downsample = IPAdapterMaskProcessor.downsample(
+                        mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
+                    )

-                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

-                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+                    current_ip_hidden_states = current_ip_hidden_states * mask_downsample

-            hidden_states = hidden_states + scale * current_ip_hidden_states
+                hidden_states = hidden_states + scale * current_ip_hidden_states

         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
@@ -2263,7 +2265,7 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
             the weight scale of image prompt.
     """

-    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
+    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0, skip=False):
         super().__init__()

         if not hasattr(F, "scaled_dot_product_attention"):
@@ -2283,6 +2285,7 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale
         if len(scale) != len(num_tokens):
             raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
         self.scale = scale
+        self.skip = skip

         self.to_k_ip = nn.ModuleList(
             [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
@@ -2382,36 +2385,37 @@ def __call__(
             ip_adapter_masks = [None] * len(self.scale)

         # for ip-adapter
-        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
-            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
-        ):
-            ip_key = to_k_ip(current_ip_hidden_states)
-            ip_value = to_v_ip(current_ip_hidden_states)
-
-            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
-            current_ip_hidden_states = F.scaled_dot_product_attention(
-                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-            )
-
-            current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                batch_size, -1, attn.heads * head_dim
-            )
-            current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
-
-            if mask is not None:
-                mask_downsample = IPAdapterMaskProcessor.downsample(
-                    mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
-                )
+        if not self.skip:
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+            ):
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+
+                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                current_ip_hidden_states = F.scaled_dot_product_attention(
+                    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                )
+
+                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                    batch_size, -1, attn.heads * head_dim
+                )
+                current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+
+                if mask is not None:
+                    mask_downsample = IPAdapterMaskProcessor.downsample(
+                        mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
+                    )

-                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

-                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+                    current_ip_hidden_states = current_ip_hidden_states * mask_downsample

-            hidden_states = hidden_states + scale * current_ip_hidden_states
+                hidden_states = hidden_states + scale * current_ip_hidden_states

         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
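Taken together, the processor changes make the image-prompt branch a no-op when skip=True while leaving the text cross-attention untouched. A minimal behavioral sketch of that combination step, using a toy combine() stand-in rather than the actual processor classes:

import torch

def combine(hidden_states, ip_hidden_states_list, scales, skip):
    # hidden_states: output of the regular text cross-attention branch
    # ip_hidden_states_list: per-adapter attention outputs over image tokens
    if not skip:
        for scale, current_ip_hidden_states in zip(scales, ip_hidden_states_list):
            hidden_states = hidden_states + scale * current_ip_hidden_states
    return hidden_states

h = torch.zeros(1, 4, 8)
ip = [torch.ones(1, 4, 8)]
print(torch.equal(combine(h, ip, [1.0], skip=True), h))         # True: image prompt ignored
print(torch.equal(combine(h, ip, [0.5], skip=False), h + 0.5))  # True: scaled image states added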