[LoRA] Add LoRA support to AuraFlow #10216


Merged: 70 commits merged into main from auraflow-lora on Apr 15, 2025

Commits (70)

e50a108
Add AuraFlowLoraLoaderMixin
Warlord-K Jul 30, 2024
658d058
Add comments, remove qkv fusion
Warlord-K Jul 30, 2024
4208d09
Add Tests
Warlord-K Jul 30, 2024
98b19f6
Add AuraFlowLoraLoaderMixin to documentation
Warlord-K Jul 30, 2024
71f8bac
Add Suggested changes
Warlord-K Aug 11, 2024
0eee03e
Change attention_kwargs->joint_attention_kwargs
Warlord-K Aug 12, 2024
4e4f780
Rebasing derp.
hameerabbasi Dec 13, 2024
c07d1f5
fix
hlky Dec 13, 2024
1b7f99f
fix
hlky Dec 13, 2024
875a3e0
Quality fixes.
hameerabbasi Dec 13, 2024
a242d7a
make style
hlky Dec 13, 2024
a73df6b
`make fix-copies`
hameerabbasi Dec 13, 2024
894eac0
`ruff check --fix`
hameerabbasi Dec 13, 2024
2b36416
Attempt 1 to fix tests.
hameerabbasi Dec 15, 2024
6b762b8
Attempt 2 to fix tests.
hameerabbasi Dec 15, 2024
bc2a466
Attempt 3 to fix tests.
hameerabbasi Dec 15, 2024
1c79095
Address review comments.
hameerabbasi Dec 19, 2024
9454e84
Rebasing derp.
hameerabbasi Dec 19, 2024
5700e52
Merge branch 'main' into auraflow-lora
hameerabbasi Jan 3, 2025
6da81f8
Merge branch 'main' into auraflow-lora
sayakpaul Jan 6, 2025
28a4918
Get more tests passing by copying from Flux. Address review comments.
hameerabbasi Jan 7, 2025
d6028cd
`joint_attention_kwargs`->`attention_kwargs`
hameerabbasi Jan 7, 2025
6e899a3
Merge branch 'main' into auraflow-lora
hameerabbasi Jan 7, 2025
2d02c2c
Add `lora_scale` property for TE LoRAs.
hameerabbasi Jan 7, 2025
2b934b4
Make test better.
hameerabbasi Jan 7, 2025
532013f
Remove useless property.
hameerabbasi Jan 7, 2025
0ea9ecd
Merge branch 'main' into auraflow-lora
hlky Jan 8, 2025
e06d8eb
Skip TE-only tests for AuraFlow.
hameerabbasi Jan 8, 2025
2b35909
Support LoRA for non-CLIP TEs.
hameerabbasi Jan 10, 2025
1ec07a1
Merge remote-tracking branch 'upstream/main' into auraflow-lora
hameerabbasi Jan 10, 2025
077a452
Merge branch 'main' into auraflow-lora
hlky Jan 10, 2025
3095644
Merge branch 'main' into auraflow-lora
hameerabbasi Jan 13, 2025
df28362
Merge remote-tracking branch 'upstream/main' into auraflow-lora
hameerabbasi Jan 19, 2025
7e63330
Restore LoRA tests.
hameerabbasi Jan 19, 2025
5620384
Undo adding LoRA support for non-CLIP TEs.
hameerabbasi Jan 19, 2025
cd691d3
Undo support for TE in AuraFlow LoRA.
hameerabbasi Jan 19, 2025
0fa5cd5
`make fix-copies`
hameerabbasi Jan 19, 2025
83e0825
Sync with upstream changes.
hameerabbasi Jan 19, 2025
12fbd11
Remove unneeded stuff.
hameerabbasi Jan 19, 2025
c602749
Merge branch 'main' into auraflow-lora
hameerabbasi Feb 26, 2025
cdd184d
Mirror `Lumina2`.
hameerabbasi Feb 26, 2025
ce1939b
Skip for MPS.
hameerabbasi Feb 26, 2025
3b9e655
Address review comments.
hameerabbasi Feb 26, 2025
c11b14d
Remove duplicated code.
hameerabbasi Feb 27, 2025
636f01c
Remove unnecessary code.
hameerabbasi Feb 27, 2025
75ba7da
Remove repeated docs.
hameerabbasi Mar 5, 2025
c2daa8a
Propagate attention.
hameerabbasi Mar 5, 2025
8aa2d69
Fix TE target modules.
hameerabbasi Mar 6, 2025
b19942f
MPS fix for LoRA tests.
hameerabbasi Mar 6, 2025
5091757
Unrelated TE LoRA tests fix.
hameerabbasi Mar 6, 2025
dee9074
Fix AuraFlow LoRA tests by applying to the right denoiser layers.
hameerabbasi Mar 26, 2025
6241109
Merge remote-tracking branch 'upstream/main' into auraflow-lora
hameerabbasi Mar 26, 2025
ed33194
Merge branch 'main' into auraflow-lora
sayakpaul Apr 8, 2025
65a3bf5
Apply style fixes
github-actions[bot] Apr 8, 2025
147a356
empty commit
sayakpaul Apr 8, 2025
0c91c1a
Fix the repo consistency issues.
hameerabbasi Apr 8, 2025
e97a83e
Remove unrelated changes.
hameerabbasi Apr 8, 2025
a5b78d1
Style.
hameerabbasi Apr 8, 2025
dbc8427
Fix `test_lora_fuse_nan`.
hameerabbasi Apr 8, 2025
22fc9d9
Merge branch 'main' into auraflow-lora
sayakpaul Apr 8, 2025
ea14465
fix quality issues.
sayakpaul Apr 8, 2025
a20d03d
`pytest.xfail` -> `ValueError`.
hameerabbasi Apr 8, 2025
fb5f5f7
Add back `skip_mps`.
hameerabbasi Apr 8, 2025
f88503b
Merge branch 'main' into auraflow-lora
sayakpaul Apr 8, 2025
12dc911
Apply style fixes
github-actions[bot] Apr 8, 2025
fd9ed52
Merge branch 'main' into auraflow-lora
sayakpaul Apr 9, 2025
e418c2f
Merge branch 'main' into auraflow-lora
sayakpaul Apr 10, 2025
5e537d1
Merge remote-tracking branch 'upstream/main' into auraflow-lora
hameerabbasi Apr 11, 2025
bc93160
`make fix-copies`
hameerabbasi Apr 11, 2025
2880ba4
Merge branch 'main' into auraflow-lora
sayakpaul Apr 15, 2025

Files changed

4 changes: 4 additions & 0 deletions docs/source/en/api/loaders/lora.md
@@ -20,6 +20,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`FluxLoraLoaderMixin`] provides similar functions for [Flux](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux).
- [`CogVideoXLoraLoaderMixin`] provides similar functions for [CogVideoX](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox).
- [`Mochi1LoraLoaderMixin`] provides similar functions for [Mochi](https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi).
- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow).
- [`LTXVideoLoraLoaderMixin`] provides similar functions for [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).
- [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana).
- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).
@@ -56,6 +57,9 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
## Mochi1LoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
## AuraFlowLoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin

## LTXVideoLoraLoaderMixin

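Note: the documented mixin is what gives `AuraFlowPipeline` its LoRA surface (`load_lora_weights`, `fuse_lora`, `unfuse_lora`). A minimal usage sketch, assuming a hypothetical adapter repository id (only `fal/AuraFlow` is a real checkpoint):

```python
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# `load_lora_weights` is provided by the new AuraFlowLoraLoaderMixin;
# the repo id below is a placeholder, not a real adapter.
pipe.load_lora_weights("your-username/auraflow-lora-example")

# Optionally merge the LoRA into the transformer weights, then undo the merge.
pipe.fuse_lora(lora_scale=1.0)
pipe.unfuse_lora()
```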
2 changes: 2 additions & 0 deletions src/diffusers/loaders/__init__.py
@@ -65,6 +65,7 @@ def text_encoder_attn_modules(text_encoder):
"AmusedLoraLoaderMixin",
"StableDiffusionLoraLoaderMixin",
"SD3LoraLoaderMixin",
"AuraFlowLoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
"LTXVideoLoraLoaderMixin",
"LoraLoaderMixin",
@@ -103,6 +104,7 @@ def text_encoder_attn_modules(text_encoder):
)
from .lora_pipeline import (
AmusedLoraLoaderMixin,
AuraFlowLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
CogView4LoraLoaderMixin,
FluxLoraLoaderMixin,
333 changes: 333 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/diffusers/loaders/peft.py
@@ -52,6 +52,7 @@
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
"SanaTransformer2DModel": lambda model_cls, weights: weights,
"AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
"WanTransformer3DModel": lambda model_cls, weights: weights,
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
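The table in `peft.py` appears to map each transformer class to the function used when expanding per-adapter weights for `set_adapters`; the new entry passes AuraFlow adapter weights through unchanged. A hedged sketch of the user-facing effect (repo ids and adapter names are made up):

```python
import torch
from diffusers import AuraFlowPipeline

# Placeholder repo ids and adapter names, purely for illustration.
pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("your-username/auraflow-style-lora", adapter_name="style")
pipe.load_lora_weights("your-username/auraflow-detail-lora", adapter_name="detail")

# Per-adapter weights for AuraFlowTransformer2DModel are used as given,
# since the registered expansion function is the identity.
pipe.set_adapters(["style", "detail"], adapter_weights=[0.8, 0.5])
```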
4 changes: 2 additions & 2 deletions src/diffusers/models/lora.py
@@ -38,7 +38,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


def text_encoder_attn_modules(text_encoder):
def text_encoder_attn_modules(text_encoder: nn.Module):
attn_modules = []

if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
@@ -52,7 +52,7 @@ def text_encoder_attn_modules(text_encoder):
return attn_modules


def text_encoder_mlp_modules(text_encoder):
def text_encoder_mlp_modules(text_encoder: nn.Module):
mlp_modules = []

if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
58 changes: 48 additions & 10 deletions src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -13,15 +13,15 @@
# limitations under the License.


from typing import Dict, Union
from typing import Any, Dict, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import logging
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import (
Attention,
@@ -160,14 +160,20 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
self.ff = AuraFlowFeedForward(dim, dim * 4)

def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
):
residual = hidden_states
attention_kwargs = attention_kwargs or {}

# Norm + Projection.
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

# Attention.
attn_output = self.attn(hidden_states=norm_hidden_states)
attn_output = self.attn(hidden_states=norm_hidden_states, **attention_kwargs)

# Process attention outputs for the `hidden_states`.
hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
@@ -223,10 +229,15 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
self.ff_context = AuraFlowFeedForward(dim, dim * 4)

def forward(
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
):
residual = hidden_states
residual_context = encoder_hidden_states
attention_kwargs = attention_kwargs or {}

# Norm + Projection.
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
@@ -236,7 +247,9 @@ def forward(

# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
**attention_kwargs,
)

# Process attention outputs for the `hidden_states`.
@@ -254,7 +267,7 @@ def forward(
return encoder_hidden_states, hidden_states


class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).

@@ -449,8 +462,24 @@ def forward(
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)

height, width = hidden_states.shape[-2:]

# Apply patch embedding, timestep embedding, and project the caption embeddings.
@@ -474,7 +503,10 @@ def forward(

else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
attention_kwargs=attention_kwargs,
)

# Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
@@ -491,7 +523,9 @@ def forward(
)

else:
combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb)
combined_hidden_states = block(
hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs
)

hidden_states = combined_hidden_states[:, encoder_seq_len:]

@@ -512,6 +546,10 @@ def forward(
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return (output,)

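To restate the pattern the diff above adds to `AuraFlowTransformer2DModel.forward`: the LoRA strength travels as `attention_kwargs["scale"]`, is popped before the blocks run, applied to every PEFT layer via `scale_lora_layers`, and reverted with `unscale_lora_layers` afterwards. A standalone sketch of that pattern (the wrapper function is illustrative and not part of the diff; the try/finally is an extra safety measure, not present in the model code):

```python
from typing import Any, Dict, Optional

from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers


def run_with_lora_scale(model, attention_kwargs: Optional[Dict[str, Any]] = None, **inputs):
    # Copy the kwargs and pop "scale", mirroring the new forward() logic.
    attention_kwargs = dict(attention_kwargs) if attention_kwargs is not None else {}
    lora_scale = attention_kwargs.pop("scale", 1.0)  # same default as the diff

    if USE_PEFT_BACKEND:
        scale_lora_layers(model, lora_scale)  # weight every LoRA layer by `lora_scale`
    try:
        return model(**inputs, attention_kwargs=attention_kwargs)
    finally:
        if USE_PEFT_BACKEND:
            unscale_lora_layers(model, lora_scale)  # restore the original scaling
```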
45 changes: 41 additions & 4 deletions src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -12,17 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from transformers import T5Tokenizer, UMT5EncoderModel

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...loaders import AuraFlowLoraLoaderMixin
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

@@ -112,7 +120,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


class AuraFlowPipeline(DiffusionPipeline):
class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
r"""
Args:
tokenizer (`T5TokenizerFast`):
@@ -233,6 +241,7 @@ def encode_prompt(
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 256,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -259,10 +268,20 @@
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, AuraFlowLoraLoaderMixin):
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)

if device is None:
device = self._execution_device

if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -346,6 +365,11 @@ def encode_prompt(
negative_prompt_embeds = None
negative_prompt_attention_mask = None

if self.text_encoder is not None:
if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)

return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
@@ -403,6 +427,10 @@ def upcast_vae(self):
def guidance_scale(self):
return self._guidance_scale

@property
def attention_kwargs(self):
return self._attention_kwargs

@property
def num_timesteps(self):
return self._num_timesteps
@@ -428,6 +456,7 @@ def __call__(
max_sequence_length: int = 256,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
@@ -486,6 +515,10 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -520,6 +553,7 @@ def __call__(
)

self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs

# 2. Determine batch size.
if prompt is not None and isinstance(prompt, str):
@@ -530,6 +564,7 @@ def __call__(
batch_size = prompt_embeds.shape[0]

device = self._execution_device
lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -553,6 +588,7 @@
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
Expand Down Expand Up @@ -594,6 +630,7 @@ def __call__(
encoder_hidden_states=prompt_embeds,
timestep=timestep,
return_dict=False,
attention_kwargs=self.attention_kwargs,
)[0]

# perform guidance
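Taken together, the pipeline-side changes let callers set the LoRA strength per call through `attention_kwargs`: the pipeline reads `scale` once for the text-encoder path (the `lora_scale` argument of `encode_prompt`) and forwards the full dict to the transformer. A minimal inference sketch, assuming a hypothetical adapter repo id and illustrative settings:

```python
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("your-username/auraflow-lora-example")  # placeholder repo id

# "scale" controls LoRA strength for this call only; it is consumed by
# encode_prompt (text encoder) and by the transformer's scale_lora_layers path.
image = pipe(
    "a watercolor painting of a fox in a snowy forest",
    num_inference_steps=28,
    guidance_scale=3.5,
    attention_kwargs={"scale": 0.7},
).images[0]
image.save("auraflow_lora.png")
```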