[LoRA] Enabling limited LoRA support for text encoder #2882

Closed · wants to merge 8 commits

4 changes: 4 additions & 0 deletions docs/source/en/api/loaders.mdx
@@ -28,3 +28,7 @@ API to load such adapter neural networks via the [`loaders.py` module](https://g
### UNet2DConditionLoadersMixin

[[autodoc]] loaders.UNet2DConditionLoadersMixin

### TextEncoderLoRAMixin

[[autodoc]] loaders.TextEncoderLoRAMixin
317 changes: 316 additions & 1 deletion src/diffusers/loaders.py
@@ -16,10 +16,18 @@
from typing import Callable, Dict, Union

import torch
import torch.nn as nn

from .models.attention_processor import LoRAAttnProcessor
from .models.modeling_utils import _get_model_file
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
from .utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
TEXT_ENCODER_TARGET_MODULES,
deprecate,
is_safetensors_available,
logging,
)


if is_safetensors_available():
@@ -32,6 +40,9 @@
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

TEXT_ENCODER_LORA_WEIGHT_NAME = "pytorch_text_encoder_lora_weights.bin"
TEXT_ENCODER_LORA_WEIGHT_NAME_SAFE = "pytorch_text_encoder_lora_weights.safetensors"


class AttnProcsLayers(torch.nn.Module):
def __init__(self, state_dict: Dict[str, torch.Tensor]):
@@ -294,3 +305,307 @@ def save_function(weights, filename):
save_function(state_dict, os.path.join(save_directory, weight_name))

logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")


class TextEncoderLoRAMixin:
r"""
    This class handles the text encoder used in our pipelines with LoRA. Its methods are mostly
    copy-pasted from [`~UNet2DConditionLoadersMixin`]. We couldn't fully reuse that class because
    we cannot do things like `self.set_attn_processor()` on a text encoder.

Args:
text_encoder (`nn.Module`):
The text encoder module underlying a [`~DiffusionPipeline`].
"""

def __init__(self, text_encoder: nn.Module):
Contributor:
Can we remove the __init__ from the Mixin? I think we could call _initialize_lora_layers() in load_attn_procs, no?

I'm not a fan of Mixins having inits because this means they can't be "plugged" into the StableDiffusionPipeline class.

Contributor:

+1, mixins should ideally not have __init__

Member (Author):

I don't envision this Mixin being used with a DiffusionPipeline class.

_initialize_lora_layers() initializes the LoRA parameters, and having it inside load_attn_procs() is not a good choice IMO, since the two methods do semantically different things.

Contributor:

Disclaimer: I'm thinking here more about inference of text encoder LoRAs than training.

I think we should have a function called load_lora() or load_lora_weights(...) that can be called from StableDiffusionPipeline(...).

I don't think we should wrap the text encoder:

text_encoder_lora_wrapper = TextEncoderLoRAMixin(text_encoder)

=> this breaks things for inference: text_encoder_lora_wrapper cannot be passed to the StableDiffusionPipeline because it doesn't have a forward method, it cannot be saved, etc.

self.text_encoder = text_encoder
self.device = text_encoder.device
self.dtype = text_encoder.dtype
self._initialize_lora_layers()

def _initialize_lora_layers(self):
self.lora_attn_procs = {}
for name, module in self.text_encoder.named_modules():
if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]):
self.lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=module.out_features, cross_attention_dim=None
)

self.text_encoder_lora_layers = AttnProcsLayers(self.lora_attn_procs)

def load_attn_procs(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
) -> nn.Module:
r"""
Load pretrained attention processor layers into
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). Instead
of setting the attention processing layers (as done in [`~UNet2DConditionLoadersMixin.load_attn_procs`]), we
use the LoRA attention layers to monkey-patch the forward passes of the attention modules of the
`text_encoder`.

<Tip warning={true}>

This function is experimental and might change in the future.

</Tip>

Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:

- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
`./my_model_directory/`.
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `diffusers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.

mirror (`str`, *optional*):
                Mirror source to accelerate downloads in China. If you are from China and have accessibility
                problems, you can set this option to resolve them. Note that we do not guarantee the
                timeliness or safety of the mirror, so please refer to the mirror site for more information.

Returns:
`nn.Module`: The text encoder module with the forward passes of its attention modules monkey-patched.

<Tip>

It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).

</Tip>

<Tip>

Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
this method in a firewalled environment.

</Tip>
"""

cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)

if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
)

allow_pickle = False
if use_safetensors is None:
use_safetensors = is_safetensors_available()
allow_pickle = True

user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}

model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or TEXT_ENCODER_LORA_WEIGHT_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except IOError as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
pass
if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or TEXT_ENCODER_LORA_WEIGHT_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path_or_dict

# fill attn processors
attn_processors = {}

is_lora = all("lora" in k for k in state_dict.keys())

if is_lora:
lora_grouped_dict = defaultdict(dict)
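            # Group the flat checkpoint keys by module: for an illustrative key
            # such as "text_model.encoder.layers.0.self_attn.q_proj.to_q_lora.down.weight",
            # the last three components form the sub_key ("to_q_lora.down.weight")
            # and the remainder identifies the attention processor.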
for key, value in state_dict.items():
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
lora_grouped_dict[attn_processor_key][sub_key] = value

for key, value_dict in lora_grouped_dict.items():
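                # Infer dimensions from the checkpoint: a LoRA down-projection
                # weight has shape (rank, in_features) and an up-projection
                # weight has shape (out_features, rank).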
rank = value_dict["to_k_lora.down.weight"].shape[0]
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

attn_processors[key] = LoRAAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
)
attn_processors[key].load_state_dict(value_dict)

else:
raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")

# set correct dtype & device
attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}

return self._modify_text_encoder(attn_processors)

def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]) -> nn.Module:
r"""
Monkey-patches the forward passes of attention modules of the text encoder.

Args:
attn_processors: Dict[str, `LoRAAttnProcessor`]:
A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
Returns:
`nn.Module`: The modified text encoder.
"""
# Loop over the original attention modules.
for name, _ in self.text_encoder.named_modules():
if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]):
# Retrieve the module and its corresponding LoRA processor.
module = self.text_encoder.get_submodule(name)
                # Construct a new function that performs the LoRA merging and
                # monkey-patch it onto the module. Bind `old_forward` and
                # `lora_layer` as default arguments so that each patched module
                # keeps its own layers; a bare closure would late-bind to the
                # values from the last loop iteration.
                lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
                old_forward = module.forward

                def new_forward(x, old_forward=old_forward, lora_layer=lora_layer):
                    return old_forward(x) + lora_layer(x)

                # Monkey-patch.
                module.forward = new_forward
return self.text_encoder

def _get_lora_layer_attribute(self, name: str) -> str:
if "q_proj" in name:
return "to_q_lora"
elif "v_proj" in name:
return "to_v_lora"
elif "k_proj" in name:
return "to_k_lora"
else:
return "to_out_lora"

def save_attn_procs(
self,
save_directory: Union[str, os.PathLike],
text_encoder_lora_layers: nn.Module,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = False,
**kwargs,
):
r"""
Save an attention processor to a directory, so that it can be re-loaded using the
[`~loaders.TextEncoderLoRAMixin.load_attn_procs`] method.

Arguments:
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
text_encoder_lora_layers (`nn.Module`):
LoRA trainable parameters provided as `nn.Module`.
is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed
                training (e.g., on TPUs) when this function needs to be called on all processes. In that
                case, set `is_main_process=True` only on the main process to avoid race conditions.
save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training
                (e.g., on TPUs) when one needs to replace `torch.save` with another method. Can be
                configured with the environment variable `DIFFUSERS_SAVE_MODE`.
"""
weight_name = weight_name or deprecate(
"weights_name",
"0.18.0",
"`weights_name` is deprecated, please use `weight_name` instead.",
take_from=kwargs,
)
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return

if save_function is None:
if safe_serialization:

def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})

else:
save_function = torch.save

os.makedirs(save_directory, exist_ok=True)

model_to_save = text_encoder_lora_layers

# Save the model
state_dict = model_to_save.state_dict()

if weight_name is None:
if safe_serialization:
weight_name = TEXT_ENCODER_LORA_WEIGHT_NAME_SAFE
else:
weight_name = TEXT_ENCODER_LORA_WEIGHT_NAME

# Save the model
save_function(state_dict, os.path.join(save_directory, weight_name))

logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -30,6 +30,7 @@
ONNX_EXTERNAL_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
TEXT_ENCODER_TARGET_MODULES,
WEIGHTS_NAME,
)
from .deprecation_utils import deprecate
1 change: 1 addition & 0 deletions src/diffusers/utils/constants.py
@@ -30,3 +30,4 @@
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
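As a quick sanity check, the new constant matches CLIP's attention projection modules by substring. A sketch, assuming transformers is installed; the model id is an illustrative choice:

# Sanity-check sketch (model id is an illustrative assumption).
from transformers import CLIPTextModel

from diffusers.utils import TEXT_ENCODER_TARGET_MODULES

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
matched = [
    name
    for name, _ in text_encoder.named_modules()
    if any(target in name for target in TEXT_ENCODER_TARGET_MODULES)
]
# One entry per projection per transformer layer, e.g.
# "text_model.encoder.layers.0.self_attn.q_proj".
print(len(matched), matched[:4])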