Skip to content

Commit 6b38c38

Browse files
committed
feat: lora support for Flux.
add tests fix imports major fixes.
1 parent 0e46067 commit 6b38c38

File tree

10 files changed

+882
-311
lines changed

10 files changed

+882
-311
lines changed

src/diffusers/loaders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def text_encoder_attn_modules(text_encoder):
6666
"SD3LoraLoaderMixin",
6767
"StableDiffusionXLLoraLoaderMixin",
6868
"LoraLoaderMixin",
69+
"FluxLoraLoaderMixin",
6970
]
7071
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
7172
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
@@ -83,6 +84,7 @@ def text_encoder_attn_modules(text_encoder):
8384
from .ip_adapter import IPAdapterMixin
8485
from .lora_pipeline import (
8586
AmusedLoraLoaderMixin,
87+
FluxLoraLoaderMixin,
8688
LoraLoaderMixin,
8789
SD3LoraLoaderMixin,
8890
StableDiffusionLoraLoaderMixin,

src/diffusers/loaders/lora_pipeline.py

Lines changed: 475 additions & 0 deletions
Large diffs are not rendered by default.

src/diffusers/loaders/peft.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"UNet2DConditionModel": _maybe_expand_lora_scales,
3333
"UNetMotionModel": _maybe_expand_lora_scales,
3434
"SD3Transformer2DModel": lambda model_cls, weights: weights,
35+
"FluxTransformer2DModel": lambda model_cls, weights: weights,
3536
}
3637

3738

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,6 @@ def forward(
373373
)
374374
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
375375

376-
print(f"{txt_ids.shape=}, {img_ids.shape=}")
377376
ids = torch.cat((txt_ids, img_ids), dim=1)
378377
image_rotary_emb = self.pos_embed(ids)
379378

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,10 @@
1717

1818
import numpy as np
1919
import torch
20-
from transformers import (
21-
CLIPTextModel,
22-
CLIPTokenizer,
23-
T5EncoderModel,
24-
T5TokenizerFast,
25-
)
20+
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
2621

2722
from ...image_processor import VaeImageProcessor
28-
from ...loaders import SD3LoraLoaderMixin
23+
from ...loaders import FluxLoraLoaderMixin
2924
from ...models.autoencoders import AutoencoderKL
3025
from ...models.transformers import FluxTransformer2DModel
3126
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -142,7 +137,7 @@ def retrieve_timesteps(
142137
return timesteps, num_inference_steps
143138

144139

145-
class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
140+
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
146141
r"""
147142
The Flux pipeline for text-to-image generation.
148143
@@ -333,7 +328,7 @@ def encode_prompt(
333328

334329
# set lora scale so that monkey patched LoRA
335330
# function of text encoder can correctly access it
336-
if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
331+
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
337332
self._lora_scale = lora_scale
338333

339334
# dynamically adjust the LoRA scale
@@ -366,12 +361,12 @@ def encode_prompt(
366361
)
367362

368363
if self.text_encoder is not None:
369-
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
364+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
370365
# Retrieve the original scale by scaling back the LoRA layers
371366
unscale_lora_layers(self.text_encoder, lora_scale)
372367

373368
if self.text_encoder_2 is not None:
374-
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
369+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
375370
# Retrieve the original scale by scaling back the LoRA layers
376371
unscale_lora_layers(self.text_encoder_2, lora_scale)
377372

tests/lora/test_lora_layers_flux.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# coding=utf-8
2+
# Copyright 2024 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import sys
16+
import unittest
17+
18+
import torch
19+
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
20+
21+
from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
22+
from diffusers.utils.testing_utils import floats_tensor, is_peft_available, require_peft_backend
23+
24+
25+
if is_peft_available():
26+
pass
27+
28+
sys.path.append(".")
29+
30+
from utils import PeftLoraLoaderMixinTests # noqa: E402
31+
32+
33+
@require_peft_backend
34+
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
35+
pipeline_class = FluxPipeline
36+
scheduler_cls = FlowMatchEulerDiscreteScheduler()
37+
scheduler_kwargs = {}
38+
uses_flow_matching = True
39+
transformer_kwargs = {
40+
"patch_size": 1,
41+
"in_channels": 4,
42+
"num_layers": 1,
43+
"num_single_layers": 1,
44+
"attention_head_dim": 16,
45+
"num_attention_heads": 2,
46+
"joint_attention_dim": 32,
47+
"pooled_projection_dim": 32,
48+
"axes_dims_rope": [4, 4, 8],
49+
}
50+
transformer_cls = FluxTransformer2DModel
51+
vae_kwargs = {
52+
"sample_size": 32,
53+
"in_channels": 3,
54+
"out_channels": 3,
55+
"block_out_channels": (4,),
56+
"layers_per_block": 1,
57+
"latent_channels": 1,
58+
"norm_num_groups": 1,
59+
"use_quant_conv": False,
60+
"use_post_quant_conv": False,
61+
"shift_factor": 0.0609,
62+
"scaling_factor": 1.5035,
63+
}
64+
has_two_text_encoders = True
65+
tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
66+
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
67+
text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2")
68+
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
69+
70+
@property
71+
def output_shape(self):
72+
return (1, 8, 8, 3)
73+
74+
def get_dummy_inputs(self, with_generator=True):
75+
batch_size = 1
76+
sequence_length = 10
77+
num_channels = 4
78+
sizes = (32, 32)
79+
80+
generator = torch.manual_seed(0)
81+
noise = floats_tensor((batch_size, num_channels) + sizes)
82+
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
83+
84+
pipeline_inputs = {
85+
"prompt": "A painting of a squirrel eating a burger",
86+
"num_inference_steps": 4,
87+
"guidance_scale": 0.0,
88+
"height": 8,
89+
"width": 8,
90+
"output_type": "np",
91+
}
92+
if with_generator:
93+
pipeline_inputs.update({"generator": generator})
94+
95+
return noise, input_ids, pipeline_inputs

tests/lora/test_lora_layers_sd.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from huggingface_hub import hf_hub_download
2323
from huggingface_hub.repocard import RepoCard
2424
from safetensors.torch import load_file
25+
from transformers import CLIPTextModel, CLIPTokenizer
2526

2627
from diffusers import (
2728
AutoPipelineForImage2Image,
@@ -80,6 +81,12 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
8081
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
8182
"latent_channels": 4,
8283
}
84+
text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2")
85+
tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
86+
87+
@property
88+
def output_shape(self):
89+
return (1, 64, 64, 3)
8390

8491
def setUp(self):
8592
super().setUp()

tests/lora/test_lora_layers_sd3.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
import sys
1616
import unittest
1717

18-
from diffusers import (
19-
FlowMatchEulerDiscreteScheduler,
20-
StableDiffusion3Pipeline,
21-
)
18+
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
19+
20+
from diffusers import FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline
2221
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device
2322

2423

@@ -35,6 +34,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
3534
pipeline_class = StableDiffusion3Pipeline
3635
scheduler_cls = FlowMatchEulerDiscreteScheduler()
3736
scheduler_kwargs = {}
37+
uses_flow_matching = True
3838
transformer_kwargs = {
3939
"sample_size": 32,
4040
"patch_size": 1,
@@ -47,6 +47,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
4747
"pooled_projection_dim": 64,
4848
"out_channels": 4,
4949
}
50+
transformer_cls = SD3Transformer2DModel
5051
vae_kwargs = {
5152
"sample_size": 32,
5253
"in_channels": 3,
@@ -61,6 +62,16 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
6162
"scaling_factor": 1.5035,
6263
}
6364
has_three_text_encoders = True
65+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
66+
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
67+
tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
68+
text_encoder = CLIPTextModelWithProjection.from_pretrained("hf-internal-testing/tiny-sd3-text_encoder")
69+
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("hf-internal-testing/tiny-sd3-text_encoder-2")
70+
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
71+
72+
@property
73+
def output_shape(self):
74+
return (1, 32, 32, 3)
6475

6576
@require_torch_gpu
6677
def test_sd3_lora(self):

tests/lora/test_lora_layers_sdxl.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import numpy as np
2323
import torch
2424
from packaging import version
25+
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
2526

2627
from diffusers import (
2728
ControlNetModel,
@@ -89,6 +90,14 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
8990
"latent_channels": 4,
9091
"sample_size": 128,
9192
}
93+
text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2")
94+
tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
95+
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("peft-internal-testing/tiny-clip-text-2")
96+
tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
97+
98+
@property
99+
def output_shape(self):
100+
return (1, 64, 64, 3)
92101

93102
def setUp(self):
94103
super().setUp()

0 commit comments

Comments
 (0)