diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md
index 68e21f918173..1bfb3f5c48b7 100644
--- a/docs/source/en/tutorials/using_peft_for_inference.md
+++ b/docs/source/en/tutorials/using_peft_for_inference.md
@@ -133,6 +133,62 @@ image
 
 ![no-lora](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_20_1.png)
 
+### Customize adapters strength
+For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`].
+
+For example, here's how you can turn on the adapter for the `down` parts, but turn it off for the `mid` and `up` parts:
+```python
+pipe.enable_lora() # enable lora again, after we disabled it above
+prompt = "toy_face of a hacker with a hoodie, pixel art"
+adapter_weight_scales = { "unet": { "down": 1, "mid": 0, "up": 0} }
+pipe.set_adapters("pixel", adapter_weight_scales)
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+image
+```
+
+![block-lora-text-and-down](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_down.png)
+
+Let's see how turning off the `down` part and turning on the `mid` and `up` parts, respectively, changes the image.
+```python
+adapter_weight_scales = { "unet": { "down": 0, "mid": 1, "up": 0} }
+pipe.set_adapters("pixel", adapter_weight_scales)
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+image
+```
+
+![block-lora-text-and-mid](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_mid.png)
+
+```python
+adapter_weight_scales = { "unet": { "down": 0, "mid": 0, "up": 1} }
+pipe.set_adapters("pixel", adapter_weight_scales)
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+image
+```
+
+![block-lora-text-and-up](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_up.png)
+
+Looks cool!
+
+This is a really powerful feature. You can use it to control the adapter strengths down to the per-transformer level, and you can even use it for multiple adapters.
+```python
+adapter_weight_scales_toy = 0.5
+adapter_weight_scales_pixel = {
+    "unet": {
+        "down": 0.9, # all transformers in the down-part will use scale 0.9
+        # "mid" # because, in this example, "mid" is not given, all transformers in the mid part will use the default scale 1.0
+        "up": {
+            "block_0": 0.6, # all 3 transformers in the 0th block in the up-part will use scale 0.6
+            "block_1": [0.4, 0.8, 1.0], # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
+        }
+    }
+}
+pipe.set_adapters(["toy", "pixel"], [adapter_weight_scales_toy, adapter_weight_scales_pixel])
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+image
+```
+
+![block-lora-mixed](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_mixed.png)
+
 ## Manage active adapters
 
 You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, use the [`~diffusers.loaders.LoraLoaderMixin.get_active_adapters`] method to check the list of active adapters:
diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md
index b59b46aeba51..b079d2165ece 100644
--- a/docs/source/en/using-diffusers/loading_adapters.md
+++ b/docs/source/en/using-diffusers/loading_adapters.md
@@ -153,18 +153,43 @@ image
-<Tip>
-
-For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.
-
-</Tip>
-
 To unload the LoRA weights, use the [`~loaders.LoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:
 
 ```py
 pipeline.unload_lora_weights()
 ```
 
+### Adjust LoRA weight scale
+
+For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.
+
+For more granular control over the amount of LoRA weights used per layer, you can use [`~loaders.LoraLoaderMixin.set_adapters`] and pass a dictionary specifying how much to scale the weights in each layer.
+```python
+pipe = ... # create pipeline
+pipe.load_lora_weights(..., adapter_name="my_adapter")
+scales = {
+    "text_encoder": 0.5,
+    "text_encoder_2": 0.5, # only usable if pipe has a 2nd text encoder
+    "unet": {
+        "down": 0.9, # all transformers in the down-part will use scale 0.9
+        # "mid" # in this example "mid" is not given, therefore all transformers in the mid part will use the default scale 1.0
+        "up": {
+            "block_0": 0.6, # all 3 transformers in the 0th block in the up-part will use scale 0.6
+            "block_1": [0.4, 0.8, 1.0], # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
+        }
+    }
+}
+pipe.set_adapters("my_adapter", scales)
+```
+
+This also works with multiple adapters - see [this guide](https://huggingface.co/docs/diffusers/tutorials/using_peft_for_inference#customize-adapters-strength) for how to do it.
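+
+For example, assuming a second adapter has been loaded under the (hypothetical) name `"other_adapter"`, you can combine the per-layer dictionary above with a plain scale for the second adapter:
+
+```python
+# "other_adapter" is a placeholder for a second adapter loaded via `load_lora_weights`
+pipe.set_adapters(["my_adapter", "other_adapter"], [scales, 0.5])
+```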
+ + + +Currently, [`~loaders.LoraLoaderMixin.set_adapters`] only supports scaling attention weights. If a LoRA has other parts (e.g., resnets or down-/upsamplers), they will keep a scale of 1.0. + + + ### Kohya and TheLastBen Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way. diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index aa53563fd39d..6ef6682b0e17 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import inspect import os from pathlib import Path @@ -985,7 +986,7 @@ def set_adapters_for_text_encoder( self, adapter_names: Union[List[str], str], text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821 - text_encoder_weights: List[float] = None, + text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None, ): """ Sets the adapter layers for the text encoder. @@ -1003,15 +1004,20 @@ def set_adapters_for_text_encoder( raise ValueError("PEFT backend is required for this method.") def process_weights(adapter_names, weights): - if weights is None: - weights = [1.0] * len(adapter_names) - elif isinstance(weights, float): - weights = [weights] + # Expand weights into a list, one entry per adapter + # e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None] + if not isinstance(weights, list): + weights = [weights] * len(adapter_names) if len(adapter_names) != len(weights): raise ValueError( f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}" ) + + # Set None values to default of 1.0 + # e.g. 
[7,7] -> [7,7] ; [3, None] -> [3,1]
+            weights = [w if w is not None else 1.0 for w in weights]
+
             return weights
 
         adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
@@ -1059,17 +1065,77 @@ def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"]
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
-        adapter_weights: Optional[List[float]] = None,
+        adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
     ):
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+
+        adapter_weights = copy.deepcopy(adapter_weights)
+
+        # Expand weights into a list, one entry per adapter
+        if not isinstance(adapter_weights, list):
+            adapter_weights = [adapter_weights] * len(adapter_names)
+
+        if len(adapter_names) != len(adapter_weights):
+            raise ValueError(
+                f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
+            )
+
+        # Decompose weights into weights for unet, text_encoder and text_encoder_2
+        unet_lora_weights, text_encoder_lora_weights, text_encoder_2_lora_weights = [], [], []
+
+        list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
+        all_adapters = {
+            adapter for adapters in list_adapters.values() for adapter in adapters
+        }  # eg ["adapter1", "adapter2"]
+        invert_list_adapters = {
+            adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
+            for adapter in all_adapters
+        }  # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
+
+        for adapter_name, weights in zip(adapter_names, adapter_weights):
+            if isinstance(weights, dict):
+                unet_lora_weight = weights.pop("unet", None)
+                text_encoder_lora_weight = weights.pop("text_encoder", None)
+                text_encoder_2_lora_weight = weights.pop("text_encoder_2", None)
+
+                if len(weights) > 0:
+                    raise ValueError(
+                        f"Got invalid key '{weights.keys()}' in lora weight dict for adapter {adapter_name}."
+                    )
+
+                if text_encoder_2_lora_weight is not None and not hasattr(self, "text_encoder_2"):
+                    logger.warning(
+                        "Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2."
+                    )
+
+                # warn if adapter doesn't have parts specified by adapter_weights
+                for part_weight, part_name in zip(
+                    [unet_lora_weight, text_encoder_lora_weight, text_encoder_2_lora_weight],
+                    ["unet", "text_encoder", "text_encoder_2"],
+                ):
+                    if part_weight is not None and part_name not in invert_list_adapters[adapter_name]:
+                        logger.warning(
+                            f"Lora weight dict for adapter '{adapter_name}' contains {part_name}, but this will be ignored because {adapter_name} does not contain weights for {part_name}. Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
+ ) + + else: + unet_lora_weight = weights + text_encoder_lora_weight = weights + text_encoder_2_lora_weight = weights + + unet_lora_weights.append(unet_lora_weight) + text_encoder_lora_weights.append(text_encoder_lora_weight) + text_encoder_2_lora_weights.append(text_encoder_2_lora_weight) + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet # Handle the UNET - unet.set_adapters(adapter_names, adapter_weights) + unet.set_adapters(adapter_names, unet_lora_weights) # Handle the Text Encoder if hasattr(self, "text_encoder"): - self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, adapter_weights) + self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, text_encoder_lora_weights) if hasattr(self, "text_encoder_2"): - self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, adapter_weights) + self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, text_encoder_2_lora_weights) def disable_lora(self): if not USE_PEFT_BACKEND: diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index c3449d53fcb7..8bbec26189b0 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -47,6 +47,7 @@ infer_stable_cascade_single_file_config, load_single_file_model_checkpoint, ) +from .unet_loader_utils import _maybe_expand_lora_scales from .utils import AttnProcsLayers @@ -564,7 +565,7 @@ def _unfuse_lora_apply(self, module): def set_adapters( self, adapter_names: Union[List[str], str], - weights: Optional[Union[List[float], float]] = None, + weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None, ): """ Set the currently active adapters for use in the UNet. @@ -597,9 +598,9 @@ def set_adapters( adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names - if weights is None: - weights = [1.0] * len(adapter_names) - elif isinstance(weights, float): + # Expand weights into a list, one entry per adapter + # examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None] + if not isinstance(weights, list): weights = [weights] * len(adapter_names) if len(adapter_names) != len(weights): @@ -607,6 +608,13 @@ def set_adapters( f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}." ) + # Set None values to default of 1.0 + # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0] + weights = [w if w is not None else 1.0 for w in weights] + + # e.g. [{...}, 7] -> [{expanded dict...}, 7] + weights = _maybe_expand_lora_scales(self, weights) + set_weights_and_activate_adapters(self, adapter_names, weights) def disable_lora(self): diff --git a/src/diffusers/loaders/unet_loader_utils.py b/src/diffusers/loaders/unet_loader_utils.py new file mode 100644 index 000000000000..918a0fca06c8 --- /dev/null +++ b/src/diffusers/loaders/unet_loader_utils.py @@ -0,0 +1,154 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy +from typing import TYPE_CHECKING, Dict, List, Union + +from ..utils import logging + + +if TYPE_CHECKING: + # import here to avoid circular imports + from ..models import UNet2DConditionModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _translate_into_actual_layer_name(name): + """Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')""" + if name == "mid": + return "mid_block.attentions.0" + + updown, block, attn = name.split(".") + + updown = updown.replace("down", "down_blocks").replace("up", "up_blocks") + block = block.replace("block_", "") + attn = "attentions." + attn + + return ".".join((updown, block, attn)) + + +def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]]): + blocks_with_transformer = { + "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")], + "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")], + } + transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1} + + expanded_weight_scales = [ + _maybe_expand_lora_scales_for_one_adapter( + weight_for_adapter, blocks_with_transformer, transformer_per_block, unet.state_dict() + ) + for weight_for_adapter in weight_scales + ] + + return expanded_weight_scales + + +def _maybe_expand_lora_scales_for_one_adapter( + scales: Union[float, Dict], + blocks_with_transformer: Dict[str, int], + transformer_per_block: Dict[str, int], + state_dict: None, +): + """ + Expands the inputs into a more granular dictionary. See the example below for more details. + + Parameters: + scales (`Union[float, Dict]`): + Scales dict to expand. + blocks_with_transformer (`Dict[str, int]`): + Dict with keys 'up' and 'down', showing which blocks have transformer layers + transformer_per_block (`Dict[str, int]`): + Dict with keys 'up' and 'down', showing how many transformer layers each block has + + E.g. 
turns + ```python + scales = { + 'down': 2, + 'mid': 3, + 'up': { + 'block_0': 4, + 'block_1': [5, 6, 7] + } + } + blocks_with_transformer = { + 'down': [1,2], + 'up': [0,1] + } + transformer_per_block = { + 'down': 2, + 'up': 3 + } + ``` + into + ```python + { + 'down.block_1.0': 2, + 'down.block_1.1': 2, + 'down.block_2.0': 2, + 'down.block_2.1': 2, + 'mid': 3, + 'up.block_0.0': 4, + 'up.block_0.1': 4, + 'up.block_0.2': 4, + 'up.block_1.0': 5, + 'up.block_1.1': 6, + 'up.block_1.2': 7, + } + ``` + """ + if sorted(blocks_with_transformer.keys()) != ["down", "up"]: + raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`") + + if sorted(transformer_per_block.keys()) != ["down", "up"]: + raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`") + + if not isinstance(scales, dict): + # don't expand if scales is a single number + return scales + + scales = copy.deepcopy(scales) + + if "mid" not in scales: + scales["mid"] = 1 + + for updown in ["up", "down"]: + if updown not in scales: + scales[updown] = 1 + + # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}} + if not isinstance(scales[updown], dict): + scales[updown] = {f"block_{i}": scales[updown] for i in blocks_with_transformer[updown]} + + # eg {"down": "block_1": 1}} to {"down": "block_1": [1, 1]}} + for i in blocks_with_transformer[updown]: + block = f"block_{i}" + if not isinstance(scales[updown][block], list): + scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])] + + # eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1} + for i in blocks_with_transformer[updown]: + block = f"block_{i}" + for tf_idx, value in enumerate(scales[updown][block]): + scales[f"{updown}.{block}.{tf_idx}"] = value + + del scales[updown] + + for layer in scales.keys(): + if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()): + raise ValueError( + f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions." + ) + + return {_translate_into_actual_layer_name(name): weight for name, weight in scales.items()} diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 718e0b46d87c..feececc56966 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -230,16 +230,26 @@ def delete_adapter_layers(model, adapter_name): def set_weights_and_activate_adapters(model, adapter_names, weights): from peft.tuners.tuners_utils import BaseTunerLayer + def get_module_weight(weight_for_adapter, module_name): + if not isinstance(weight_for_adapter, dict): + # If weight_for_adapter is a single number, always return it. 
+            return weight_for_adapter
+
+        for layer_name, weight_ in weight_for_adapter.items():
+            if layer_name in module_name:
+                return weight_
+        raise RuntimeError(f"No LoRA weight found for module {module_name}.")
+
     # iterate over each adapter, make it active and set the corresponding scaling weight
     for adapter_name, weight in zip(adapter_names, weights):
-        for module in model.modules():
+        for module_name, module in model.named_modules():
             if isinstance(module, BaseTunerLayer):
                 # For backward compatbility with previous PEFT versions
                 if hasattr(module, "set_adapter"):
                     module.set_adapter(adapter_name)
                 else:
                     module.active_adapter = adapter_name
-                module.set_scale(adapter_name, weight)
+                module.set_scale(adapter_name, get_module_weight(weight, module_name))
 
     # set multiple active adapters
     for module in model.modules():
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 7d84ac024dee..9aed5defada2 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -15,6 +15,7 @@
 import os
 import tempfile
 import unittest
+from itertools import product
 
 import numpy as np
 import torch
@@ -762,6 +763,218 @@ def test_simple_inference_with_text_unet_multi_adapter(self):
                 "output with no lora and output with lora disabled should give same results",
             )
 
+    def test_simple_inference_with_text_unet_block_scale(self):
+        """
+        Tests a simple inference with lora attached to text encoder and unet, attaches
+        one adapter and sets different weights for different blocks (i.e. block lora)
+        """
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+
+            if self.has_two_text_encoders:
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                )
+
+            weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
+            pipe.set_adapters("adapter-1", weights_1)
+            output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            weights_2 = {"unet": {"up": 5}}
+            pipe.set_adapters("adapter-1", weights_2)
+            output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            self.assertFalse(
+                np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
+                "LoRA weights 1 and 2 should give different results",
+            )
+            self.assertFalse(
+                np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
+                "No adapter and LoRA weights 1 should give different results",
+            )
+            self.assertFalse(
+                np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
+                "No adapter and LoRA weights 2 should give different results",
+            )
+
+            pipe.disable_lora()
+            output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            self.assertTrue(
+                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
+                "output with no lora and output with lora disabled should give same results",
+            )
+
+    def test_simple_inference_with_text_unet_multi_adapter_block_lora(self):
+        """
+        Tests a simple inference with lora attached to text encoder and unet, attaches
+        multiple adapters and sets different weights for different blocks (i.e. block lora)
+        """
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+
+            if self.has_two_text_encoders:
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                )
+
+            scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
+            scales_2 = {"unet": {"down": 5, "mid": 5}}
+            pipe.set_adapters("adapter-1", scales_1)
+
+            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            pipe.set_adapters("adapter-2", scales_2)
+            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
+
+            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            # The different adapter settings should lead to different results
+            self.assertFalse(
+                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+                "Adapter 1 and 2 should give different results",
+            )
+
+            self.assertFalse(
+                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+                "Adapter 1 and mixed adapters should give different results",
+            )
+
+            self.assertFalse(
+                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+                "Adapter 2 and mixed adapters should give different results",
+            )
+
+            pipe.disable_lora()
+
+            output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            self.assertTrue(
+                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
+                "output with no lora and output with lora disabled should give same results",
+            )
+
+            # a mismatching number of adapter_names and adapter_weights should raise an error
+            with self.assertRaises(ValueError):
+                pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])
+
+    def test_simple_inference_with_text_unet_block_scale_for_all_dict_options(self):
+        """Tests that any valid combination of lora block scales can be used in pipe.set_adapters"""
+
+        def updown_options(blocks_with_tf, layers_per_block, value):
+            """
+            Generate every possible combination for how a lora weight dict for the up/down part can be.
+            E.g. 2, {"block_1": 2}, {"block_1": [2,2,2]}, {"block_1": 2, "block_2": [2,2,2]}, ...
+            """
+ """ + num_val = value + list_val = [value] * layers_per_block + + node_opts = [None, num_val, list_val] + node_opts_foreach_block = [node_opts] * len(blocks_with_tf) + + updown_opts = [num_val] + for nodes in product(*node_opts_foreach_block): + if all(n is None for n in nodes): + continue + opt = {} + for b, n in zip(blocks_with_tf, nodes): + if n is not None: + opt["block_" + str(b)] = n + updown_opts.append(opt) + return updown_opts + + def all_possible_dict_opts(unet, value): + """ + Generate every possible combination for how a lora weight dict can be. + E.g. 2, {"unet: {"down": 2}}, {"unet: {"down": [2,2,2]}}, {"unet: {"mid": 2, "up": [2,2,2]}}, ... + """ + + down_blocks_with_tf = [i for i, d in enumerate(unet.down_blocks) if hasattr(d, "attentions")] + up_blocks_with_tf = [i for i, u in enumerate(unet.up_blocks) if hasattr(u, "attentions")] + + layers_per_block = unet.config.layers_per_block + + text_encoder_opts = [None, value] + text_encoder_2_opts = [None, value] + mid_opts = [None, value] + down_opts = [None] + updown_options(down_blocks_with_tf, layers_per_block, value) + up_opts = [None] + updown_options(up_blocks_with_tf, layers_per_block + 1, value) + + opts = [] + + for t1, t2, d, m, u in product(text_encoder_opts, text_encoder_2_opts, down_opts, mid_opts, up_opts): + if all(o is None for o in (t1, t2, d, m, u)): + continue + opt = {} + if t1 is not None: + opt["text_encoder"] = t1 + if t2 is not None: + opt["text_encoder_2"] = t2 + if all(o is None for o in (d, m, u)): + # no unet scaling + continue + opt["unet"] = {} + if d is not None: + opt["unet"]["down"] = d + if m is not None: + opt["unet"]["mid"] = m + if u is not None: + opt["unet"]["up"] = u + opts.append(opt) + + return opts + + components, text_lora_config, unet_lora_config = self.get_dummy_components(self.scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + pipe.unet.add_adapter(unet_lora_config, "adapter-1") + + if self.has_two_text_encoders: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + + for scale_dict in all_possible_dict_opts(pipe.unet, value=1234): + # test if lora block scales can be set with this scale_dict + if not self.has_two_text_encoders and "text_encoder_2" in scale_dict: + del scale_dict["text_encoder_2"] + + pipe.set_adapters("adapter-1", scale_dict) # test will fail if this line throws an error + def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self): """ Tests a simple inference with lora attached to text encoder and unet, attaches
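To make the scale expansion in the new `unet_loader_utils.py` concrete, here is a minimal, self-contained sketch (an illustration, not part of the patch). It mirrors the docstring example of `_maybe_expand_lora_scales_for_one_adapter`, skips the final `state_dict` validation and the translation into actual layer names, and uses the block indices and transformer counts given in that docstring.

```python
import copy
from typing import Dict, List, Union


def expand_scales(
    scales: Union[float, Dict],
    blocks_with_transformer: Dict[str, List[int]],
    transformer_per_block: Dict[str, int],
):
    """Approximate the per-transformer expansion described in unet_loader_utils.py."""
    if not isinstance(scales, dict):
        # a single number applies everywhere, nothing to expand
        return scales

    scales = copy.deepcopy(scales)
    scales.setdefault("mid", 1)

    for updown in ["up", "down"]:
        scales.setdefault(updown, 1)

        # e.g. {"down": 2} -> {"down": {"block_1": 2, "block_2": 2}}
        if not isinstance(scales[updown], dict):
            scales[updown] = {f"block_{i}": scales[updown] for i in blocks_with_transformer[updown]}

        for i in blocks_with_transformer[updown]:
            block = f"block_{i}"
            # e.g. {"block_1": 2} -> {"block_1": [2, 2]}
            if not isinstance(scales[updown][block], list):
                scales[updown][block] = [scales[updown][block]] * transformer_per_block[updown]

            # e.g. {"block_1": [2, 2]} -> {"down.block_1.0": 2, "down.block_1.1": 2}
            for tf_idx, value in enumerate(scales[updown][block]):
                scales[f"{updown}.{block}.{tf_idx}"] = value

        del scales[updown]

    return scales


# Reproduces the docstring example (block indices and transformer counts as given there)
print(
    expand_scales(
        {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}},
        blocks_with_transformer={"down": [1, 2], "up": [0, 1]},
        transformer_per_block={"down": 2, "up": 3},
    )
)
```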