Load Kohya-ss style LoRAs with auxiliary states #4147

Merged

5 changes: 3 additions & 2 deletions docs/source/en/training/lora.mdx
@@ -14,8 +14,7 @@ specific language governing permissions and limitations under the License.

<Tip warning={true}>

Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionModel`]. We also
support fine-tuning the text encoder for DreamBooth with LoRA in a limited capacity. Fine-tuning the text encoder for DreamBooth generally yields better results, but it can increase compute usage.
This is an experimental feature. Its APIs can change in the future.

</Tip>

@@ -286,6 +285,8 @@ You can call [`~diffusers.loaders.LoraLoaderMixin.unload_lora_weights`] on a pip

## Supporting A1111 themed LoRA checkpoints from Diffusers

This support was made possible because of our amazing contributors: [@takuma104](https://github.com/takuma104) and [@isidentical](https://github.com/isidentical).

To provide seamless interoperability with A1111 for our users, we support loading A1111-formatted
LoRA checkpoints using [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] in a limited capacity.
In this section, we explain how to load an A1111-formatted LoRA checkpoint from [CivitAI](https://civitai.com/)
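
The snippet below sketches the workflow this section describes: loading a locally downloaded CivitAI checkpoint into a pipeline. It is a minimal sketch; the base model id and the `weight_name` file are placeholders, not part of this PR.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Load a Kohya/A1111-style LoRA file downloaded from CivitAI into both the
# UNet and the text encoder. "light_and_shadow.safetensors" is a placeholder name.
pipe.load_lora_weights(".", weight_name="light_and_shadow.safetensors")

image = pipe(
    "masterpiece, best quality, mountain landscape",
    num_inference_steps=25,
).images[0]

# Detach the LoRA layers again once they are no longer needed.
pipe.unload_lora_weights()
```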
116 changes: 109 additions & 7 deletions src/diffusers/loaders.py
@@ -25,6 +25,7 @@
from huggingface_hub import hf_hub_download
from torch import nn

from .models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
from .utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
@@ -56,6 +57,7 @@

LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
TOTAL_EXAMPLE_KEYS = 5

TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@@ -105,6 +107,20 @@ def text_encoder_attn_modules(text_encoder):
return attn_modules


def text_encoder_mlp_modules(text_encoder):
mlp_modules = []

if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
mlp_mod = layer.mlp
name = f"text_model.encoder.layers.{i}.mlp"
mlp_modules.append((name, mlp_mod))
else:
raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")

return mlp_modules
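
For a quick sanity check of what this helper yields, the sketch below iterates over a standard CLIP text encoder. The model id is only an example, and the import assumes the helper remains a module-level function in `diffusers.loaders`.

```python
from transformers import CLIPTextModel

from diffusers.loaders import text_encoder_mlp_modules

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

# Each entry pairs the dotted module path with the MLP submodule itself,
# mirroring what text_encoder_attn_modules does for the attention blocks.
for name, mlp in text_encoder_mlp_modules(text_encoder):
    print(name, type(mlp.fc1).__name__, type(mlp.fc2).__name__)
# text_model.encoder.layers.0.mlp Linear Linear
# ...
```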


def text_encoder_lora_state_dict(text_encoder):
state_dict = {}

@@ -304,6 +320,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

# fill attn processors
attn_processors = {}
non_attn_lora_layers = []

is_lora = all("lora" in k for k in state_dict.keys())
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
@@ -327,13 +344,33 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
lora_grouped_dict[attn_processor_key][sub_key] = value

for key, value_dict in lora_grouped_dict.items():
rank = value_dict["to_k_lora.down.weight"].shape[0]
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

attn_processor = self
for sub_key in key.split("."):
attn_processor = getattr(attn_processor, sub_key)

# Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
# or add_{k,v,q,out_proj}_proj_lora layers.
if "lora.down.weight" in value_dict:
rank = value_dict["lora.down.weight"].shape[0]
hidden_size = value_dict["lora.up.weight"].shape[0]

if isinstance(attn_processor, LoRACompatibleConv):
lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
elif isinstance(attn_processor, LoRACompatibleLinear):
lora = LoRALinearLayer(
attn_processor.in_features, attn_processor.out_features, rank, network_alpha
)
else:
raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")

value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
lora.load_state_dict(value_dict)
non_attn_lora_layers.append((attn_processor, lora))
continue

rank = value_dict["to_k_lora.down.weight"].shape[0]
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

if isinstance(
attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
):
@@ -390,10 +427,16 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

# set correct dtype & device
attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
non_attn_lora_layers = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in non_attn_lora_layers]

# set layers
self.set_attn_processor(attn_processors)

# set ff layers
for target_module, lora_layer in non_attn_lora_layers:
if hasattr(target_module, "set_lora_layer"):
target_module.set_lora_layer(lora_layer)
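
The `set_lora_layer` hook is what lets the feed-forward and projection modules pick up these auxiliary LoRA weights without being swapped out for a different class. Conceptually it behaves roughly like the simplified sketch below, which is illustrative only and not the actual `LoRACompatibleLinear` implementation from `diffusers.models.lora`.

```python
import torch.nn as nn


class LoRACompatibleLinearSketch(nn.Linear):
    """Illustrative stand-in for diffusers' LoRACompatibleLinear."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = None

    def set_lora_layer(self, lora_layer):
        # Passing None detaches the LoRA layer again (see unload_lora_weights).
        self.lora_layer = lora_layer

    def forward(self, hidden_states):
        out = super().forward(hidden_states)
        if self.lora_layer is not None:
            # Add the low-rank update on top of the frozen base projection.
            out = out + self.lora_layer(hidden_states)
        return out
```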

def save_attn_procs(
self,
save_directory: Union[str, os.PathLike],
@@ -840,7 +883,10 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
self.load_lora_into_text_encoder(
state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lora_scale=self.lora_scale
state_dict,
network_alpha=network_alpha,
text_encoder=self.text_encoder,
lora_scale=self.lora_scale,
)

@classmethod
@@ -1049,6 +1095,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}

if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")

@@ -1092,8 +1139,9 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr
rank = text_encoder_lora_state_dict[
"text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
].shape[1]
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())

Reviewer comment (Contributor):

OK with this for now, but it would be really nice to avoid this if possible in the future. The meta point to think about here is that once we have checks like this anywhere inside the code, we have to consider the implications for any state-dict checking or changing code every time we touch a model definition.

This part specifically checks a model definition from a separate library, which is even hairier. We're lucky that transformers, as it is written, very rarely changes model definitions once they exist, but in general that's not something we should rely on.

I think a good analogue is applications on your computer that serialize their state as locally stored files. That's all state dicts are: an application serialization format. Almost all applications say you should not make assumptions about the format or modify the files they store. If they do say the files are user-editable, that format is usually documented very explicitly, whereas our state-dict formats are documented only implicitly, through a combination of code in different libraries and how diffusers elects to monkey-patch updated model definitions.


cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank)
cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank, patch_mlp=patch_mlp)

# set correct dtype & device
text_encoder_lora_state_dict = {
@@ -1125,8 +1173,21 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
attn_module.v_proj = attn_module.v_proj.regular_linear_layer
attn_module.out_proj = attn_module.out_proj.regular_linear_layer

for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1 = mlp_module.fc1.regular_linear_layer
mlp_module.fc2 = mlp_module.fc2.regular_linear_layer

@classmethod
def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None):
def _modify_text_encoder(
cls,
text_encoder,
lora_scale=1,
network_alpha=None,
rank=4,
dtype=None,
patch_mlp=False,
):
r"""
Monkey-patches the forward passes of attention modules of the text encoder.
"""
@@ -1157,6 +1218,18 @@ def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, ra
)
lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

if patch_mlp:
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
mlp_module.fc1 = PatchedLoraProjection(
mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
)
lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())

mlp_module.fc2 = PatchedLoraProjection(
mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
)
lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())

return lora_parameters
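
For the text encoder the mechanism is a monkey patch rather than a hook: `fc1`/`fc2` (like the attention projections above) are wrapped in `PatchedLoraProjection`, which keeps the original layer around as `regular_linear_layer` so the patch can later be reverted. A simplified sketch of that wrapper, not the exact diffusers class:

```python
import torch.nn as nn

from diffusers.models.lora import LoRALinearLayer


class PatchedLoraProjectionSketch(nn.Module):
    """Illustrative stand-in for the PatchedLoraProjection wrapper."""

    def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4):
        super().__init__()
        self.regular_linear_layer = regular_linear_layer
        self.lora_scale = lora_scale
        self.lora_linear_layer = LoRALinearLayer(
            regular_linear_layer.in_features,
            regular_linear_layer.out_features,
            rank=rank,
            network_alpha=network_alpha,
        )

    def forward(self, hidden_states):
        # Frozen base projection plus a scaled low-rank correction.
        base = self.regular_linear_layer(hidden_states)
        return base + self.lora_scale * self.lora_linear_layer(hidden_states)
```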

@classmethod
@@ -1261,9 +1334,12 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
unet_state_dict = {}
te_state_dict = {}
network_alpha = None
unloaded_keys = []

for key, value in state_dict.items():
if "lora_down" in key:
if "hada" in key or "skip" in key:
unloaded_keys.append(key)
elif "lora_down" in key:
lora_name = key.split(".")[0]
lora_name_up = lora_name + ".lora_up.weight"
lora_name_alpha = lora_name + ".alpha"
@@ -1284,12 +1360,21 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
unet_state_dict[diffusers_name] = value
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif "ff" in diffusers_name:
unet_state_dict[diffusers_name] = value
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
unet_state_dict[diffusers_name] = value
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]

elif lora_name.startswith("lora_te_"):
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
@@ -1301,6 +1386,19 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
if "self_attn" in diffusers_name:
te_state_dict[diffusers_name] = value
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te_state_dict[diffusers_name] = value
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]

logger.info("Kohya-style checkpoint detected.")
if len(unloaded_keys) > 0:
example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS])
logger.warning(
f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for."
)
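
To make the remapping concrete, here is roughly what happens to one of the newly supported text-encoder MLP keys. The key name is illustrative, and the replacement chain mirrors the steps in this method:

```python
kohya_key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"

diffusers_key = (
    kohya_key.replace("lora_te_", "")
    .replace("_", ".")
    .replace("text.model", "text_model")
    .replace(".lora.", ".lora_linear_layer.")
)
print(diffusers_key)
# text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.down.weight
# The key is then prefixed with "text_encoder." before the state dict is returned.
```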

unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
@@ -1346,6 +1444,10 @@ def unload_lora_weights(self):
[attention_proc_class] = unet_attention_classes
self.unet.set_attn_processor(regular_attention_classes[attention_proc_class]())

for _, module in self.unet.named_modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)

# Safe to call the following regardless of LoRA.
self._remove_text_encoder_monkey_patch()

5 changes: 3 additions & 2 deletions src/diffusers/models/attention.py
@@ -21,6 +21,7 @@
from .activations import get_activation
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings
from .lora import LoRACompatibleLinear


@maybe_allow_in_graph
@@ -245,7 +246,7 @@ def __init__(
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(nn.Linear(inner_dim, dim_out))
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
@@ -289,7 +290,7 @@ class GEGLU(nn.Module):

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)

def gelu(self, gate):
if gate.device.type != "mps":
31 changes: 1 addition & 30 deletions src/diffusers/models/attention_processor.py
@@ -19,6 +19,7 @@

from ..utils import deprecate, logging, maybe_allow_in_graph
from ..utils.import_utils import is_xformers_available
from .lora import LoRALinearLayer


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -505,36 +506,6 @@ def __call__(
return hidden_states


class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
super().__init__()

if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha
self.rank = rank

nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)

def forward(self, hidden_states):
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype

down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)

if self.network_alpha is not None:
up_hidden_states *= self.network_alpha / self.rank

return up_hidden_states.to(orig_dtype)


class LoRAAttnProcessor(nn.Module):
r"""
Processor for implementing the LoRA attention mechanism.