diff --git a/docs/source/en/api/loaders.mdx b/docs/source/en/api/loaders.mdx index 1d55bd03c064..8cbf21b8e0cf 100644 --- a/docs/source/en/api/loaders.mdx +++ b/docs/source/en/api/loaders.mdx @@ -28,3 +28,11 @@ API to load such adapter neural networks via the [`loaders.py` module](https://g ### UNet2DConditionLoadersMixin [[autodoc]] loaders.UNet2DConditionLoadersMixin + +### TextualInversionLoaderMixin + +[[autodoc]] loaders.TextualInversionLoaderMixin + +### LoraLoaderMixin + +[[autodoc]] loaders.LoraLoaderMixin diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index a262833938e7..31939ca4b481 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -21,6 +21,7 @@ from .utils import ( DIFFUSERS_CACHE, HF_HUB_OFFLINE, + TEXT_ENCODER_TARGET_MODULES, _get_model_file, deprecate, is_safetensors_available, @@ -81,12 +82,12 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict r""" Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be defined in - [cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py) + [`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py) and be a `torch.nn.Module` class. - This function is experimental and might change in the future. + This function is experimental and might change in the future. @@ -125,7 +126,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are located inside a subfolder of the model repo (either remote in huggingface.co or downloaded locally), you can specify the folder name here. - mirror (`str`, *optional*): Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. @@ -133,8 +133,8 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models). + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). """ @@ -250,7 +250,7 @@ def save_attn_procs( ): r""" Save an attention processor to a directory, so that it can be re-loaded using the - `[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`]` method. + [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method. Arguments: save_directory (`str` or `os.PathLike`): @@ -372,12 +372,12 @@ def load_textual_inversion( - This function is experimental and might change in the future. + This function is experimental and might change in the future. Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`): + pretrained_model_name_or_path (`str` or `os.PathLike`): Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. @@ -566,4 +566,452 @@ def load_textual_inversion( for token_id, embedding in zip(token_ids, embeddings): self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding - logger.info("Loaded textual inversion embedding for {token}.") + logger.info(f"Loaded textual inversion embedding for {token}.") + + +class LoraLoaderMixin: + r""" + Utility class for handling the loading LoRA layers into UNet (of class [`UNet2DConditionModel`]) and Text Encoder + (of class [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)). + + + + This function is experimental and might change in the future. + + + """ + text_encoder_name = "text_encoder" + unet_name = "unet" + + def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + r""" + Load pretrained attention processor layers (such as LoRA) into [`UNet2DConditionModel`] and + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)). + + + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + """ + # Load the main state dict first which has the LoRA layers for either of + # UNet and text encoder or both. + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except IOError as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + + # Load the layers corresponding to UNet. + if all(key.startswith(self.unet_name) for key in keys): + logger.info(f"Loading {self.unet_name}.") + unet_lora_state_dict = {k: v for k, v in state_dict.items() if k.startswith(self.unet_name)} + self.unet.load_attn_procs(unet_lora_state_dict) + + # Load the layers corresponding to text encoder and make necessary adjustments. + elif all(key.startswith(self.text_encoder_name) for key in keys): + logger.info(f"Loading {self.text_encoder_name}.") + text_encoder_lora_state_dict = { + k: v for k, v in state_dict.items() if k.startswith(self.text_encoder_name) + } + attn_procs_text_encoder = self.load_attn_procs(text_encoder_lora_state_dict) + self._modify_text_encoder(attn_procs_text_encoder) + + # Otherwise, we're dealing with the old format. This means the `state_dict` should only + # contain the module names of the `unet` as its keys WITHOUT any prefix. + elif not all( + key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys() + ): + self.unet.load_attn_procs(state_dict) + deprecation_message = "You have saved the LoRA weights using the old format. This will be" + " deprecated soon. To convert the old LoRA weights to the new format, you can first load them" + " in a dictionary and then create a new dictionary like the following:" + " `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`." + deprecate("legacy LoRA weights", "1.0.0", deprecation_message, standard_warn=False) + + def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): + r""" + Monkey-patches the forward passes of attention modules of the text encoder. + + Parameters: + attn_processors: Dict[str, `LoRAAttnProcessor`]: + A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`]. + """ + # Loop over the original attention modules. + for name, _ in self.text_encoder.named_modules(): + if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]): + # Retrieve the module and its corresponding LoRA processor. + module = self.text_encoder.get_submodule(name) + # Construct a new function that performs the LoRA merging. We will monkey patch + # this forward pass. + lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) + old_forward = module.forward + + def new_forward(x): + return old_forward(x) + lora_layer(x) + + # Monkey-patch. + module.forward = new_forward + + def _get_lora_layer_attribute(self, name: str) -> str: + if "q_proj" in name: + return "to_q_lora" + elif "v_proj" in name: + return "to_v_lora" + elif "k_proj" in name: + return "to_k_lora" + else: + return "to_out_lora" + + def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + r""" + Load pretrained attention processor layers for + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). + + + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + + Returns: + `Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding + [`LoRAAttnProcessor`]. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + """ + + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + if use_safetensors and not is_safetensors_available(): + raise ValueError( + "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors" + ) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = is_safetensors_available() + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except IOError as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + + # fill attn processors + attn_processors = {} + + is_lora = all("lora" in k for k in state_dict.keys()) + + if is_lora: + lora_grouped_dict = defaultdict(dict) + for key, value in state_dict.items(): + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + lora_grouped_dict[attn_processor_key][sub_key] = value + + for key, value_dict in lora_grouped_dict.items(): + rank = value_dict["to_k_lora.down.weight"].shape[0] + cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] + hidden_size = value_dict["to_k_lora.up.weight"].shape[0] + + attn_processors[key] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank + ) + attn_processors[key].load_state_dict(value_dict) + + else: + raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.") + + # set correct dtype & device + attn_processors = { + k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items() + } + return attn_processors + + @classmethod + def save_lora_weights( + self, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, torch.nn.Module] = None, + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = False, + ): + r""" + Save the LoRA parameters corresponding to the UNet and the text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + unet_lora_layers (`Dict[str, torch.nn.Module`]): + State dict of the LoRA layers corresponding to the UNet. Specifying this helps to make the + serialization process easier and cleaner. + text_encoder_lora_layers (`Dict[str, torch.nn.Module`]): + State dict of the LoRA layers corresponding to the `text_encoder`. Since the `text_encoder` comes from + `transformers`, we cannot rejig it. That is why we have to explicitly pass the text encoder LoRA state + dict. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + if save_function is None: + if safe_serialization: + + def save_function(weights, filename): + return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + + else: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + + # Create a flat dictionary. + state_dict = {} + if unet_lora_layers is not None: + unet_lora_state_dict = { + f"{self.unet_name}.{module_name}": param + for module_name, param in unet_lora_layers.state_dict().items() + } + state_dict.update(unet_lora_state_dict) + if text_encoder_lora_layers is not None: + text_encoder_lora_state_dict = { + f"{self.text_encoder_name}.{module_name}": param + for module_name, param in text_encoder_lora_layers.state_dict().items() + } + state_dict.update(text_encoder_lora_state_dict) + + # Save the model + if weight_name is None: + if safe_serialization: + weight_name = LORA_WEIGHT_NAME_SAFE + else: + weight_name = LORA_WEIGHT_NAME + + save_function(state_dict, os.path.join(save_directory, weight_name)) + logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index fcf44f02c731..689febe3e891 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -20,7 +20,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict -from ...loaders import TextualInversionLoaderMixin +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -53,7 +53,7 @@ """ -class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion. diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 3a1103ac1adf..bb159d9db375 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -30,6 +30,7 @@ ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, + TEXT_ENCODER_TARGET_MODULES, WEIGHTS_NAME, ) from .deprecation_utils import deprecate diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index b9e60a2a873b..1134ba6fb656 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -30,3 +30,4 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] +TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"] diff --git a/tests/test_lora_layers.py b/tests/test_lora_layers.py new file mode 100644 index 000000000000..9bcdc5d93301 --- /dev/null +++ b/tests/test_lora_layers.py @@ -0,0 +1,213 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import tempfile +import unittest + +import torch +import torch.nn as nn +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin +from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device + + +def create_unet_lora_layers(unet: nn.Module): + lora_attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + unet_lora_layers = AttnProcsLayers(lora_attn_procs) + return lora_attn_procs, unet_lora_layers + + +def create_text_encoder_lora_layers(text_encoder: nn.Module): + text_lora_attn_procs = {} + for name, module in text_encoder.named_modules(): + if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]): + text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) + text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) + return text_encoder_lora_layers + + +class LoraLoaderMixinTests(unittest.TestCase): + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet) + text_encoder_lora_layers = create_text_encoder_lora_layers(text_encoder) + + pipeline_components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + lora_components = { + "unet_lora_layers": unet_lora_layers, + "text_encoder_lora_layers": text_encoder_lora_layers, + "unet_lora_attn_procs": unet_lora_attn_procs, + } + return pipeline_components, lora_components + + def get_dummy_inputs(self): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + } + + return noise, input_ids, pipeline_inputs + + def test_lora_save_load(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + noise, input_ids, pipeline_inputs = self.get_dummy_inputs() + + original_images = sd_pipe(**pipeline_inputs).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Outputs shouldn't match. + self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + def test_lora_save_load_safetensors(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + noise, input_ids, pipeline_inputs = self.get_dummy_inputs() + + original_images = sd_pipe(**pipeline_inputs).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Outputs shouldn't match. + self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + def test_lora_save_load_legacy(self): + pipeline_components, lora_components = self.get_dummy_components() + unet_lora_attn_procs = lora_components["unet_lora_attn_procs"] + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + noise, input_ids, pipeline_inputs = self.get_dummy_inputs() + + original_images = sd_pipe(**pipeline_inputs).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + unet = sd_pipe.unet + unet.set_attn_processor(unet_lora_attn_procs) + unet.save_attn_procs(tmpdirname) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Outputs shouldn't match. + self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))