diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 5db13825c9eb..bccd37ddc42f 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -66,6 +66,7 @@ def text_encoder_attn_modules(text_encoder): "SD3LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin", "LoraLoaderMixin", + "FluxLoraLoaderMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["ip_adapter"] = ["IPAdapterMixin"] @@ -83,6 +84,7 @@ def text_encoder_attn_modules(text_encoder): from .ip_adapter import IPAdapterMixin from .lora_pipeline import ( AmusedLoraLoaderMixin, + FluxLoraLoaderMixin, LoraLoaderMixin, SD3LoraLoaderMixin, StableDiffusionLoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 73273618956a..f612cc0c6e53 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1475,6 +1475,481 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t super().unfuse_lora(components=components) +class FluxLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`FluxTransformer2DModel`], + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). + + Specific to [`StableDiffusion3Pipeline`]. + """ + + _lora_loadable_modules = ["transformer", "text_encoder"] + transformer_name = TRANSFORMER_NAME + text_encoder_name = TEXT_ENCODER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = cls._fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + return state_dict + + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is + loaded. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + ) + + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=None, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer + def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`SD3Transformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + keys = list(state_dict.keys()) + + transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] + state_dict = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys + } + + if len(state_dict.keys()) > 0: + # check with first key if is not in peft format + first_key = next(iter(state_dict.keys())) + if "lora_A" not in first_key: + state_dict = convert_unet_state_dict_to_peft(state_dict) + + if adapter_name in getattr(transformer, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." + ) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(transformer) + + # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks + # otherwise loading LoRA weights will lead to an error + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder + def load_lora_into_text_encoder( + cls, + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + adapter_name=None, + _pipeline=None, + ): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key should be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alphas (`Dict[str, float]`): + See `LoRALinearLayer` for more details. + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + from peft import LoraConfig + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + prefix = cls.text_encoder_name if prefix is None else prefix + + # Safe prefix to check with. + if any(cls.text_encoder_name in key for key in keys): + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] + text_encoder_lora_state_dict = { + k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + } + + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {prefix}.") + rank = {} + text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) + + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + if network_alphas is not None: + alpha_keys = [ + k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix + ] + network_alphas = { + k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) + + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + ) + + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) + + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not (transformer_lora_layers or text_encoder_lora_layers): + raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if text_encoder_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer + def fuse_lora( + self, + components: List[str] = ["transformer", "text_encoder"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + """ + super().unfuse_lora(components=components) + + # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index fd6c639a7cdf..89d6a28b14dd 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -32,6 +32,7 @@ "UNet2DConditionModel": _maybe_expand_lora_scales, "UNetMotionModel": _maybe_expand_lora_scales, "SD3Transformer2DModel": lambda model_cls, weights: weights, + "FluxTransformer2DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 64cc0ae7b5ad..c1a7010d919a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -20,7 +20,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from ...image_processor import VaeImageProcessor -from ...loaders import SD3LoraLoaderMixin +from ...loaders import FluxLoraLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -137,7 +137,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin): +class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin): r""" The Flux pipeline for text-to-image generation. @@ -321,7 +321,7 @@ def encode_prompt( # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -354,12 +354,12 @@ def encode_prompt( ) if self.text_encoder is not None: - if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: - if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py new file mode 100644 index 000000000000..c0f0684ac4de --- /dev/null +++ b/tests/lora/test_lora_layers_flux.py @@ -0,0 +1,92 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import unittest + +import torch +from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel +from diffusers.utils.testing_utils import floats_tensor, require_peft_backend + + +sys.path.append(".") + +from utils import PeftLoraLoaderMixinTests # noqa: E402 + + +@require_peft_backend +class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = FluxPipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler() + scheduler_kwargs = {} + uses_flow_matching = True + transformer_kwargs = { + "patch_size": 1, + "in_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 16, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "pooled_projection_dim": 32, + "axes_dims_rope": [4, 4, 8], + } + transformer_cls = FluxTransformer2DModel + vae_kwargs = { + "sample_size": 32, + "in_channels": 3, + "out_channels": 3, + "block_out_channels": (4,), + "layers_per_block": 1, + "latent_channels": 1, + "norm_num_groups": 1, + "use_quant_conv": False, + "use_post_quant_conv": False, + "shift_factor": 0.0609, + "scaling_factor": 1.5035, + } + has_two_text_encoders = True + tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" + tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2" + text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + + @property + def output_shape(self): + return (1, 8, 8, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "num_inference_steps": 4, + "guidance_scale": 0.0, + "height": 8, + "width": 8, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py index 46b965ec33d9..0aee4f57c2c6 100644 --- a/tests/lora/test_lora_layers_sd.py +++ b/tests/lora/test_lora_layers_sd.py @@ -22,6 +22,7 @@ from huggingface_hub import hf_hub_download from huggingface_hub.repocard import RepoCard from safetensors.torch import load_file +from transformers import CLIPTextModel, CLIPTokenizer from diffusers import ( AutoPipelineForImage2Image, @@ -80,6 +81,12 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], "latent_channels": 4, } + text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2" + tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" + + @property + def output_shape(self): + return (1, 64, 64, 3) def setUp(self): super().setUp() diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py index 9ce559be7f06..31c62f27a75a 100644 --- a/tests/lora/test_lora_layers_sd3.py +++ b/tests/lora/test_lora_layers_sd3.py @@ -15,10 +15,9 @@ import sys import unittest -from diffusers import ( - FlowMatchEulerDiscreteScheduler, - StableDiffusion3Pipeline, -) +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel + +from diffusers import FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device @@ -35,6 +34,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = StableDiffusion3Pipeline scheduler_cls = FlowMatchEulerDiscreteScheduler() scheduler_kwargs = {} + uses_flow_matching = True transformer_kwargs = { "sample_size": 32, "patch_size": 1, @@ -47,6 +47,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "pooled_projection_dim": 64, "out_channels": 4, } + transformer_cls = SD3Transformer2DModel vae_kwargs = { "sample_size": 32, "in_channels": 3, @@ -61,6 +62,16 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "scaling_factor": 1.5035, } has_three_text_encoders = True + tokenizer_cls, tokenizer_id = CLIPTokenizer, "hf-internal-testing/tiny-random-clip" + tokenizer_2_cls, tokenizer_2_id = CLIPTokenizer, "hf-internal-testing/tiny-random-clip" + tokenizer_3_cls, tokenizer_3_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = CLIPTextModelWithProjection, "hf-internal-testing/tiny-sd3-text_encoder" + text_encoder_2_cls, text_encoder_2_id = CLIPTextModelWithProjection, "hf-internal-testing/tiny-sd3-text_encoder-2" + text_encoder_3_cls, text_encoder_3_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + + @property + def output_shape(self): + return (1, 32, 32, 3) @require_torch_gpu def test_sd3_lora(self): diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index f6ca4f304eb9..f00f7b193abf 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -22,6 +22,7 @@ import numpy as np import torch from packaging import version +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from diffusers import ( ControlNetModel, @@ -89,6 +90,14 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): "latent_channels": 4, "sample_size": 128, } + text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2" + tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" + text_encoder_2_cls, text_encoder_2_id = CLIPTextModelWithProjection, "peft-internal-testing/tiny-clip-text-2" + tokenizer_2_cls, tokenizer_2_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2" + + @property + def output_shape(self): + return (1, 64, 64, 3) def setUp(self): super().setUp() diff --git a/tests/lora/utils.py b/tests/lora/utils.py index ca2e92832229..283b9f534766 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import os import tempfile import unittest @@ -19,14 +20,12 @@ import numpy as np import torch -from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel from diffusers import ( AutoencoderKL, DDIMScheduler, FlowMatchEulerDiscreteScheduler, LCMScheduler, - SD3Transformer2DModel, UNet2DConditionModel, ) from diffusers.utils.import_utils import is_peft_available @@ -72,9 +71,19 @@ class PeftLoraLoaderMixinTests: pipeline_class = None scheduler_cls = None scheduler_kwargs = None + uses_flow_matching = False + has_two_text_encoders = False has_three_text_encoders = False + text_encoder_cls, text_encoder_id = None, None + text_encoder_2_cls, text_encoder_2_id = None, None + text_encoder_3_cls, text_encoder_3_id = None, None + tokenizer_cls, tokenizer_id = None, None + tokenizer_2_cls, tokenizer_2_id = None, None + tokenizer_3_cls, tokenizer_3_id = None, None + unet_kwargs = None + transformer_cls = None transformer_kwargs = None vae_kwargs = None @@ -91,28 +100,23 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): if self.unet_kwargs is not None: unet = UNet2DConditionModel(**self.unet_kwargs) else: - transformer = SD3Transformer2DModel(**self.transformer_kwargs) + transformer = self.transformer_cls(**self.transformer_kwargs) scheduler = scheduler_cls(**self.scheduler_kwargs) torch.manual_seed(0) vae = AutoencoderKL(**self.vae_kwargs) - if not self.has_three_text_encoders: - text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2") - tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2") + text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id) + tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id) - if self.has_two_text_encoders: - text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("peft-internal-testing/tiny-clip-text-2") - tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2") + if self.text_encoder_2_cls is not None: + text_encoder_2 = self.text_encoder_2_cls.from_pretrained(self.text_encoder_2_id) + tokenizer_2 = self.tokenizer_2_cls.from_pretrained(self.tokenizer_2_id) - if self.has_three_text_encoders: - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") - text_encoder = CLIPTextModelWithProjection.from_pretrained("hf-internal-testing/tiny-sd3-text_encoder") - text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("hf-internal-testing/tiny-sd3-text_encoder-2") - text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + if self.text_encoder_3_cls is not None: + text_encoder_3 = self.text_encoder_3_cls.from_pretrained(self.text_encoder_3_id) + tokenizer_3 = self.tokenizer_3_cls.from_pretrained(self.tokenizer_3_id) text_lora_config = LoraConfig( r=rank, @@ -130,45 +134,39 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): use_dora=use_dora, ) - if self.has_two_text_encoders or self.has_three_text_encoders: - if self.unet_kwargs is not None: - pipeline_components = { - "unet": unet, - "scheduler": scheduler, - "vae": vae, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - "text_encoder_2": text_encoder_2, - "tokenizer_2": tokenizer_2, - "image_encoder": None, - "feature_extractor": None, - } - elif self.has_three_text_encoders and self.transformer_kwargs is not None: - pipeline_components = { - "transformer": transformer, - "scheduler": scheduler, - "vae": vae, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - "text_encoder_2": text_encoder_2, - "tokenizer_2": tokenizer_2, - "text_encoder_3": text_encoder_3, - "tokenizer_3": tokenizer_3, - } - else: - pipeline_components = { - "unet": unet, - "scheduler": scheduler, - "vae": vae, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - "safety_checker": None, - "feature_extractor": None, - "image_encoder": None, - } + pipeline_components = { + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + # Denoiser + if self.unet_kwargs is not None: + pipeline_components.update({"unet": unet}) + elif self.transformer_kwargs is not None: + pipeline_components.update({"transformer": transformer}) + + # Remaining text encoders. + if self.text_encoder_2_cls is not None: + pipeline_components.update({"tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2}) + if self.text_encoder_3_cls is not None: + pipeline_components.update({"tokenizer_3": tokenizer_3, "text_encoder_3": text_encoder_3}) + + # Remaining stuff + init_params = inspect.signature(self.pipeline_class.__init__).parameters + if "safety_checker" in init_params: + pipeline_components.update({"safety_checker": None}) + if "feature_extractor" in init_params: + pipeline_components.update({"feature_extractor": None}) + if "image_encoder" in init_params: + pipeline_components.update({"image_encoder": None}) return pipeline_components, text_lora_config, denoiser_lora_config + @property + def output_shape(self): + raise NotImplementedError + def get_dummy_inputs(self, with_generator=True): batch_size = 1 sequence_length = 10 @@ -205,9 +203,7 @@ def test_simple_inference(self): Tests a simple inference and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) @@ -217,8 +213,7 @@ def test_simple_inference(self): _, _, inputs = self.get_dummy_inputs() output_no_lora = pipe(**inputs).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) def test_simple_inference_with_text_lora(self): """ @@ -226,9 +221,7 @@ def test_simple_inference_with_text_lora(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) @@ -238,17 +231,18 @@ def test_simple_inference_with_text_lora(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertTrue( @@ -261,9 +255,7 @@ def test_simple_inference_with_text_lora_and_scale(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) @@ -273,17 +265,18 @@ def test_simple_inference_with_text_lora_and_scale(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertTrue( @@ -322,9 +315,7 @@ def test_simple_inference_with_text_lora_fused(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) @@ -334,26 +325,27 @@ def test_simple_inference_with_text_lora_fused(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) pipe.fuse_lora() # Fusing should still keep the LoRA layers self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") if self.has_two_text_encoders or self.has_three_text_encoders: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertFalse( @@ -366,9 +358,7 @@ def test_simple_inference_with_text_lora_unloaded(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) @@ -378,17 +368,18 @@ def test_simple_inference_with_text_lora_unloaded(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) pipe.unload_lora_weights() # unloading should remove the LoRA layers @@ -397,10 +388,11 @@ def test_simple_inference_with_text_lora_unloaded(self): ) if self.has_two_text_encoders or self.has_three_text_encoders: - self.assertFalse( - check_if_lora_correctly_set(pipe.text_encoder_2), - "Lora not correctly unloaded in text encoder 2", - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + self.assertFalse( + check_if_lora_correctly_set(pipe.text_encoder_2), + "Lora not correctly unloaded in text encoder 2", + ) ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertTrue( @@ -413,9 +405,7 @@ def test_simple_inference_with_text_lora_save_load(self): Tests a simple usecase where users could use saving utilities for LoRA. """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) @@ -425,31 +415,32 @@ def test_simple_inference_with_text_lora_save_load(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images with tempfile.TemporaryDirectory() as tmpdirname: text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder) if self.has_two_text_encoders or self.has_three_text_encoders: - text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, - text_encoder_lora_layers=text_encoder_state_dict, - text_encoder_2_lora_layers=text_encoder_2_state_dict, - safe_serialization=False, - ) + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, + text_encoder_lora_layers=text_encoder_state_dict, + text_encoder_2_lora_layers=text_encoder_2_state_dict, + safe_serialization=False, + ) else: self.pipeline_class.save_lora_weights( save_directory=tmpdirname, @@ -457,6 +448,14 @@ def test_simple_inference_with_text_lora_save_load(self): safe_serialization=False, ) + if self.has_two_text_encoders: + if "text_encoder_2" not in self.pipeline_class._lora_loadable_modules: + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, + text_encoder_lora_layers=text_encoder_state_dict, + safe_serialization=False, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) pipe.unload_lora_weights() @@ -466,9 +465,10 @@ def test_simple_inference_with_text_lora_save_load(self): self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") if self.has_two_text_encoders or self.has_three_text_encoders: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) self.assertTrue( np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), @@ -482,9 +482,7 @@ def test_simple_inference_with_partial_text_lora(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, _, _ = self.get_dummy_components(scheduler_cls) @@ -503,8 +501,7 @@ def test_simple_inference_with_partial_text_lora(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") @@ -517,17 +514,18 @@ def test_simple_inference_with_partial_text_lora(self): } if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) - state_dict.update( - { - f"text_encoder_2.{module_name}": param - for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items() - if "text_model.encoder.layers.4" not in module_name - } - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + state_dict.update( + { + f"text_encoder_2.{module_name}": param + for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items() + if "text_model.encoder.layers.4" not in module_name + } + ) output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertTrue( @@ -549,9 +547,7 @@ def test_simple_inference_save_pretrained(self): Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) @@ -561,17 +557,17 @@ def test_simple_inference_save_pretrained(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images @@ -587,10 +583,11 @@ def test_simple_inference_save_pretrained(self): ) if self.has_two_text_encoders or self.has_three_text_encoders: - self.assertTrue( - check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2), - "Lora not correctly set in text encoder 2", - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + self.assertTrue( + check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2), + "Lora not correctly set in text encoder 2", + ) images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0)).images @@ -604,14 +601,10 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -621,8 +614,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) if self.unet_kwargs is not None: @@ -635,10 +627,11 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in Unet") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images @@ -650,32 +643,23 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): else: denoiser_state_dict = get_peft_model_state_dict(pipe.transformer) - if self.has_two_text_encoders or self.has_three_text_encoders: - text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) + saving_kwargs = { + "save_directory": tmpdirname, + "text_encoder_lora_layers": text_encoder_state_dict, + "safe_serialization": False, + } - if self.unet_kwargs is not None: - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, - text_encoder_lora_layers=text_encoder_state_dict, - text_encoder_2_lora_layers=text_encoder_2_state_dict, - unet_lora_layers=denoiser_state_dict, - safe_serialization=False, - ) - else: - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, - text_encoder_lora_layers=text_encoder_state_dict, - text_encoder_2_lora_layers=text_encoder_2_state_dict, - transformer_lora_layers=denoiser_state_dict, - safe_serialization=False, - ) + if self.unet_kwargs is not None: + saving_kwargs.update({"unet_lora_layers": denoiser_state_dict}) else: - self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, - text_encoder_lora_layers=text_encoder_state_dict, - unet_lora_layers=denoiser_state_dict, - safe_serialization=False, - ) + saving_kwargs.update({"transformer_lora_layers": denoiser_state_dict}) + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2) + saving_kwargs.update({"text_encoder_2_lora_layers": text_encoder_2_state_dict}) + + self.pipeline_class.save_lora_weights(**saving_kwargs) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) pipe.unload_lora_weights() @@ -688,9 +672,10 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) self.assertTrue( np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), @@ -703,9 +688,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -715,8 +698,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) if self.unet_kwargs is not None: @@ -728,10 +710,11 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertTrue( @@ -775,9 +758,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self): and makes sure it works as expected - with unet """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -787,8 +768,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) if self.unet_kwargs is not None: @@ -801,10 +781,11 @@ def test_simple_inference_with_text_lora_denoiser_fused(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) pipe.fuse_lora() # Fusing should still keep the LoRA layers @@ -813,9 +794,10 @@ def test_simple_inference_with_text_lora_denoiser_fused(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertFalse( @@ -828,9 +810,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -840,8 +820,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) if self.unet_kwargs is not None: @@ -853,10 +832,11 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) pipe.unload_lora_weights() # unloading should remove the LoRA layers @@ -869,10 +849,11 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self): ) if self.has_two_text_encoders or self.has_three_text_encoders: - self.assertFalse( - check_if_lora_correctly_set(pipe.text_encoder_2), - "Lora not correctly unloaded in text encoder 2", - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + self.assertFalse( + check_if_lora_correctly_set(pipe.text_encoder_2), + "Lora not correctly unloaded in text encoder 2", + ) ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images self.assertTrue( @@ -886,9 +867,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -908,10 +887,11 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) pipe.fuse_lora() @@ -926,9 +906,10 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Unfuse should still keep LoRA layers") if self.has_two_text_encoders or self.has_three_text_encoders: - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers" + ) # Fuse and unfuse should lead to the same results self.assertTrue( @@ -942,9 +923,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self): multiple adapters and set them """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -972,11 +951,12 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) pipe.set_adapters("adapter-1") @@ -1023,9 +1003,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self): return scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1047,10 +1025,11 @@ def test_simple_inference_with_text_denoiser_block_scale(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) weights_1 = {"text_encoder": 2, "unet": {"down": 5}} pipe.set_adapters("adapter-1", weights_1) @@ -1090,9 +1069,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): return scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1120,11 +1097,12 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) scales_1 = {"text_encoder": 2, "unet": {"down": 5}} scales_2 = {"unet": {"down": 5, "mid": 5}} @@ -1170,7 +1148,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): """Tests that any valid combination of lora block scales can be used in pipe.set_adapter""" - if self.pipeline_class.__name__ == "StableDiffusion3Pipeline": + if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline"]: return def updown_options(blocks_with_tf, layers_per_block, value): @@ -1249,7 +1227,9 @@ def all_possible_dict_opts(unet, value): pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") for scale_dict in all_possible_dict_opts(pipe.unet, value=1234): # test if lora block scales can be set with this scale_dict @@ -1264,9 +1244,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self): multiple adapters and set/delete them """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1294,11 +1272,13 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) pipe.set_adapters("adapter-1") @@ -1370,9 +1350,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): multiple adapters and set them """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1400,11 +1378,13 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) pipe.set_adapters("adapter-1") @@ -1453,9 +1433,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): @skip_mps def test_lora_fuse_nan(self): scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1501,9 +1479,7 @@ def test_get_adapters(self): are the expected results """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1539,9 +1515,7 @@ def test_get_list_adapters(self): are the expected results """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1607,9 +1581,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): and makes sure it works as expected - with unet and multi-adapter case """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1619,8 +1591,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") if self.unet_kwargs is not None: @@ -1640,11 +1611,13 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") - pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) # set them to multi-adapter inference mode pipe.set_adapters(["adapter-1", "adapter-2"]) @@ -1676,9 +1649,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self): @require_peft_version_greater(peft_version="0.9.0") def test_simple_inference_with_dora(self): scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components( @@ -1690,8 +1661,7 @@ def test_simple_inference_with_dora(self): _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images - shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3) - self.assertTrue(output_no_dora_lora.shape == shape_to_be_checked) + self.assertTrue(output_no_dora_lora.shape == self.output_shape) pipe.text_encoder.add_adapter(text_lora_config) if self.unet_kwargs is not None: @@ -1704,10 +1674,12 @@ def test_simple_inference_with_dora(self): self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser") if self.has_two_text_encoders or self.has_three_text_encoders: - pipe.text_encoder_2.add_adapter(text_lora_config) - self.assertTrue( - check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" - ) + lora_loadable_components = self.pipeline_class._lora_loadable_modules + if "text_encoder_2" in lora_loadable_components: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images @@ -1723,9 +1695,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): and makes sure it works as expected """ scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) @@ -1760,7 +1730,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): _ = pipe(**inputs, generator=torch.manual_seed(0)).images def test_modify_padding_mode(self): - if self.pipeline_class.__name__ == "StableDiffusion3Pipeline": + if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline"]: return def set_pad_mode(network, mode="circular"): @@ -1769,9 +1739,7 @@ def set_pad_mode(network, mode="circular"): module.padding_mode = mode scheduler_classes = ( - [FlowMatchEulerDiscreteScheduler] - if self.has_three_text_encoders and self.transformer_kwargs - else [DDIMScheduler, LCMScheduler] + [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler] ) for scheduler_cls in scheduler_classes: components, _, _ = self.get_dummy_components(scheduler_cls)