diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md
index 58611a61c25d..093cc99972a3 100644
--- a/docs/source/en/api/loaders/lora.md
+++ b/docs/source/en/api/loaders/lora.md
@@ -20,6 +20,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`FluxLoraLoaderMixin`] provides similar functions for [Flux](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux).
- [`CogVideoXLoraLoaderMixin`] provides similar functions for [CogVideoX](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox).
- [`Mochi1LoraLoaderMixin`] provides similar functions for [Mochi](https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi).
+- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/docs/diffusers/main/en/api/pipelines/aura_flow).
- [`LTXVideoLoraLoaderMixin`] provides similar functions for [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).
- [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana).
- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).
@@ -56,6 +57,9 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
## Mochi1LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
+## AuraFlowLoraLoaderMixin
+
+[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin
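+
+A minimal usage sketch (the adapter repository id and weight file name below are placeholders, not a published checkpoint):
+
+```py
+import torch
+from diffusers import AuraFlowPipeline
+
+pipeline = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
+# hypothetical adapter repository and weight name, shown for illustration only
+pipeline.load_lora_weights("your-org/aura-flow-lora", weight_name="pytorch_lora_weights.safetensors")
+image = pipeline("a watercolor painting of a fox", num_inference_steps=28).images[0]
+```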
## LTXVideoLoraLoaderMixin
diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py
index 3ba1bfacf3dd..7b440f6f4515 100644
--- a/src/diffusers/loaders/__init__.py
+++ b/src/diffusers/loaders/__init__.py
@@ -65,6 +65,7 @@ def text_encoder_attn_modules(text_encoder):
"AmusedLoraLoaderMixin",
"StableDiffusionLoraLoaderMixin",
"SD3LoraLoaderMixin",
+ "AuraFlowLoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
"LTXVideoLoraLoaderMixin",
"LoraLoaderMixin",
@@ -103,6 +104,7 @@ def text_encoder_attn_modules(text_encoder):
)
from .lora_pipeline import (
AmusedLoraLoaderMixin,
+ AuraFlowLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
CogView4LoraLoaderMixin,
FluxLoraLoaderMixin,
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 2e241bc9ffad..aa508cf87f40 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -1593,6 +1593,339 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t
super().unfuse_lora(components=components, **kwargs)
+class AuraFlowLoraLoaderMixin(LoraBaseMixin):
+ r"""
+    Load LoRA layers into [`AuraFlowTransformer2DModel`]. Specific to [`AuraFlowPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+        </Tip>
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ """
+        # Load the main state dict first, which has the LoRA layers for the
+        # transformer, the text encoder, or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`. All kwargs
+        are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+                A standard state dict containing the LoRA layer parameters. The keys can either be indexed directly
+                into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
+                them from text encoder LoRA layers.
+ transformer (`AuraFlowTransformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+            hotswap (`bool`, *optional*, defaults to `False`):
+                Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+ to call an additional method before loading the adapter:
+
+ ```py
+ pipeline = ... # load diffusers pipeline
+ max_rank = ... # the highest rank among all LoRAs that you want to load
+ # call *before* compiling and loading the LoRA adapter
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+ pipeline.load_lora_weights(file_name)
+ # optionally compile the model now
+ ```
+
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+        Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when
+                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+ This is an experimental API.
+
+        </Tip>
+
+ Args:
+            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing, skipping any weights whose values are NaN.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+ This is an experimental API.
+
+        </Tip>
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
class FluxLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`FluxTransformer2DModel`],
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 9165c46f3c78..1d990e81458d 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -52,6 +52,7 @@
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
"SanaTransformer2DModel": lambda model_cls, weights: weights,
+ "AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
"WanTransformer3DModel": lambda model_cls, weights: weights,
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py
index 4e9e0c07ca75..3b54303584bf 100644
--- a/src/diffusers/models/lora.py
+++ b/src/diffusers/models/lora.py
@@ -38,7 +38,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-def text_encoder_attn_modules(text_encoder):
+def text_encoder_attn_modules(text_encoder: nn.Module):
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
@@ -52,7 +52,7 @@ def text_encoder_attn_modules(text_encoder):
return attn_modules
-def text_encoder_mlp_modules(text_encoder):
+def text_encoder_mlp_modules(text_encoder: nn.Module):
mlp_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py
index 4938ed23c506..8781424c61d3 100644
--- a/src/diffusers/models/transformers/auraflow_transformer_2d.py
+++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -13,15 +13,15 @@
# limitations under the License.
-from typing import Dict, Union
+from typing import Any, Dict, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin
-from ...utils import logging
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import (
Attention,
@@ -160,14 +160,20 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
self.ff = AuraFlowFeedForward(dim, dim * 4)
- def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
residual = hidden_states
+ attention_kwargs = attention_kwargs or {}
# Norm + Projection.
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
# Attention.
- attn_output = self.attn(hidden_states=norm_hidden_states)
+ attn_output = self.attn(hidden_states=norm_hidden_states, **attention_kwargs)
# Process attention outputs for the `hidden_states`.
hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
@@ -223,10 +229,15 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
self.ff_context = AuraFlowFeedForward(dim, dim * 4)
def forward(
- self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
):
residual = hidden_states
residual_context = encoder_hidden_states
+ attention_kwargs = attention_kwargs or {}
# Norm + Projection.
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
@@ -236,7 +247,9 @@ def forward(
# Attention.
attn_output, context_attn_output = self.attn(
- hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ **attention_kwargs,
)
# Process attention outputs for the `hidden_states`.
@@ -254,7 +267,7 @@ def forward(
return encoder_hidden_states, hidden_states
-class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
@@ -449,8 +462,24 @@ def forward(
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
height, width = hidden_states.shape[-2:]
# Apply patch embedding, timestep embedding, and project the caption embeddings.
@@ -474,7 +503,10 @@ def forward(
else:
encoder_hidden_states, hidden_states = block(
- hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ attention_kwargs=attention_kwargs,
)
# Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
@@ -491,7 +523,9 @@ def forward(
)
else:
- combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb)
+ combined_hidden_states = block(
+ hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs
+ )
hidden_states = combined_hidden_states[:, encoder_seq_len:]
@@ -512,6 +546,10 @@ def forward(
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
)
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
if not return_dict:
return (output,)
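The `forward` changes above follow the LoRA-scaling pattern diffusers uses across its transformers: copy `attention_kwargs`, pop `scale`, scale every PEFT layer before the blocks run, and unscale afterwards so the module is left as it was. A minimal standalone sketch of that pattern, assuming an arbitrary `torch.nn.Module` with PEFT LoRA layers attached (the helper name is illustrative):

```py
from typing import Any, Dict, Optional

import torch
from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers


def forward_with_lora_scale(
    model: torch.nn.Module,
    x: torch.Tensor,
    attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
    # Copy so the caller's dict is not mutated, then peel off the LoRA scale.
    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        scale_lora_layers(model, lora_scale)  # weight every PEFT layer by lora_scale

    out = model(x)  # the remaining attention_kwargs would be forwarded to the blocks here

    if USE_PEFT_BACKEND:
        unscale_lora_layers(model, lora_scale)  # restore the original scaling
    return out
```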
diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
index ea60e66d2db9..7c98b3b71c48 100644
--- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
+++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -12,17 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import T5Tokenizer, UMT5EncoderModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
+from ...loaders import AuraFlowLoraLoaderMixin
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -112,7 +120,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class AuraFlowPipeline(DiffusionPipeline):
+class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
r"""
Args:
tokenizer (`T5TokenizerFast`):
@@ -233,6 +241,7 @@ def encode_prompt(
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 256,
+ lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -259,10 +268,20 @@ def encode_prompt(
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, AuraFlowLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
if device is None:
device = self._execution_device
-
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -346,6 +365,11 @@ def encode_prompt(
negative_prompt_embeds = None
negative_prompt_attention_mask = None
+ if self.text_encoder is not None:
+ if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
@@ -403,6 +427,10 @@ def upcast_vae(self):
def guidance_scale(self):
return self._guidance_scale
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
@property
def num_timesteps(self):
return self._num_timesteps
@@ -428,6 +456,7 @@ def __call__(
max_sequence_length: int = 256,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
@@ -486,6 +515,10 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -520,6 +553,7 @@ def __call__(
)
self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
# 2. Determine batch size.
if prompt is not None and isinstance(prompt, str):
@@ -530,6 +564,7 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -553,6 +588,7 @@ def __call__(
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
@@ -594,6 +630,7 @@ def __call__(
encoder_hidden_states=prompt_embeds,
timestep=timestep,
return_dict=False,
+ attention_kwargs=self.attention_kwargs,
)[0]
# perform guidance
diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py
new file mode 100644
index 000000000000..ac1fed608cc8
--- /dev/null
+++ b/tests/lora/test_lora_layers_auraflow.py
@@ -0,0 +1,136 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import unittest
+
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from diffusers import (
+ AuraFlowPipeline,
+ AuraFlowTransformer2DModel,
+ FlowMatchEulerDiscreteScheduler,
+)
+from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+
+
+sys.path.append(".")
+
+from utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+@require_peft_backend
+class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = AuraFlowPipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "sample_size": 64,
+ "patch_size": 1,
+ "in_channels": 4,
+ "num_mmdit_layers": 1,
+ "num_single_dit_layers": 1,
+ "attention_head_dim": 16,
+ "num_attention_heads": 2,
+ "joint_attention_dim": 32,
+ "caption_projection_dim": 32,
+ "pos_embed_max_size": 64,
+ }
+ transformer_cls = AuraFlowTransformer2DModel
+ vae_kwargs = {
+ "sample_size": 32,
+ "in_channels": 3,
+ "out_channels": 3,
+ "block_out_channels": (4,),
+ "layers_per_block": 1,
+ "latent_channels": 4,
+ "norm_num_groups": 1,
+ "use_quant_conv": False,
+ "use_post_quant_conv": False,
+ "shift_factor": 0.0609,
+ "scaling_factor": 1.5035,
+ }
+ tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+ text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5"
+ text_encoder_target_modules = ["q", "k", "v", "o"]
+ denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"]
+
+ @property
+ def output_shape(self):
+ return (1, 8, 8, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 10
+ num_channels = 4
+ sizes = (32, 32)
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "num_inference_steps": 4,
+ "guidance_scale": 0.0,
+ "height": 8,
+ "width": 8,
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ @unittest.skip("Not supported in AuraFlow.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in AuraFlow.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in AuraFlow.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
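The new suite is driven almost entirely by the inherited `PeftLoraLoaderMixinTests`. A quick local run, assuming the working directory is `tests/lora` with `peft` installed (equivalent to `pytest tests/lora/test_lora_layers_auraflow.py`):

```py
import unittest

from test_lora_layers_auraflow import AuraFlowLoRATests

# Collect every inherited LoRA test that is not explicitly skipped above.
suite = unittest.TestLoader().loadTestsFromTestCase(AuraFlowLoRATests)
unittest.TextTestRunner(verbosity=2).run(suite)
```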
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 27fef495a484..87a8fddfa583 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -104,6 +104,7 @@ class PeftLoraLoaderMixinTests:
vae_kwargs = None
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
+ denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
def get_dummy_components(self, scheduler_cls=None, use_dora=False):
if self.unet_kwargs and self.transformer_kwargs:
@@ -157,7 +158,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False):
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=rank,
- target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+ target_modules=self.denoiser_target_modules,
init_lora_weights=False,
use_dora=use_dora,
)
@@ -602,9 +603,9 @@ def test_simple_inference_with_partial_text_lora(self):
# Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
text_lora_config = LoraConfig(
r=4,
- rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},
+ rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)},
lora_alpha=4,
- target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ target_modules=self.text_encoder_target_modules,
init_lora_weights=False,
use_dora=False,
)
@@ -1451,17 +1452,27 @@ def test_lora_fuse_nan(self):
].weight += float("inf")
else:
named_modules = [name for name, _ in pipe.transformer.named_modules()]
- tower_name = (
- "transformer_blocks"
- if any(name == "transformer_blocks" for name in named_modules)
- else "blocks"
- )
- transformer_tower = getattr(pipe.transformer, tower_name)
- has_attn1 = any("attn1" in name for name in named_modules)
- if has_attn1:
- transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
- else:
- transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
+ possible_tower_names = [
+ "transformer_blocks",
+ "blocks",
+ "joint_transformer_blocks",
+ "single_transformer_blocks",
+ ]
+ filtered_tower_names = [
+ tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
+ ]
+ if len(filtered_tower_names) == 0:
+ reason = (
+ f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
+ )
+ raise ValueError(reason)
+ for tower_name in filtered_tower_names:
+ transformer_tower = getattr(pipe.transformer, tower_name)
+ has_attn1 = any("attn1" in name for name in named_modules)
+ if has_attn1:
+ transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
+ else:
+ transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
@@ -1908,7 +1919,7 @@ def test_lora_B_bias(self):
bias_values = {}
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
for name, module in denoiser.named_modules():
- if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
+ if any(k in name for k in self.denoiser_target_modules):
if module.bias is not None:
bias_values[name] = module.bias.data.clone()
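For reference, `denoiser_target_modules` is the knob the AuraFlow test subclass overrides to also wrap `linear_1` on top of the default attention projections. A sketch of the adapter config the harness ends up building (the rank value is illustrative):

```py
from peft import LoraConfig

# AuraFlow's override of the default attention-projection targets.
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"]
rank = 4  # illustrative; get_dummy_components chooses the actual rank

denoiser_lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank,
    target_modules=denoiser_target_modules,
    init_lora_weights=False,
    use_dora=False,
)
```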