Load Kohya-ss style LoRAs with auxiliary states #4147
@@ -25,6 +25,7 @@
from huggingface_hub import hf_hub_download
from torch import nn

from .models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
from .utils import (
    DIFFUSERS_CACHE,
    HF_HUB_OFFLINE,
@@ -56,6 +57,7 @@

LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
TOTAL_EXAMPLE_KEYS = 5

TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@@ -105,6 +107,20 @@ def text_encoder_attn_modules(text_encoder):
    return attn_modules


def text_encoder_mlp_modules(text_encoder):
    mlp_modules = []

    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
            mlp_mod = layer.mlp
            name = f"text_model.encoder.layers.{i}.mlp"
            mlp_modules.append((name, mlp_mod))
    else:
        raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")

    return mlp_modules


def text_encoder_lora_state_dict(text_encoder):
    state_dict = {}
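The new helper mirrors `text_encoder_attn_modules`: it walks the CLIP encoder layers and collects `(name, module)` pairs for each feed-forward block. A minimal sketch of the same traversal, using a small CLIP checkpoint purely as an example (the checkpoint name is not from this PR):

```python
from transformers import CLIPTextModel

# Any CLIP text encoder works here; the checkpoint name is just an example.
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

mlp_modules = []
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
    # Each CLIPEncoderLayer exposes its feed-forward block as `layer.mlp`,
    # which contains the two linear projections `fc1` and `fc2`.
    mlp_modules.append((f"text_model.encoder.layers.{i}.mlp", layer.mlp))

print(len(mlp_modules))         # number of encoder layers, e.g. 12
print(mlp_modules[0][0])        # "text_model.encoder.layers.0.mlp"
print(type(mlp_modules[0][1]))  # transformers CLIPMLP module with fc1/fc2
```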
@@ -304,6 +320,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

        # fill attn processors
        attn_processors = {}
        non_attn_lora_layers = []

        is_lora = all("lora" in k for k in state_dict.keys())
        is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
@@ -327,13 +344,33 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                lora_grouped_dict[attn_processor_key][sub_key] = value

            for key, value_dict in lora_grouped_dict.items():
                rank = value_dict["to_k_lora.down.weight"].shape[0]
                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

                attn_processor = self
                for sub_key in key.split("."):
                    attn_processor = getattr(attn_processor, sub_key)

                # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
                # or add_{k,v,q,out_proj}_proj_lora layers.
                if "lora.down.weight" in value_dict:
                    rank = value_dict["lora.down.weight"].shape[0]
                    hidden_size = value_dict["lora.up.weight"].shape[0]

                    if isinstance(attn_processor, LoRACompatibleConv):
                        lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
                    elif isinstance(attn_processor, LoRACompatibleLinear):
                        lora = LoRALinearLayer(
                            attn_processor.in_features, attn_processor.out_features, rank, network_alpha
                        )
                    else:
                        raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")

                    value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
                    lora.load_state_dict(value_dict)
                    non_attn_lora_layers.append((attn_processor, lora))
                    continue

                rank = value_dict["to_k_lora.down.weight"].shape[0]
                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

                if isinstance(
                    attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
                ):
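For these auxiliary (non-attention) layers, the rank and output width are read straight off the shapes of `lora.down.weight` and `lora.up.weight`: the down projection is `(rank, in_features)` and the up projection is `(out_features, rank)`. A small sketch of that inspection, with made-up tensor sizes purely for illustration:

```python
import torch

# Hypothetical entry from lora_grouped_dict for a feed-forward (non-attention)
# layer: a rank-4 LoRA wrapping a Linear(in_features=320, out_features=1280).
value_dict = {
    "lora.down.weight": torch.randn(4, 320),   # (rank, in_features)
    "lora.up.weight": torch.randn(1280, 4),    # (out_features, rank)
}

rank = value_dict["lora.down.weight"].shape[0]       # 4
hidden_size = value_dict["lora.up.weight"].shape[0]  # 1280

print(rank, hidden_size)
```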
@@ -390,10 +427,16 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

        # set correct dtype & device
        attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
        non_attn_lora_layers = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in non_attn_lora_layers]

        # set layers
        self.set_attn_processor(attn_processors)

        # set ff layers
        for target_module, lora_layer in non_attn_lora_layers:
            if hasattr(target_module, "set_lora_layer"):
                target_module.set_lora_layer(lora_layer)

    def save_attn_procs(
        self,
        save_directory: Union[str, os.PathLike],
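The `set_lora_layer` hook is what lets the loader attach LoRA weights to plain conv/linear modules without going through attention processors. The following is a simplified illustration of that pattern, not the actual `LoRACompatibleLinear`/`LoRALinearLayer` implementation: the wrapped module keeps an optional `lora_layer` and adds its output in `forward`.

```python
import torch
from torch import nn


class TinyLoRALinear(nn.Module):
    """Minimal stand-in for a rank-decomposed residual: up(down(x))."""

    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.zeros_(self.up.weight)  # start as a no-op

    def forward(self, x):
        return self.up(self.down(x))


class TinyLoRACompatibleLinear(nn.Linear):
    """Linear layer that can have a LoRA residual attached or detached."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = None

    def set_lora_layer(self, lora_layer):
        # Passing None detaches the residual again (see unload_lora_weights below).
        self.lora_layer = lora_layer

    def forward(self, x):
        out = super().forward(x)
        if self.lora_layer is not None:
            out = out + self.lora_layer(x)
        return out


layer = TinyLoRACompatibleLinear(320, 1280)
layer.set_lora_layer(TinyLoRALinear(320, 1280, rank=4))
print(layer(torch.randn(2, 320)).shape)  # torch.Size([2, 1280])
```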
@@ -840,7 +883,10 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
        state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
        self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
        self.load_lora_into_text_encoder(
            state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lora_scale=self.lora_scale
            state_dict,
            network_alpha=network_alpha,
            text_encoder=self.text_encoder,
            lora_scale=self.lora_scale,
        )

    @classmethod
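For context, `load_lora_weights` is the entry point end users hit when loading a Kohya-style file. A hedged usage sketch (the directory, file name, and prompt are placeholders, not from this PR):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# A Kohya-ss/sd-scripts LoRA checkpoint; path and file name are hypothetical.
pipe.load_lora_weights("path/to/lora/dir", weight_name="my_kohya_lora.safetensors")

image = pipe("a photo in the style the LoRA was trained on").images[0]
image.save("lora_sample.png")
```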
@@ -1049,6 +1095,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr
        text_encoder_lora_state_dict = {
            k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
        }

        if len(text_encoder_lora_state_dict) > 0:
            logger.info(f"Loading {prefix}.")
@@ -1092,8 +1139,9 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr
            rank = text_encoder_lora_state_dict[
                "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
            ].shape[1]
            patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())

Reviewer comment on the `patch_mlp` check: OK with this for now, but it would be really nice to avoid it in the future. The meta point is that once we have checks like this anywhere in the code, we have to consider the implications for any state-dict checking or conversion code every time we touch a model definition. This particular check inspects a model definition from a separate library, which is even hairier. We're lucky that transformers very rarely changes model definitions once they're written, but in general that's not something we should rely on. A good analogue is applications that serialize their state as locally stored files; that's all state dicts are, an application serialization format. Almost all applications tell you not to make assumptions about the format or modify the files they store. When files are user-editable, that's usually explicitly documented, whereas our state-dict formats are only implicitly documented through a combination of code in different libraries and how diffusers elects to monkey-patch updated model definitions.

            cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank)
            cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank, patch_mlp=patch_mlp)

            # set correct dtype & device
            text_encoder_lora_state_dict = {
@@ -1125,8 +1173,21 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
                attn_module.v_proj = attn_module.v_proj.regular_linear_layer
                attn_module.out_proj = attn_module.out_proj.regular_linear_layer

        for _, mlp_module in text_encoder_mlp_modules(text_encoder):
            if isinstance(mlp_module.fc1, PatchedLoraProjection):
                mlp_module.fc1 = mlp_module.fc1.regular_linear_layer
                mlp_module.fc2 = mlp_module.fc2.regular_linear_layer

    @classmethod
    def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None):
    def _modify_text_encoder(
        cls,
        text_encoder,
        lora_scale=1,
        network_alpha=None,
        rank=4,
        dtype=None,
        patch_mlp=False,
    ):
        r"""
        Monkey-patches the forward passes of attention modules of the text encoder.
        """
@@ -1157,6 +1218,18 @@ def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, ra
            )
            lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

        if patch_mlp:
            for _, mlp_module in text_encoder_mlp_modules(text_encoder):
                mlp_module.fc1 = PatchedLoraProjection(
                    mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
                )
                lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())

                mlp_module.fc2 = PatchedLoraProjection(
                    mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
                )
                lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())

        return lora_parameters

    @classmethod
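The monkey patch wraps each `fc1`/`fc2` (and the attention projections) in a `PatchedLoraProjection`, which keeps the original linear layer and adds a scaled LoRA residual on top. Below is a simplified sketch of that idea; it mirrors the call signature seen in the diff but is not the library's actual implementation, and the `network_alpha / rank` rescaling follows the usual Kohya convention.

```python
import torch
from torch import nn


class SketchLoRALinearLayer(nn.Module):
    """Rank-r residual with the usual network_alpha / rank rescaling."""

    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        self.network_alpha = network_alpha
        self.rank = rank
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, x):
        out = self.up(self.down(x))
        if self.network_alpha is not None:
            out = out * (self.network_alpha / self.rank)
        return out


class SketchPatchedLoraProjection(nn.Module):
    """Keeps the frozen linear layer and adds a scaled LoRA residual."""

    def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4):
        super().__init__()
        self.regular_linear_layer = regular_linear_layer
        self.lora_scale = lora_scale
        self.lora_linear_layer = SketchLoRALinearLayer(
            regular_linear_layer.in_features, regular_linear_layer.out_features, rank, network_alpha
        )

    def forward(self, x):
        return self.regular_linear_layer(x) + self.lora_scale * self.lora_linear_layer(x)


fc1 = nn.Linear(768, 3072)
patched = SketchPatchedLoraProjection(fc1, lora_scale=1.0, network_alpha=8, rank=4)
print(patched(torch.randn(1, 77, 768)).shape)  # torch.Size([1, 77, 3072])

# Un-patching (as in _remove_text_encoder_monkey_patch) is just:
fc1_again = patched.regular_linear_layer
```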
@@ -1261,9 +1334,12 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
        unet_state_dict = {}
        te_state_dict = {}
        network_alpha = None
        unloaded_keys = []

        for key, value in state_dict.items():
            if "lora_down" in key:
            if "hada" in key or "skip" in key:
                unloaded_keys.append(key)
            elif "lora_down" in key:
                lora_name = key.split(".")[0]
                lora_name_up = lora_name + ".lora_up.weight"
                lora_name_alpha = lora_name + ".alpha"
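Kohya/sd-scripts checkpoints store each adapted module as a triplet of flat keys: `<module>.lora_down.weight`, `<module>.lora_up.weight`, and a scalar `<module>.alpha` (the `network_alpha` used to rescale the update by `alpha / rank`). LyCORIS-style `hada`/`skip` tensors are now tracked as `unloaded_keys` and warned about instead of being silently dropped. A tiny illustration with a fabricated key and made-up shapes:

```python
import torch

# Hypothetical slice of a Kohya-style state dict; the key name follows the
# sd-scripts convention, the shapes are made up for illustration.
kohya_sd = {
    "lora_unet_up_blocks_1_attentions_0_transformer_blocks_0_ff_net_0_proj.lora_down.weight": torch.randn(8, 640),
    "lora_unet_up_blocks_1_attentions_0_transformer_blocks_0_ff_net_0_proj.lora_up.weight": torch.randn(5120, 8),
    "lora_unet_up_blocks_1_attentions_0_transformer_blocks_0_ff_net_0_proj.alpha": torch.tensor(8.0),
}

for key in kohya_sd:
    if "lora_down" in key:
        lora_name = key.split(".")[0]
        rank = kohya_sd[key].shape[0]
        alpha = kohya_sd[lora_name + ".alpha"].item()
        print(lora_name, "rank:", rank, "scale:", alpha / rank)
```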
@@ -1284,12 +1360,21 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
                diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
                diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
                diffusers_name = diffusers_name.replace("proj.in", "proj_in")
                diffusers_name = diffusers_name.replace("proj.out", "proj_out")

                if "transformer_blocks" in diffusers_name:
                    if "attn1" in diffusers_name or "attn2" in diffusers_name:
                        diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
                        diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
                        unet_state_dict[diffusers_name] = value
                        unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
                    elif "ff" in diffusers_name:
                        unet_state_dict[diffusers_name] = value
                        unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
                elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
                    unet_state_dict[diffusers_name] = value
                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]

            elif lora_name.startswith("lora_te_"):
                diffusers_name = key.replace("lora_te_", "").replace("_", ".")
                diffusers_name = diffusers_name.replace("text.model", "text_model")
@@ -1301,6 +1386,19 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                if "self_attn" in diffusers_name:
                    te_state_dict[diffusers_name] = value
                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
                elif "mlp" in diffusers_name:
                    # Be aware that this is the new diffusers convention and the rest of the code might
                    # not utilize it yet.
                    diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
                    te_state_dict[diffusers_name] = value
                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]

        logger.info("Kohya-style checkpoint detected.")
        if len(unloaded_keys) > 0:
            example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS])
            logger.warning(
                f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for."
            )

        unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
        te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
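Putting the text-encoder branch together: a Kohya MLP key ends up as a `lora_linear_layer` key that later gets the `TEXT_ENCODER_NAME` prefix. A short walk-through of the renaming on one fabricated key (only the replacements relevant to this key are shown):

```python
# Fabricated example key, following the sd-scripts text-encoder naming.
key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"

diffusers_name = key.replace("lora_te_", "").replace("_", ".")
# -> "text.model.encoder.layers.0.mlp.fc1.lora.down.weight"
diffusers_name = diffusers_name.replace("text.model", "text_model")
# attention-specific renames (to_q_lora, ...) don't match this key and are skipped
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")

print(diffusers_name)
# text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.down.weight
# _convert_kohya_lora_to_diffusers then prefixes it with f"{TEXT_ENCODER_NAME}."
```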
@@ -1346,6 +1444,10 @@ def unload_lora_weights(self):
            [attention_proc_class] = unet_attention_classes
            self.unet.set_attn_processor(regular_attention_classes[attention_proc_class]())

            for _, module in self.unet.named_modules():
                if hasattr(module, "set_lora_layer"):
                    module.set_lora_layer(None)

        # Safe to call the following regardless of LoRA.
        self._remove_text_encoder_monkey_patch()
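With the new `set_lora_layer(None)` sweep, unloading also detaches the auxiliary (non-attention) LoRA layers. A hedged before/after sketch (checkpoint, file name, and prompt are placeholders):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipe.load_lora_weights("path/to/lora/dir", weight_name="my_kohya_lora.safetensors")
with_lora = pipe("test prompt", num_inference_steps=20).images[0]

# Restores the original attention processors, calls set_lora_layer(None) on every
# LoRA-compatible UNet module, and removes the text-encoder monkey patch.
pipe.unload_lora_weights()
without_lora = pipe("test prompt", num_inference_steps=20).images[0]
```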