Merged
Changes from all commits
81 commits
91b34ab
throw warning when more than one lora is attempted to be fused.
sayakpaul Aug 30, 2023
c9eeb78
introduce support of lora scale during fusion.
sayakpaul Aug 30, 2023
37692b1
change test name
sayakpaul Aug 30, 2023
cfd19a5
changes
sayakpaul Aug 30, 2023
8a9dad0
change to _lora_scale
sayakpaul Aug 30, 2023
ed3b37a
lora_scale to call whenever applicable.
sayakpaul Aug 30, 2023
b86a8f6
debugging
sayakpaul Aug 30, 2023
80839d6
lora_scale additional.
sayakpaul Aug 30, 2023
2ed1f2a
cross_attention_kwargs
sayakpaul Aug 30, 2023
3967da8
lora_scale -> scale.
sayakpaul Aug 30, 2023
e24fd70
lora_scale fix
sayakpaul Aug 30, 2023
21e765b
lora_scale in patched projection.
sayakpaul Aug 30, 2023
9678ed2
debugging
sayakpaul Aug 30, 2023
acbbb4d
debugging
sayakpaul Aug 30, 2023
0c2dad4
debugging
sayakpaul Aug 30, 2023
6269412
debugging
sayakpaul Aug 30, 2023
cc0c7ec
debugging
sayakpaul Aug 30, 2023
b357ffc
debugging
sayakpaul Aug 30, 2023
8495b43
debugging
sayakpaul Aug 30, 2023
96fc1af
debugging
sayakpaul Aug 30, 2023
0c501d3
debugging
sayakpaul Aug 30, 2023
016d3e9
debugging
sayakpaul Aug 30, 2023
910e96b
debugging
sayakpaul Aug 30, 2023
de159da
debugging
sayakpaul Aug 30, 2023
4ee8dbf
debugging
sayakpaul Aug 30, 2023
cd9ac47
styling.
sayakpaul Aug 30, 2023
1cd983f
debugging
sayakpaul Aug 30, 2023
860a374
debugging
sayakpaul Aug 30, 2023
1b2346c
debugging
sayakpaul Aug 30, 2023
583da5f
debugging
sayakpaul Aug 30, 2023
77f6459
debugging
sayakpaul Aug 30, 2023
ec67361
debugging
sayakpaul Aug 30, 2023
6c9c5dc
debugging
sayakpaul Aug 30, 2023
d7b35d4
debugging
sayakpaul Aug 30, 2023
e601d2b
debugging
sayakpaul Aug 30, 2023
35148d0
debugging
sayakpaul Aug 30, 2023
55efe9c
debugging
sayakpaul Aug 30, 2023
0d7b3df
debugging
sayakpaul Aug 30, 2023
cdc7963
remove unneeded prints.
sayakpaul Aug 30, 2023
2a3e358
remove unneeded prints.
sayakpaul Aug 30, 2023
42c2c0a
assign cross_attention_kwargs.
sayakpaul Aug 30, 2023
98e6eca
debugging
sayakpaul Aug 30, 2023
03abb4c
debugging
sayakpaul Aug 30, 2023
32a175f
debugging
sayakpaul Aug 30, 2023
ef1ad84
debugging
sayakpaul Aug 30, 2023
9a759b9
debugging
sayakpaul Aug 30, 2023
369a53f
debugging
sayakpaul Aug 30, 2023
833fd35
debugging
sayakpaul Aug 30, 2023
d8371ab
debugging
sayakpaul Aug 30, 2023
a5925ab
debugging
sayakpaul Aug 30, 2023
d3d6ab1
debugging
sayakpaul Aug 30, 2023
8c0b584
debugging
sayakpaul Aug 30, 2023
43d6c8d
debugging
sayakpaul Aug 30, 2023
caa8625
debugging
sayakpaul Aug 30, 2023
38cbe46
debugging
sayakpaul Aug 30, 2023
b29e025
debugging
sayakpaul Aug 30, 2023
b275947
debugging
sayakpaul Aug 30, 2023
d8b4bf7
debugging
sayakpaul Aug 30, 2023
a3df6cd
debugging
sayakpaul Aug 30, 2023
265d5f4
debugging
sayakpaul Aug 30, 2023
00167be
Merge branch 'main' into lora-improvements-pt3
sayakpaul Aug 30, 2023
7d34884
clean up.
sayakpaul Aug 30, 2023
9dee7d4
refactor scale retrieval logic a bit.
sayakpaul Aug 31, 2023
f81f77d
fix nonetypw
sayakpaul Aug 31, 2023
92e1194
fix: tests
sayakpaul Aug 31, 2023
4511f48
add more tests
sayakpaul Aug 31, 2023
6667e68
more fixes.
sayakpaul Aug 31, 2023
b941b88
figure out a way to pass lora_scale.
sayakpaul Aug 31, 2023
9705cc2
Apply suggestions from code review
sayakpaul Sep 4, 2023
bebab12
unify the retrieval logic of lora_scale.
sayakpaul Sep 4, 2023
81f7ddf
move adjust_lora_scale_text_encoder to lora.py.
sayakpaul Sep 4, 2023
e2c835c
introduce dynamic adjustment lora scale support to sd
sayakpaul Sep 4, 2023
ca48db6
Merge branch 'main' into lora-improvements-pt3
sayakpaul Sep 4, 2023
f2026ac
fix up copies
sayakpaul Sep 4, 2023
7444896
Empty-Commit
sayakpaul Sep 4, 2023
e60f450
add: test to check fusion equivalence on different scales.
sayakpaul Sep 4, 2023
bf1052b
handle lora fusion warning.
sayakpaul Sep 4, 2023
4733384
make lora smaller
patrickvonplaten Sep 4, 2023
dabdd58
make lora smaller
patrickvonplaten Sep 4, 2023
51824c7
make lora smaller
patrickvonplaten Sep 4, 2023
972c8e8
Merge branch 'main' into lora-improvements-pt3
patrickvonplaten Sep 4, 2023
45 changes: 31 additions & 14 deletions src/diffusers/loaders.py
@@ -95,7 +95,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):

return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)

def _fuse_lora(self):
def _fuse_lora(self, lora_scale=1.0):
if self.lora_linear_layer is None:
return

@@ -108,7 +108,7 @@ def _fuse_lora(self):
if self.lora_linear_layer.network_alpha is not None:
w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank

fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)

# we can drop the lora layer now
@@ -117,6 +117,7 @@ def _fuse_lora(self):
# offload the up and down matrices to CPU to not blow the memory
self.w_up = w_up.cpu()
self.w_down = w_down.cpu()
self.lora_scale = lora_scale

def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
@@ -128,16 +129,19 @@ def _unfuse_lora(self):
w_up = self.w_up.to(device=device).float()
w_down = self.w_down.to(device).float()

unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)

self.w_up = None
self.w_down = None

def forward(self, input):
# print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}")
if self.lora_scale is None:
self.lora_scale = 1.0
if self.lora_linear_layer is None:
return self.regular_linear_layer(input)
return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input)
return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input))


def text_encoder_attn_modules(text_encoder):
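Taken together, these hunks make the fusion arithmetic reversible at any strength: fusing adds a scaled low-rank product to the frozen weight, and the scale is stored so that unfusing subtracts exactly the same term. A minimal standalone sketch of that arithmetic (tensor shapes are made up for illustration and are not taken from the PR):

import torch

w_orig = torch.randn(32, 32)   # frozen linear weight
w_up = torch.randn(32, 4)      # LoRA up-projection
w_down = torch.randn(4, 32)    # LoRA down-projection
lora_scale = 0.5

# Fuse: bake the scaled low-rank update into the base weight.
fused = w_orig + lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]

# Unfuse: subtract the same scaled term, which is why the scale is
# remembered on the module (self.lora_scale) at fuse time.
unfused = fused - lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]

assert torch.allclose(unfused, w_orig, atol=1e-5)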
@@ -576,12 +580,13 @@ def save_function(weights, filename):
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")

def fuse_lora(self):
def fuse_lora(self, lora_scale=1.0):
self.lora_scale = lora_scale
self.apply(self._fuse_lora_apply)

def _fuse_lora_apply(self, module):
if hasattr(module, "_fuse_lora"):
module._fuse_lora()
module._fuse_lora(self.lora_scale)

def unfuse_lora(self):
self.apply(self._unfuse_lora_apply)
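At the UNet level, fuse_lora simply stores the scale and relies on nn.Module.apply to hand it to every submodule that defines _fuse_lora. A standalone sketch of that traversal pattern, using toy classes rather than the real diffusers modules:

import torch.nn as nn

class ToyLoRALayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.fused_with = None

    def _fuse_lora(self, lora_scale=1.0):
        # In the real code this folds the scaled LoRA weights into the layer.
        self.fused_with = lora_scale

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(ToyLoRALayer(), nn.ReLU(), ToyLoRALayer())

    def fuse_lora(self, lora_scale=1.0):
        self.lora_scale = lora_scale
        self.apply(self._fuse_lora_apply)

    def _fuse_lora_apply(self, module):
        if hasattr(module, "_fuse_lora"):
            module._fuse_lora(self.lora_scale)

model = ToyModel()
model.fuse_lora(0.5)
print([m.fused_with for m in model.modules() if isinstance(m, ToyLoRALayer)])  # [0.5, 0.5]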
@@ -924,6 +929,7 @@ class LoraLoaderMixin:
"""
text_encoder_name = TEXT_ENCODER_NAME
unet_name = UNET_NAME
num_fused_loras = 0

def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
"""
@@ -1807,7 +1813,7 @@ def unload_lora_weights(self):
# Safe to call the following regardless of LoRA.
self._remove_text_encoder_monkey_patch()

def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True):
def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.

@@ -1822,22 +1828,31 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True):
fuse_text_encoder (`bool`, defaults to `True`):
Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
"""
if fuse_unet or fuse_text_encoder:
self.num_fused_loras += 1
if self.num_fused_loras > 1:
logger.warn(
"The current API is supported for operating with a single LoRA file. You are trying to load and fuse more than one LoRA which is not well-supported.",
)

if fuse_unet:
self.unet.fuse_lora()
self.unet.fuse_lora(lora_scale)

def fuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._fuse_lora()
attn_module.k_proj._fuse_lora()
attn_module.v_proj._fuse_lora()
attn_module.out_proj._fuse_lora()
attn_module.q_proj._fuse_lora(lora_scale)
attn_module.k_proj._fuse_lora(lora_scale)
attn_module.v_proj._fuse_lora(lora_scale)
attn_module.out_proj._fuse_lora(lora_scale)

for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._fuse_lora()
mlp_module.fc2._fuse_lora()
mlp_module.fc1._fuse_lora(lora_scale)
mlp_module.fc2._fuse_lora(lora_scale)

if fuse_text_encoder:
if hasattr(self, "text_encoder"):
@@ -1884,6 +1899,8 @@ def unfuse_text_encoder_lora(text_encoder):
if hasattr(self, "text_encoder_2"):
unfuse_text_encoder_lora(self.text_encoder_2)

self.num_fused_loras -= 1


class FromSingleFileMixin:
"""
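With the pieces above in place, the fusion strength becomes a user-facing knob on the pipeline mixin. A hedged usage sketch, assuming a Stable Diffusion checkpoint and a single LoRA file (the identifiers and file path below are placeholders, not taken from this PR):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Load one LoRA and bake it into the UNet and text encoder at 70% strength.
# Fusing a second LoRA afterwards would trip the new num_fused_loras warning.
pipe.load_lora_weights("path/to/lora.safetensors")
pipe.fuse_lora(lora_scale=0.7)

image = pipe("an astronaut riding a horse").images[0]

# Undo the fusion; the stored scale guarantees the original weights return.
pipe.unfuse_lora()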
33 changes: 21 additions & 12 deletions src/diffusers/models/attention.py
@@ -177,7 +177,7 @@ def forward(
class_labels: Optional[torch.LongTensor] = None,
):
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Self-Attention
# 0. Self-Attention
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
@@ -187,7 +187,10 @@ def forward(
else:
norm_hidden_states = self.norm1(hidden_states)

# 0. Prepare GLIGEN inputs
# 1. Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

# 2. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

@@ -201,12 +204,12 @@ def forward(
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = attn_output + hidden_states

# 1.5 GLIGEN Control
# 2.5 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 1.5 ends
# 2.5 ends

# 2. Cross-Attention
# 3. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
@@ -220,7 +223,7 @@ def forward(
)
hidden_states = attn_output + hidden_states

# 3. Feed-forward
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)

if self.use_ada_layer_norm_zero:
@@ -235,11 +238,14 @@

num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
ff_output = torch.cat(
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
[
self.ff(hid_slice, scale=lora_scale)
for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
],
dim=self._chunk_dim,
)
else:
ff_output = self.ff(norm_hidden_states)
ff_output = self.ff(norm_hidden_states, scale=lora_scale)

if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
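For LoRAs that are not fused, the strength still arrives per call through cross_attention_kwargs, and the retrieval above now forwards it to the feed-forward path as well. A small standalone sketch of that retrieval logic, assuming only the dict-or-None convention used by BasicTransformerBlock.forward:

from typing import Any, Dict, Optional

def retrieve_lora_scale(cross_attention_kwargs: Optional[Dict[str, Any]]) -> float:
    # Mirrors the logic added above: default to 1.0 when no kwargs
    # (or no "scale" entry) are supplied.
    if cross_attention_kwargs is None:
        return 1.0
    return cross_attention_kwargs.get("scale", 1.0)

assert retrieve_lora_scale(None) == 1.0
assert retrieve_lora_scale({"scale": 0.5}) == 0.5
assert retrieve_lora_scale({"gligen": {}}) == 1.0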
@@ -295,9 +301,12 @@ def __init__(
if final_dropout:
self.net.append(nn.Dropout(dropout))

def forward(self, hidden_states):
def forward(self, hidden_states, scale: float = 1.0):
for module in self.net:
hidden_states = module(hidden_states)
if isinstance(module, (LoRACompatibleLinear, GEGLU)):
hidden_states = module(hidden_states, scale)
else:
hidden_states = module(hidden_states)
return hidden_states


@@ -342,8 +351,8 @@ def gelu(self, gate):
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
def forward(self, hidden_states, scale: float = 1.0):
hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)


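The FeedForward and GEGLU changes dispatch on layer type so the scale only reaches modules that can use it. A simplified standalone sketch with a stand-in class (the real LoRACompatibleLinear and GEGLU live in diffusers.models; the stand-in below is hypothetical):

import torch
import torch.nn as nn

class ScaleAwareLinear(nn.Linear):
    """Stand-in for LoRACompatibleLinear: accepts an extra `scale` argument."""
    def forward(self, hidden_states, scale: float = 1.0):
        # A real LoRA-compatible layer would add scale * lora(hidden_states).
        return super().forward(hidden_states)

def run_ff(net: nn.ModuleList, hidden_states: torch.Tensor, scale: float = 1.0):
    # Same dispatch as FeedForward.forward above: pass the scale only to
    # layers that accept it, call everything else unchanged.
    for module in net:
        if isinstance(module, ScaleAwareLinear):
            hidden_states = module(hidden_states, scale)
        else:
            hidden_states = module(hidden_states)
    return hidden_states

net = nn.ModuleList([ScaleAwareLinear(8, 32), nn.GELU(), ScaleAwareLinear(32, 8)])
print(run_ff(net, torch.randn(2, 4, 8), scale=0.5).shape)  # torch.Size([2, 4, 8])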
44 changes: 22 additions & 22 deletions src/diffusers/models/attention_processor.py
@@ -570,15 +570,15 @@ def __call__(
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states, lora_scale=scale)
query = attn.to_q(hidden_states, scale=scale)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states, lora_scale=scale)
value = attn.to_v(encoder_hidden_states, lora_scale=scale)
key = attn.to_k(encoder_hidden_states, scale=scale)
value = attn.to_v(encoder_hidden_states, scale=scale)

query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
@@ -589,7 +589,7 @@ def __call__(
hidden_states = attn.batch_to_head_dim(hidden_states)

# linear proj
hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
hidden_states = attn.to_out[0](hidden_states, scale=scale)
# dropout
hidden_states = attn.to_out[1](hidden_states)

@@ -722,17 +722,17 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states, lora_scale=scale)
query = attn.to_q(hidden_states, scale=scale)
query = attn.head_to_batch_dim(query)

encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, lora_scale=scale)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, lora_scale=scale)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

if not attn.only_cross_attention:
key = attn.to_k(hidden_states, lora_scale=scale)
value = attn.to_v(hidden_states, lora_scale=scale)
key = attn.to_k(hidden_states, scale=scale)
value = attn.to_v(hidden_states, scale=scale)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -746,7 +746,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
hidden_states = attn.batch_to_head_dim(hidden_states)

# linear proj
hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
hidden_states = attn.to_out[0](hidden_states, scale=scale)
# dropout
hidden_states = attn.to_out[1](hidden_states)

@@ -782,7 +782,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states, lora_scale=scale)
query = attn.to_q(hidden_states, scale=scale)
query = attn.head_to_batch_dim(query, out_dim=4)

encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -791,8 +791,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)

if not attn.only_cross_attention:
key = attn.to_k(hidden_states, lora_scale=scale)
value = attn.to_v(hidden_states, lora_scale=scale)
key = attn.to_k(hidden_states, scale=scale)
value = attn.to_v(hidden_states, scale=scale)
key = attn.head_to_batch_dim(key, out_dim=4)
value = attn.head_to_batch_dim(value, out_dim=4)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -809,7 +809,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

# linear proj
hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
hidden_states = attn.to_out[0](hidden_states, scale=scale)
# dropout
hidden_states = attn.to_out[1](hidden_states)

@@ -937,15 +937,15 @@ def __call__(
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states, lora_scale=scale)
query = attn.to_q(hidden_states, scale=scale)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states, lora_scale=scale)
value = attn.to_v(encoder_hidden_states, lora_scale=scale)
key = attn.to_k(encoder_hidden_states, scale=scale)
value = attn.to_v(encoder_hidden_states, scale=scale)

query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
@@ -958,7 +958,7 @@ def __call__(
hidden_states = attn.batch_to_head_dim(hidden_states)

# linear proj
hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
hidden_states = attn.to_out[0](hidden_states, scale=scale)
# dropout
hidden_states = attn.to_out[1](hidden_states)

@@ -1015,15 +1015,15 @@ def __call__(
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states, lora_scale=scale)
query = attn.to_q(hidden_states, scale=scale)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states, lora_scale=scale)
value = attn.to_v(encoder_hidden_states, lora_scale=scale)
key = attn.to_k(encoder_hidden_states, scale=scale)
value = attn.to_v(encoder_hidden_states, scale=scale)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@@ -1043,7 +1043,7 @@ def __call__(
hidden_states = hidden_states.to(query.dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
hidden_states = attn.to_out[0](hidden_states, scale=scale)
# dropout
hidden_states = attn.to_out[1](hidden_states)

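Every processor in this file now calls the attention projections with scale= instead of lora_scale=, matching the keyword the LoRA-compatible layers expose. A hedged sketch of what such a projection looks like (a simplified stand-in, not the actual diffusers implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyLoRALinear(nn.Linear):
    """Simplified stand-in: a base linear layer plus a scaled low-rank branch."""
    def __init__(self, in_features, out_features, rank=4):
        super().__init__(in_features, out_features)
        self.lora_down = nn.Linear(in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_features, bias=False)

    def forward(self, hidden_states, scale: float = 1.0):
        # `scale` weights the low-rank update without touching the base
        # projection, which is what the scale=scale call sites above rely on.
        base = super().forward(hidden_states)
        return base + scale * self.lora_up(self.lora_down(hidden_states))

to_q = TinyLoRALinear(64, 64)
hidden = torch.randn(1, 16, 64)
# With scale=0.0 the LoRA branch vanishes and only the base projection remains.
assert torch.allclose(to_q(hidden, scale=0.0), F.linear(hidden, to_q.weight, to_q.bias))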