From b209a10a3c01f3409cd1be5ec466be4aad32c107 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 7 Feb 2024 12:48:30 +0530 Subject: [PATCH 1/4] remove torch_dtype from to() --- src/diffusers/pipelines/pipeline_utils.py | 26 ++--------------------- 1 file changed, 2 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 06187645f000..f06633a69068 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -775,32 +775,10 @@ def to(self, *args, **kwargs): Returns: [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`. """ - - torch_dtype = kwargs.pop("torch_dtype", None) - if torch_dtype is not None: - deprecate("torch_dtype", "0.27.0", "") - torch_device = kwargs.pop("torch_device", None) - if torch_device is not None: - deprecate("torch_device", "0.27.0", "") - - dtype_kwarg = kwargs.pop("dtype", None) - device_kwarg = kwargs.pop("device", None) + dtype = kwargs.pop("dtype", None) + device= kwargs.pop("device", None) silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False) - if torch_dtype is not None and dtype_kwarg is not None: - raise ValueError( - "You have passed both `torch_dtype` and `dtype` as a keyword argument. Please make sure to only pass `dtype`." - ) - - dtype = torch_dtype or dtype_kwarg - - if torch_device is not None and device_kwarg is not None: - raise ValueError( - "You have passed both `torch_device` and `device` as a keyword argument. Please make sure to only pass `device`." - ) - - device = torch_device or device_kwarg - dtype_arg = None device_arg = None if len(args) == 1: From 9a41549417cb6b2d9ca710a2c6720fc9cf57d6b8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 7 Feb 2024 12:50:26 +0530 Subject: [PATCH 2/4] remove torch_dtype from usage scripts. 
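User-facing call sites migrate by renaming the keyword only: `pipe.to(torch_dtype=...)` becomes `pipe.to(dtype=...)`, and `torch_device` becomes `device` (positional `.to(torch.float16)` / `.to("cuda")` calls are unaffected). A minimal sketch of the new calling convention; the checkpoint id below is illustrative, any `DiffusionPipeline` checkpoint behaves the same way:

    import torch

    from diffusers import DiffusionPipeline

    # Illustrative checkpoint id; substitute any pipeline checkpoint you already use.
    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # Before: pipe.to(torch_dtype=torch.float16)  # deprecated keyword, removed by this series
    # After: pass the plain `dtype` keyword.
    pipe = pipe.to(dtype=torch.float16)

    # `device` replaces `torch_device` in the same way, e.g. pipe.to(device="cuda").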
--- scripts/convert_gligen_to_diffusers.py | 2 +- scripts/convert_original_stable_diffusion_to_diffusers.py | 2 +- scripts/convert_zero123_to_diffusers.py | 2 +- src/diffusers/pipelines/pipeline_utils.py | 6 +++--- tests/pipelines/animatediff/test_animatediff.py | 2 +- tests/pipelines/animatediff/test_animatediff_video2video.py | 2 +- tests/pipelines/audioldm2/test_audioldm2.py | 2 +- tests/pipelines/musicldm/test_musicldm.py | 2 +- tests/pipelines/pia/test_pia.py | 2 +- .../stable_video_diffusion/test_stable_video_diffusion.py | 2 +- tests/pipelines/test_pipelines.py | 2 +- tests/pipelines/test_pipelines_common.py | 2 +- 12 files changed, 14 insertions(+), 14 deletions(-) diff --git a/scripts/convert_gligen_to_diffusers.py b/scripts/convert_gligen_to_diffusers.py index 30d789b60634..83c1f928e407 100644 --- a/scripts/convert_gligen_to_diffusers.py +++ b/scripts/convert_gligen_to_diffusers.py @@ -576,6 +576,6 @@ def convert_gligen_to_diffusers( ) if args.half: - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) pipe.save_pretrained(args.dump_path) diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index 2ca70963d132..980446179cfd 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -179,7 +179,7 @@ ) if args.half: - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) if args.controlnet: # only save the controlnet model diff --git a/scripts/convert_zero123_to_diffusers.py b/scripts/convert_zero123_to_diffusers.py index f016312b8bb6..3bb6f6c041c9 100644 --- a/scripts/convert_zero123_to_diffusers.py +++ b/scripts/convert_zero123_to_diffusers.py @@ -801,6 +801,6 @@ def convert_from_original_zero123_ckpt(checkpoint_path, original_config_file, ex ) if args.half: - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index f06633a69068..769fcd2e832a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -776,7 +776,7 @@ def to(self, *args, **kwargs): [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`. """ dtype = kwargs.pop("dtype", None) - device= kwargs.pop("device", None) + device = kwargs.pop("device", None) silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False) dtype_arg = None @@ -851,12 +851,12 @@ def module_is_offloaded(module): if is_loaded_in_8bit and dtype is not None: logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision." + f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision." ) if is_loaded_in_8bit and device is not None: logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}." + f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}." 
) else: module.to(device, dtype) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 80a8fd19f5a0..525ca24bbd9a 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -218,7 +218,7 @@ def test_to_dtype(self): model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py index 3226bdb3ca6e..767fc30b4eb5 100644 --- a/tests/pipelines/animatediff/test_animatediff_video2video.py +++ b/tests/pipelines/animatediff/test_animatediff_video2video.py @@ -224,7 +224,7 @@ def test_to_dtype(self): model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) diff --git a/tests/pipelines/audioldm2/test_audioldm2.py b/tests/pipelines/audioldm2/test_audioldm2.py index 60ef86518e35..e2655515bc40 100644 --- a/tests/pipelines/audioldm2/test_audioldm2.py +++ b/tests/pipelines/audioldm2/test_audioldm2.py @@ -483,7 +483,7 @@ def test_to_dtype(self): self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values())) # Once we send to fp16, all params are in half-precision, including the logit scale - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")} self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values())) diff --git a/tests/pipelines/musicldm/test_musicldm.py b/tests/pipelines/musicldm/test_musicldm.py index 4bf03569bbf3..fe78ab6acbb1 100644 --- a/tests/pipelines/musicldm/test_musicldm.py +++ b/tests/pipelines/musicldm/test_musicldm.py @@ -400,7 +400,7 @@ def test_to_dtype(self): self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values())) # Once we send to fp16, all params are in half-precision, including the logit scale - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")} self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values())) diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py index eb76457abc9d..edd129560c63 100644 --- a/tests/pipelines/pia/test_pia.py +++ b/tests/pipelines/pia/test_pia.py @@ -231,7 +231,7 @@ def test_to_dtype(self): model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] 
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) diff --git a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py index 871266fb9c24..60c411283803 100644 --- a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py +++ b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py @@ -396,7 +396,7 @@ def test_to_dtype(self): model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 32ae81ddc2d8..bd9f42f185e6 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -1623,7 +1623,7 @@ def test_pipe_to(self): sd1 = sd.to(torch.float16) sd2 = sd.to(None, torch.float16) sd3 = sd.to(dtype=torch.float16) - sd4 = sd.to(torch_dtype=torch.float16) + sd4 = sd.to(dtype=torch.float16) sd5 = sd.to(None, dtype=torch.float16) sd6 = sd.to(None, torch_dtype=torch.float16) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index e3c8a4ef503f..7f51847caf07 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -716,7 +716,7 @@ def test_to_dtype(self): model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) From adcddf6ba421f847e7da2a0ce57b9456cae43356 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 7 Feb 2024 13:15:52 +0530 Subject: [PATCH 3/4] remove old lora backend --- src/diffusers/loaders/lora.py | 397 +--- tests/lora/test_lora_layers_old_backend.py | 2193 -------------------- 2 files changed, 95 insertions(+), 2495 deletions(-) delete mode 100644 tests/lora/test_lora_layers_old_backend.py diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 922c98b98bf4..6e0e9af51740 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -13,7 +13,6 @@ # limitations under the License. import inspect import os -from contextlib import nullcontext from pathlib import Path from typing import Callable, Dict, List, Optional, Union @@ -26,7 +25,7 @@ from torch import nn from .. 
import __version__ -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from ..utils import ( USE_PEFT_BACKEND, _get_model_file, @@ -51,10 +50,9 @@ if is_transformers_available(): from transformers import PreTrainedModel - from ..models.lora import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules + from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules if is_accelerate_available(): - from accelerate import init_empty_weights from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module logger = logging.get_logger(__name__) @@ -106,6 +104,9 @@ def load_lora_weights( Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) @@ -397,6 +398,11 @@ def load_lora_into_unet( Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as @@ -427,9 +433,7 @@ def load_lora_into_unet( warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`." logger.warn(warn_message) - if USE_PEFT_BACKEND and len(state_dict.keys()) > 0: - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - + if len(state_dict.keys()) > 0: if adapter_name in getattr(unet, "peft_config", {}): raise ValueError( f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name." @@ -518,6 +522,11 @@ def load_lora_into_text_encoder( Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. 
""" + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + from peft import LoraConfig + low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), @@ -539,34 +548,21 @@ def load_lora_into_text_encoder( rank = {} text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) - if USE_PEFT_BACKEND: - # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - rank_key = f"{name}.out_proj.lora_B.weight" - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys()) - if patch_mlp: - for name, _ in text_encoder_mlp_modules(text_encoder): - rank_key_fc1 = f"{name}.fc1.lora_B.weight" - rank_key_fc2 = f"{name}.fc2.lora_B.weight" - - rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1] - rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1] - else: - for name, _ in text_encoder_attn_modules(text_encoder): - rank_key = f"{name}.out_proj.lora_linear_layer.up.weight" - rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]}) - - patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys()) - if patch_mlp: - for name, _ in text_encoder_mlp_modules(text_encoder): - rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight" - rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight" - rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1] - rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1] + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + rank_key = f"{name}.out_proj.lora_B.weight" + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + patch_mlp = any(".mlp." 
in key for key in text_encoder_lora_state_dict.keys()) + if patch_mlp: + for name, _ in text_encoder_mlp_modules(text_encoder): + rank_key_fc1 = f"{name}.fc1.lora_B.weight" + rank_key_fc2 = f"{name}.fc2.lora_B.weight" + + rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1] + rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1] if network_alphas is not None: alpha_keys = [ @@ -576,84 +572,25 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - if USE_PEFT_BACKEND: - from peft import LoraConfig - - lora_config_kwargs = get_peft_kwargs( - rank, network_alphas, text_encoder_lora_state_dict, is_unet=False - ) - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(text_encoder) - - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + lora_config = LoraConfig(**lora_config_kwargs) - # inject LoRA layers and load the state dict - # in transformers we automatically check whether the adapter name is already in use or not - text_encoder.load_adapter( - adapter_name=adapter_name, - adapter_state_dict=text_encoder_lora_state_dict, - peft_config=lora_config, - ) - - # scale LoRA layers with `lora_scale` - scale_lora_layers(text_encoder, weight=lora_scale) - else: - cls._modify_text_encoder( - text_encoder, - lora_scale, - network_alphas, - rank=rank, - patch_mlp=patch_mlp, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - is_pipeline_offloaded = _pipeline is not None and any( - isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") - for c in _pipeline.components.values() - ) - if is_pipeline_offloaded and low_cpu_mem_usage: - low_cpu_mem_usage = True - logger.info( - f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced." 
- ) + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) - if low_cpu_mem_usage: - device = next(iter(text_encoder_lora_state_dict.values())).device - dtype = next(iter(text_encoder_lora_state_dict.values())).dtype - unexpected_keys = load_model_dict_into_meta( - text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype - ) - else: - load_state_dict_results = text_encoder.load_state_dict( - text_encoder_lora_state_dict, strict=False - ) - unexpected_keys = load_state_dict_results.unexpected_keys + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - if len(unexpected_keys) != 0: - raise ValueError( - f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}" - ) + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + ) - # float: return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 def _remove_text_encoder_monkey_patch(self): - if USE_PEFT_BACKEND: - remove_method = recurse_remove_peft_layers - else: - remove_method = self._remove_text_encoder_monkey_patch_classmethod - if hasattr(self, "text_encoder"): - remove_method(self.text_encoder) - - # In case text encoder have no Lora attached - if USE_PEFT_BACKEND and getattr(self.text_encoder, "peft_config", None) is not None: + recurse_remove_peft_layers(self.text_encoder) + if getattr(self.text_encoder, "peft_config", None) is not None: del self.text_encoder.peft_config self.text_encoder._hf_peft_config_loaded = None + if hasattr(self, "text_encoder_2"): - remove_method(self.text_encoder_2) - if USE_PEFT_BACKEND: + recurse_remove_peft_layers(self.text_encoder_2) + if getattr(self.text_encoder_2, "peft_config", None) is not None: del self.text_encoder_2.peft_config self.text_encoder_2._hf_peft_config_loaded = None - @classmethod - def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder): - deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.27", LORA_DEPRECATION_MESSAGE) - - for _, attn_module in text_encoder_attn_modules(text_encoder): - if isinstance(attn_module.q_proj, PatchedLoraProjection): - attn_module.q_proj.lora_linear_layer = None - attn_module.k_proj.lora_linear_layer = None - attn_module.v_proj.lora_linear_layer = None - attn_module.out_proj.lora_linear_layer = None - - for _, mlp_module in text_encoder_mlp_modules(text_encoder): - if isinstance(mlp_module.fc1, PatchedLoraProjection): - mlp_module.fc1.lora_linear_layer = None - mlp_module.fc2.lora_linear_layer = None - - @classmethod - def _modify_text_encoder( - cls, - text_encoder, - lora_scale=1, - network_alphas=None, - rank: Union[Dict[str, int], int] = 4, - dtype=None, - patch_mlp=False, - low_cpu_mem_usage=False, - ): - r""" - Monkey-patches the forward passes of attention modules of the text encoder. 
- """ - deprecate("_modify_text_encoder", "0.27", LORA_DEPRECATION_MESSAGE) - - def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters): - linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model - ctx = init_empty_weights if low_cpu_mem_usage else nullcontext - with ctx(): - model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype) - - lora_parameters.extend(model.lora_linear_layer.parameters()) - return model - - # First, remove any monkey-patch that might have been applied before - cls._remove_text_encoder_monkey_patch_classmethod(text_encoder) - - lora_parameters = [] - network_alphas = {} if network_alphas is None else network_alphas - is_network_alphas_populated = len(network_alphas) > 0 - - for name, attn_module in text_encoder_attn_modules(text_encoder): - query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None) - key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None) - value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None) - out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None) - - if isinstance(rank, dict): - current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight") - else: - current_rank = rank - - attn_module.q_proj = create_patched_linear_lora( - attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters - ) - attn_module.k_proj = create_patched_linear_lora( - attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters - ) - attn_module.v_proj = create_patched_linear_lora( - attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters - ) - attn_module.out_proj = create_patched_linear_lora( - attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters - ) - - if patch_mlp: - for name, mlp_module in text_encoder_mlp_modules(text_encoder): - fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha", None) - fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha", None) - - current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight") - current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight") - - mlp_module.fc1 = create_patched_linear_lora( - mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters - ) - mlp_module.fc2 = create_patched_linear_lora( - mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters - ) - - if is_network_alphas_populated and len(network_alphas) > 0: - raise ValueError( - f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}" - ) - - return lora_parameters - @classmethod def save_lora_weights( cls, @@ -1039,6 +879,8 @@ def fuse_lora( pipeline.fuse_lora(lora_scale=0.7) ``` """ + from peft.tuners.tuners_utils import BaseTunerLayer + if fuse_unet or fuse_text_encoder: self.num_fused_loras += 1 if self.num_fused_loras > 1: @@ -1050,52 +892,26 @@ def fuse_lora( unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names) - if USE_PEFT_BACKEND: - from peft.tuners.tuners_utils import BaseTunerLayer - - def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): - merge_kwargs = {"safe_merge": safe_fusing} - - for module in text_encoder.modules(): - if isinstance(module, BaseTunerLayer): - if lora_scale != 1.0: - 
module.scale_layer(lora_scale) - - # For BC with previous PEFT versions, we need to check the signature - # of the `merge` method to see if it supports the `adapter_names` argument. - supported_merge_kwargs = list(inspect.signature(module.merge).parameters) - if "adapter_names" in supported_merge_kwargs: - merge_kwargs["adapter_names"] = adapter_names - elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None: - raise ValueError( - "The `adapter_names` argument is not supported with your PEFT version. " - "Please upgrade to the latest version of PEFT. `pip install -U peft`" - ) - - module.merge(**merge_kwargs) + def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): + merge_kwargs = {"safe_merge": safe_fusing} - else: - deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE) - - def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, **kwargs): - if "adapter_names" in kwargs and kwargs["adapter_names"] is not None: - raise ValueError( - "The `adapter_names` argument is not supported in your environment. Please switch to PEFT " - "backend to use this argument by installing latest PEFT and transformers." - " `pip install -U peft transformers`" - ) + for module in text_encoder.modules(): + if isinstance(module, BaseTunerLayer): + if lora_scale != 1.0: + module.scale_layer(lora_scale) - for _, attn_module in text_encoder_attn_modules(text_encoder): - if isinstance(attn_module.q_proj, PatchedLoraProjection): - attn_module.q_proj._fuse_lora(lora_scale, safe_fusing) - attn_module.k_proj._fuse_lora(lora_scale, safe_fusing) - attn_module.v_proj._fuse_lora(lora_scale, safe_fusing) - attn_module.out_proj._fuse_lora(lora_scale, safe_fusing) + # For BC with previous PEFT versions, we need to check the signature + # of the `merge` method to see if it supports the `adapter_names` argument. + supported_merge_kwargs = list(inspect.signature(module.merge).parameters) + if "adapter_names" in supported_merge_kwargs: + merge_kwargs["adapter_names"] = adapter_names + elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None: + raise ValueError( + "The `adapter_names` argument is not supported with your PEFT version. " + "Please upgrade to the latest version of PEFT. `pip install -U peft`" + ) - for _, mlp_module in text_encoder_mlp_modules(text_encoder): - if isinstance(mlp_module.fc1, PatchedLoraProjection): - mlp_module.fc1._fuse_lora(lora_scale, safe_fusing) - mlp_module.fc2._fuse_lora(lora_scale, safe_fusing) + module.merge(**merge_kwargs) if fuse_text_encoder: if hasattr(self, "text_encoder"): @@ -1120,40 +936,18 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect. 
""" + from peft.tuners.tuners_utils import BaseTunerLayer + unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet if unfuse_unet: - if not USE_PEFT_BACKEND: - unet.unfuse_lora() - else: - from peft.tuners.tuners_utils import BaseTunerLayer - - for module in unet.modules(): - if isinstance(module, BaseTunerLayer): - module.unmerge() - - if USE_PEFT_BACKEND: - from peft.tuners.tuners_utils import BaseTunerLayer - - def unfuse_text_encoder_lora(text_encoder): - for module in text_encoder.modules(): - if isinstance(module, BaseTunerLayer): - module.unmerge() - - else: - deprecate("unfuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE) - - def unfuse_text_encoder_lora(text_encoder): - for _, attn_module in text_encoder_attn_modules(text_encoder): - if isinstance(attn_module.q_proj, PatchedLoraProjection): - attn_module.q_proj._unfuse_lora() - attn_module.k_proj._unfuse_lora() - attn_module.v_proj._unfuse_lora() - attn_module.out_proj._unfuse_lora() + for module in unet.modules(): + if isinstance(module, BaseTunerLayer): + module.unmerge() - for _, mlp_module in text_encoder_mlp_modules(text_encoder): - if isinstance(mlp_module.fc1, PatchedLoraProjection): - mlp_module.fc1._unfuse_lora() - mlp_module.fc2._unfuse_lora() + def unfuse_text_encoder_lora(text_encoder): + for module in text_encoder.modules(): + if isinstance(module, BaseTunerLayer): + module.unmerge() if unfuse_text_encoder: if hasattr(self, "text_encoder"): @@ -1434,6 +1228,9 @@ def load_lora_weights( kwargs (`dict`, *optional*): See [`~loaders.LoraLoaderMixin.lora_state_dict`]. """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + # We could have accessed the unet config from `lora_state_dict()` too. We pass # it here explicitly to be able to tell that it's coming from an SDXL # pipeline. @@ -1538,17 +1335,13 @@ def pack_weights(layers, prefix): ) def _remove_text_encoder_monkey_patch(self): - if USE_PEFT_BACKEND: - recurse_remove_peft_layers(self.text_encoder) - # TODO: @younesbelkada handle this in transformers side - if getattr(self.text_encoder, "peft_config", None) is not None: - del self.text_encoder.peft_config - self.text_encoder._hf_peft_config_loaded = None - - recurse_remove_peft_layers(self.text_encoder_2) - if getattr(self.text_encoder_2, "peft_config", None) is not None: - del self.text_encoder_2.peft_config - self.text_encoder_2._hf_peft_config_loaded = None - else: - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) + recurse_remove_peft_layers(self.text_encoder) + # TODO: @younesbelkada handle this in transformers side + if getattr(self.text_encoder, "peft_config", None) is not None: + del self.text_encoder.peft_config + self.text_encoder._hf_peft_config_loaded = None + + recurse_remove_peft_layers(self.text_encoder_2) + if getattr(self.text_encoder_2, "peft_config", None) is not None: + del self.text_encoder_2.peft_config + self.text_encoder_2._hf_peft_config_loaded = None diff --git a/tests/lora/test_lora_layers_old_backend.py b/tests/lora/test_lora_layers_old_backend.py deleted file mode 100644 index 148e551d6c1a..000000000000 --- a/tests/lora/test_lora_layers_old_backend.py +++ /dev/null @@ -1,2193 +0,0 @@ -# coding=utf-8 -# Copyright 2023 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import copy -import gc -import os -import random -import tempfile -import time -import unittest - -import numpy as np -import torch -import torch.nn as nn -from huggingface_hub.repocard import RepoCard -from PIL import Image -from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer - -from diffusers import ( - AutoencoderKL, - ControlNetModel, - DDIMScheduler, - DiffusionPipeline, - EulerDiscreteScheduler, - PNDMScheduler, - StableDiffusionInpaintPipeline, - StableDiffusionPipeline, - StableDiffusionXLControlNetPipeline, - StableDiffusionXLPipeline, - UNet2DConditionModel, - UNet3DConditionModel, -) -from diffusers.loaders import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin -from diffusers.models.attention_processor import ( - Attention, - AttnProcessor, - AttnProcessor2_0, - XFormersAttnProcessor, -) -from diffusers.models.lora import LoRALinearLayer -from diffusers.training_utils import unet_lora_state_dict -from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.testing_utils import ( - deprecate_after_peft_backend, - floats_tensor, - load_image, - nightly, - require_torch_gpu, - slow, - torch_device, -) - - -def text_encoder_attn_modules(text_encoder: nn.Module): - """Fetches the attention modules from `text_encoder`.""" - attn_modules = [] - - if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): - for i, layer in enumerate(text_encoder.text_model.encoder.layers): - name = f"text_model.encoder.layers.{i}.self_attn" - mod = layer.self_attn - attn_modules.append((name, mod)) - else: - raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}") - - return attn_modules - - -def text_encoder_lora_state_dict(text_encoder: nn.Module): - """Returns the LoRA state dict of the `text_encoder`. Assumes that `_modify_text_encoder()` was already called on it.""" - state_dict = {} - - for name, module in text_encoder_attn_modules(text_encoder): - for k, v in module.q_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v - - for k, v in module.k_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v - - for k, v in module.v_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v - - for k, v in module.out_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v - - return state_dict - - -def create_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True): - """Creates and returns the LoRA state dict for the UNet.""" - # So that we accidentally don't end up using the in-place modified UNet. - unet_lora_parameters = [] - - for attn_processor_name, attn_processor in unet.attn_processors.items(): - # Parse the attention module. - attn_module = unet - for n in attn_processor_name.split(".")[:-1]: - attn_module = getattr(attn_module, n) - - # Set the `lora_layer` attribute of the attention-related matrices. 
- attn_module.to_q.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_q.in_features, - out_features=attn_module.to_q.out_features, - rank=rank, - ) - ) - attn_module.to_k.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_k.in_features, - out_features=attn_module.to_k.out_features, - rank=rank, - ) - ) - attn_module.to_v.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_v.in_features, - out_features=attn_module.to_v.out_features, - rank=rank, - ) - ) - attn_module.to_out[0].set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_out[0].in_features, - out_features=attn_module.to_out[0].out_features, - rank=rank, - ) - ) - - if mock_weights: - with torch.no_grad(): - attn_module.to_q.lora_layer.up.weight += 1 - attn_module.to_k.lora_layer.up.weight += 1 - attn_module.to_v.lora_layer.up.weight += 1 - attn_module.to_out[0].lora_layer.up.weight += 1 - - unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) - - unet_lora_sd = unet_lora_state_dict(unet) - # Unload LoRA. - unet.unload_lora() - - return unet_lora_parameters, unet_lora_sd - - -def create_3d_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True): - """Creates and returns the LoRA state dict for the 3D UNet.""" - for attn_processor_name in unet.attn_processors.keys(): - has_cross_attention = attn_processor_name.endswith("attn2.processor") and not ( - attn_processor_name.startswith("transformer_in") or "temp_attentions" in attn_processor_name.split(".") - ) - cross_attention_dim = unet.config.cross_attention_dim if has_cross_attention else None - - if attn_processor_name.startswith("mid_block"): - hidden_size = unet.config.block_out_channels[-1] - elif attn_processor_name.startswith("up_blocks"): - block_id = int(attn_processor_name[len("up_blocks.")]) - hidden_size = list(reversed(unet.config.block_out_channels))[block_id] - elif attn_processor_name.startswith("down_blocks"): - block_id = int(attn_processor_name[len("down_blocks.")]) - hidden_size = unet.config.block_out_channels[block_id] - elif attn_processor_name.startswith("transformer_in"): - # Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148 - hidden_size = 8 * unet.config.attention_head_dim - - # Parse the attention module. 
- attn_module = unet - for n in attn_processor_name.split(".")[:-1]: - attn_module = getattr(attn_module, n) - - attn_module.to_q.set_lora_layer( - LoRALinearLayer( - in_features=min(attn_module.to_q.in_features, hidden_size), - out_features=attn_module.to_q.out_features - if cross_attention_dim is None - else max(attn_module.to_q.out_features, cross_attention_dim), - rank=rank, - ) - ) - attn_module.to_k.set_lora_layer( - LoRALinearLayer( - in_features=min(attn_module.to_k.in_features, hidden_size), - out_features=attn_module.to_k.out_features - if cross_attention_dim is None - else max(attn_module.to_k.out_features, cross_attention_dim), - rank=rank, - ) - ) - attn_module.to_v.set_lora_layer( - LoRALinearLayer( - in_features=min(attn_module.to_v.in_features, hidden_size), - out_features=attn_module.to_v.out_features - if cross_attention_dim is None - else max(attn_module.to_v.out_features, cross_attention_dim), - rank=rank, - ) - ) - attn_module.to_out[0].set_lora_layer( - LoRALinearLayer( - in_features=min(attn_module.to_out[0].in_features, hidden_size), - out_features=attn_module.to_out[0].out_features - if cross_attention_dim is None - else max(attn_module.to_out[0].out_features, cross_attention_dim), - rank=rank, - ) - ) - - if mock_weights: - with torch.no_grad(): - attn_module.to_q.lora_layer.up.weight += 1 - attn_module.to_k.lora_layer.up.weight += 1 - attn_module.to_v.lora_layer.up.weight += 1 - attn_module.to_out[0].lora_layer.up.weight += 1 - - unet_lora_sd = unet_lora_state_dict(unet) - - # Unload LoRA. - unet.unload_lora() - - return unet_lora_sd - - -def set_lora_weights(lora_attn_parameters, randn_weight=False, var=1.0): - """Randomizes the LoRA params if specified.""" - if not isinstance(lora_attn_parameters, dict): - with torch.no_grad(): - for parameter in lora_attn_parameters: - if randn_weight: - parameter[:] = torch.randn_like(parameter) * var - else: - torch.zero_(parameter) - else: - if randn_weight: - modified_state_dict = {k: torch.rand_like(v) * var for k, v in lora_attn_parameters.items()} - else: - modified_state_dict = {k: torch.zeros_like(v) * var for k, v in lora_attn_parameters.items()} - return modified_state_dict - - -def state_dicts_almost_equal(sd1, sd2): - sd1 = dict(sorted(sd1.items())) - sd2 = dict(sorted(sd2.items())) - - models_are_equal = True - for ten1, ten2 in zip(sd1.values(), sd2.values()): - if (ten1 - ten2).abs().max() > 1e-3: - models_are_equal = False - - return models_are_equal - - -@deprecate_after_peft_backend -class LoraLoaderMixinTests(unittest.TestCase): - lora_rank = 4 - - def get_dummy_components(self): - torch.manual_seed(0) - unet = UNet2DConditionModel( - block_out_channels=(32, 64), - layers_per_block=2, - sample_size=32, - in_channels=4, - out_channels=4, - down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), - up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), - cross_attention_dim=32, - ) - scheduler = DDIMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - clip_sample=False, - set_alpha_to_one=False, - steps_offset=1, - ) - torch.manual_seed(0) - vae = AutoencoderKL( - block_out_channels=[32, 64], - in_channels=3, - out_channels=3, - down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], - up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], - latent_channels=4, - ) - text_encoder_config = CLIPTextConfig( - bos_token_id=0, - eos_token_id=2, - hidden_size=32, - intermediate_size=37, - layer_norm_eps=1e-05, - num_attention_heads=4, - num_hidden_layers=5, - pad_token_id=1, - 
vocab_size=1000, - ) - text_encoder = CLIPTextModel(text_encoder_config) - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - unet_lora_raw_params, unet_lora_params = create_unet_lora_layers(unet, rank=self.lora_rank) - text_encoder_lora_params = LoraLoaderMixin._modify_text_encoder( - text_encoder, dtype=torch.float32, rank=self.lora_rank - ) - text_encoder_lora_params = text_encoder_lora_state_dict(text_encoder) - # We call this to ensure that the effects of the in-place `_modify_text_encoder` have been erased. - LoraLoaderMixin._remove_text_encoder_monkey_patch_classmethod(text_encoder) - - pipeline_components = { - "unet": unet, - "scheduler": scheduler, - "vae": vae, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - "safety_checker": None, - "feature_extractor": None, - "image_encoder": None, - } - lora_components = { - "unet_lora_raw_params": unet_lora_raw_params, - "unet_lora_params": unet_lora_params, - "text_encoder_lora_params": text_encoder_lora_params, - } - return pipeline_components, lora_components - - def get_dummy_inputs(self, with_generator=True): - batch_size = 1 - sequence_length = 10 - num_channels = 4 - sizes = (32, 32) - - generator = torch.manual_seed(0) - noise = floats_tensor((batch_size, num_channels) + sizes) - input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) - - pipeline_inputs = { - "prompt": "A painting of a squirrel eating a burger", - "num_inference_steps": 2, - "guidance_scale": 6.0, - "output_type": "np", - } - if with_generator: - pipeline_inputs.update({"generator": generator}) - - return noise, input_ids, pipeline_inputs - - # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb - def get_dummy_tokens(self): - max_seq_length = 77 - - inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)) - - prepared_inputs = {} - prepared_inputs["input_ids"] = inputs - return prepared_inputs - - def create_lora_weight_file(self, tmpdirname): - _, lora_components = self.get_dummy_components() - LoraLoaderMixin.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - - @unittest.skipIf(not torch.cuda.is_available() or not is_xformers_available(), reason="xformers requires cuda") - def test_stable_diffusion_xformers_attn_processors(self): - # disable_full_determinism() - device = "cuda" # ensure determinism for the device-dependent torch.Generator - components, _ = self.get_dummy_components() - sd_pipe = StableDiffusionPipeline(**components) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - _, _, inputs = self.get_dummy_inputs() - - # run xformers attention - sd_pipe.enable_xformers_memory_efficient_attention() - image = sd_pipe(**inputs).images - assert image.shape == (1, 64, 64, 3) - - def test_stable_diffusion_lora(self): - components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionPipeline(**components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - sd_pipe.unet.set_default_attn_processor() - - # forward 1 - _, _, inputs = self.get_dummy_inputs() - - output = sd_pipe(**inputs) - image = output.images - image_slice = image[0, -3:, -3:, -1] - - # set lora layers - 
sd_pipe.unet.load_attn_procs(lora_components["unet_lora_params"]) - - # forward 2 - _, _, inputs = self.get_dummy_inputs() - - output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.0}) - image = output.images - image_slice_1 = image[0, -3:, -3:, -1] - - # forward 3 - _, _, inputs = self.get_dummy_inputs() - - output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.5}) - image = output.images - image_slice_2 = image[0, -3:, -3:, -1] - - assert np.abs(image_slice - image_slice_1).max() < 1e-2 - assert np.abs(image_slice - image_slice_2).max() > 1e-2 - - def test_lora_save_load(self): - pipeline_components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - _, _, pipeline_inputs = self.get_dummy_inputs() - - original_images = sd_pipe(**pipeline_inputs).images - orig_image_slice = original_images[0, -3:, -3:, -1] - - with tempfile.TemporaryDirectory() as tmpdirname: - LoraLoaderMixin.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - sd_pipe.load_lora_weights(tmpdirname) - - lora_images = sd_pipe(**pipeline_inputs).images - lora_image_slice = lora_images[0, -3:, -3:, -1] - - # Outputs shouldn't match. - self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) - - def test_lora_save_load_no_safe_serialization(self): - pipeline_components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - _, _, pipeline_inputs = self.get_dummy_inputs() - - original_images = sd_pipe(**pipeline_inputs).images - orig_image_slice = original_images[0, -3:, -3:, -1] - - with tempfile.TemporaryDirectory() as tmpdirname: - LoraLoaderMixin.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - safe_serialization=False, - ) - - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - sd_pipe.load_lora_weights(tmpdirname) - - lora_images = sd_pipe(**pipeline_inputs).images - lora_image_slice = lora_images[0, -3:, -3:, -1] - - # Outputs shouldn't match. 
- self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) - - def test_text_encoder_lora_monkey_patch(self): - pipeline_components, _ = self.get_dummy_components() - pipe = StableDiffusionPipeline(**pipeline_components) - - dummy_tokens = self.get_dummy_tokens() - - # inference without lora - outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_without_lora.shape == (1, 77, 32) - - # monkey patch - text_encoder_lora_params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) - text_encoder_lora_params = set_lora_weights( - text_encoder_lora_state_dict(pipe.text_encoder), randn_weight=False - ) - with tempfile.TemporaryDirectory() as tmpdirname: - LoraLoaderMixin.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=None, - text_encoder_lora_layers=text_encoder_lora_params, - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - pipe.load_lora_weights(tmpdirname) - - # inference with lora - outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_with_lora.shape == (1, 77, 32) - - assert torch.allclose( - outputs_without_lora, outputs_with_lora - ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" - - # monkey patch - pipeline_components, _ = self.get_dummy_components() - pipe = StableDiffusionPipeline(**pipeline_components) - - text_encoder_lora_params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) - text_encoder_lora_params = set_lora_weights( - text_encoder_lora_state_dict(pipe.text_encoder), randn_weight=True, var=0.1 - ) - with tempfile.TemporaryDirectory() as tmpdirname: - LoraLoaderMixin.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=None, - text_encoder_lora_layers=text_encoder_lora_params, - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - pipe.load_lora_weights(tmpdirname) - - # inference with lora - outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_with_lora.shape == (1, 77, 32) - - assert not torch.allclose( - outputs_without_lora, outputs_with_lora - ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs" - - def test_text_encoder_lora_remove_monkey_patch(self): - pipeline_components, _ = self.get_dummy_components() - pipe = StableDiffusionPipeline(**pipeline_components) - - dummy_tokens = self.get_dummy_tokens() - - # inference without lora - outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_without_lora.shape == (1, 77, 32) - - # monkey patch - params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) - params = set_lora_weights(text_encoder_lora_state_dict(pipe.text_encoder), var=0.1, randn_weight=True) - with tempfile.TemporaryDirectory() as tmpdirname: - LoraLoaderMixin.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=None, - text_encoder_lora_layers=params, - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - pipe.load_lora_weights(tmpdirname) - - # inference with lora - outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_with_lora.shape == (1, 77, 32) - - assert not torch.allclose( - outputs_without_lora, outputs_with_lora - ), "lora outputs should be different to without lora outputs" - - # remove monkey patch - pipe._remove_text_encoder_monkey_patch() - - # inference with removed lora - 
outputs_without_lora_removed = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_without_lora_removed.shape == (1, 77, 32) - - assert torch.allclose( - outputs_without_lora, outputs_without_lora_removed - ), "remove lora monkey patch should restore the original outputs" - - def test_text_encoder_lora_scale(self): - pipeline_components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - _, _, pipeline_inputs = self.get_dummy_inputs() - - with tempfile.TemporaryDirectory() as tmpdirname: - LoraLoaderMixin.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - sd_pipe.load_lora_weights(tmpdirname) - - lora_images = sd_pipe(**pipeline_inputs).images - lora_image_slice = lora_images[0, -3:, -3:, -1] - - lora_images_with_scale = sd_pipe(**pipeline_inputs, cross_attention_kwargs={"scale": 0.5}).images - lora_image_with_scale_slice = lora_images_with_scale[0, -3:, -3:, -1] - - # Outputs shouldn't match. - self.assertFalse( - torch.allclose(torch.from_numpy(lora_image_slice), torch.from_numpy(lora_image_with_scale_slice)) - ) - - def test_lora_unet_attn_processors(self): - with tempfile.TemporaryDirectory() as tmpdirname: - self.create_lora_weight_file(tmpdirname) - - pipeline_components, _ = self.get_dummy_components() - sd_pipe = StableDiffusionPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - # check if vanilla attention processors are used - for _, module in sd_pipe.unet.named_modules(): - if isinstance(module, Attention): - self.assertIsInstance(module.processor, (AttnProcessor, AttnProcessor2_0)) - - # load LoRA weight file - sd_pipe.load_lora_weights(tmpdirname) - - # check if lora attention processors are used - for _, module in sd_pipe.unet.named_modules(): - if isinstance(module, Attention): - self.assertIsNotNone(module.to_q.lora_layer) - self.assertIsNotNone(module.to_k.lora_layer) - self.assertIsNotNone(module.to_v.lora_layer) - self.assertIsNotNone(module.to_out[0].lora_layer) - - def test_unload_lora_sd(self): - pipeline_components, lora_components = self.get_dummy_components() - _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) - sd_pipe = StableDiffusionPipeline(**pipeline_components) - sd_pipe.unet.set_default_attn_processor() - - original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - orig_image_slice = original_images[0, -3:, -3:, -1] - - with tempfile.TemporaryDirectory() as tmpdirname: - LoraLoaderMixin.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - sd_pipe.load_lora_weights(tmpdirname) - - lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - lora_image_slice = lora_images[0, -3:, -3:, -1] - - # Unload LoRA parameters. 
- sd_pipe.unload_lora_weights() - original_images_two = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - orig_image_slice_two = original_images_two[0, -3:, -3:, -1] - - assert not np.allclose( - orig_image_slice, lora_image_slice - ), "LoRA parameters should lead to a different image slice." - assert not np.allclose( - orig_image_slice_two, lora_image_slice - ), "LoRA parameters should lead to a different image slice." - assert np.allclose( - orig_image_slice, orig_image_slice_two, atol=1e-3 - ), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters." - - @unittest.skipIf(torch_device != "cuda" or not is_xformers_available(), "This test is supposed to run on GPU") - def test_lora_unet_attn_processors_with_xformers(self): - with tempfile.TemporaryDirectory() as tmpdirname: - self.create_lora_weight_file(tmpdirname) - - pipeline_components, _ = self.get_dummy_components() - sd_pipe = StableDiffusionPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - # enable XFormers - sd_pipe.enable_xformers_memory_efficient_attention() - - # check if xFormers attention processors are used - for _, module in sd_pipe.unet.named_modules(): - if isinstance(module, Attention): - self.assertIsInstance(module.processor, XFormersAttnProcessor) - - # load LoRA weight file - sd_pipe.load_lora_weights(tmpdirname) - - # check if lora attention processors are used - for _, module in sd_pipe.unet.named_modules(): - if isinstance(module, Attention): - self.assertIsNotNone(module.to_q.lora_layer) - self.assertIsNotNone(module.to_k.lora_layer) - self.assertIsNotNone(module.to_v.lora_layer) - self.assertIsNotNone(module.to_out[0].lora_layer) - - # unload lora weights - sd_pipe.unload_lora_weights() - - # check if attention processors are reverted back to xFormers - for _, module in sd_pipe.unet.named_modules(): - if isinstance(module, Attention): - self.assertIsInstance(module.processor, XFormersAttnProcessor) - - @unittest.skipIf(torch_device != "cuda" or not is_xformers_available(), "This test is supposed to run on GPU") - def test_lora_save_load_with_xformers(self): - pipeline_components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - _, _, pipeline_inputs = self.get_dummy_inputs() - - # enable XFormers - sd_pipe.enable_xformers_memory_efficient_attention() - - original_images = sd_pipe(**pipeline_inputs).images - orig_image_slice = original_images[0, -3:, -3:, -1] - - with tempfile.TemporaryDirectory() as tmpdirname: - LoraLoaderMixin.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - sd_pipe.load_lora_weights(tmpdirname) - - lora_images = sd_pipe(**pipeline_inputs).images - lora_image_slice = lora_images[0, -3:, -3:, -1] - - # Outputs shouldn't match. 
- self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) - - -@deprecate_after_peft_backend -class SDInpaintLoraMixinTests(unittest.TestCase): - lora_rank = 4 - - def get_dummy_inputs(self, device, seed=0, img_res=64, output_pil=True): - # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched - if output_pil: - # Get random floats in [0, 1] as image - image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) - image = image.cpu().permute(0, 2, 3, 1)[0] - mask_image = torch.ones_like(image) - # Convert image and mask_image to [0, 255] - image = 255 * image - mask_image = 255 * mask_image - # Convert to PIL image - init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((img_res, img_res)) - mask_image = Image.fromarray(np.uint8(mask_image)).convert("RGB").resize((img_res, img_res)) - else: - # Get random floats in [0, 1] as image with spatial size (img_res, img_res) - image = floats_tensor((1, 3, img_res, img_res), rng=random.Random(seed)).to(device) - # Convert image to [-1, 1] - init_image = 2.0 * image - 1.0 - mask_image = torch.ones((1, 1, img_res, img_res), device=device) - - if str(device).startswith("mps"): - generator = torch.manual_seed(seed) - else: - generator = torch.Generator(device=device).manual_seed(seed) - - inputs = { - "prompt": "A painting of a squirrel eating a burger", - "image": init_image, - "mask_image": mask_image, - "generator": generator, - "num_inference_steps": 2, - "guidance_scale": 6.0, - "output_type": "numpy", - } - return inputs - - def get_dummy_components(self): - torch.manual_seed(0) - unet = UNet2DConditionModel( - block_out_channels=(32, 64), - layers_per_block=2, - sample_size=32, - in_channels=9, - out_channels=4, - down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), - up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), - cross_attention_dim=32, - ) - scheduler = PNDMScheduler(skip_prk_steps=True) - torch.manual_seed(0) - vae = AutoencoderKL( - block_out_channels=[32, 64], - in_channels=3, - out_channels=3, - down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], - up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], - latent_channels=4, - ) - torch.manual_seed(0) - text_encoder_config = CLIPTextConfig( - bos_token_id=0, - eos_token_id=2, - hidden_size=32, - intermediate_size=37, - layer_norm_eps=1e-05, - num_attention_heads=4, - num_hidden_layers=5, - pad_token_id=1, - vocab_size=1000, - ) - text_encoder = CLIPTextModel(text_encoder_config) - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - unet_lora_raw_params, unet_lora_params = create_unet_lora_layers(unet, rank=self.lora_rank) - text_encoder_lora_params = StableDiffusionXLLoraLoaderMixin._modify_text_encoder( - text_encoder, dtype=torch.float32, rank=self.lora_rank - ) - text_encoder_lora_params = set_lora_weights( - text_encoder_lora_state_dict(text_encoder), randn_weight=True, var=0.1 - ) - - components = { - "unet": unet, - "scheduler": scheduler, - "vae": vae, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - "safety_checker": None, - "feature_extractor": None, - "image_encoder": None, - } - lora_components = { - "unet_lora_raw_params": unet_lora_raw_params, - "unet_lora_params": unet_lora_params, - "text_encoder_lora_params": text_encoder_lora_params, - } - return components, lora_components - - def test_stable_diffusion_inpaint_lora(self): - device = "cpu" # ensure determinism for the device-dependent 
torch.Generator - - components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionInpaintPipeline(**components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - sd_pipe.unet.set_default_attn_processor() - - # forward 1 - inputs = self.get_dummy_inputs(device) - output = sd_pipe(**inputs) - image = output.images - image_slice = image[0, -3:, -3:, -1] - - # set lora layers - sd_pipe.unet.load_attn_procs(lora_components["unet_lora_params"]) - - # forward 2 - inputs = self.get_dummy_inputs(device) - output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.0}) - image = output.images - image_slice_1 = image[0, -3:, -3:, -1] - - # forward 3 - inputs = self.get_dummy_inputs(device) - output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.5}) - image = output.images - image_slice_2 = image[0, -3:, -3:, -1] - - assert np.abs(image_slice - image_slice_1).max() < 1e-2 - assert np.abs(image_slice - image_slice_2).max() > 1e-2 - - -@deprecate_after_peft_backend -class SDXLLoraLoaderMixinTests(unittest.TestCase): - lora_rank = 4 - - def get_dummy_components(self, modify_text_encoder=True): - torch.manual_seed(0) - unet = UNet2DConditionModel( - block_out_channels=(32, 64), - layers_per_block=2, - sample_size=32, - in_channels=4, - out_channels=4, - down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), - up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), - # SD2-specific config below - attention_head_dim=(2, 4), - use_linear_projection=True, - addition_embed_type="text_time", - addition_time_embed_dim=8, - transformer_layers_per_block=(1, 2), - projection_class_embeddings_input_dim=80, # 6 * 8 + 32 - cross_attention_dim=64, - ) - scheduler = EulerDiscreteScheduler( - beta_start=0.00085, - beta_end=0.012, - steps_offset=1, - beta_schedule="scaled_linear", - timestep_spacing="leading", - ) - torch.manual_seed(0) - vae = AutoencoderKL( - block_out_channels=[32, 64], - in_channels=3, - out_channels=3, - down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], - up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], - latent_channels=4, - sample_size=128, - ) - torch.manual_seed(0) - text_encoder_config = CLIPTextConfig( - bos_token_id=0, - eos_token_id=2, - hidden_size=32, - intermediate_size=37, - layer_norm_eps=1e-05, - num_attention_heads=4, - num_hidden_layers=5, - pad_token_id=1, - vocab_size=1000, - # SD2-specific config below - hidden_act="gelu", - projection_dim=32, - ) - text_encoder = CLIPTextModel(text_encoder_config) - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) - tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - _, unet_lora_params = create_unet_lora_layers(unet, rank=self.lora_rank) - - if modify_text_encoder: - _ = StableDiffusionXLLoraLoaderMixin._modify_text_encoder( - text_encoder, dtype=torch.float32, rank=self.lora_rank - ) - text_encoder_lora_params = text_encoder_lora_state_dict(text_encoder) - StableDiffusionXLLoraLoaderMixin._remove_text_encoder_monkey_patch_classmethod(text_encoder) - - _ = StableDiffusionXLLoraLoaderMixin._modify_text_encoder( - text_encoder_2, dtype=torch.float32, rank=self.lora_rank - ) - text_encoder_two_lora_params = text_encoder_lora_state_dict(text_encoder_2) - StableDiffusionXLLoraLoaderMixin._remove_text_encoder_monkey_patch_classmethod(text_encoder_2) - else: - text_encoder_lora_params = None - text_encoder_two_lora_params = None 
- - pipeline_components = { - "unet": unet, - "scheduler": scheduler, - "vae": vae, - "text_encoder": text_encoder, - "text_encoder_2": text_encoder_2, - "tokenizer": tokenizer, - "tokenizer_2": tokenizer_2, - "image_encoder": None, - "feature_extractor": None, - } - lora_components = { - "unet_lora_params": unet_lora_params, - "text_encoder_lora_params": text_encoder_lora_params, - "text_encoder_two_lora_params": text_encoder_two_lora_params, - } - return pipeline_components, lora_components - - def get_dummy_inputs(self, with_generator=True): - batch_size = 1 - sequence_length = 10 - num_channels = 4 - sizes = (32, 32) - - generator = torch.manual_seed(0) - noise = floats_tensor((batch_size, num_channels) + sizes) - input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) - - pipeline_inputs = { - "prompt": "A painting of a squirrel eating a burger", - "num_inference_steps": 2, - "guidance_scale": 6.0, - "output_type": "np", - } - if with_generator: - pipeline_inputs.update({"generator": generator}) - - return noise, input_ids, pipeline_inputs - - def test_lora_save_load(self): - pipeline_components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionXLPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - _, _, pipeline_inputs = self.get_dummy_inputs() - - original_images = sd_pipe(**pipeline_inputs).images - orig_image_slice = original_images[0, -3:, -3:, -1] - - with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionXLPipeline.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - sd_pipe.load_lora_weights(tmpdirname) - - lora_images = sd_pipe(**pipeline_inputs).images - lora_image_slice = lora_images[0, -3:, -3:, -1] - - # Outputs shouldn't match. - self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) - - def test_unload_lora_sdxl(self): - pipeline_components, lora_components = self.get_dummy_components() - _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) - sd_pipe = StableDiffusionXLPipeline(**pipeline_components) - sd_pipe.unet.set_default_attn_processor() - - original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - orig_image_slice = original_images[0, -3:, -3:, -1] - - with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionXLPipeline.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - sd_pipe.load_lora_weights(tmpdirname) - - lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - lora_image_slice = lora_images[0, -3:, -3:, -1] - - # Unload LoRA parameters. 
- sd_pipe.unload_lora_weights() - original_images_two = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - orig_image_slice_two = original_images_two[0, -3:, -3:, -1] - - assert not np.allclose( - orig_image_slice, lora_image_slice - ), "LoRA parameters should lead to a different image slice." - assert not np.allclose( - orig_image_slice_two, lora_image_slice - ), "LoRA parameters should lead to a different image slice." - assert np.allclose( - orig_image_slice, orig_image_slice_two, atol=1e-3 - ), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters." - - def test_load_lora_locally(self): - pipeline_components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionXLPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionXLPipeline.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], - safe_serialization=False, - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - - sd_pipe.unload_lora_weights() - - def test_text_encoder_lora_state_dict_unchanged(self): - pipeline_components, lora_components = self.get_dummy_components(modify_text_encoder=False) - sd_pipe = StableDiffusionXLPipeline(**pipeline_components) - - text_encoder_1_sd_keys = sorted(sd_pipe.text_encoder.state_dict().keys()) - text_encoder_2_sd_keys = sorted(sd_pipe.text_encoder_2.state_dict().keys()) - - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - # Modify the text encoder. 
- _ = StableDiffusionXLLoraLoaderMixin._modify_text_encoder( - sd_pipe.text_encoder, dtype=torch.float32, rank=self.lora_rank - ) - lora_components["text_encoder_lora_params"] = set_lora_weights( - text_encoder_lora_state_dict(sd_pipe.text_encoder), randn_weight=True, var=0.1 - ) - _ = StableDiffusionXLLoraLoaderMixin._modify_text_encoder( - sd_pipe.text_encoder_2, dtype=torch.float32, rank=self.lora_rank - ) - lora_components["text_encoder_two_lora_params"] = set_lora_weights( - text_encoder_lora_state_dict(sd_pipe.text_encoder_2), randn_weight=True, var=0.1 - ) - - with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionXLPipeline.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], - safe_serialization=False, - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) - sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) - - text_encoder_1_sd_keys_2 = sorted(sd_pipe.text_encoder.state_dict().keys()) - text_encoder_2_sd_keys_2 = sorted(sd_pipe.text_encoder_2.state_dict().keys()) - - sd_pipe.unload_lora_weights() - - text_encoder_1_sd_keys_3 = sorted(sd_pipe.text_encoder.state_dict().keys()) - text_encoder_2_sd_keys_3 = sorted(sd_pipe.text_encoder_2.state_dict().keys()) - - # default & unloaded LoRA weights should have identical state_dicts - assert text_encoder_1_sd_keys == text_encoder_1_sd_keys_3 - # default & loaded LoRA weights should NOT have identical state_dicts - assert text_encoder_1_sd_keys != text_encoder_1_sd_keys_2 - - # default & unloaded LoRA weights should have identical state_dicts - assert text_encoder_2_sd_keys == text_encoder_2_sd_keys_3 - # default & loaded LoRA weights should NOT have identical state_dicts - assert text_encoder_2_sd_keys != text_encoder_2_sd_keys_2 - - def test_load_lora_locally_safetensors(self): - pipeline_components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionXLPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionXLPipeline.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], - safe_serialization=True, - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - sd_pipe.unload_lora_weights() - - def test_lora_fuse_nan(self): - pipeline_components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionXLPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionXLPipeline.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], - safe_serialization=True, - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) 
- sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - # corrupt one LoRA weight with `inf` values - with torch.no_grad(): - sd_pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_layer.down.weight += float( - "NaN" - ) - - # with `safe_fusing=True` we should see an Error - with self.assertRaises(ValueError): - sd_pipe.fuse_lora(safe_fusing=True) - - # without we should not see an error, but every image will be black - sd_pipe.fuse_lora(safe_fusing=False) - - out = sd_pipe("test", num_inference_steps=2, output_type="np").images - - assert np.isnan(out).all() - - def test_lora_fusion(self): - pipeline_components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionXLPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) - - original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - orig_image_slice = original_images[0, -3:, -3:, -1] - - with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionXLPipeline.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], - safe_serialization=True, - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - sd_pipe.fuse_lora() - lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - lora_image_slice = lora_images[0, -3:, -3:, -1] - - self.assertFalse(np.allclose(orig_image_slice, lora_image_slice, atol=1e-3)) - - def test_unfuse_lora(self): - pipeline_components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionXLPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - sd_pipe.unet.set_default_attn_processor() - - _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) - - original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - orig_image_slice = original_images[0, -3:, -3:, -1] - - with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionXLPipeline.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], - safe_serialization=True, - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - sd_pipe.fuse_lora() - lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - lora_image_slice = lora_images[0, -3:, -3:, -1] - - # Reverse LoRA fusion. - sd_pipe.unfuse_lora() - original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - orig_image_slice_two = original_images[0, -3:, -3:, -1] - - assert not np.allclose( - orig_image_slice, lora_image_slice - ), "Fusion of LoRAs should lead to a different image slice." - assert not np.allclose( - orig_image_slice_two, lora_image_slice - ), "Fusion of LoRAs should lead to a different image slice." 
- assert np.allclose( - orig_image_slice, orig_image_slice_two, atol=1e-3 - ), "Reversing LoRA fusion should lead to results similar to what was obtained with the pipeline without any LoRA parameters." - - def test_lora_fusion_is_not_affected_by_unloading(self): - pipeline_components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionXLPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) - - _ = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - - with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionXLPipeline.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], - safe_serialization=True, - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - sd_pipe.fuse_lora() - lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - lora_image_slice = lora_images[0, -3:, -3:, -1] - - # Unload LoRA parameters. - sd_pipe.unload_lora_weights() - images_with_unloaded_lora = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - images_with_unloaded_lora_slice = images_with_unloaded_lora[0, -3:, -3:, -1] - - assert ( - np.abs(lora_image_slice - images_with_unloaded_lora_slice).max() < 2e-1 - ), "`unload_lora_weights()` should have not effect on the semantics of the results as the LoRA parameters were fused." - - def test_fuse_lora_with_different_scales(self): - pipeline_components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionXLPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) - - _ = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - - with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionXLPipeline.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], - safe_serialization=True, - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - sd_pipe.fuse_lora(lora_scale=1.0) - lora_images_scale_one = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - lora_image_slice_scale_one = lora_images_scale_one[0, -3:, -3:, -1] - - # Reverse LoRA fusion. 
- sd_pipe.unfuse_lora() - - with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionXLPipeline.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], - safe_serialization=True, - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - sd_pipe.fuse_lora(lora_scale=0.5) - lora_images_scale_0_5 = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1] - - assert not np.allclose( - lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-03 - ), "Different LoRA scales should influence the outputs accordingly." - - def test_with_different_scales(self): - pipeline_components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionXLPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - sd_pipe.unet.set_default_attn_processor() - - _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) - original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - original_imagee_slice = original_images[0, -3:, -3:, -1] - - with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionXLPipeline.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], - safe_serialization=True, - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - lora_images_scale_one = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - lora_image_slice_scale_one = lora_images_scale_one[0, -3:, -3:, -1] - - lora_images_scale_0_5 = sd_pipe( - **pipeline_inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5} - ).images - lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1] - - lora_images_scale_0_0 = sd_pipe( - **pipeline_inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0} - ).images - lora_image_slice_scale_0_0 = lora_images_scale_0_0[0, -3:, -3:, -1] - - assert not np.allclose( - lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-03 - ), "Different LoRA scales should influence the outputs accordingly." - - assert np.allclose( - original_imagee_slice, lora_image_slice_scale_0_0, atol=1e-03 - ), "LoRA scale of 0.0 shouldn't be different from the results without LoRA." 
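The removed test above checks that the LoRA contribution can be scaled per call and that a scale of 0.0 reproduces the no-LoRA baseline. A minimal sketch of that workflow outside the test harness, assuming `pipe` is a StableDiffusionXLPipeline with LoRA weights already saved under `lora_dir` (for example via `save_lora_weights`, as in these tests); the prompt, step count, and tolerances are illustrative placeholders, not values taken from the patch:

import numpy as np
import torch

kwargs = {"num_inference_steps": 2, "output_type": "np"}

# Baseline without any LoRA applied.
baseline = pipe("a prompt", generator=torch.manual_seed(0), **kwargs).images

pipe.load_lora_weights(lora_dir)

# The LoRA contribution is scaled at call time via cross_attention_kwargs.
half = pipe(
    "a prompt",
    generator=torch.manual_seed(0),
    cross_attention_kwargs={"scale": 0.5},
    **kwargs,
).images
off = pipe(
    "a prompt",
    generator=torch.manual_seed(0),
    cross_attention_kwargs={"scale": 0.0},
    **kwargs,
).images

# Scale 0.0 should match the no-LoRA baseline; a non-zero scale should not.
assert np.allclose(baseline[0, -3:, -3:, -1], off[0, -3:, -3:, -1], atol=1e-3)
assert not np.allclose(baseline[0, -3:, -3:, -1], half[0, -3:, -3:, -1], atol=1e-3)

# Unloading afterwards removes the LoRA layers from the pipeline again.
pipe.unload_lora_weights()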
- - def test_with_different_scales_fusion_equivalence(self): - pipeline_components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionXLPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - sd_pipe.unet.set_default_attn_processor() - - _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) - - images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - images_slice = images[0, -3:, -3:, -1] - - with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionXLPipeline.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], - safe_serialization=True, - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - lora_images_scale_0_5 = sd_pipe( - **pipeline_inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5} - ).images - lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1] - - sd_pipe.fuse_lora(lora_scale=0.5) - lora_images_scale_0_5_fusion = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - lora_image_slice_scale_0_5_fusion = lora_images_scale_0_5_fusion[0, -3:, -3:, -1] - - assert np.allclose( - lora_image_slice_scale_0_5, lora_image_slice_scale_0_5_fusion, atol=1e-03 - ), "Fusion shouldn't affect the results when calling the pipeline with a non-default LoRA scale." - - sd_pipe.unfuse_lora() - images_unfused = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - images_slice_unfused = images_unfused[0, -3:, -3:, -1] - - assert np.allclose(images_slice, images_slice_unfused, atol=1e-03), "Unfused should match no LoRA" - - assert not np.allclose( - images_slice, lora_image_slice_scale_0_5, atol=1e-03 - ), "0.5 scale and no scale shouldn't match" - - def test_save_load_fused_lora_modules(self): - pipeline_components, lora_components = self.get_dummy_components() - sd_pipe = StableDiffusionXLPipeline(**pipeline_components) - sd_pipe = sd_pipe.to(torch_device) - sd_pipe.set_progress_bar_config(disable=None) - - _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) - - with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionXLPipeline.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_params"], - text_encoder_lora_layers=lora_components["text_encoder_lora_params"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], - safe_serialization=True, - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - sd_pipe.fuse_lora() - lora_images_fusion = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images - lora_image_slice_fusion = lora_images_fusion[0, -3:, -3:, -1] - - with tempfile.TemporaryDirectory() as tmpdirname: - sd_pipe.save_pretrained(tmpdirname) - sd_pipe_loaded = StableDiffusionXLPipeline.from_pretrained(tmpdirname).to(torch_device) - - loaded_lora_images = sd_pipe_loaded(**pipeline_inputs, generator=torch.manual_seed(0)).images - loaded_lora_image_slice = loaded_lora_images[0, -3:, -3:, -1] - - assert np.allclose( - 
lora_image_slice_fusion, loaded_lora_image_slice, atol=1e-03 - ), "The pipeline was serialized with LoRA parameters fused inside of the respected modules. The loaded pipeline should yield proper outputs, henceforth." - - -@deprecate_after_peft_backend -class UNet2DConditionLoRAModelTests(unittest.TestCase): - model_class = UNet2DConditionModel - main_input_name = "sample" - lora_rank = 4 - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - sizes = (32, 32) - - noise = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 32), rng=random.Random(0)).to(torch_device) - - return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} - - @property - def input_shape(self): - return (4, 32, 32) - - @property - def output_shape(self): - return (4, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "block_out_channels": (32, 64), - "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), - "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), - "cross_attention_dim": 32, - "attention_head_dim": 8, - "out_channels": 4, - "in_channels": 4, - "layers_per_block": 2, - "sample_size": 32, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_lora_at_different_scales(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - model = self.model_class(**init_dict) - model.to(torch_device) - - with torch.no_grad(): - sample1 = model(**inputs_dict).sample - - _, lora_params = create_unet_lora_layers(model, rank=self.lora_rank) - - # make sure we can set a list of attention processors - model.load_attn_procs(lora_params) - model.to(torch_device) - - with torch.no_grad(): - sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample - sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - assert (sample1 - sample2).abs().max() < 3e-3 - assert (sample3 - sample4).abs().max() < 3e-3 - - # sample 2 and sample 3 should be different - assert (sample2 - sample3).abs().max() > 1e-4 - - def test_lora_on_off(self, expected_max_diff=1e-3): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - - with torch.no_grad(): - old_sample = model(**inputs_dict).sample - - _, lora_params = create_unet_lora_layers(model, rank=self.lora_rank) - model.load_attn_procs(lora_params) - - with torch.no_grad(): - sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample - - # Unload LoRA. 
- model.unload_lora() - - with torch.no_grad(): - new_sample = model(**inputs_dict).sample - - max_diff_new_sample = (sample - new_sample).abs().max() - max_diff_old_sample = (sample - old_sample).abs().max() - - assert max_diff_new_sample < expected_max_diff - assert max_diff_old_sample < expected_max_diff - - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_lora_xformers_on_off(self, expected_max_diff=6e-4): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = (8, 16) - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - _, lora_params = create_unet_lora_layers(model, rank=self.lora_rank) - model.load_attn_procs(lora_params) - - # default - with torch.no_grad(): - sample = model(**inputs_dict).sample - - model.enable_xformers_memory_efficient_attention() - on_sample = model(**inputs_dict).sample - - model.disable_xformers_memory_efficient_attention() - off_sample = model(**inputs_dict).sample - - max_diff_on_sample = (sample - on_sample).abs().max() - max_diff_off_sample = (sample - off_sample).abs().max() - - assert max_diff_on_sample < expected_max_diff - assert max_diff_off_sample < expected_max_diff - - -@deprecate_after_peft_backend -class UNet3DConditionLoRAModelTests(unittest.TestCase): - model_class = UNet3DConditionModel - main_input_name = "sample" - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - num_frames = 4 - sizes = (32, 32) - - noise = floats_tensor((batch_size, num_channels, num_frames) + sizes, rng=random.Random(0)).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 32), rng=random.Random(0)).to(torch_device) - - return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} - - @property - def input_shape(self): - return (4, 4, 32, 32) - - @property - def output_shape(self): - return (4, 4, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "block_out_channels": (32, 64), - "down_block_types": ( - "CrossAttnDownBlock3D", - "DownBlock3D", - ), - "up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D"), - "cross_attention_dim": 32, - "attention_head_dim": 8, - "out_channels": 4, - "in_channels": 4, - "layers_per_block": 1, - "sample_size": 32, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_lora_at_different_scales(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["attention_head_dim"] = 8 - - model = self.model_class(**init_dict) - model.to(torch_device) - - with torch.no_grad(): - sample1 = model(**inputs_dict).sample - - unet_lora_params = create_3d_unet_lora_layers(model) - - # make sure we can set a list of attention processors - model.load_attn_procs(unet_lora_params) - model.to(torch_device) - - with torch.no_grad(): - sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample - sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - - assert (sample1 - sample2).abs().max() < 3e-3 - assert (sample3 - sample4).abs().max() < 3e-3 - - # sample 2 and sample 3 should be different - assert (sample2 - sample3).abs().max() > 3e-3 - - -@slow 
-@deprecate_after_peft_backend -@require_torch_gpu -class LoraIntegrationTests(unittest.TestCase): - def tearDown(self): - super().tearDown() - gc.collect() - torch.cuda.empty_cache() - - def test_dreambooth_old_format(self): - generator = torch.Generator("cpu").manual_seed(0) - - lora_model_id = "hf-internal-testing/lora_dreambooth_dog_example" - card = RepoCard.load(lora_model_id) - base_model_id = card.data.to_dict()["base_model"] - - pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) - pipe = pipe.to(torch_device) - pipe.load_lora_weights(lora_model_id) - - images = pipe( - "A photo of a sks dog floating in the river", output_type="np", generator=generator, num_inference_steps=2 - ).images - - images = images[0, -3:, -3:, -1].flatten() - - expected = np.array([0.7207, 0.6787, 0.6010, 0.7478, 0.6838, 0.6064, 0.6984, 0.6443, 0.5785]) - - self.assertTrue(np.allclose(images, expected, atol=1e-4)) - - def test_dreambooth_text_encoder_new_format(self): - generator = torch.Generator().manual_seed(0) - - lora_model_id = "hf-internal-testing/lora-trained" - card = RepoCard.load(lora_model_id) - base_model_id = card.data.to_dict()["base_model"] - - pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) - pipe = pipe.to(torch_device) - pipe.load_lora_weights(lora_model_id) - - images = pipe("A photo of a sks dog", output_type="np", generator=generator, num_inference_steps=2).images - - images = images[0, -3:, -3:, -1].flatten() - - expected = np.array([0.6628, 0.6138, 0.5390, 0.6625, 0.6130, 0.5463, 0.6166, 0.5788, 0.5359]) - - self.assertTrue(np.allclose(images, expected, atol=1e-4)) - - def test_a1111(self): - generator = torch.Generator().manual_seed(0) - - pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None).to( - torch_device - ) - lora_model_id = "hf-internal-testing/civitai-light-shadow-lora" - lora_filename = "light_and_shadow.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - - images = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - - images = images[0, -3:, -3:, -1].flatten() - expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292]) - - self.assertTrue(np.allclose(images, expected, atol=1e-3)) - - def test_lycoris(self): - generator = torch.Generator().manual_seed(0) - - pipe = StableDiffusionPipeline.from_pretrained( - "hf-internal-testing/Amixx", safety_checker=None, use_safetensors=True, variant="fp16" - ).to(torch_device) - lora_model_id = "hf-internal-testing/edgLycorisMugler-light" - lora_filename = "edgLycorisMugler-light.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - - images = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - - images = images[0, -3:, -3:, -1].flatten() - expected = np.array([0.6463, 0.658, 0.599, 0.6542, 0.6512, 0.6213, 0.658, 0.6485, 0.6017]) - - self.assertTrue(np.allclose(images, expected, atol=1e-3)) - - def test_a1111_with_model_cpu_offload(self): - generator = torch.Generator().manual_seed(0) - - pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None) - pipe.enable_model_cpu_offload() - lora_model_id = "hf-internal-testing/civitai-light-shadow-lora" - lora_filename = "light_and_shadow.safetensors" - pipe.load_lora_weights(lora_model_id, 
weight_name=lora_filename) - - images = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - - images = images[0, -3:, -3:, -1].flatten() - expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292]) - - self.assertTrue(np.allclose(images, expected, atol=1e-3)) - - def test_a1111_with_sequential_cpu_offload(self): - generator = torch.Generator().manual_seed(0) - - pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None) - pipe.enable_sequential_cpu_offload() - lora_model_id = "hf-internal-testing/civitai-light-shadow-lora" - lora_filename = "light_and_shadow.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - - images = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - - images = images[0, -3:, -3:, -1].flatten() - expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292]) - - self.assertTrue(np.allclose(images, expected, atol=1e-3)) - - def test_kohya_sd_v15_with_higher_dimensions(self): - generator = torch.Generator().manual_seed(0) - - pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to( - torch_device - ) - lora_model_id = "hf-internal-testing/urushisato-lora" - lora_filename = "urushisato_v15.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - - images = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - - images = images[0, -3:, -3:, -1].flatten() - expected = np.array([0.7165, 0.6616, 0.5833, 0.7504, 0.6718, 0.587, 0.6871, 0.6361, 0.5694]) - - self.assertTrue(np.allclose(images, expected, atol=1e-3)) - - def test_vanilla_funetuning(self): - generator = torch.Generator().manual_seed(0) - - lora_model_id = "hf-internal-testing/sd-model-finetuned-lora-t4" - card = RepoCard.load(lora_model_id) - base_model_id = card.data.to_dict()["base_model"] - - pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) - pipe = pipe.to(torch_device) - pipe.load_lora_weights(lora_model_id) - - images = pipe("A pokemon with blue eyes.", output_type="np", generator=generator, num_inference_steps=2).images - - images = images[0, -3:, -3:, -1].flatten() - - expected = np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583]) - - self.assertTrue(np.allclose(images, expected, atol=1e-4)) - - def test_unload_kohya_lora(self): - generator = torch.manual_seed(0) - prompt = "masterpiece, best quality, mountain" - num_inference_steps = 2 - - pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to( - torch_device - ) - initial_images = pipe( - prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps - ).images - initial_images = initial_images[0, -3:, -3:, -1].flatten() - - lora_model_id = "hf-internal-testing/civitai-colored-icons-lora" - lora_filename = "Colored_Icons_by_vizsumit.safetensors" - - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - generator = torch.manual_seed(0) - lora_images = pipe( - prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps - ).images - lora_images = lora_images[0, -3:, -3:, -1].flatten() - - pipe.unload_lora_weights() - generator = torch.manual_seed(0) - 
unloaded_lora_images = pipe( - prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps - ).images - unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten() - - self.assertFalse(np.allclose(initial_images, lora_images)) - self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3)) - - def test_load_unload_load_kohya_lora(self): - # This test ensures that a Kohya-style LoRA can be safely unloaded and then loaded - # without introducing any side-effects. Even though the test uses a Kohya-style - # LoRA, the underlying adapter handling mechanism is format-agnostic. - generator = torch.manual_seed(0) - prompt = "masterpiece, best quality, mountain" - num_inference_steps = 2 - - pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to( - torch_device - ) - initial_images = pipe( - prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps - ).images - initial_images = initial_images[0, -3:, -3:, -1].flatten() - - lora_model_id = "hf-internal-testing/civitai-colored-icons-lora" - lora_filename = "Colored_Icons_by_vizsumit.safetensors" - - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - generator = torch.manual_seed(0) - lora_images = pipe( - prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps - ).images - lora_images = lora_images[0, -3:, -3:, -1].flatten() - - pipe.unload_lora_weights() - generator = torch.manual_seed(0) - unloaded_lora_images = pipe( - prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps - ).images - unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten() - - self.assertFalse(np.allclose(initial_images, lora_images)) - self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3)) - - # make sure we can load a LoRA again after unloading and they don't have - # any undesired effects. 
- pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - generator = torch.manual_seed(0) - lora_images_again = pipe( - prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps - ).images - lora_images_again = lora_images_again[0, -3:, -3:, -1].flatten() - - self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3)) - - def test_sdxl_0_9_lora_one(self): - generator = torch.Generator().manual_seed(0) - - pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9") - lora_model_id = "hf-internal-testing/sdxl-0.9-daiton-lora" - lora_filename = "daiton-xl-lora-test.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - pipe.enable_model_cpu_offload() - - images = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - - images = images[0, -3:, -3:, -1].flatten() - expected = np.array([0.3838, 0.3482, 0.3588, 0.3162, 0.319, 0.3369, 0.338, 0.3366, 0.3213]) - - self.assertTrue(np.allclose(images, expected, atol=1e-3)) - - def test_sdxl_0_9_lora_two(self): - generator = torch.Generator().manual_seed(0) - - pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9") - lora_model_id = "hf-internal-testing/sdxl-0.9-costumes-lora" - lora_filename = "saijo.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - pipe.enable_model_cpu_offload() - - images = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - - images = images[0, -3:, -3:, -1].flatten() - expected = np.array([0.3137, 0.3269, 0.3355, 0.255, 0.2577, 0.2563, 0.2679, 0.2758, 0.2626]) - - self.assertTrue(np.allclose(images, expected, atol=1e-3)) - - def test_sdxl_0_9_lora_three(self): - generator = torch.Generator().manual_seed(0) - - pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9") - lora_model_id = "hf-internal-testing/sdxl-0.9-kamepan-lora" - lora_filename = "kame_sdxl_v2-000020-16rank.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - pipe.enable_model_cpu_offload() - - images = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - - images = images[0, -3:, -3:, -1].flatten() - expected = np.array([0.4015, 0.3761, 0.3616, 0.3745, 0.3462, 0.3337, 0.3564, 0.3649, 0.3468]) - - self.assertTrue(np.allclose(images, expected, atol=5e-3)) - - def test_sdxl_1_0_lora(self): - generator = torch.Generator().manual_seed(0) - - pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") - pipe.enable_model_cpu_offload() - lora_model_id = "hf-internal-testing/sdxl-1.0-lora" - lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - - images = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - - images = images[0, -3:, -3:, -1].flatten() - expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535]) - - self.assertTrue(np.allclose(images, expected, atol=1e-4)) - - def test_sdxl_1_0_lora_fusion(self): - generator = torch.Generator().manual_seed(0) - - pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") - lora_model_id = "hf-internal-testing/sdxl-1.0-lora" - lora_filename = 
"sd_xl_offset_example-lora_1.0.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - pipe.fuse_lora() - pipe.enable_model_cpu_offload() - - images = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - - images = images[0, -3:, -3:, -1].flatten() - # This way we also test equivalence between LoRA fusion and the non-fusion behaviour. - expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535]) - - self.assertTrue(np.allclose(images, expected, atol=1e-4)) - - def test_sdxl_1_0_lora_unfusion(self): - generator = torch.Generator().manual_seed(0) - - pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") - lora_model_id = "hf-internal-testing/sdxl-1.0-lora" - lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - pipe.fuse_lora() - pipe.enable_model_cpu_offload() - - images = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - images_with_fusion = images[0, -3:, -3:, -1].flatten() - - pipe.unfuse_lora() - generator = torch.Generator().manual_seed(0) - images = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - images_without_fusion = images[0, -3:, -3:, -1].flatten() - - self.assertFalse(np.allclose(images_with_fusion, images_without_fusion, atol=1e-3)) - - def test_sdxl_1_0_lora_unfusion_effectivity(self): - pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") - pipe.enable_model_cpu_offload() - - generator = torch.Generator().manual_seed(0) - images = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - original_image_slice = images[0, -3:, -3:, -1].flatten() - - lora_model_id = "hf-internal-testing/sdxl-1.0-lora" - lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - pipe.fuse_lora() - - generator = torch.Generator().manual_seed(0) - _ = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - - pipe.unfuse_lora() - generator = torch.Generator().manual_seed(0) - images = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - images_without_fusion_slice = images[0, -3:, -3:, -1].flatten() - - self.assertTrue(np.allclose(original_image_slice, images_without_fusion_slice, atol=1e-3)) - - def test_sdxl_1_0_lora_fusion_efficiency(self): - generator = torch.Generator().manual_seed(0) - lora_model_id = "hf-internal-testing/sdxl-1.0-lora" - lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" - - pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16) - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16) - pipe.enable_model_cpu_offload() - - start_time = time.time() - for _ in range(3): - pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - end_time = time.time() - elapsed_time_non_fusion = end_time - start_time - - del pipe - - pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16) - 
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16) - pipe.fuse_lora() - pipe.enable_model_cpu_offload() - - generator = torch.Generator().manual_seed(0) - start_time = time.time() - for _ in range(3): - pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - end_time = time.time() - elapsed_time_fusion = end_time - start_time - - self.assertTrue(elapsed_time_fusion < elapsed_time_non_fusion) - - def test_sdxl_1_0_last_ben(self): - generator = torch.Generator().manual_seed(0) - - pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") - pipe.enable_model_cpu_offload() - lora_model_id = "TheLastBen/Papercut_SDXL" - lora_filename = "papercut.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - - images = pipe("papercut.safetensors", output_type="np", generator=generator, num_inference_steps=2).images - - images = images[0, -3:, -3:, -1].flatten() - expected = np.array([0.5244, 0.4347, 0.4312, 0.4246, 0.4398, 0.4409, 0.4884, 0.4938, 0.4094]) - - self.assertTrue(np.allclose(images, expected, atol=1e-3)) - - def test_sdxl_1_0_fuse_unfuse_all(self): - pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16) - text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict()) - text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict()) - unet_sd = copy.deepcopy(pipe.unet.state_dict()) - - pipe.load_lora_weights( - "davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors", torch_dtype=torch.float16 - ) - pipe.fuse_lora() - pipe.unload_lora_weights() - pipe.unfuse_lora() - - assert state_dicts_almost_equal(text_encoder_1_sd, pipe.text_encoder.state_dict()) - assert state_dicts_almost_equal(text_encoder_2_sd, pipe.text_encoder_2.state_dict()) - assert state_dicts_almost_equal(unet_sd, pipe.unet.state_dict()) - - def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self): - generator = torch.Generator().manual_seed(0) - - pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") - pipe.enable_sequential_cpu_offload() - lora_model_id = "hf-internal-testing/sdxl-1.0-lora" - lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - - images = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - - images = images[0, -3:, -3:, -1].flatten() - expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535]) - - self.assertTrue(np.allclose(images, expected, atol=1e-3)) - - def test_canny_lora(self): - controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0") - - pipe = StableDiffusionXLControlNetPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet - ) - pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors") - pipe.enable_sequential_cpu_offload() - - generator = torch.Generator(device="cpu").manual_seed(0) - prompt = "corgi" - image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" - ) - - images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images - - assert images[0].shape == (768, 512, 3) - - original_image = images[0, -3:, -3:, -1].flatten() - expected_image = 
np.array([0.4574, 0.4461, 0.4435, 0.4462, 0.4396, 0.439, 0.4474, 0.4486, 0.4333]) - assert np.allclose(original_image, expected_image, atol=1e-04) - - @nightly - def test_sequential_fuse_unfuse(self): - pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") - - # 1. round - pipe.load_lora_weights("Pclanglais/TintinIA") - pipe.fuse_lora() - - generator = torch.Generator().manual_seed(0) - images = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - image_slice = images[0, -3:, -3:, -1].flatten() - - pipe.unfuse_lora() - - # 2. round - pipe.load_lora_weights("ProomptEngineer/pe-balloon-diffusion-style") - pipe.fuse_lora() - pipe.unfuse_lora() - - # 3. round - pipe.load_lora_weights("ostris/crayon_style_lora_sdxl") - pipe.fuse_lora() - pipe.unfuse_lora() - - # 4. back to 1st round - pipe.load_lora_weights("Pclanglais/TintinIA") - pipe.fuse_lora() - - generator = torch.Generator().manual_seed(0) - images_2 = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 - ).images - image_slice_2 = images_2[0, -3:, -3:, -1].flatten() - - self.assertTrue(np.allclose(image_slice, image_slice_2, atol=1e-3)) From eb90e96d50badade2474e40429152a732338cc83 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 7 Feb 2024 13:19:12 +0530 Subject: [PATCH 4/4] Revert "remove old lora backend" This reverts commit adcddf6ba421f847e7da2a0ce57b9456cae43356. --- src/diffusers/loaders/lora.py | 397 +++- tests/lora/test_lora_layers_old_backend.py | 2193 ++++++++++++++++++++ 2 files changed, 2495 insertions(+), 95 deletions(-) create mode 100644 tests/lora/test_lora_layers_old_backend.py diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 6e0e9af51740..922c98b98bf4 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect import os +from contextlib import nullcontext from pathlib import Path from typing import Callable, Dict, List, Optional, Union @@ -25,7 +26,7 @@ from torch import nn from .. import __version__ -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..utils import ( USE_PEFT_BACKEND, _get_model_file, @@ -50,9 +51,10 @@ if is_transformers_available(): from transformers import PreTrainedModel - from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules + from ..models.lora import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules if is_accelerate_available(): + from accelerate import init_empty_weights from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module logger = logging.get_logger(__name__) @@ -104,9 +106,6 @@ def load_lora_weights( Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) @@ -398,11 +397,6 @@ def load_lora_into_unet( Adapter name to be used for referencing the loaded adapter model. 
If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as @@ -433,7 +427,9 @@ def load_lora_into_unet( warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`." logger.warn(warn_message) - if len(state_dict.keys()) > 0: + if USE_PEFT_BACKEND and len(state_dict.keys()) > 0: + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + if adapter_name in getattr(unet, "peft_config", {}): raise ValueError( f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name." @@ -522,11 +518,6 @@ def load_lora_into_text_encoder( Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - from peft import LoraConfig - low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), @@ -548,21 +539,34 @@ def load_lora_into_text_encoder( rank = {} text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) - # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - rank_key = f"{name}.out_proj.lora_B.weight" - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys()) - if patch_mlp: - for name, _ in text_encoder_mlp_modules(text_encoder): - rank_key_fc1 = f"{name}.fc1.lora_B.weight" - rank_key_fc2 = f"{name}.fc2.lora_B.weight" - - rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1] - rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1] + if USE_PEFT_BACKEND: + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + rank_key = f"{name}.out_proj.lora_B.weight" + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + patch_mlp = any(".mlp." 
in key for key in text_encoder_lora_state_dict.keys()) + if patch_mlp: + for name, _ in text_encoder_mlp_modules(text_encoder): + rank_key_fc1 = f"{name}.fc1.lora_B.weight" + rank_key_fc2 = f"{name}.fc2.lora_B.weight" + + rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1] + rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1] + else: + for name, _ in text_encoder_attn_modules(text_encoder): + rank_key = f"{name}.out_proj.lora_linear_layer.up.weight" + rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]}) + + patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys()) + if patch_mlp: + for name, _ in text_encoder_mlp_modules(text_encoder): + rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight" + rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight" + rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1] + rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1] if network_alphas is not None: alpha_keys = [ @@ -572,25 +576,84 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) - lora_config = LoraConfig(**lora_config_kwargs) + if USE_PEFT_BACKEND: + from peft import LoraConfig - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(text_encoder) + lora_config_kwargs = get_peft_kwargs( + rank, network_alphas, text_encoder_lora_state_dict, is_unet=False + ) + lora_config = LoraConfig(**lora_config_kwargs) - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) - # inject LoRA layers and load the state dict - # in transformers we automatically check whether the adapter name is already in use or not - text_encoder.load_adapter( - adapter_name=adapter_name, - adapter_state_dict=text_encoder_lora_state_dict, - peft_config=lora_config, - ) + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + ) + + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) + else: + cls._modify_text_encoder( + text_encoder, + lora_scale, + network_alphas, + rank=rank, + patch_mlp=patch_mlp, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + is_pipeline_offloaded = _pipeline is not None and any( + isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") + for c in _pipeline.components.values() + ) + if is_pipeline_offloaded and low_cpu_mem_usage: + low_cpu_mem_usage = True + logger.info( + f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced." 
+ ) + + if low_cpu_mem_usage: + device = next(iter(text_encoder_lora_state_dict.values())).device + dtype = next(iter(text_encoder_lora_state_dict.values())).dtype + unexpected_keys = load_model_dict_into_meta( + text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype + ) + else: + load_state_dict_results = text_encoder.load_state_dict( + text_encoder_lora_state_dict, strict=False + ) + unexpected_keys = load_state_dict_results.unexpected_keys - # scale LoRA layers with `lora_scale` - scale_lora_layers(text_encoder, weight=lora_scale) + if len(unexpected_keys) != 0: + raise ValueError( + f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}" + ) + + # float: return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 def _remove_text_encoder_monkey_patch(self): + if USE_PEFT_BACKEND: + remove_method = recurse_remove_peft_layers + else: + remove_method = self._remove_text_encoder_monkey_patch_classmethod + if hasattr(self, "text_encoder"): - recurse_remove_peft_layers(self.text_encoder) - if getattr(self.text_encoder, "peft_config", None) is not None: + remove_method(self.text_encoder) + + # In case text encoder have no Lora attached + if USE_PEFT_BACKEND and getattr(self.text_encoder, "peft_config", None) is not None: del self.text_encoder.peft_config self.text_encoder._hf_peft_config_loaded = None - if hasattr(self, "text_encoder_2"): - recurse_remove_peft_layers(self.text_encoder_2) - if getattr(self.text_encoder_2, "peft_config", None) is not None: + remove_method(self.text_encoder_2) + if USE_PEFT_BACKEND: del self.text_encoder_2.peft_config self.text_encoder_2._hf_peft_config_loaded = None + @classmethod + def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder): + deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.27", LORA_DEPRECATION_MESSAGE) + + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj.lora_linear_layer = None + attn_module.k_proj.lora_linear_layer = None + attn_module.v_proj.lora_linear_layer = None + attn_module.out_proj.lora_linear_layer = None + + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + if isinstance(mlp_module.fc1, PatchedLoraProjection): + mlp_module.fc1.lora_linear_layer = None + mlp_module.fc2.lora_linear_layer = None + + @classmethod + def _modify_text_encoder( + cls, + text_encoder, + lora_scale=1, + network_alphas=None, + rank: Union[Dict[str, int], int] = 4, + dtype=None, + patch_mlp=False, + low_cpu_mem_usage=False, + ): + r""" + Monkey-patches the forward passes of attention modules of the text encoder. 
+ """ + deprecate("_modify_text_encoder", "0.27", LORA_DEPRECATION_MESSAGE) + + def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters): + linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model + ctx = init_empty_weights if low_cpu_mem_usage else nullcontext + with ctx(): + model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype) + + lora_parameters.extend(model.lora_linear_layer.parameters()) + return model + + # First, remove any monkey-patch that might have been applied before + cls._remove_text_encoder_monkey_patch_classmethod(text_encoder) + + lora_parameters = [] + network_alphas = {} if network_alphas is None else network_alphas + is_network_alphas_populated = len(network_alphas) > 0 + + for name, attn_module in text_encoder_attn_modules(text_encoder): + query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None) + key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None) + value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None) + out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None) + + if isinstance(rank, dict): + current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight") + else: + current_rank = rank + + attn_module.q_proj = create_patched_linear_lora( + attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters + ) + attn_module.k_proj = create_patched_linear_lora( + attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters + ) + attn_module.v_proj = create_patched_linear_lora( + attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters + ) + attn_module.out_proj = create_patched_linear_lora( + attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters + ) + + if patch_mlp: + for name, mlp_module in text_encoder_mlp_modules(text_encoder): + fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha", None) + fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha", None) + + current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight") + current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight") + + mlp_module.fc1 = create_patched_linear_lora( + mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters + ) + mlp_module.fc2 = create_patched_linear_lora( + mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters + ) + + if is_network_alphas_populated and len(network_alphas) > 0: + raise ValueError( + f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}" + ) + + return lora_parameters + @classmethod def save_lora_weights( cls, @@ -879,8 +1039,6 @@ def fuse_lora( pipeline.fuse_lora(lora_scale=0.7) ``` """ - from peft.tuners.tuners_utils import BaseTunerLayer - if fuse_unet or fuse_text_encoder: self.num_fused_loras += 1 if self.num_fused_loras > 1: @@ -892,26 +1050,52 @@ def fuse_lora( unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names) - def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): - merge_kwargs = {"safe_merge": safe_fusing} + if USE_PEFT_BACKEND: + from peft.tuners.tuners_utils import BaseTunerLayer - for module in text_encoder.modules(): - if isinstance(module, BaseTunerLayer): - if lora_scale != 1.0: - 
module.scale_layer(lora_scale) + def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): + merge_kwargs = {"safe_merge": safe_fusing} - # For BC with previous PEFT versions, we need to check the signature - # of the `merge` method to see if it supports the `adapter_names` argument. - supported_merge_kwargs = list(inspect.signature(module.merge).parameters) - if "adapter_names" in supported_merge_kwargs: - merge_kwargs["adapter_names"] = adapter_names - elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None: - raise ValueError( - "The `adapter_names` argument is not supported with your PEFT version. " - "Please upgrade to the latest version of PEFT. `pip install -U peft`" - ) + for module in text_encoder.modules(): + if isinstance(module, BaseTunerLayer): + if lora_scale != 1.0: + module.scale_layer(lora_scale) + + # For BC with previous PEFT versions, we need to check the signature + # of the `merge` method to see if it supports the `adapter_names` argument. + supported_merge_kwargs = list(inspect.signature(module.merge).parameters) + if "adapter_names" in supported_merge_kwargs: + merge_kwargs["adapter_names"] = adapter_names + elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None: + raise ValueError( + "The `adapter_names` argument is not supported with your PEFT version. " + "Please upgrade to the latest version of PEFT. `pip install -U peft`" + ) + + module.merge(**merge_kwargs) - module.merge(**merge_kwargs) + else: + deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE) + + def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, **kwargs): + if "adapter_names" in kwargs and kwargs["adapter_names"] is not None: + raise ValueError( + "The `adapter_names` argument is not supported in your environment. Please switch to PEFT " + "backend to use this argument by installing latest PEFT and transformers." + " `pip install -U peft transformers`" + ) + + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj._fuse_lora(lora_scale, safe_fusing) + attn_module.k_proj._fuse_lora(lora_scale, safe_fusing) + attn_module.v_proj._fuse_lora(lora_scale, safe_fusing) + attn_module.out_proj._fuse_lora(lora_scale, safe_fusing) + + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + if isinstance(mlp_module.fc1, PatchedLoraProjection): + mlp_module.fc1._fuse_lora(lora_scale, safe_fusing) + mlp_module.fc2._fuse_lora(lora_scale, safe_fusing) if fuse_text_encoder: if hasattr(self, "text_encoder"): @@ -936,18 +1120,40 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect. 
""" - from peft.tuners.tuners_utils import BaseTunerLayer - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet if unfuse_unet: - for module in unet.modules(): - if isinstance(module, BaseTunerLayer): - module.unmerge() + if not USE_PEFT_BACKEND: + unet.unfuse_lora() + else: + from peft.tuners.tuners_utils import BaseTunerLayer + + for module in unet.modules(): + if isinstance(module, BaseTunerLayer): + module.unmerge() + + if USE_PEFT_BACKEND: + from peft.tuners.tuners_utils import BaseTunerLayer + + def unfuse_text_encoder_lora(text_encoder): + for module in text_encoder.modules(): + if isinstance(module, BaseTunerLayer): + module.unmerge() + + else: + deprecate("unfuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE) + + def unfuse_text_encoder_lora(text_encoder): + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj._unfuse_lora() + attn_module.k_proj._unfuse_lora() + attn_module.v_proj._unfuse_lora() + attn_module.out_proj._unfuse_lora() - def unfuse_text_encoder_lora(text_encoder): - for module in text_encoder.modules(): - if isinstance(module, BaseTunerLayer): - module.unmerge() + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + if isinstance(mlp_module.fc1, PatchedLoraProjection): + mlp_module.fc1._unfuse_lora() + mlp_module.fc2._unfuse_lora() if unfuse_text_encoder: if hasattr(self, "text_encoder"): @@ -1228,9 +1434,6 @@ def load_lora_weights( kwargs (`dict`, *optional*): See [`~loaders.LoraLoaderMixin.lora_state_dict`]. """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - # We could have accessed the unet config from `lora_state_dict()` too. We pass # it here explicitly to be able to tell that it's coming from an SDXL # pipeline. @@ -1335,13 +1538,17 @@ def pack_weights(layers, prefix): ) def _remove_text_encoder_monkey_patch(self): - recurse_remove_peft_layers(self.text_encoder) - # TODO: @younesbelkada handle this in transformers side - if getattr(self.text_encoder, "peft_config", None) is not None: - del self.text_encoder.peft_config - self.text_encoder._hf_peft_config_loaded = None - - recurse_remove_peft_layers(self.text_encoder_2) - if getattr(self.text_encoder_2, "peft_config", None) is not None: - del self.text_encoder_2.peft_config - self.text_encoder_2._hf_peft_config_loaded = None + if USE_PEFT_BACKEND: + recurse_remove_peft_layers(self.text_encoder) + # TODO: @younesbelkada handle this in transformers side + if getattr(self.text_encoder, "peft_config", None) is not None: + del self.text_encoder.peft_config + self.text_encoder._hf_peft_config_loaded = None + + recurse_remove_peft_layers(self.text_encoder_2) + if getattr(self.text_encoder_2, "peft_config", None) is not None: + del self.text_encoder_2.peft_config + self.text_encoder_2._hf_peft_config_loaded = None + else: + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) diff --git a/tests/lora/test_lora_layers_old_backend.py b/tests/lora/test_lora_layers_old_backend.py new file mode 100644 index 000000000000..148e551d6c1a --- /dev/null +++ b/tests/lora/test_lora_layers_old_backend.py @@ -0,0 +1,2193 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import gc +import os +import random +import tempfile +import time +import unittest + +import numpy as np +import torch +import torch.nn as nn +from huggingface_hub.repocard import RepoCard +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + ControlNetModel, + DDIMScheduler, + DiffusionPipeline, + EulerDiscreteScheduler, + PNDMScheduler, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + StableDiffusionXLControlNetPipeline, + StableDiffusionXLPipeline, + UNet2DConditionModel, + UNet3DConditionModel, +) +from diffusers.loaders import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin +from diffusers.models.attention_processor import ( + Attention, + AttnProcessor, + AttnProcessor2_0, + XFormersAttnProcessor, +) +from diffusers.models.lora import LoRALinearLayer +from diffusers.training_utils import unet_lora_state_dict +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import ( + deprecate_after_peft_backend, + floats_tensor, + load_image, + nightly, + require_torch_gpu, + slow, + torch_device, +) + + +def text_encoder_attn_modules(text_encoder: nn.Module): + """Fetches the attention modules from `text_encoder`.""" + attn_modules = [] + + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + else: + raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}") + + return attn_modules + + +def text_encoder_lora_state_dict(text_encoder: nn.Module): + """Returns the LoRA state dict of the `text_encoder`. Assumes that `_modify_text_encoder()` was already called on it.""" + state_dict = {} + + for name, module in text_encoder_attn_modules(text_encoder): + for k, v in module.q_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v + + for k, v in module.k_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v + + for k, v in module.v_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v + + for k, v in module.out_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v + + return state_dict + + +def create_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True): + """Creates and returns the LoRA state dict for the UNet.""" + # So that we accidentally don't end up using the in-place modified UNet. + unet_lora_parameters = [] + + for attn_processor_name, attn_processor in unet.attn_processors.items(): + # Parse the attention module. + attn_module = unet + for n in attn_processor_name.split(".")[:-1]: + attn_module = getattr(attn_module, n) + + # Set the `lora_layer` attribute of the attention-related matrices. 
+ attn_module.to_q.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_q.in_features, + out_features=attn_module.to_q.out_features, + rank=rank, + ) + ) + attn_module.to_k.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_k.in_features, + out_features=attn_module.to_k.out_features, + rank=rank, + ) + ) + attn_module.to_v.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_v.in_features, + out_features=attn_module.to_v.out_features, + rank=rank, + ) + ) + attn_module.to_out[0].set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_out[0].in_features, + out_features=attn_module.to_out[0].out_features, + rank=rank, + ) + ) + + if mock_weights: + with torch.no_grad(): + attn_module.to_q.lora_layer.up.weight += 1 + attn_module.to_k.lora_layer.up.weight += 1 + attn_module.to_v.lora_layer.up.weight += 1 + attn_module.to_out[0].lora_layer.up.weight += 1 + + unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) + + unet_lora_sd = unet_lora_state_dict(unet) + # Unload LoRA. + unet.unload_lora() + + return unet_lora_parameters, unet_lora_sd + + +def create_3d_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True): + """Creates and returns the LoRA state dict for the 3D UNet.""" + for attn_processor_name in unet.attn_processors.keys(): + has_cross_attention = attn_processor_name.endswith("attn2.processor") and not ( + attn_processor_name.startswith("transformer_in") or "temp_attentions" in attn_processor_name.split(".") + ) + cross_attention_dim = unet.config.cross_attention_dim if has_cross_attention else None + + if attn_processor_name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif attn_processor_name.startswith("up_blocks"): + block_id = int(attn_processor_name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif attn_processor_name.startswith("down_blocks"): + block_id = int(attn_processor_name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + elif attn_processor_name.startswith("transformer_in"): + # Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148 + hidden_size = 8 * unet.config.attention_head_dim + + # Parse the attention module. 
+ attn_module = unet + for n in attn_processor_name.split(".")[:-1]: + attn_module = getattr(attn_module, n) + + attn_module.to_q.set_lora_layer( + LoRALinearLayer( + in_features=min(attn_module.to_q.in_features, hidden_size), + out_features=attn_module.to_q.out_features + if cross_attention_dim is None + else max(attn_module.to_q.out_features, cross_attention_dim), + rank=rank, + ) + ) + attn_module.to_k.set_lora_layer( + LoRALinearLayer( + in_features=min(attn_module.to_k.in_features, hidden_size), + out_features=attn_module.to_k.out_features + if cross_attention_dim is None + else max(attn_module.to_k.out_features, cross_attention_dim), + rank=rank, + ) + ) + attn_module.to_v.set_lora_layer( + LoRALinearLayer( + in_features=min(attn_module.to_v.in_features, hidden_size), + out_features=attn_module.to_v.out_features + if cross_attention_dim is None + else max(attn_module.to_v.out_features, cross_attention_dim), + rank=rank, + ) + ) + attn_module.to_out[0].set_lora_layer( + LoRALinearLayer( + in_features=min(attn_module.to_out[0].in_features, hidden_size), + out_features=attn_module.to_out[0].out_features + if cross_attention_dim is None + else max(attn_module.to_out[0].out_features, cross_attention_dim), + rank=rank, + ) + ) + + if mock_weights: + with torch.no_grad(): + attn_module.to_q.lora_layer.up.weight += 1 + attn_module.to_k.lora_layer.up.weight += 1 + attn_module.to_v.lora_layer.up.weight += 1 + attn_module.to_out[0].lora_layer.up.weight += 1 + + unet_lora_sd = unet_lora_state_dict(unet) + + # Unload LoRA. + unet.unload_lora() + + return unet_lora_sd + + +def set_lora_weights(lora_attn_parameters, randn_weight=False, var=1.0): + """Randomizes the LoRA params if specified.""" + if not isinstance(lora_attn_parameters, dict): + with torch.no_grad(): + for parameter in lora_attn_parameters: + if randn_weight: + parameter[:] = torch.randn_like(parameter) * var + else: + torch.zero_(parameter) + else: + if randn_weight: + modified_state_dict = {k: torch.rand_like(v) * var for k, v in lora_attn_parameters.items()} + else: + modified_state_dict = {k: torch.zeros_like(v) * var for k, v in lora_attn_parameters.items()} + return modified_state_dict + + +def state_dicts_almost_equal(sd1, sd2): + sd1 = dict(sorted(sd1.items())) + sd2 = dict(sorted(sd2.items())) + + models_are_equal = True + for ten1, ten2 in zip(sd1.values(), sd2.values()): + if (ten1 - ten2).abs().max() > 1e-3: + models_are_equal = False + + return models_are_equal + + +@deprecate_after_peft_backend +class LoraLoaderMixinTests(unittest.TestCase): + lora_rank = 4 + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + 
vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + unet_lora_raw_params, unet_lora_params = create_unet_lora_layers(unet, rank=self.lora_rank) + text_encoder_lora_params = LoraLoaderMixin._modify_text_encoder( + text_encoder, dtype=torch.float32, rank=self.lora_rank + ) + text_encoder_lora_params = text_encoder_lora_state_dict(text_encoder) + # We call this to ensure that the effects of the in-place `_modify_text_encoder` have been erased. + LoraLoaderMixin._remove_text_encoder_monkey_patch_classmethod(text_encoder) + + pipeline_components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + "image_encoder": None, + } + lora_components = { + "unet_lora_raw_params": unet_lora_raw_params, + "unet_lora_params": unet_lora_params, + "text_encoder_lora_params": text_encoder_lora_params, + } + return pipeline_components, lora_components + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb + def get_dummy_tokens(self): + max_seq_length = 77 + + inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)) + + prepared_inputs = {} + prepared_inputs["input_ids"] = inputs + return prepared_inputs + + def create_lora_weight_file(self, tmpdirname): + _, lora_components = self.get_dummy_components() + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + + @unittest.skipIf(not torch.cuda.is_available() or not is_xformers_available(), reason="xformers requires cuda") + def test_stable_diffusion_xformers_attn_processors(self): + # disable_full_determinism() + device = "cuda" # ensure determinism for the device-dependent torch.Generator + components, _ = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, inputs = self.get_dummy_inputs() + + # run xformers attention + sd_pipe.enable_xformers_memory_efficient_attention() + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + def test_stable_diffusion_lora(self): + components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + sd_pipe.unet.set_default_attn_processor() + + # forward 1 + _, _, inputs = self.get_dummy_inputs() + + output = sd_pipe(**inputs) + image = output.images + image_slice = image[0, -3:, -3:, -1] + + # set lora layers + 
sd_pipe.unet.load_attn_procs(lora_components["unet_lora_params"]) + + # forward 2 + _, _, inputs = self.get_dummy_inputs() + + output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.0}) + image = output.images + image_slice_1 = image[0, -3:, -3:, -1] + + # forward 3 + _, _, inputs = self.get_dummy_inputs() + + output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.5}) + image = output.images + image_slice_2 = image[0, -3:, -3:, -1] + + assert np.abs(image_slice - image_slice_1).max() < 1e-2 + assert np.abs(image_slice - image_slice_2).max() > 1e-2 + + def test_lora_save_load(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs() + + original_images = sd_pipe(**pipeline_inputs).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Outputs shouldn't match. + self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + def test_lora_save_load_no_safe_serialization(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs() + + original_images = sd_pipe(**pipeline_inputs).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + safe_serialization=False, + ) + + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Outputs shouldn't match. 
+ self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + def test_text_encoder_lora_monkey_patch(self): + pipeline_components, _ = self.get_dummy_components() + pipe = StableDiffusionPipeline(**pipeline_components) + + dummy_tokens = self.get_dummy_tokens() + + # inference without lora + outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_without_lora.shape == (1, 77, 32) + + # monkey patch + text_encoder_lora_params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) + text_encoder_lora_params = set_lora_weights( + text_encoder_lora_state_dict(pipe.text_encoder), randn_weight=False + ) + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=None, + text_encoder_lora_layers=text_encoder_lora_params, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + pipe.load_lora_weights(tmpdirname) + + # inference with lora + outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_with_lora.shape == (1, 77, 32) + + assert torch.allclose( + outputs_without_lora, outputs_with_lora + ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" + + # monkey patch + pipeline_components, _ = self.get_dummy_components() + pipe = StableDiffusionPipeline(**pipeline_components) + + text_encoder_lora_params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) + text_encoder_lora_params = set_lora_weights( + text_encoder_lora_state_dict(pipe.text_encoder), randn_weight=True, var=0.1 + ) + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=None, + text_encoder_lora_layers=text_encoder_lora_params, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + pipe.load_lora_weights(tmpdirname) + + # inference with lora + outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_with_lora.shape == (1, 77, 32) + + assert not torch.allclose( + outputs_without_lora, outputs_with_lora + ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs" + + def test_text_encoder_lora_remove_monkey_patch(self): + pipeline_components, _ = self.get_dummy_components() + pipe = StableDiffusionPipeline(**pipeline_components) + + dummy_tokens = self.get_dummy_tokens() + + # inference without lora + outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_without_lora.shape == (1, 77, 32) + + # monkey patch + params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) + params = set_lora_weights(text_encoder_lora_state_dict(pipe.text_encoder), var=0.1, randn_weight=True) + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=None, + text_encoder_lora_layers=params, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + pipe.load_lora_weights(tmpdirname) + + # inference with lora + outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_with_lora.shape == (1, 77, 32) + + assert not torch.allclose( + outputs_without_lora, outputs_with_lora + ), "lora outputs should be different to without lora outputs" + + # remove monkey patch + pipe._remove_text_encoder_monkey_patch() + + # inference with removed lora + 
outputs_without_lora_removed = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_without_lora_removed.shape == (1, 77, 32) + + assert torch.allclose( + outputs_without_lora, outputs_without_lora_removed + ), "remove lora monkey patch should restore the original outputs" + + def test_text_encoder_lora_scale(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs() + + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + lora_images_with_scale = sd_pipe(**pipeline_inputs, cross_attention_kwargs={"scale": 0.5}).images + lora_image_with_scale_slice = lora_images_with_scale[0, -3:, -3:, -1] + + # Outputs shouldn't match. + self.assertFalse( + torch.allclose(torch.from_numpy(lora_image_slice), torch.from_numpy(lora_image_with_scale_slice)) + ) + + def test_lora_unet_attn_processors(self): + with tempfile.TemporaryDirectory() as tmpdirname: + self.create_lora_weight_file(tmpdirname) + + pipeline_components, _ = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + # check if vanilla attention processors are used + for _, module in sd_pipe.unet.named_modules(): + if isinstance(module, Attention): + self.assertIsInstance(module.processor, (AttnProcessor, AttnProcessor2_0)) + + # load LoRA weight file + sd_pipe.load_lora_weights(tmpdirname) + + # check if lora attention processors are used + for _, module in sd_pipe.unet.named_modules(): + if isinstance(module, Attention): + self.assertIsNotNone(module.to_q.lora_layer) + self.assertIsNotNone(module.to_k.lora_layer) + self.assertIsNotNone(module.to_v.lora_layer) + self.assertIsNotNone(module.to_out[0].lora_layer) + + def test_unload_lora_sd(self): + pipeline_components, lora_components = self.get_dummy_components() + _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe.unet.set_default_attn_processor() + + original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Unload LoRA parameters. 
+ sd_pipe.unload_lora_weights()
+ original_images_two = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+ orig_image_slice_two = original_images_two[0, -3:, -3:, -1]
+
+ assert not np.allclose(
+ orig_image_slice, lora_image_slice
+ ), "LoRA parameters should lead to a different image slice."
+ assert not np.allclose(
+ orig_image_slice_two, lora_image_slice
+ ), "LoRA parameters should lead to a different image slice."
+ assert np.allclose(
+ orig_image_slice, orig_image_slice_two, atol=1e-3
+ ), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
+
+ @unittest.skipIf(torch_device != "cuda" or not is_xformers_available(), "This test is supposed to run on GPU")
+ def test_lora_unet_attn_processors_with_xformers(self):
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ self.create_lora_weight_file(tmpdirname)
+
+ pipeline_components, _ = self.get_dummy_components()
+ sd_pipe = StableDiffusionPipeline(**pipeline_components)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ # enable XFormers
+ sd_pipe.enable_xformers_memory_efficient_attention()
+
+ # check if xFormers attention processors are used
+ for _, module in sd_pipe.unet.named_modules():
+ if isinstance(module, Attention):
+ self.assertIsInstance(module.processor, XFormersAttnProcessor)
+
+ # load LoRA weight file
+ sd_pipe.load_lora_weights(tmpdirname)
+
+ # check if lora attention processors are used
+ for _, module in sd_pipe.unet.named_modules():
+ if isinstance(module, Attention):
+ self.assertIsNotNone(module.to_q.lora_layer)
+ self.assertIsNotNone(module.to_k.lora_layer)
+ self.assertIsNotNone(module.to_v.lora_layer)
+ self.assertIsNotNone(module.to_out[0].lora_layer)
+
+ # unload lora weights
+ sd_pipe.unload_lora_weights()
+
+ # check if attention processors are reverted back to xFormers
+ for _, module in sd_pipe.unet.named_modules():
+ if isinstance(module, Attention):
+ self.assertIsInstance(module.processor, XFormersAttnProcessor)
+
+ @unittest.skipIf(torch_device != "cuda" or not is_xformers_available(), "This test is supposed to run on GPU")
+ def test_lora_save_load_with_xformers(self):
+ pipeline_components, lora_components = self.get_dummy_components()
+ sd_pipe = StableDiffusionPipeline(**pipeline_components)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ _, _, pipeline_inputs = self.get_dummy_inputs()
+
+ # enable XFormers
+ sd_pipe.enable_xformers_memory_efficient_attention()
+
+ original_images = sd_pipe(**pipeline_inputs).images
+ orig_image_slice = original_images[0, -3:, -3:, -1]
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ LoraLoaderMixin.save_lora_weights(
+ save_directory=tmpdirname,
+ unet_lora_layers=lora_components["unet_lora_params"],
+ text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
+ )
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+ sd_pipe.load_lora_weights(tmpdirname)
+
+ lora_images = sd_pipe(**pipeline_inputs).images
+ lora_image_slice = lora_images[0, -3:, -3:, -1]
+
+ # Outputs shouldn't match. 
+ self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + +@deprecate_after_peft_backend +class SDInpaintLoraMixinTests(unittest.TestCase): + lora_rank = 4 + + def get_dummy_inputs(self, device, seed=0, img_res=64, output_pil=True): + # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched + if output_pil: + # Get random floats in [0, 1] as image + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + mask_image = torch.ones_like(image) + # Convert image and mask_image to [0, 255] + image = 255 * image + mask_image = 255 * mask_image + # Convert to PIL image + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((img_res, img_res)) + mask_image = Image.fromarray(np.uint8(mask_image)).convert("RGB").resize((img_res, img_res)) + else: + # Get random floats in [0, 1] as image with spatial size (img_res, img_res) + image = floats_tensor((1, 3, img_res, img_res), rng=random.Random(seed)).to(device) + # Convert image to [-1, 1] + init_image = 2.0 * image - 1.0 + mask_image = torch.ones((1, 1, img_res, img_res), device=device) + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": init_image, + "mask_image": mask_image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + } + return inputs + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=9, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + scheduler = PNDMScheduler(skip_prk_steps=True) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + unet_lora_raw_params, unet_lora_params = create_unet_lora_layers(unet, rank=self.lora_rank) + text_encoder_lora_params = StableDiffusionXLLoraLoaderMixin._modify_text_encoder( + text_encoder, dtype=torch.float32, rank=self.lora_rank + ) + text_encoder_lora_params = set_lora_weights( + text_encoder_lora_state_dict(text_encoder), randn_weight=True, var=0.1 + ) + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + "image_encoder": None, + } + lora_components = { + "unet_lora_raw_params": unet_lora_raw_params, + "unet_lora_params": unet_lora_params, + "text_encoder_lora_params": text_encoder_lora_params, + } + return components, lora_components + + def test_stable_diffusion_inpaint_lora(self): + device = "cpu" # ensure determinism for the device-dependent 
torch.Generator + + components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionInpaintPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + sd_pipe.unet.set_default_attn_processor() + + # forward 1 + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs) + image = output.images + image_slice = image[0, -3:, -3:, -1] + + # set lora layers + sd_pipe.unet.load_attn_procs(lora_components["unet_lora_params"]) + + # forward 2 + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.0}) + image = output.images + image_slice_1 = image[0, -3:, -3:, -1] + + # forward 3 + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.5}) + image = output.images + image_slice_2 = image[0, -3:, -3:, -1] + + assert np.abs(image_slice - image_slice_1).max() < 1e-2 + assert np.abs(image_slice - image_slice_2).max() > 1e-2 + + +@deprecate_after_peft_backend +class SDXLLoraLoaderMixinTests(unittest.TestCase): + lora_rank = 4 + + def get_dummy_components(self, modify_text_encoder=True): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + # SD2-specific config below + attention_head_dim=(2, 4), + use_linear_projection=True, + addition_embed_type="text_time", + addition_time_embed_dim=8, + transformer_layers_per_block=(1, 2), + projection_class_embeddings_input_dim=80, # 6 * 8 + 32 + cross_attention_dim=64, + ) + scheduler = EulerDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + steps_offset=1, + beta_schedule="scaled_linear", + timestep_spacing="leading", + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=32, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + _, unet_lora_params = create_unet_lora_layers(unet, rank=self.lora_rank) + + if modify_text_encoder: + _ = StableDiffusionXLLoraLoaderMixin._modify_text_encoder( + text_encoder, dtype=torch.float32, rank=self.lora_rank + ) + text_encoder_lora_params = text_encoder_lora_state_dict(text_encoder) + StableDiffusionXLLoraLoaderMixin._remove_text_encoder_monkey_patch_classmethod(text_encoder) + + _ = StableDiffusionXLLoraLoaderMixin._modify_text_encoder( + text_encoder_2, dtype=torch.float32, rank=self.lora_rank + ) + text_encoder_two_lora_params = text_encoder_lora_state_dict(text_encoder_2) + StableDiffusionXLLoraLoaderMixin._remove_text_encoder_monkey_patch_classmethod(text_encoder_2) + else: + text_encoder_lora_params = None + text_encoder_two_lora_params = None 
+ + pipeline_components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "image_encoder": None, + "feature_extractor": None, + } + lora_components = { + "unet_lora_params": unet_lora_params, + "text_encoder_lora_params": text_encoder_lora_params, + "text_encoder_two_lora_params": text_encoder_two_lora_params, + } + return pipeline_components, lora_components + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + def test_lora_save_load(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs() + + original_images = sd_pipe(**pipeline_inputs).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Outputs shouldn't match. + self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + def test_unload_lora_sdxl(self): + pipeline_components, lora_components = self.get_dummy_components() + _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe.unet.set_default_attn_processor() + + original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Unload LoRA parameters. 
+ sd_pipe.unload_lora_weights() + original_images_two = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + orig_image_slice_two = original_images_two[0, -3:, -3:, -1] + + assert not np.allclose( + orig_image_slice, lora_image_slice + ), "LoRA parameters should lead to a different image slice." + assert not np.allclose( + orig_image_slice_two, lora_image_slice + ), "LoRA parameters should lead to a different image slice." + assert np.allclose( + orig_image_slice, orig_image_slice_two, atol=1e-3 + ), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters." + + def test_load_lora_locally(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], + safe_serialization=False, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + + sd_pipe.unload_lora_weights() + + def test_text_encoder_lora_state_dict_unchanged(self): + pipeline_components, lora_components = self.get_dummy_components(modify_text_encoder=False) + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + + text_encoder_1_sd_keys = sorted(sd_pipe.text_encoder.state_dict().keys()) + text_encoder_2_sd_keys = sorted(sd_pipe.text_encoder_2.state_dict().keys()) + + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + # Modify the text encoder. 
+ _ = StableDiffusionXLLoraLoaderMixin._modify_text_encoder( + sd_pipe.text_encoder, dtype=torch.float32, rank=self.lora_rank + ) + lora_components["text_encoder_lora_params"] = set_lora_weights( + text_encoder_lora_state_dict(sd_pipe.text_encoder), randn_weight=True, var=0.1 + ) + _ = StableDiffusionXLLoraLoaderMixin._modify_text_encoder( + sd_pipe.text_encoder_2, dtype=torch.float32, rank=self.lora_rank + ) + lora_components["text_encoder_two_lora_params"] = set_lora_weights( + text_encoder_lora_state_dict(sd_pipe.text_encoder_2), randn_weight=True, var=0.1 + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], + safe_serialization=False, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin")) + + text_encoder_1_sd_keys_2 = sorted(sd_pipe.text_encoder.state_dict().keys()) + text_encoder_2_sd_keys_2 = sorted(sd_pipe.text_encoder_2.state_dict().keys()) + + sd_pipe.unload_lora_weights() + + text_encoder_1_sd_keys_3 = sorted(sd_pipe.text_encoder.state_dict().keys()) + text_encoder_2_sd_keys_3 = sorted(sd_pipe.text_encoder_2.state_dict().keys()) + + # default & unloaded LoRA weights should have identical state_dicts + assert text_encoder_1_sd_keys == text_encoder_1_sd_keys_3 + # default & loaded LoRA weights should NOT have identical state_dicts + assert text_encoder_1_sd_keys != text_encoder_1_sd_keys_2 + + # default & unloaded LoRA weights should have identical state_dicts + assert text_encoder_2_sd_keys == text_encoder_2_sd_keys_3 + # default & loaded LoRA weights should NOT have identical state_dicts + assert text_encoder_2_sd_keys != text_encoder_2_sd_keys_2 + + def test_load_lora_locally_safetensors(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + sd_pipe.unload_lora_weights() + + def test_lora_fuse_nan(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) 
+ sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + # corrupt one LoRA weight with `inf` values + with torch.no_grad(): + sd_pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_layer.down.weight += float( + "NaN" + ) + + # with `safe_fusing=True` we should see an Error + with self.assertRaises(ValueError): + sd_pipe.fuse_lora(safe_fusing=True) + + # without we should not see an error, but every image will be black + sd_pipe.fuse_lora(safe_fusing=False) + + out = sd_pipe("test", num_inference_steps=2, output_type="np").images + + assert np.isnan(out).all() + + def test_lora_fusion(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) + + original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + sd_pipe.fuse_lora() + lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + self.assertFalse(np.allclose(orig_image_slice, lora_image_slice, atol=1e-3)) + + def test_unfuse_lora(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + sd_pipe.unet.set_default_attn_processor() + + _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) + + original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + sd_pipe.fuse_lora() + lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Reverse LoRA fusion. + sd_pipe.unfuse_lora() + original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + orig_image_slice_two = original_images[0, -3:, -3:, -1] + + assert not np.allclose( + orig_image_slice, lora_image_slice + ), "Fusion of LoRAs should lead to a different image slice." + assert not np.allclose( + orig_image_slice_two, lora_image_slice + ), "Fusion of LoRAs should lead to a different image slice." 
+ assert np.allclose(
+ orig_image_slice, orig_image_slice_two, atol=1e-3
+ ), "Reversing LoRA fusion should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
+
+ def test_lora_fusion_is_not_affected_by_unloading(self):
+ pipeline_components, lora_components = self.get_dummy_components()
+ sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+
+ _ = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=tmpdirname,
+ unet_lora_layers=lora_components["unet_lora_params"],
+ text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
+ text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
+ safe_serialization=True,
+ )
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+ sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+ sd_pipe.fuse_lora()
+ lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+ lora_image_slice = lora_images[0, -3:, -3:, -1]
+
+ # Unload LoRA parameters.
+ sd_pipe.unload_lora_weights()
+ images_with_unloaded_lora = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+ images_with_unloaded_lora_slice = images_with_unloaded_lora[0, -3:, -3:, -1]
+
+ assert (
+ np.abs(lora_image_slice - images_with_unloaded_lora_slice).max() < 2e-1
+ ), "`unload_lora_weights()` should have no effect on the semantics of the results as the LoRA parameters were fused."
+
+ def test_fuse_lora_with_different_scales(self):
+ pipeline_components, lora_components = self.get_dummy_components()
+ sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+
+ _ = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=tmpdirname,
+ unet_lora_layers=lora_components["unet_lora_params"],
+ text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
+ text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
+ safe_serialization=True,
+ )
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+ sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+ sd_pipe.fuse_lora(lora_scale=1.0)
+ lora_images_scale_one = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+ lora_image_slice_scale_one = lora_images_scale_one[0, -3:, -3:, -1]
+
+ # Reverse LoRA fusion.
+ sd_pipe.unfuse_lora()
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=tmpdirname,
+ unet_lora_layers=lora_components["unet_lora_params"],
+ text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
+ text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
+ safe_serialization=True,
+ )
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+ sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+ sd_pipe.fuse_lora(lora_scale=0.5)
+ lora_images_scale_0_5 = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+ lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1]
+
+ assert not np.allclose(
+ lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-03
+ ), "Different LoRA scales should influence the outputs accordingly."
+
+ def test_with_different_scales(self):
+ pipeline_components, lora_components = self.get_dummy_components()
+ sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+ sd_pipe.unet.set_default_attn_processor()
+
+ _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+ original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+ original_image_slice = original_images[0, -3:, -3:, -1]
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=tmpdirname,
+ unet_lora_layers=lora_components["unet_lora_params"],
+ text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
+ text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
+ safe_serialization=True,
+ )
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+ sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+ lora_images_scale_one = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+ lora_image_slice_scale_one = lora_images_scale_one[0, -3:, -3:, -1]
+
+ lora_images_scale_0_5 = sd_pipe(
+ **pipeline_inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
+ ).images
+ lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1]
+
+ lora_images_scale_0_0 = sd_pipe(
+ **pipeline_inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
+ ).images
+ lora_image_slice_scale_0_0 = lora_images_scale_0_0[0, -3:, -3:, -1]
+
+ assert not np.allclose(
+ lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-03
+ ), "Different LoRA scales should influence the outputs accordingly."
+
+ assert np.allclose(
+ original_image_slice, lora_image_slice_scale_0_0, atol=1e-03
+ ), "LoRA scale of 0.0 shouldn't be different from the results without LoRA."
+ + def test_with_different_scales_fusion_equivalence(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + sd_pipe.unet.set_default_attn_processor() + + _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) + + images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + images_slice = images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + lora_images_scale_0_5 = sd_pipe( + **pipeline_inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5} + ).images + lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1] + + sd_pipe.fuse_lora(lora_scale=0.5) + lora_images_scale_0_5_fusion = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + lora_image_slice_scale_0_5_fusion = lora_images_scale_0_5_fusion[0, -3:, -3:, -1] + + assert np.allclose( + lora_image_slice_scale_0_5, lora_image_slice_scale_0_5_fusion, atol=1e-03 + ), "Fusion shouldn't affect the results when calling the pipeline with a non-default LoRA scale." + + sd_pipe.unfuse_lora() + images_unfused = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + images_slice_unfused = images_unfused[0, -3:, -3:, -1] + + assert np.allclose(images_slice, images_slice_unfused, atol=1e-03), "Unfused should match no LoRA" + + assert not np.allclose( + images_slice, lora_image_slice_scale_0_5, atol=1e-03 + ), "0.5 scale and no scale shouldn't match" + + def test_save_load_fused_lora_modules(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_params"], + text_encoder_lora_layers=lora_components["text_encoder_lora_params"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + sd_pipe.fuse_lora() + lora_images_fusion = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + lora_image_slice_fusion = lora_images_fusion[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + sd_pipe.save_pretrained(tmpdirname) + sd_pipe_loaded = StableDiffusionXLPipeline.from_pretrained(tmpdirname).to(torch_device) + + loaded_lora_images = sd_pipe_loaded(**pipeline_inputs, generator=torch.manual_seed(0)).images + loaded_lora_image_slice = loaded_lora_images[0, -3:, -3:, -1] + + assert np.allclose( + 
lora_image_slice_fusion, loaded_lora_image_slice, atol=1e-03
+ ), "The pipeline was serialized with the LoRA parameters fused inside the respective modules, so the loaded pipeline should yield the same outputs."
+
+
+ @deprecate_after_peft_backend
+ class UNet2DConditionLoRAModelTests(unittest.TestCase):
+ model_class = UNet2DConditionModel
+ main_input_name = "sample"
+ lora_rank = 4
+
+ @property
+ def dummy_input(self):
+ batch_size = 4
+ num_channels = 4
+ sizes = (32, 32)
+
+ noise = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
+ time_step = torch.tensor([10]).to(torch_device)
+ encoder_hidden_states = floats_tensor((batch_size, 4, 32), rng=random.Random(0)).to(torch_device)
+
+ return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
+
+ @property
+ def input_shape(self):
+ return (4, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (4, 32, 32)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "block_out_channels": (32, 64),
+ "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
+ "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
+ "cross_attention_dim": 32,
+ "attention_head_dim": 8,
+ "out_channels": 4,
+ "in_channels": 4,
+ "layers_per_block": 2,
+ "sample_size": 32,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_lora_at_different_scales(self):
+ # enable deterministic behavior for gradient checkpointing
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ init_dict["attention_head_dim"] = (8, 16)
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+
+ with torch.no_grad():
+ sample1 = model(**inputs_dict).sample
+
+ _, lora_params = create_unet_lora_layers(model, rank=self.lora_rank)
+
+ # make sure we can set a list of attention processors
+ model.load_attn_procs(lora_params)
+ model.to(torch_device)
+
+ with torch.no_grad():
+ sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
+ sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+ sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+ assert (sample1 - sample2).abs().max() < 3e-3
+ assert (sample3 - sample4).abs().max() < 3e-3
+
+ # sample 2 and sample 3 should be different
+ assert (sample2 - sample3).abs().max() > 1e-4
+
+ def test_lora_on_off(self, expected_max_diff=1e-3):
+ # enable deterministic behavior for gradient checkpointing
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ init_dict["attention_head_dim"] = (8, 16)
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+
+ with torch.no_grad():
+ old_sample = model(**inputs_dict).sample
+
+ _, lora_params = create_unet_lora_layers(model, rank=self.lora_rank)
+ model.load_attn_procs(lora_params)
+
+ with torch.no_grad():
+ sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
+
+ # Unload LoRA.
+ model.unload_lora() + + with torch.no_grad(): + new_sample = model(**inputs_dict).sample + + max_diff_new_sample = (sample - new_sample).abs().max() + max_diff_old_sample = (sample - old_sample).abs().max() + + assert max_diff_new_sample < expected_max_diff + assert max_diff_old_sample < expected_max_diff + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_lora_xformers_on_off(self, expected_max_diff=6e-4): + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + _, lora_params = create_unet_lora_layers(model, rank=self.lora_rank) + model.load_attn_procs(lora_params) + + # default + with torch.no_grad(): + sample = model(**inputs_dict).sample + + model.enable_xformers_memory_efficient_attention() + on_sample = model(**inputs_dict).sample + + model.disable_xformers_memory_efficient_attention() + off_sample = model(**inputs_dict).sample + + max_diff_on_sample = (sample - on_sample).abs().max() + max_diff_off_sample = (sample - off_sample).abs().max() + + assert max_diff_on_sample < expected_max_diff + assert max_diff_off_sample < expected_max_diff + + +@deprecate_after_peft_backend +class UNet3DConditionLoRAModelTests(unittest.TestCase): + model_class = UNet3DConditionModel + main_input_name = "sample" + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 4 + num_frames = 4 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels, num_frames) + sizes, rng=random.Random(0)).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 4, 32), rng=random.Random(0)).to(torch_device) + + return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + + @property + def input_shape(self): + return (4, 4, 32, 32) + + @property + def output_shape(self): + return (4, 4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64), + "down_block_types": ( + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + "up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D"), + "cross_attention_dim": 32, + "attention_head_dim": 8, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 1, + "sample_size": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_lora_at_different_scales(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = 8 + + model = self.model_class(**init_dict) + model.to(torch_device) + + with torch.no_grad(): + sample1 = model(**inputs_dict).sample + + unet_lora_params = create_3d_unet_lora_layers(model) + + # make sure we can set a list of attention processors + model.load_attn_procs(unet_lora_params) + model.to(torch_device) + + with torch.no_grad(): + sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample + sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample + + assert (sample1 - sample2).abs().max() < 3e-3 + assert (sample3 - sample4).abs().max() < 3e-3 + + # sample 2 and sample 3 should be different + assert (sample2 - sample3).abs().max() > 3e-3 + + +@slow 
+@deprecate_after_peft_backend +@require_torch_gpu +class LoraIntegrationTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_dreambooth_old_format(self): + generator = torch.Generator("cpu").manual_seed(0) + + lora_model_id = "hf-internal-testing/lora_dreambooth_dog_example" + card = RepoCard.load(lora_model_id) + base_model_id = card.data.to_dict()["base_model"] + + pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) + pipe = pipe.to(torch_device) + pipe.load_lora_weights(lora_model_id) + + images = pipe( + "A photo of a sks dog floating in the river", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + + expected = np.array([0.7207, 0.6787, 0.6010, 0.7478, 0.6838, 0.6064, 0.6984, 0.6443, 0.5785]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + + def test_dreambooth_text_encoder_new_format(self): + generator = torch.Generator().manual_seed(0) + + lora_model_id = "hf-internal-testing/lora-trained" + card = RepoCard.load(lora_model_id) + base_model_id = card.data.to_dict()["base_model"] + + pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) + pipe = pipe.to(torch_device) + pipe.load_lora_weights(lora_model_id) + + images = pipe("A photo of a sks dog", output_type="np", generator=generator, num_inference_steps=2).images + + images = images[0, -3:, -3:, -1].flatten() + + expected = np.array([0.6628, 0.6138, 0.5390, 0.6625, 0.6130, 0.5463, 0.6166, 0.5788, 0.5359]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + + def test_a1111(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None).to( + torch_device + ) + lora_model_id = "hf-internal-testing/civitai-light-shadow-lora" + lora_filename = "light_and_shadow.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + + def test_lycoris(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/Amixx", safety_checker=None, use_safetensors=True, variant="fp16" + ).to(torch_device) + lora_model_id = "hf-internal-testing/edgLycorisMugler-light" + lora_filename = "edgLycorisMugler-light.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.6463, 0.658, 0.599, 0.6542, 0.6512, 0.6213, 0.658, 0.6485, 0.6017]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + + def test_a1111_with_model_cpu_offload(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None) + pipe.enable_model_cpu_offload() + lora_model_id = "hf-internal-testing/civitai-light-shadow-lora" + lora_filename = "light_and_shadow.safetensors" + pipe.load_lora_weights(lora_model_id, 
weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + + def test_a1111_with_sequential_cpu_offload(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None) + pipe.enable_sequential_cpu_offload() + lora_model_id = "hf-internal-testing/civitai-light-shadow-lora" + lora_filename = "light_and_shadow.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + + def test_kohya_sd_v15_with_higher_dimensions(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to( + torch_device + ) + lora_model_id = "hf-internal-testing/urushisato-lora" + lora_filename = "urushisato_v15.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.7165, 0.6616, 0.5833, 0.7504, 0.6718, 0.587, 0.6871, 0.6361, 0.5694]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + + def test_vanilla_funetuning(self): + generator = torch.Generator().manual_seed(0) + + lora_model_id = "hf-internal-testing/sd-model-finetuned-lora-t4" + card = RepoCard.load(lora_model_id) + base_model_id = card.data.to_dict()["base_model"] + + pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None) + pipe = pipe.to(torch_device) + pipe.load_lora_weights(lora_model_id) + + images = pipe("A pokemon with blue eyes.", output_type="np", generator=generator, num_inference_steps=2).images + + images = images[0, -3:, -3:, -1].flatten() + + expected = np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + + def test_unload_kohya_lora(self): + generator = torch.manual_seed(0) + prompt = "masterpiece, best quality, mountain" + num_inference_steps = 2 + + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to( + torch_device + ) + initial_images = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + initial_images = initial_images[0, -3:, -3:, -1].flatten() + + lora_model_id = "hf-internal-testing/civitai-colored-icons-lora" + lora_filename = "Colored_Icons_by_vizsumit.safetensors" + + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + generator = torch.manual_seed(0) + lora_images = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + lora_images = lora_images[0, -3:, -3:, -1].flatten() + + pipe.unload_lora_weights() + generator = torch.manual_seed(0) + 
unloaded_lora_images = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten() + + self.assertFalse(np.allclose(initial_images, lora_images)) + self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3)) + + def test_load_unload_load_kohya_lora(self): + # This test ensures that a Kohya-style LoRA can be safely unloaded and then loaded + # without introducing any side-effects. Even though the test uses a Kohya-style + # LoRA, the underlying adapter handling mechanism is format-agnostic. + generator = torch.manual_seed(0) + prompt = "masterpiece, best quality, mountain" + num_inference_steps = 2 + + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to( + torch_device + ) + initial_images = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + initial_images = initial_images[0, -3:, -3:, -1].flatten() + + lora_model_id = "hf-internal-testing/civitai-colored-icons-lora" + lora_filename = "Colored_Icons_by_vizsumit.safetensors" + + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + generator = torch.manual_seed(0) + lora_images = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + lora_images = lora_images[0, -3:, -3:, -1].flatten() + + pipe.unload_lora_weights() + generator = torch.manual_seed(0) + unloaded_lora_images = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten() + + self.assertFalse(np.allclose(initial_images, lora_images)) + self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3)) + + # make sure we can load a LoRA again after unloading and they don't have + # any undesired effects. 
+ pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + generator = torch.manual_seed(0) + lora_images_again = pipe( + prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps + ).images + lora_images_again = lora_images_again[0, -3:, -3:, -1].flatten() + + self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3)) + + def test_sdxl_0_9_lora_one(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9") + lora_model_id = "hf-internal-testing/sdxl-0.9-daiton-lora" + lora_filename = "daiton-xl-lora-test.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.enable_model_cpu_offload() + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.3838, 0.3482, 0.3588, 0.3162, 0.319, 0.3369, 0.338, 0.3366, 0.3213]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + + def test_sdxl_0_9_lora_two(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9") + lora_model_id = "hf-internal-testing/sdxl-0.9-costumes-lora" + lora_filename = "saijo.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.enable_model_cpu_offload() + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.3137, 0.3269, 0.3355, 0.255, 0.2577, 0.2563, 0.2679, 0.2758, 0.2626]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + + def test_sdxl_0_9_lora_three(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9") + lora_model_id = "hf-internal-testing/sdxl-0.9-kamepan-lora" + lora_filename = "kame_sdxl_v2-000020-16rank.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.enable_model_cpu_offload() + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.4015, 0.3761, 0.3616, 0.3745, 0.3462, 0.3337, 0.3564, 0.3649, 0.3468]) + + self.assertTrue(np.allclose(images, expected, atol=5e-3)) + + def test_sdxl_1_0_lora(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + pipe.enable_model_cpu_offload() + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + + def test_sdxl_1_0_lora_fusion(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = 
"sd_xl_offset_example-lora_1.0.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.fuse_lora() + pipe.enable_model_cpu_offload() + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + # This way we also test equivalence between LoRA fusion and the non-fusion behaviour. + expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + + def test_sdxl_1_0_lora_unfusion(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.fuse_lora() + pipe.enable_model_cpu_offload() + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + images_with_fusion = images[0, -3:, -3:, -1].flatten() + + pipe.unfuse_lora() + generator = torch.Generator().manual_seed(0) + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + images_without_fusion = images[0, -3:, -3:, -1].flatten() + + self.assertFalse(np.allclose(images_with_fusion, images_without_fusion, atol=1e-3)) + + def test_sdxl_1_0_lora_unfusion_effectivity(self): + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + pipe.enable_model_cpu_offload() + + generator = torch.Generator().manual_seed(0) + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + original_image_slice = images[0, -3:, -3:, -1].flatten() + + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.fuse_lora() + + generator = torch.Generator().manual_seed(0) + _ = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + pipe.unfuse_lora() + generator = torch.Generator().manual_seed(0) + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + images_without_fusion_slice = images[0, -3:, -3:, -1].flatten() + + self.assertTrue(np.allclose(original_image_slice, images_without_fusion_slice, atol=1e-3)) + + def test_sdxl_1_0_lora_fusion_efficiency(self): + generator = torch.Generator().manual_seed(0) + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16) + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16) + pipe.enable_model_cpu_offload() + + start_time = time.time() + for _ in range(3): + pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + end_time = time.time() + elapsed_time_non_fusion = end_time - start_time + + del pipe + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16) + 
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16) + pipe.fuse_lora() + pipe.enable_model_cpu_offload() + + generator = torch.Generator().manual_seed(0) + start_time = time.time() + for _ in range(3): + pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + end_time = time.time() + elapsed_time_fusion = end_time - start_time + + self.assertTrue(elapsed_time_fusion < elapsed_time_non_fusion) + + def test_sdxl_1_0_last_ben(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + pipe.enable_model_cpu_offload() + lora_model_id = "TheLastBen/Papercut_SDXL" + lora_filename = "papercut.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe("papercut.safetensors", output_type="np", generator=generator, num_inference_steps=2).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.5244, 0.4347, 0.4312, 0.4246, 0.4398, 0.4409, 0.4884, 0.4938, 0.4094]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + + def test_sdxl_1_0_fuse_unfuse_all(self): + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16) + text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict()) + text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict()) + unet_sd = copy.deepcopy(pipe.unet.state_dict()) + + pipe.load_lora_weights( + "davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors", torch_dtype=torch.float16 + ) + pipe.fuse_lora() + pipe.unload_lora_weights() + pipe.unfuse_lora() + + assert state_dicts_almost_equal(text_encoder_1_sd, pipe.text_encoder.state_dict()) + assert state_dicts_almost_equal(text_encoder_2_sd, pipe.text_encoder_2.state_dict()) + assert state_dicts_almost_equal(unet_sd, pipe.unet.state_dict()) + + def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + pipe.enable_sequential_cpu_offload() + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + + def test_canny_lora(self): + controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0") + + pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet + ) + pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors") + pipe.enable_sequential_cpu_offload() + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "corgi" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ) + + images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images + + assert images[0].shape == (768, 512, 3) + + original_image = images[0, -3:, -3:, -1].flatten() + expected_image = 
np.array([0.4574, 0.4461, 0.4435, 0.4462, 0.4396, 0.439, 0.4474, 0.4486, 0.4333]) + assert np.allclose(original_image, expected_image, atol=1e-04) + + @nightly + def test_sequential_fuse_unfuse(self): + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + + # 1. round + pipe.load_lora_weights("Pclanglais/TintinIA") + pipe.fuse_lora() + + generator = torch.Generator().manual_seed(0) + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + image_slice = images[0, -3:, -3:, -1].flatten() + + pipe.unfuse_lora() + + # 2. round + pipe.load_lora_weights("ProomptEngineer/pe-balloon-diffusion-style") + pipe.fuse_lora() + pipe.unfuse_lora() + + # 3. round + pipe.load_lora_weights("ostris/crayon_style_lora_sdxl") + pipe.fuse_lora() + pipe.unfuse_lora() + + # 4. back to 1st round + pipe.load_lora_weights("Pclanglais/TintinIA") + pipe.fuse_lora() + + generator = torch.Generator().manual_seed(0) + images_2 = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + image_slice_2 = images_2[0, -3:, -3:, -1].flatten() + + self.assertTrue(np.allclose(image_slice, image_slice_2, atol=1e-3))