diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md
index 68e21f918173..1bfb3f5c48b7 100644
--- a/docs/source/en/tutorials/using_peft_for_inference.md
+++ b/docs/source/en/tutorials/using_peft_for_inference.md
@@ -133,6 +133,62 @@ image

+### Customize adapters strength
+For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`].
+
+For example, here's how you can turn on the adapter for the `down` parts, but turn it off for the `mid` and `up` parts:
+```python
+pipe.enable_lora() # enable lora again, after we disabled it above
+prompt = "toy_face of a hacker with a hoodie, pixel art"
+adapter_weight_scales = { "unet": { "down": 1, "mid": 0, "up": 0} }
+pipe.set_adapters("pixel", adapter_weight_scales)
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+image
+```
+
+
+
+Let's see how turning off the `down` part and turning on the `mid` and `up` part respectively changes the image.
+```python
+adapter_weight_scales = { "unet": { "down": 0, "mid": 1, "up": 0} }
+pipe.set_adapters("pixel", adapter_weight_scales)
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+image
+```
+
+
+
+```python
+adapter_weight_scales = { "unet": { "down": 0, "mid": 0, "up": 1} }
+pipe.set_adapters("pixel", adapter_weight_scales)
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+image
+```
+
+
+
+Looks cool!
+
+This is a really powerful feature. You can use it to control the adapter strengths down to per-transformer level. And you can even use it for multiple adapters.
+```python
+adapter_weight_scales_toy = 0.5
+adapter_weight_scales_pixel = {
+ "unet": {
+ "down": 0.9, # all transformers in the down-part will use scale 0.9
+ # "mid" # because, in this example, "mid" is not given, all transformers in the mid part will use the default scale 1.0
+ "up": {
+ "block_0": 0.6, # all 3 transformers in the 0th block in the up-part will use scale 0.6
+ "block_1": [0.4, 0.8, 1.0], # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
+ }
+ }
+}
+pipe.set_adapters(["toy", "pixel"], [adapter_weight_scales_toy, adapter_weight_scales_pixel])
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+image
+```
+
+
+
## Manage active adapters
You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, use the [`~diffusers.loaders.LoraLoaderMixin.get_active_adapters`] method to check the list of active adapters:
diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md
index b59b46aeba51..b079d2165ece 100644
--- a/docs/source/en/using-diffusers/loading_adapters.md
+++ b/docs/source/en/using-diffusers/loading_adapters.md
@@ -153,18 +153,43 @@ image
-
-
-For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.
-
-
-
To unload the LoRA weights, use the [`~loaders.LoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:
```py
pipeline.unload_lora_weights()
```
+### Adjust LoRA weight scale
+
+For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.
+
+For more granular control on the amount of LoRA weights used per layer, you can use [`~loaders.LoraLoaderMixin.set_adapters`] and pass a dictionary specifying by how much to scale the weights in each layer by.
+```python
+pipe = ... # create pipeline
+pipe.load_lora_weights(..., adapter_name="my_adapter")
+scales = {
+ "text_encoder": 0.5,
+ "text_encoder_2": 0.5, # only usable if pipe has a 2nd text encoder
+ "unet": {
+ "down": 0.9, # all transformers in the down-part will use scale 0.9
+ # "mid" # in this example "mid" is not given, therefore all transformers in the mid part will use the default scale 1.0
+ "up": {
+ "block_0": 0.6, # all 3 transformers in the 0th block in the up-part will use scale 0.6
+ "block_1": [0.4, 0.8, 1.0], # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
+ }
+ }
+}
+pipe.set_adapters("my_adapter", scales)
+```
+
+This also works with multiple adapters - see [this guide](https://huggingface.co/docs/diffusers/tutorials/using_peft_for_inference#customize-adapters-strength) for how to do it.
+
+
+
+Currently, [`~loaders.LoraLoaderMixin.set_adapters`] only supports scaling attention weights. If a LoRA has other parts (e.g., resnets or down-/upsamplers), they will keep a scale of 1.0.
+
+
+
### Kohya and TheLastBen
Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index aa53563fd39d..6ef6682b0e17 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import inspect
import os
from pathlib import Path
@@ -985,7 +986,7 @@ def set_adapters_for_text_encoder(
self,
adapter_names: Union[List[str], str],
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
- text_encoder_weights: List[float] = None,
+ text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
):
"""
Sets the adapter layers for the text encoder.
@@ -1003,15 +1004,20 @@ def set_adapters_for_text_encoder(
raise ValueError("PEFT backend is required for this method.")
def process_weights(adapter_names, weights):
- if weights is None:
- weights = [1.0] * len(adapter_names)
- elif isinstance(weights, float):
- weights = [weights]
+ # Expand weights into a list, one entry per adapter
+ # e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
+ if not isinstance(weights, list):
+ weights = [weights] * len(adapter_names)
if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
)
+
+ # Set None values to default of 1.0
+ # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
+ weights = [w if w is not None else 1.0 for w in weights]
+
return weights
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
@@ -1059,17 +1065,77 @@ def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"]
def set_adapters(
self,
adapter_names: Union[List[str], str],
- adapter_weights: Optional[List[float]] = None,
+ adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
):
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+
+ adapter_weights = copy.deepcopy(adapter_weights)
+
+ # Expand weights into a list, one entry per adapter
+ if not isinstance(adapter_weights, list):
+ adapter_weights = [adapter_weights] * len(adapter_names)
+
+ if len(adapter_names) != len(adapter_weights):
+ raise ValueError(
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
+ )
+
+ # Decompose weights into weights for unet, text_encoder and text_encoder_2
+ unet_lora_weights, text_encoder_lora_weights, text_encoder_2_lora_weights = [], [], []
+
+ list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
+ all_adapters = {
+ adapter for adapters in list_adapters.values() for adapter in adapters
+ } # eg ["adapter1", "adapter2"]
+ invert_list_adapters = {
+ adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
+ for adapter in all_adapters
+ } # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
+
+ for adapter_name, weights in zip(adapter_names, adapter_weights):
+ if isinstance(weights, dict):
+ unet_lora_weight = weights.pop("unet", None)
+ text_encoder_lora_weight = weights.pop("text_encoder", None)
+ text_encoder_2_lora_weight = weights.pop("text_encoder_2", None)
+
+ if len(weights) > 0:
+ raise ValueError(
+ f"Got invalid key '{weights.keys()}' in lora weight dict for adapter {adapter_name}."
+ )
+
+ if text_encoder_2_lora_weight is not None and not hasattr(self, "text_encoder_2"):
+ logger.warning(
+ "Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2."
+ )
+
+ # warn if adapter doesn't have parts specified by adapter_weights
+ for part_weight, part_name in zip(
+ [unet_lora_weight, text_encoder_lora_weight, text_encoder_2_lora_weight],
+ ["uent", "text_encoder", "text_encoder_2"],
+ ):
+ if part_weight is not None and part_name not in invert_list_adapters[adapter_name]:
+ logger.warning(
+ f"Lora weight dict for adapter '{adapter_name}' contains {part_name}, but this will be ignored because {adapter_name} does not contain weights for {part_name}. Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
+ )
+
+ else:
+ unet_lora_weight = weights
+ text_encoder_lora_weight = weights
+ text_encoder_2_lora_weight = weights
+
+ unet_lora_weights.append(unet_lora_weight)
+ text_encoder_lora_weights.append(text_encoder_lora_weight)
+ text_encoder_2_lora_weights.append(text_encoder_2_lora_weight)
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
# Handle the UNET
- unet.set_adapters(adapter_names, adapter_weights)
+ unet.set_adapters(adapter_names, unet_lora_weights)
# Handle the Text Encoder
if hasattr(self, "text_encoder"):
- self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, adapter_weights)
+ self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, text_encoder_lora_weights)
if hasattr(self, "text_encoder_2"):
- self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, adapter_weights)
+ self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, text_encoder_2_lora_weights)
def disable_lora(self):
if not USE_PEFT_BACKEND:
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index c3449d53fcb7..8bbec26189b0 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -47,6 +47,7 @@
infer_stable_cascade_single_file_config,
load_single_file_model_checkpoint,
)
+from .unet_loader_utils import _maybe_expand_lora_scales
from .utils import AttnProcsLayers
@@ -564,7 +565,7 @@ def _unfuse_lora_apply(self, module):
def set_adapters(
self,
adapter_names: Union[List[str], str],
- weights: Optional[Union[List[float], float]] = None,
+ weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
):
"""
Set the currently active adapters for use in the UNet.
@@ -597,9 +598,9 @@ def set_adapters(
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
- if weights is None:
- weights = [1.0] * len(adapter_names)
- elif isinstance(weights, float):
+ # Expand weights into a list, one entry per adapter
+ # examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None]
+ if not isinstance(weights, list):
weights = [weights] * len(adapter_names)
if len(adapter_names) != len(weights):
@@ -607,6 +608,13 @@ def set_adapters(
f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
)
+ # Set None values to default of 1.0
+ # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
+ weights = [w if w is not None else 1.0 for w in weights]
+
+ # e.g. [{...}, 7] -> [{expanded dict...}, 7]
+ weights = _maybe_expand_lora_scales(self, weights)
+
set_weights_and_activate_adapters(self, adapter_names, weights)
def disable_lora(self):
diff --git a/src/diffusers/loaders/unet_loader_utils.py b/src/diffusers/loaders/unet_loader_utils.py
new file mode 100644
index 000000000000..918a0fca06c8
--- /dev/null
+++ b/src/diffusers/loaders/unet_loader_utils.py
@@ -0,0 +1,154 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+from typing import TYPE_CHECKING, Dict, List, Union
+
+from ..utils import logging
+
+
+if TYPE_CHECKING:
+ # import here to avoid circular imports
+ from ..models import UNet2DConditionModel
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _translate_into_actual_layer_name(name):
+ """Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')"""
+ if name == "mid":
+ return "mid_block.attentions.0"
+
+ updown, block, attn = name.split(".")
+
+ updown = updown.replace("down", "down_blocks").replace("up", "up_blocks")
+ block = block.replace("block_", "")
+ attn = "attentions." + attn
+
+ return ".".join((updown, block, attn))
+
+
+def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]]):
+ blocks_with_transformer = {
+ "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
+ "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
+ }
+ transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1}
+
+ expanded_weight_scales = [
+ _maybe_expand_lora_scales_for_one_adapter(
+ weight_for_adapter, blocks_with_transformer, transformer_per_block, unet.state_dict()
+ )
+ for weight_for_adapter in weight_scales
+ ]
+
+ return expanded_weight_scales
+
+
+def _maybe_expand_lora_scales_for_one_adapter(
+ scales: Union[float, Dict],
+ blocks_with_transformer: Dict[str, int],
+ transformer_per_block: Dict[str, int],
+ state_dict: None,
+):
+ """
+ Expands the inputs into a more granular dictionary. See the example below for more details.
+
+ Parameters:
+ scales (`Union[float, Dict]`):
+ Scales dict to expand.
+ blocks_with_transformer (`Dict[str, int]`):
+ Dict with keys 'up' and 'down', showing which blocks have transformer layers
+ transformer_per_block (`Dict[str, int]`):
+ Dict with keys 'up' and 'down', showing how many transformer layers each block has
+
+ E.g. turns
+ ```python
+ scales = {
+ 'down': 2,
+ 'mid': 3,
+ 'up': {
+ 'block_0': 4,
+ 'block_1': [5, 6, 7]
+ }
+ }
+ blocks_with_transformer = {
+ 'down': [1,2],
+ 'up': [0,1]
+ }
+ transformer_per_block = {
+ 'down': 2,
+ 'up': 3
+ }
+ ```
+ into
+ ```python
+ {
+ 'down.block_1.0': 2,
+ 'down.block_1.1': 2,
+ 'down.block_2.0': 2,
+ 'down.block_2.1': 2,
+ 'mid': 3,
+ 'up.block_0.0': 4,
+ 'up.block_0.1': 4,
+ 'up.block_0.2': 4,
+ 'up.block_1.0': 5,
+ 'up.block_1.1': 6,
+ 'up.block_1.2': 7,
+ }
+ ```
+ """
+ if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
+ raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`")
+
+ if sorted(transformer_per_block.keys()) != ["down", "up"]:
+ raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`")
+
+ if not isinstance(scales, dict):
+ # don't expand if scales is a single number
+ return scales
+
+ scales = copy.deepcopy(scales)
+
+ if "mid" not in scales:
+ scales["mid"] = 1
+
+ for updown in ["up", "down"]:
+ if updown not in scales:
+ scales[updown] = 1
+
+ # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
+ if not isinstance(scales[updown], dict):
+ scales[updown] = {f"block_{i}": scales[updown] for i in blocks_with_transformer[updown]}
+
+ # eg {"down": "block_1": 1}} to {"down": "block_1": [1, 1]}}
+ for i in blocks_with_transformer[updown]:
+ block = f"block_{i}"
+ if not isinstance(scales[updown][block], list):
+ scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]
+
+ # eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
+ for i in blocks_with_transformer[updown]:
+ block = f"block_{i}"
+ for tf_idx, value in enumerate(scales[updown][block]):
+ scales[f"{updown}.{block}.{tf_idx}"] = value
+
+ del scales[updown]
+
+ for layer in scales.keys():
+ if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
+ raise ValueError(
+ f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
+ )
+
+ return {_translate_into_actual_layer_name(name): weight for name, weight in scales.items()}
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index 718e0b46d87c..feececc56966 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -230,16 +230,26 @@ def delete_adapter_layers(model, adapter_name):
def set_weights_and_activate_adapters(model, adapter_names, weights):
from peft.tuners.tuners_utils import BaseTunerLayer
+ def get_module_weight(weight_for_adapter, module_name):
+ if not isinstance(weight_for_adapter, dict):
+ # If weight_for_adapter is a single number, always return it.
+ return weight_for_adapter
+
+ for layer_name, weight_ in weight_for_adapter.items():
+ if layer_name in module_name:
+ return weight_
+ raise RuntimeError(f"No LoRA weight found for module {module_name}.")
+
# iterate over each adapter, make it active and set the corresponding scaling weight
for adapter_name, weight in zip(adapter_names, weights):
- for module in model.modules():
+ for module_name, module in model.named_modules():
if isinstance(module, BaseTunerLayer):
# For backward compatbility with previous PEFT versions
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name
- module.set_scale(adapter_name, weight)
+ module.set_scale(adapter_name, get_module_weight(weight, module_name))
# set multiple active adapters
for module in model.modules():
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 7d84ac024dee..9aed5defada2 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -15,6 +15,7 @@
import os
import tempfile
import unittest
+from itertools import product
import numpy as np
import torch
@@ -762,6 +763,218 @@ def test_simple_inference_with_text_unet_multi_adapter(self):
"output with no lora and output with lora disabled should give same results",
)
+ def test_simple_inference_with_text_unet_block_scale(self):
+ """
+ Tests a simple inference with lora attached to text encoder and unet, attaches
+ one adapter and set differnt weights for different blocks (i.e. block lora)
+ """
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
+
+ weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
+ pipe.set_adapters("adapter-1", weights_1)
+ output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+ weights_2 = {"unet": {"up": 5}}
+ pipe.set_adapters("adapter-1", weights_2)
+ output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+ self.assertFalse(
+ np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
+ "LoRA weights 1 and 2 should give different results",
+ )
+ self.assertFalse(
+ np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
+ "No adapter and LoRA weights 1 should give different results",
+ )
+ self.assertFalse(
+ np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
+ "No adapter and LoRA weights 2 should give different results",
+ )
+
+ pipe.disable_lora()
+ output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+ self.assertTrue(
+ np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
+ "output with no lora and output with lora disabled should give same results",
+ )
+
+ def test_simple_inference_with_text_unet_multi_adapter_block_lora(self):
+ """
+ Tests a simple inference with lora attached to text encoder and unet, attaches
+ multiple adapters and set differnt weights for different blocks (i.e. block lora)
+ """
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+
+ pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
+
+ scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
+ scales_2 = {"unet": {"down": 5, "mid": 5}}
+ pipe.set_adapters("adapter-1", scales_1)
+
+ output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+ pipe.set_adapters("adapter-2", scales_2)
+ output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+ pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
+
+ output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+ # Fuse and unfuse should lead to the same results
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and 2 should give different results",
+ )
+
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and mixed adapters should give different results",
+ )
+
+ self.assertFalse(
+ np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 2 and mixed adapters should give different results",
+ )
+
+ pipe.disable_lora()
+
+ output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+ self.assertTrue(
+ np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
+ "output with no lora and output with lora disabled should give same results",
+ )
+
+ # a mismatching number of adapter_names and adapter_weights should raise an error
+ with self.assertRaises(ValueError):
+ pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])
+
+ def test_simple_inference_with_text_unet_block_scale_for_all_dict_options(self):
+ """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""
+
+ def updown_options(blocks_with_tf, layers_per_block, value):
+ """
+ Generate every possible combination for how a lora weight dict for the up/down part can be.
+ E.g. 2, {"block_1": 2}, {"block_1": [2,2,2]}, {"block_1": 2, "block_2": [2,2,2]}, ...
+ """
+ num_val = value
+ list_val = [value] * layers_per_block
+
+ node_opts = [None, num_val, list_val]
+ node_opts_foreach_block = [node_opts] * len(blocks_with_tf)
+
+ updown_opts = [num_val]
+ for nodes in product(*node_opts_foreach_block):
+ if all(n is None for n in nodes):
+ continue
+ opt = {}
+ for b, n in zip(blocks_with_tf, nodes):
+ if n is not None:
+ opt["block_" + str(b)] = n
+ updown_opts.append(opt)
+ return updown_opts
+
+ def all_possible_dict_opts(unet, value):
+ """
+ Generate every possible combination for how a lora weight dict can be.
+ E.g. 2, {"unet: {"down": 2}}, {"unet: {"down": [2,2,2]}}, {"unet: {"mid": 2, "up": [2,2,2]}}, ...
+ """
+
+ down_blocks_with_tf = [i for i, d in enumerate(unet.down_blocks) if hasattr(d, "attentions")]
+ up_blocks_with_tf = [i for i, u in enumerate(unet.up_blocks) if hasattr(u, "attentions")]
+
+ layers_per_block = unet.config.layers_per_block
+
+ text_encoder_opts = [None, value]
+ text_encoder_2_opts = [None, value]
+ mid_opts = [None, value]
+ down_opts = [None] + updown_options(down_blocks_with_tf, layers_per_block, value)
+ up_opts = [None] + updown_options(up_blocks_with_tf, layers_per_block + 1, value)
+
+ opts = []
+
+ for t1, t2, d, m, u in product(text_encoder_opts, text_encoder_2_opts, down_opts, mid_opts, up_opts):
+ if all(o is None for o in (t1, t2, d, m, u)):
+ continue
+ opt = {}
+ if t1 is not None:
+ opt["text_encoder"] = t1
+ if t2 is not None:
+ opt["text_encoder_2"] = t2
+ if all(o is None for o in (d, m, u)):
+ # no unet scaling
+ continue
+ opt["unet"] = {}
+ if d is not None:
+ opt["unet"]["down"] = d
+ if m is not None:
+ opt["unet"]["mid"] = m
+ if u is not None:
+ opt["unet"]["up"] = u
+ opts.append(opt)
+
+ return opts
+
+ components, text_lora_config, unet_lora_config = self.get_dummy_components(self.scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+
+ for scale_dict in all_possible_dict_opts(pipe.unet, value=1234):
+ # test if lora block scales can be set with this scale_dict
+ if not self.has_two_text_encoders and "text_encoder_2" in scale_dict:
+ del scale_dict["text_encoder_2"]
+
+ pipe.set_adapters("adapter-1", scale_dict) # test will fail if this line throws an error
+
def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches