Skip to content

Fix running LoRA with xformers #2286

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 40 additions & 4 deletions src/diffusers/models/cross_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def __init__(
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
is_lora = hasattr(self, "processor") and isinstance(
self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor)
)

if use_memory_efficient_attention_xformers:
if self.added_kv_proj_dim is not None:
# TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
Expand Down Expand Up @@ -138,9 +142,28 @@ def set_use_memory_efficient_attention_xformers(
except Exception as e:
raise e

processor = XFormersCrossAttnProcessor(attention_op=attention_op)
if is_lora:
processor = LoRAXFormersCrossAttnProcessor(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
rank=self.processor.rank,
attention_op=attention_op,
)
processor.load_state_dict(self.processor.state_dict())
processor.to(self.processor.to_q_lora.up.weight.device)
else:
processor = XFormersCrossAttnProcessor(attention_op=attention_op)
else:
processor = CrossAttnProcessor()
if is_lora:
processor = LoRACrossAttnProcessor(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
rank=self.processor.rank,
)
processor.load_state_dict(self.processor.state_dict())
processor.to(self.processor.to_q_lora.up.weight.device)
else:
processor = CrossAttnProcessor()

self.set_processor(processor)

Expand Down Expand Up @@ -324,6 +347,10 @@ class LoRACrossAttnProcessor(nn.Module):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
super().__init__()

self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank

self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
Expand Down Expand Up @@ -437,9 +464,14 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No


class LoRAXFormersCrossAttnProcessor(nn.Module):
def __init__(self, hidden_size, cross_attention_dim, rank=4):
def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
super().__init__()

self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
self.attention_op = attention_op

self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
Expand All @@ -462,7 +494,9 @@ def __call__(
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()

hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op
)
hidden_states = attn.batch_to_head_dim(hidden_states)

# linear proj
Expand Down Expand Up @@ -595,4 +629,6 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=
SlicedAttnProcessor,
CrossAttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
LoRACrossAttnProcessor,
LoRAXFormersCrossAttnProcessor,
]
29 changes: 29 additions & 0 deletions tests/models/test_models_unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,35 @@ def test_lora_on_off(self):
assert (sample - new_sample).abs().max() < 1e-4
assert (sample - old_sample).abs().max() < 1e-4

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_lora_xformers_on_off(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

init_dict["attention_head_dim"] = (8, 16)

torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
lora_attn_procs = create_lora_layers(model)
model.set_attn_processor(lora_attn_procs)

# default
with torch.no_grad():
sample = model(**inputs_dict).sample

model.enable_xformers_memory_efficient_attention()
on_sample = model(**inputs_dict).sample

model.disable_xformers_memory_efficient_attention()
off_sample = model(**inputs_dict).sample

assert (sample - on_sample).abs().max() < 1e-4
assert (sample - off_sample).abs().max() < 1e-4


@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
Expand Down