diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py
index 5db13825c9eb..bccd37ddc42f 100644
--- a/src/diffusers/loaders/__init__.py
+++ b/src/diffusers/loaders/__init__.py
@@ -66,6 +66,7 @@ def text_encoder_attn_modules(text_encoder):
"SD3LoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
"LoraLoaderMixin",
+ "FluxLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
@@ -83,6 +84,7 @@ def text_encoder_attn_modules(text_encoder):
from .ip_adapter import IPAdapterMixin
from .lora_pipeline import (
AmusedLoraLoaderMixin,
+ FluxLoraLoaderMixin,
LoraLoaderMixin,
SD3LoraLoaderMixin,
StableDiffusionLoraLoaderMixin,
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 73273618956a..f612cc0c6e53 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -1475,6 +1475,481 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t
super().unfuse_lora(components=components)
+class FluxLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`FluxTransformer2DModel`],
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
+
+ Specific to [`StableDiffusion3Pipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer", "text_encoder"]
+ transformer_name = TRANSFORMER_NAME
+ text_encoder_name = TEXT_ENCODER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = cls._fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ return state_dict
+
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`.
+
+ All kwargs are forwarded to `self.lora_state_dict`.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
+ loaded.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+ if len(text_encoder_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_state_dict,
+ network_alphas=None,
+ text_encoder=self.text_encoder,
+ prefix="text_encoder",
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
+ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`SD3Transformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ """
+ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+ keys = list(state_dict.keys())
+
+ transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
+ state_dict = {
+ k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
+ }
+
+ if len(state_dict.keys()) > 0:
+ # check with first key if is not in peft format
+ first_key = next(iter(state_dict.keys()))
+ if "lora_A" not in first_key:
+ state_dict = convert_unet_state_dict_to_peft(state_dict)
+
+ if adapter_name in getattr(transformer, "peft_config", {}):
+ raise ValueError(
+ f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
+ )
+
+ rank = {}
+ for key, val in state_dict.items():
+ if "lora_B" in key:
+ rank[key] = val.shape[1]
+
+ lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
+ if "use_dora" in lora_config_kwargs:
+ if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
+ raise ValueError(
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ lora_config_kwargs.pop("use_dora")
+ lora_config = LoraConfig(**lora_config_kwargs)
+
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(transformer)
+
+ # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
+ # otherwise loading LoRA weights will lead to an error
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+ inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
+ incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
+
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Offload back.
+ if is_model_cpu_offload:
+ _pipeline.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ _pipeline.enable_sequential_cpu_offload()
+ # Unsafe code />
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
+ def load_lora_into_text_encoder(
+ cls,
+ state_dict,
+ network_alphas,
+ text_encoder,
+ prefix=None,
+ lora_scale=1.0,
+ adapter_name=None,
+ _pipeline=None,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The key should be prefixed with an
+ additional `text_encoder` to distinguish between unet lora layers.
+ network_alphas (`Dict[str, float]`):
+ See `LoRALinearLayer` for more details.
+ text_encoder (`CLIPTextModel`):
+ The text encoder model to load the LoRA layers into.
+ prefix (`str`):
+ Expected prefix of the `text_encoder` in the `state_dict`.
+ lora_scale (`float`):
+ How much to scale the output of the lora linear layer before it is added with the output of the regular
+ lora layer.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ from peft import LoraConfig
+
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
+ # their prefixes.
+ keys = list(state_dict.keys())
+ prefix = cls.text_encoder_name if prefix is None else prefix
+
+ # Safe prefix to check with.
+ if any(cls.text_encoder_name in key for key in keys):
+ # Load the layers corresponding to text encoder and make necessary adjustments.
+ text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
+ text_encoder_lora_state_dict = {
+ k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+ }
+
+ if len(text_encoder_lora_state_dict) > 0:
+ logger.info(f"Loading {prefix}.")
+ rank = {}
+ text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+
+ # convert state dict
+ text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+
+ for name, _ in text_encoder_attn_modules(text_encoder):
+ for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in text_encoder_lora_state_dict:
+ continue
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+ for name, _ in text_encoder_mlp_modules(text_encoder):
+ for module in ("fc1", "fc2"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in text_encoder_lora_state_dict:
+ continue
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+ if network_alphas is not None:
+ alpha_keys = [
+ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
+ ]
+ network_alphas = {
+ k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+ }
+
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+ if "use_dora" in lora_config_kwargs:
+ if lora_config_kwargs["use_dora"]:
+ if is_peft_version("<", "0.9.0"):
+ raise ValueError(
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<", "0.9.0"):
+ lora_config_kwargs.pop("use_dora")
+ lora_config = LoraConfig(**lora_config_kwargs)
+
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(text_encoder)
+
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+ # inject LoRA layers and load the state dict
+ # in transformers we automatically check whether the adapter name is already in use or not
+ text_encoder.load_adapter(
+ adapter_name=adapter_name,
+ adapter_state_dict=text_encoder_lora_state_dict,
+ peft_config=lora_config,
+ )
+
+ # scale LoRA layers with `lora_scale`
+ scale_lora_layers(text_encoder, weight=lora_scale)
+
+ text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+ # Offload back.
+ if is_model_cpu_offload:
+ _pipeline.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ _pipeline.enable_sequential_cpu_offload()
+ # Unsafe code />
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not (transformer_lora_layers or text_encoder_lora_layers):
+ raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ if text_encoder_lora_layers:
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer", "text_encoder"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ )
+
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ """
+ super().unfuse_lora(components=components)
+
+
# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index fd6c639a7cdf..89d6a28b14dd 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -32,6 +32,7 @@
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
"SD3Transformer2DModel": lambda model_cls, weights: weights,
+ "FluxTransformer2DModel": lambda model_cls, weights: weights,
}
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 64cc0ae7b5ad..c1a7010d919a 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -20,7 +20,7 @@
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...image_processor import VaeImageProcessor
-from ...loaders import SD3LoraLoaderMixin
+from ...loaders import FluxLoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -137,7 +137,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
+class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
r"""
The Flux pipeline for text-to-image generation.
@@ -321,7 +321,7 @@ def encode_prompt(
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
- if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
@@ -354,12 +354,12 @@ def encode_prompt(
)
if self.text_encoder is not None:
- if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None:
- if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2, lora_scale)
diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
new file mode 100644
index 000000000000..c0f0684ac4de
--- /dev/null
+++ b/tests/lora/test_lora_layers_flux.py
@@ -0,0 +1,92 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import unittest
+
+import torch
+from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
+from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+
+
+sys.path.append(".")
+
+from utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+@require_peft_backend
+class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = FluxPipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler()
+ scheduler_kwargs = {}
+ uses_flow_matching = True
+ transformer_kwargs = {
+ "patch_size": 1,
+ "in_channels": 4,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "attention_head_dim": 16,
+ "num_attention_heads": 2,
+ "joint_attention_dim": 32,
+ "pooled_projection_dim": 32,
+ "axes_dims_rope": [4, 4, 8],
+ }
+ transformer_cls = FluxTransformer2DModel
+ vae_kwargs = {
+ "sample_size": 32,
+ "in_channels": 3,
+ "out_channels": 3,
+ "block_out_channels": (4,),
+ "layers_per_block": 1,
+ "latent_channels": 1,
+ "norm_num_groups": 1,
+ "use_quant_conv": False,
+ "use_post_quant_conv": False,
+ "shift_factor": 0.0609,
+ "scaling_factor": 1.5035,
+ }
+ has_two_text_encoders = True
+ tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
+ tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+ text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
+ text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+
+ @property
+ def output_shape(self):
+ return (1, 8, 8, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 10
+ num_channels = 4
+ sizes = (32, 32)
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "num_inference_steps": 4,
+ "guidance_scale": 0.0,
+ "height": 8,
+ "width": 8,
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py
index 46b965ec33d9..0aee4f57c2c6 100644
--- a/tests/lora/test_lora_layers_sd.py
+++ b/tests/lora/test_lora_layers_sd.py
@@ -22,6 +22,7 @@
from huggingface_hub import hf_hub_download
from huggingface_hub.repocard import RepoCard
from safetensors.torch import load_file
+from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoPipelineForImage2Image,
@@ -80,6 +81,12 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
"latent_channels": 4,
}
+ text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
+ tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
+
+ @property
+ def output_shape(self):
+ return (1, 64, 64, 3)
def setUp(self):
super().setUp()
diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py
index 9ce559be7f06..31c62f27a75a 100644
--- a/tests/lora/test_lora_layers_sd3.py
+++ b/tests/lora/test_lora_layers_sd3.py
@@ -15,10 +15,9 @@
import sys
import unittest
-from diffusers import (
- FlowMatchEulerDiscreteScheduler,
- StableDiffusion3Pipeline,
-)
+from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
+
+from diffusers import FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device
@@ -35,6 +34,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler()
scheduler_kwargs = {}
+ uses_flow_matching = True
transformer_kwargs = {
"sample_size": 32,
"patch_size": 1,
@@ -47,6 +47,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"pooled_projection_dim": 64,
"out_channels": 4,
}
+ transformer_cls = SD3Transformer2DModel
vae_kwargs = {
"sample_size": 32,
"in_channels": 3,
@@ -61,6 +62,16 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"scaling_factor": 1.5035,
}
has_three_text_encoders = True
+ tokenizer_cls, tokenizer_id = CLIPTokenizer, "hf-internal-testing/tiny-random-clip"
+ tokenizer_2_cls, tokenizer_2_id = CLIPTokenizer, "hf-internal-testing/tiny-random-clip"
+ tokenizer_3_cls, tokenizer_3_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+ text_encoder_cls, text_encoder_id = CLIPTextModelWithProjection, "hf-internal-testing/tiny-sd3-text_encoder"
+ text_encoder_2_cls, text_encoder_2_id = CLIPTextModelWithProjection, "hf-internal-testing/tiny-sd3-text_encoder-2"
+ text_encoder_3_cls, text_encoder_3_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+
+ @property
+ def output_shape(self):
+ return (1, 32, 32, 3)
@require_torch_gpu
def test_sd3_lora(self):
diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py
index f6ca4f304eb9..f00f7b193abf 100644
--- a/tests/lora/test_lora_layers_sdxl.py
+++ b/tests/lora/test_lora_layers_sdxl.py
@@ -22,6 +22,7 @@
import numpy as np
import torch
from packaging import version
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import (
ControlNetModel,
@@ -89,6 +90,14 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
"latent_channels": 4,
"sample_size": 128,
}
+ text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
+ tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
+ text_encoder_2_cls, text_encoder_2_id = CLIPTextModelWithProjection, "peft-internal-testing/tiny-clip-text-2"
+ tokenizer_2_cls, tokenizer_2_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
+
+ @property
+ def output_shape(self):
+ return (1, 64, 64, 3)
def setUp(self):
super().setUp()
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index ca2e92832229..283b9f534766 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import os
import tempfile
import unittest
@@ -19,14 +20,12 @@
import numpy as np
import torch
-from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKL,
DDIMScheduler,
FlowMatchEulerDiscreteScheduler,
LCMScheduler,
- SD3Transformer2DModel,
UNet2DConditionModel,
)
from diffusers.utils.import_utils import is_peft_available
@@ -72,9 +71,19 @@ class PeftLoraLoaderMixinTests:
pipeline_class = None
scheduler_cls = None
scheduler_kwargs = None
+ uses_flow_matching = False
+
has_two_text_encoders = False
has_three_text_encoders = False
+ text_encoder_cls, text_encoder_id = None, None
+ text_encoder_2_cls, text_encoder_2_id = None, None
+ text_encoder_3_cls, text_encoder_3_id = None, None
+ tokenizer_cls, tokenizer_id = None, None
+ tokenizer_2_cls, tokenizer_2_id = None, None
+ tokenizer_3_cls, tokenizer_3_id = None, None
+
unet_kwargs = None
+ transformer_cls = None
transformer_kwargs = None
vae_kwargs = None
@@ -91,28 +100,23 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False):
if self.unet_kwargs is not None:
unet = UNet2DConditionModel(**self.unet_kwargs)
else:
- transformer = SD3Transformer2DModel(**self.transformer_kwargs)
+ transformer = self.transformer_cls(**self.transformer_kwargs)
scheduler = scheduler_cls(**self.scheduler_kwargs)
torch.manual_seed(0)
vae = AutoencoderKL(**self.vae_kwargs)
- if not self.has_three_text_encoders:
- text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2")
- tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
+ text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
+ tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)
- if self.has_two_text_encoders:
- text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("peft-internal-testing/tiny-clip-text-2")
- tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
+ if self.text_encoder_2_cls is not None:
+ text_encoder_2 = self.text_encoder_2_cls.from_pretrained(self.text_encoder_2_id)
+ tokenizer_2 = self.tokenizer_2_cls.from_pretrained(self.tokenizer_2_id)
- if self.has_three_text_encoders:
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
- text_encoder = CLIPTextModelWithProjection.from_pretrained("hf-internal-testing/tiny-sd3-text_encoder")
- text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("hf-internal-testing/tiny-sd3-text_encoder-2")
- text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ if self.text_encoder_3_cls is not None:
+ text_encoder_3 = self.text_encoder_3_cls.from_pretrained(self.text_encoder_3_id)
+ tokenizer_3 = self.tokenizer_3_cls.from_pretrained(self.tokenizer_3_id)
text_lora_config = LoraConfig(
r=rank,
@@ -130,45 +134,39 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False):
use_dora=use_dora,
)
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if self.unet_kwargs is not None:
- pipeline_components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "text_encoder_2": text_encoder_2,
- "tokenizer_2": tokenizer_2,
- "image_encoder": None,
- "feature_extractor": None,
- }
- elif self.has_three_text_encoders and self.transformer_kwargs is not None:
- pipeline_components = {
- "transformer": transformer,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "text_encoder_2": text_encoder_2,
- "tokenizer_2": tokenizer_2,
- "text_encoder_3": text_encoder_3,
- "tokenizer_3": tokenizer_3,
- }
- else:
- pipeline_components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- "image_encoder": None,
- }
+ pipeline_components = {
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ # Denoiser
+ if self.unet_kwargs is not None:
+ pipeline_components.update({"unet": unet})
+ elif self.transformer_kwargs is not None:
+ pipeline_components.update({"transformer": transformer})
+
+ # Remaining text encoders.
+ if self.text_encoder_2_cls is not None:
+ pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2})
+ if self.text_encoder_3_cls is not None:
+ pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3})
+
+ # Remaining stuff
+ init_params = inspect.signature(self.pipeline_class.__init__).parameters
+ if "safety_checker" in init_params:
+ pipeline_components.update({"safety_checker": None})
+ if "feature_extractor" in init_params:
+ pipeline_components.update({"feature_extractor": None})
+ if "image_encoder" in init_params:
+ pipeline_components.update({"image_encoder": None})
return pipeline_components, text_lora_config, denoiser_lora_config
+ @property
+ def output_shape(self):
+ raise NotImplementedError
+
def get_dummy_inputs(self, with_generator=True):
batch_size = 1
sequence_length = 10
@@ -205,9 +203,7 @@ def test_simple_inference(self):
Tests a simple inference and makes sure it works as expected
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -217,8 +213,7 @@ def test_simple_inference(self):
_, _, inputs = self.get_dummy_inputs()
output_no_lora = pipe(**inputs).images
- shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
- self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+ self.assertTrue(output_no_lora.shape == self.output_shape)
def test_simple_inference_with_text_lora(self):
"""
@@ -226,9 +221,7 @@ def test_simple_inference_with_text_lora(self):
and makes sure it works as expected
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -238,17 +231,18 @@ def test_simple_inference_with_text_lora(self):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
- self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+ self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ lora_loadable_components = self.pipeline_class._lora_loadable_modules
+ if "text_encoder_2" in lora_loadable_components:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
@@ -261,9 +255,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
and makes sure it works as expected
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -273,17 +265,18 @@ def test_simple_inference_with_text_lora_and_scale(self):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
- self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+ self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ lora_loadable_components = self.pipeline_class._lora_loadable_modules
+ if "text_encoder_2" in lora_loadable_components:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
@@ -322,9 +315,7 @@ def test_simple_inference_with_text_lora_fused(self):
and makes sure it works as expected
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -334,26 +325,27 @@ def test_simple_inference_with_text_lora_fused(self):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
- self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+ self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
pipe.fuse_lora()
# Fusing should still keep the LoRA layers
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
@@ -366,9 +358,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
and makes sure it works as expected
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -378,17 +368,18 @@ def test_simple_inference_with_text_lora_unloaded(self):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
- self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+ self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ lora_loadable_components = self.pipeline_class._lora_loadable_modules
+ if "text_encoder_2" in lora_loadable_components:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
pipe.unload_lora_weights()
# unloading should remove the LoRA layers
@@ -397,10 +388,11 @@ def test_simple_inference_with_text_lora_unloaded(self):
)
if self.has_two_text_encoders or self.has_three_text_encoders:
- self.assertFalse(
- check_if_lora_correctly_set(pipe.text_encoder_2),
- "Lora not correctly unloaded in text encoder 2",
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertFalse(
+ check_if_lora_correctly_set(pipe.text_encoder_2),
+ "Lora not correctly unloaded in text encoder 2",
+ )
ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
@@ -413,9 +405,7 @@ def test_simple_inference_with_text_lora_save_load(self):
Tests a simple usecase where users could use saving utilities for LoRA.
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -425,31 +415,32 @@ def test_simple_inference_with_text_lora_save_load(self):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
- self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+ self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
with tempfile.TemporaryDirectory() as tmpdirname:
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
if self.has_two_text_encoders or self.has_three_text_encoders:
- text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname,
- text_encoder_lora_layers=text_encoder_state_dict,
- text_encoder_2_lora_layers=text_encoder_2_state_dict,
- safe_serialization=False,
- )
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname,
+ text_encoder_lora_layers=text_encoder_state_dict,
+ text_encoder_2_lora_layers=text_encoder_2_state_dict,
+ safe_serialization=False,
+ )
else:
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
@@ -457,6 +448,14 @@ def test_simple_inference_with_text_lora_save_load(self):
safe_serialization=False,
)
+ if self.has_two_text_encoders:
+ if "text_encoder_2" not in self.pipeline_class._lora_loadable_modules:
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname,
+ text_encoder_lora_layers=text_encoder_state_dict,
+ safe_serialization=False,
+ )
+
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights()
@@ -466,9 +465,10 @@ def test_simple_inference_with_text_lora_save_load(self):
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
@@ -482,9 +482,7 @@ def test_simple_inference_with_partial_text_lora(self):
and makes sure it works as expected
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, _, _ = self.get_dummy_components(scheduler_cls)
@@ -503,8 +501,7 @@ def test_simple_inference_with_partial_text_lora(self):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
- self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+ self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -517,17 +514,18 @@ def test_simple_inference_with_partial_text_lora(self):
}
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
- state_dict.update(
- {
- f"text_encoder_2.{module_name}": param
- for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
- if "text_model.encoder.layers.4" not in module_name
- }
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
+ state_dict.update(
+ {
+ f"text_encoder_2.{module_name}": param
+ for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
+ if "text_model.encoder.layers.4" not in module_name
+ }
+ )
output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
@@ -549,9 +547,7 @@ def test_simple_inference_save_pretrained(self):
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -561,17 +557,17 @@ def test_simple_inference_save_pretrained(self):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
- self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+ self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
@@ -587,10 +583,11 @@ def test_simple_inference_save_pretrained(self):
)
if self.has_two_text_encoders or self.has_three_text_encoders:
- self.assertTrue(
- check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
- "Lora not correctly set in text encoder 2",
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
+ "Lora not correctly set in text encoder 2",
+ )
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0)).images
@@ -604,14 +601,10 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -621,8 +614,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
- self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+ self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None:
@@ -635,10 +627,11 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
@@ -650,32 +643,23 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
else:
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
- if self.has_two_text_encoders or self.has_three_text_encoders:
- text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
+ saving_kwargs = {
+ "save_directory": tmpdirname,
+ "text_encoder_lora_layers": text_encoder_state_dict,
+ "safe_serialization": False,
+ }
- if self.unet_kwargs is not None:
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname,
- text_encoder_lora_layers=text_encoder_state_dict,
- text_encoder_2_lora_layers=text_encoder_2_state_dict,
- unet_lora_layers=denoiser_state_dict,
- safe_serialization=False,
- )
- else:
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname,
- text_encoder_lora_layers=text_encoder_state_dict,
- text_encoder_2_lora_layers=text_encoder_2_state_dict,
- transformer_lora_layers=denoiser_state_dict,
- safe_serialization=False,
- )
+ if self.unet_kwargs is not None:
+ saving_kwargs.update({"unet_lora_layers": denoiser_state_dict})
else:
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname,
- text_encoder_lora_layers=text_encoder_state_dict,
- unet_lora_layers=denoiser_state_dict,
- safe_serialization=False,
- )
+ saving_kwargs.update({"transformer_lora_layers": denoiser_state_dict})
+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
+ saving_kwargs.update({"text_encoder_2_lora_layers": text_encoder_2_state_dict})
+
+ self.pipeline_class.save_lora_weights(**saving_kwargs)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights()
@@ -688,9 +672,10 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders:
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
@@ -703,9 +688,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
and makes sure it works as expected
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -715,8 +698,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
- self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+ self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None:
@@ -728,10 +710,11 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
@@ -775,9 +758,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
and makes sure it works as expected - with unet
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -787,8 +768,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
- self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+ self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None:
@@ -801,10 +781,11 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
pipe.fuse_lora()
# Fusing should still keep the LoRA layers
@@ -813,9 +794,10 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders:
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
@@ -828,9 +810,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
and makes sure it works as expected
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -840,8 +820,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
- self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+ self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None:
@@ -853,10 +832,11 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
pipe.unload_lora_weights()
# unloading should remove the LoRA layers
@@ -869,10 +849,11 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
)
if self.has_two_text_encoders or self.has_three_text_encoders:
- self.assertFalse(
- check_if_lora_correctly_set(pipe.text_encoder_2),
- "Lora not correctly unloaded in text encoder 2",
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertFalse(
+ check_if_lora_correctly_set(pipe.text_encoder_2),
+ "Lora not correctly unloaded in text encoder 2",
+ )
ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
@@ -886,9 +867,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self):
and makes sure it works as expected
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -908,10 +887,11 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self):
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
pipe.fuse_lora()
@@ -926,9 +906,10 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self):
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Unfuse should still keep LoRA layers")
if self.has_two_text_encoders or self.has_three_text_encoders:
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
+ )
# Fuse and unfuse should lead to the same results
self.assertTrue(
@@ -942,9 +923,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
multiple adapters and set them
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -972,11 +951,12 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
pipe.set_adapters("adapter-1")
@@ -1023,9 +1003,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
return
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1047,10 +1025,11 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
pipe.set_adapters("adapter-1", weights_1)
@@ -1090,9 +1069,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
return
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1120,11 +1097,12 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
scales_2 = {"unet": {"down": 5, "mid": 5}}
@@ -1170,7 +1148,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
"""Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""
- if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
+ if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline"]:
return
def updown_options(blocks_with_tf, layers_per_block, value):
@@ -1249,7 +1227,9 @@ def all_possible_dict_opts(unet, value):
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ lora_loadable_components = self.pipeline_class._lora_loadable_modules
+ if "text_encoder_2" in lora_loadable_components:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
for scale_dict in all_possible_dict_opts(pipe.unet, value=1234):
# test if lora block scales can be set with this scale_dict
@@ -1264,9 +1244,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
multiple adapters and set/delete them
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1294,11 +1272,13 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ lora_loadable_components = self.pipeline_class._lora_loadable_modules
+ if "text_encoder_2" in lora_loadable_components:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
pipe.set_adapters("adapter-1")
@@ -1370,9 +1350,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
multiple adapters and set them
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1400,11 +1378,13 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ lora_loadable_components = self.pipeline_class._lora_loadable_modules
+ if "text_encoder_2" in lora_loadable_components:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
pipe.set_adapters("adapter-1")
@@ -1453,9 +1433,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
@skip_mps
def test_lora_fuse_nan(self):
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1501,9 +1479,7 @@ def test_get_adapters(self):
are the expected results
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1539,9 +1515,7 @@ def test_get_list_adapters(self):
are the expected results
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1607,9 +1581,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
and makes sure it works as expected - with unet and multi-adapter case
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1619,8 +1591,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
- self.assertTrue(output_no_lora.shape == shape_to_be_checked)
+ self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
if self.unet_kwargs is not None:
@@ -1640,11 +1611,13 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ lora_loadable_components = self.pipeline_class._lora_loadable_modules
+ if "text_encoder_2" in lora_loadable_components:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
# set them to multi-adapter inference mode
pipe.set_adapters(["adapter-1", "adapter-2"])
@@ -1676,9 +1649,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
@require_peft_version_greater(peft_version="0.9.0")
def test_simple_inference_with_dora(self):
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
@@ -1690,8 +1661,7 @@ def test_simple_inference_with_dora(self):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
- self.assertTrue(output_no_dora_lora.shape == shape_to_be_checked)
+ self.assertTrue(output_no_dora_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None:
@@ -1704,10 +1674,12 @@ def test_simple_inference_with_dora(self):
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ lora_loadable_components = self.pipeline_class._lora_loadable_modules
+ if "text_encoder_2" in lora_loadable_components:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
@@ -1723,9 +1695,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
and makes sure it works as expected
"""
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1760,7 +1730,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
_ = pipe(**inputs, generator=torch.manual_seed(0)).images
def test_modify_padding_mode(self):
- if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
+ if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline"]:
return
def set_pad_mode(network, mode="circular"):
@@ -1769,9 +1739,7 @@ def set_pad_mode(network, mode="circular"):
module.padding_mode = mode
scheduler_classes = (
- [FlowMatchEulerDiscreteScheduler]
- if self.has_three_text_encoders and self.transformer_kwargs
- else [DDIMScheduler, LCMScheduler]
+ [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, _, _ = self.get_dummy_components(scheduler_cls)