From 66a775ac2e3f7e2389e2fa772d97b63cd4724151 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Tue, 18 Jul 2023 20:44:47 +0300 Subject: [PATCH 1/2] Support to load Kohya-ss style LoRA file format (without restrictions) Co-Authored-By: Takuma Mori Co-Authored-By: Sayak Paul --- docs/source/en/training/lora.mdx | 5 +- src/diffusers/loaders.py | 118 ++++++++++++++++++-- src/diffusers/models/attention.py | 5 +- src/diffusers/models/attention_processor.py | 31 +---- src/diffusers/models/lora.py | 115 +++++++++++++++++++ src/diffusers/models/transformer_2d.py | 5 +- tests/models/test_layers_utils.py | 5 +- tests/models/test_lora_layers.py | 50 ++++++++- 8 files changed, 288 insertions(+), 46 deletions(-) create mode 100644 src/diffusers/models/lora.py diff --git a/docs/source/en/training/lora.mdx b/docs/source/en/training/lora.mdx index f7cfa5a8ea77..670a94658160 100644 --- a/docs/source/en/training/lora.mdx +++ b/docs/source/en/training/lora.mdx @@ -14,8 +14,7 @@ specific language governing permissions and limitations under the License. -Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionalModel`]. We also -support fine-tuning the text encoder for DreamBooth with LoRA in a limited capacity. Fine-tuning the text encoder for DreamBooth generally yields better results, but it can increase compute usage. +This is an experimental feature. Its APIs can change in future. @@ -286,6 +285,8 @@ You can call [`~diffusers.loaders.LoraLoaderMixin.unload_lora_weights`] on a pip ## Supporting A1111 themed LoRA checkpoints from Diffusers +This support was made possible because of our amazing contributors: [@takuma104](https://github.com/takuma104) and [@isidentical](https://github.com/isidentical). + To provide seamless interoperability with A1111 to our users, we support loading A1111 formatted LoRA checkpoints using [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] in a limited capacity. In this section, we explain how to load an A1111 formatted LoRA checkpoint from [CivitAI](https://civitai.com/) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 246765f76037..da91af85466f 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -25,6 +25,7 @@ from huggingface_hub import hf_hub_download from torch import nn +from .models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer from .utils import ( DIFFUSERS_CACHE, HF_HUB_OFFLINE, @@ -56,6 +57,7 @@ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" +TOTAL_EXAMPLE_KEYS = 5 TEXT_INVERSION_NAME = "learned_embeds.bin" TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors" @@ -105,6 +107,22 @@ def text_encoder_attn_modules(text_encoder): return attn_modules +def text_encoder_mlp_modules(text_encoder): + mlp_modules = [] + + if isinstance(text_encoder, CLIPTextModel): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + mlp_mod = layer.mlp + name = f"text_model.encoder.layers.{i}.mlp" + mlp_modules.append((name, mlp_mod)) + elif isinstance(text_encoder, CLIPTextModelWithProjection): + pass # SDXL is not supported yet. 
+ else: + raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}") + + return mlp_modules + + def text_encoder_lora_state_dict(text_encoder): state_dict = {} @@ -304,6 +322,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict # fill attn processors attn_processors = {} + non_attn_lora_layers = [] is_lora = all("lora" in k for k in state_dict.keys()) is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys()) @@ -327,13 +346,33 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict lora_grouped_dict[attn_processor_key][sub_key] = value for key, value_dict in lora_grouped_dict.items(): - rank = value_dict["to_k_lora.down.weight"].shape[0] - hidden_size = value_dict["to_k_lora.up.weight"].shape[0] - attn_processor = self for sub_key in key.split("."): attn_processor = getattr(attn_processor, sub_key) + # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers + # or add_{k,v,q,out_proj}_proj_lora layers. + if "lora.down.weight" in value_dict: + rank = value_dict["lora.down.weight"].shape[0] + hidden_size = value_dict["lora.up.weight"].shape[0] + + if isinstance(attn_processor, LoRACompatibleConv): + lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha) + elif isinstance(attn_processor, LoRACompatibleLinear): + lora = LoRALinearLayer( + attn_processor.in_features, attn_processor.out_features, rank, network_alpha + ) + else: + raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.") + + value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} + lora.load_state_dict(value_dict) + non_attn_lora_layers.append((attn_processor, lora)) + continue + + rank = value_dict["to_k_lora.down.weight"].shape[0] + hidden_size = value_dict["to_k_lora.up.weight"].shape[0] + if isinstance( attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0) ): @@ -390,10 +429,16 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict # set correct dtype & device attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()} + non_attn_lora_layers = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in non_attn_lora_layers] # set layers self.set_attn_processor(attn_processors) + # set ff layers + for target_module, lora_layer in non_attn_lora_layers: + if hasattr(target_module, "set_lora_layer"): + target_module.set_lora_layer(lora_layer) + def save_attn_procs( self, save_directory: Union[str, os.PathLike], @@ -840,7 +885,10 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet) self.load_lora_into_text_encoder( - state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lora_scale=self.lora_scale + state_dict, + network_alpha=network_alpha, + text_encoder=self.text_encoder, + lora_scale=self.lora_scale, ) @classmethod @@ -1049,6 +1097,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr text_encoder_lora_state_dict = { k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys } + if len(text_encoder_lora_state_dict) > 0: logger.info(f"Loading {prefix}.") @@ -1092,8 +1141,9 @@ def load_lora_into_text_encoder(cls, state_dict, 
network_alpha, text_encoder, pr rank = text_encoder_lora_state_dict[ "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight" ].shape[1] + patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys()) - cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank) + cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank, patch_mlp=patch_mlp) # set correct dtype & device text_encoder_lora_state_dict = { @@ -1125,8 +1175,21 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder): attn_module.v_proj = attn_module.v_proj.regular_linear_layer attn_module.out_proj = attn_module.out_proj.regular_linear_layer + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + if isinstance(mlp_module.fc1, PatchedLoraProjection): + mlp_module.fc1 = mlp_module.fc1.regular_linear_layer + mlp_module.fc2 = mlp_module.fc2.regular_linear_layer + @classmethod - def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None): + def _modify_text_encoder( + cls, + text_encoder, + lora_scale=1, + network_alpha=None, + rank=4, + dtype=None, + patch_mlp=False, + ): r""" Monkey-patches the forward passes of attention modules of the text encoder. """ @@ -1157,6 +1220,18 @@ def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, ra ) lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters()) + if patch_mlp: + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + mlp_module.fc1 = PatchedLoraProjection( + mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype + ) + lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters()) + + mlp_module.fc2 = PatchedLoraProjection( + mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype + ) + lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters()) + return lora_parameters @classmethod @@ -1261,9 +1336,12 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): unet_state_dict = {} te_state_dict = {} network_alpha = None + unloaded_keys = [] for key, value in state_dict.items(): - if "lora_down" in key: + if "hada" in key or "skip" in key: + unloaded_keys.append(key) + elif "lora_down" in key: lora_name = key.split(".")[0] lora_name_up = lora_name + ".lora_up.weight" lora_name_alpha = lora_name + ".alpha" @@ -1284,12 +1362,21 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") + diffusers_name = diffusers_name.replace("proj.in", "proj_in") + diffusers_name = diffusers_name.replace("proj.out", "proj_out") if "transformer_blocks" in diffusers_name: if "attn1" in diffusers_name or "attn2" in diffusers_name: diffusers_name = diffusers_name.replace("attn1", "attn1.processor") diffusers_name = diffusers_name.replace("attn2", "attn2.processor") unet_state_dict[diffusers_name] = value unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + elif "ff" in diffusers_name: + unet_state_dict[diffusers_name] = value + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + elif any(key in diffusers_name for key in ("proj_in", "proj_out")): + unet_state_dict[diffusers_name] = value + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + elif lora_name.startswith("lora_te_"): 
diffusers_name = key.replace("lora_te_", "").replace("_", ".") diffusers_name = diffusers_name.replace("text.model", "text_model") @@ -1301,6 +1388,19 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict): if "self_attn" in diffusers_name: te_state_dict[diffusers_name] = value te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it yet. + diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + te_state_dict[diffusers_name] = value + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + + logger.info("Kohya-style checkpoint detected.") + if len(unloaded_keys) > 0: + example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS]) + logger.warning( + f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for." + ) unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()} te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()} @@ -1346,6 +1446,10 @@ def unload_lora_weights(self): [attention_proc_class] = unet_attention_classes self.unet.set_attn_processor(regular_attention_classes[attention_proc_class]()) + for _, module in self.unet.named_modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + # Safe to call the following regardless of LoRA. self._remove_text_encoder_monkey_patch() diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 6b05bf35e87f..ad899212e5a5 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -21,6 +21,7 @@ from .activations import get_activation from .attention_processor import Attention from .embeddings import CombinedTimestepLabelEmbeddings +from .lora import LoRACompatibleLinear @maybe_allow_in_graph @@ -245,7 +246,7 @@ def __init__( # project dropout self.net.append(nn.Dropout(dropout)) # project out - self.net.append(nn.Linear(inner_dim, dim_out)) + self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) # FF as used in Vision Transformer, MLP-Mixer, etc. 
have a final dropout if final_dropout: self.net.append(nn.Dropout(dropout)) @@ -289,7 +290,7 @@ class GEGLU(nn.Module): def __init__(self, dim_in: int, dim_out: int): super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) + self.proj = LoRACompatibleLinear(dim_in, dim_out * 2) def gelu(self, gate): if gate.device.type != "mps": diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 8468c864fe4d..de4adec042f6 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -19,6 +19,7 @@ from ..utils import deprecate, logging, maybe_allow_in_graph from ..utils.import_utils import is_xformers_available +from .lora import LoRALinearLayer logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -505,36 +506,6 @@ def __call__( return hidden_states -class LoRALinearLayer(nn.Module): - def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): - super().__init__() - - if rank > min(in_features, out_features): - raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") - - self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) - self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) - # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. - # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning - self.network_alpha = network_alpha - self.rank = rank - - nn.init.normal_(self.down.weight, std=1 / rank) - nn.init.zeros_(self.up.weight) - - def forward(self, hidden_states): - orig_dtype = hidden_states.dtype - dtype = self.down.weight.dtype - - down_hidden_states = self.down(hidden_states.to(dtype)) - up_hidden_states = self.up(down_hidden_states) - - if self.network_alpha is not None: - up_hidden_states *= self.network_alpha / self.rank - - return up_hidden_states.to(orig_dtype) - - class LoRAAttnProcessor(nn.Module): r""" Processor for implementing the LoRA attention mechanism. diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py new file mode 100644 index 000000000000..bb8389745776 --- /dev/null +++ b/src/diffusers/models/lora.py @@ -0,0 +1,115 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional + +from torch import nn + + +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): + super().__init__() + + if rank > min(in_features, out_features): + raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") + + self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) + self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + + +class LoRAConv2dLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4, network_alpha=None): + super().__init__() + + if rank > min(in_features, out_features): + raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") + + self.down = nn.Conv2d(in_features, rank, (1, 1), (1, 1), bias=False) + self.up = nn.Conv2d(rank, out_features, (1, 1), (1, 1), bias=False) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + + +class LoRACompatibleConv(nn.Conv2d): + """ + A convolutional layer that can be used with LoRA. + """ + + def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs): + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer + + def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]): + self.lora_layer = lora_layer + + def forward(self, x): + if self.lora_layer is None: + return super().forward(x) + else: + return super().forward(x) + self.lora_layer(x) + + +class LoRACompatibleLinear(nn.Linear): + """ + A Linear layer that can be used with LoRA. 
+ """
+
+ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.lora_layer = lora_layer
+
+ def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
+ self.lora_layer = lora_layer
+
+ def forward(self, x):
+ if self.lora_layer is None:
+ return super().forward(x)
+ else:
+ return super().forward(x) + self.lora_layer(x)
diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py
index 83da16838ae2..bbd93430da14 100644
--- a/src/diffusers/models/transformer_2d.py
+++ b/src/diffusers/models/transformer_2d.py
@@ -23,6 +23,7 @@
 from ..utils import BaseOutput, deprecate
 from .attention import BasicTransformerBlock
 from .embeddings import PatchEmbed
+from .lora import LoRACompatibleConv
 from .modeling_utils import ModelMixin
@@ -138,7 +139,7 @@ def __init__(
 if use_linear_projection:
 self.proj_in = nn.Linear(in_channels, inner_dim)
 else:
- self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
 elif self.is_input_vectorized:
 assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
 assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
@@ -194,7 +195,7 @@ def __init__(
 if use_linear_projection:
 self.proj_out = nn.Linear(inner_dim, in_channels)
 else:
- self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
 elif self.is_input_vectorized:
 self.norm_out = nn.LayerNorm(inner_dim)
 self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
diff --git a/tests/models/test_layers_utils.py b/tests/models/test_layers_utils.py
index b438b2ddb4af..40627cc93caa 100644
--- a/tests/models/test_layers_utils.py
+++ b/tests/models/test_layers_utils.py
@@ -22,6 +22,7 @@
 from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU
 from diffusers.models.embeddings import get_timestep_embedding
+from diffusers.models.lora import LoRACompatibleLinear
 from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
 from diffusers.models.transformer_2d import Transformer2DModel
 from diffusers.utils import torch_device
@@ -482,7 +483,7 @@ def test_spatial_transformer_default_ff_layers(self):
 assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU
 assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
- assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
+ assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LoRACompatibleLinear
 dim = 32
 inner_dim = 128
@@ -506,7 +507,7 @@ def test_spatial_transformer_geglu_approx_ff_layers(self):
 assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU
 assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
- assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
+ assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LoRACompatibleLinear
 dim = 32
 inner_dim = 128
diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py
index 3f563535eb9f..8c751bc6bf07 100644
--- a/tests/models/test_lora_layers.py
+++ b/tests/models/test_lora_layers.py
@@ -738,7 +738,7 
@@ def test_a1111(self): images = images[0, -3:, -3:, -1].flatten() - expected = np.array([0.3743, 0.3893, 0.3835, 0.3891, 0.3949, 0.3649, 0.3858, 0.3802, 0.3245]) + expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292]) self.assertTrue(np.allclose(images, expected, atol=1e-4)) @@ -778,6 +778,7 @@ def test_unload_lora(self): lora_filename = "Colored_Icons_by_vizsumit.safetensors" pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + generator = torch.manual_seed(0) lora_images = pipe( prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps ).images @@ -792,3 +793,50 @@ def test_unload_lora(self): self.assertFalse(np.allclose(initial_images, lora_images)) self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3)) + + def test_load_unload_load_kohya_lora(self): + # This test ensures that a Kohya-style LoRA can be safely unloaded and then loaded + # without introducing any side-effects. Even though the test uses a Kohya-style + # LoRA, the underlying adapter handling mechanism is format-agnostic. + generator = torch.manual_seed(0) + prompt = "masterpiece, best quality, mountain" + num_inference_steps = 2 + + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to( + torch_device + ) + initial_images = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + initial_images = initial_images[0, -3:, -3:, -1].flatten() + + lora_model_id = "hf-internal-testing/civitai-colored-icons-lora" + lora_filename = "Colored_Icons_by_vizsumit.safetensors" + + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + generator = torch.manual_seed(0) + lora_images = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + lora_images = lora_images[0, -3:, -3:, -1].flatten() + + pipe.unload_lora_weights() + generator = torch.manual_seed(0) + unloaded_lora_images = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten() + + self.assertFalse(np.allclose(initial_images, lora_images)) + self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3)) + + # make sure we can load a LoRA again after unloading and they don't have + # any undesired effects. 
+ pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + generator = torch.manual_seed(0) + lora_images_again = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + lora_images_again = lora_images_again[0, -3:, -3:, -1].flatten() + + self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3)) From db46e2b84c43c21a0edce0f0e31e4b839e6465db Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Tue, 25 Jul 2023 17:07:58 +0300 Subject: [PATCH 2/2] tmp: add sdxl to mlp_modules --- src/diffusers/loaders.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index da91af85466f..84454aa4772d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -110,13 +110,11 @@ def text_encoder_attn_modules(text_encoder): def text_encoder_mlp_modules(text_encoder): mlp_modules = [] - if isinstance(text_encoder, CLIPTextModel): + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): for i, layer in enumerate(text_encoder.text_model.encoder.layers): mlp_mod = layer.mlp name = f"text_model.encoder.layers.{i}.mlp" mlp_modules.append((name, mlp_mod)) - elif isinstance(text_encoder, CLIPTextModelWithProjection): - pass # SDXL is not supported yet. else: raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")
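A minimal end-to-end sketch of the behaviour this patch adds, assembled from the integration tests above. The pipeline id, LoRA repo id and weight filename are the ones used in `test_load_unload_load_kohya_lora`; running on CUDA is an assumption, and the snippet is illustrative rather than part of the patch.

import torch
from diffusers import StableDiffusionPipeline

# Load a plain Stable Diffusion pipeline (same checkpoint as the tests).
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", safety_checker=None
).to("cuda")  # assumes a CUDA device is available

# load_lora_weights detects the Kohya-ss layout and converts it via
# _convert_kohya_lora_to_diffusers, patching attention, feed-forward (ff.net),
# proj_in/proj_out and text-encoder MLP layers where the checkpoint provides them.
pipe.load_lora_weights(
    "hf-internal-testing/civitai-colored-icons-lora",
    weight_name="Colored_Icons_by_vizsumit.safetensors",
)

generator = torch.manual_seed(0)
image = pipe("masterpiece, best quality, mountain", generator=generator).images[0]

# Remove the injected LoRA layers again; subsequent runs should match the
# original pipeline output (this is what test_load_unload_load_kohya_lora checks).
pipe.unload_lora_weights()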