Skip to content

[FLUX] support LoRA #9057

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/diffusers/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def text_encoder_attn_modules(text_encoder):
"SD3LoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
"LoraLoaderMixin",
"FluxLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
Expand All @@ -83,6 +84,7 @@ def text_encoder_attn_modules(text_encoder):
from .ip_adapter import IPAdapterMixin
from .lora_pipeline import (
AmusedLoraLoaderMixin,
FluxLoraLoaderMixin,
LoraLoaderMixin,
SD3LoraLoaderMixin,
StableDiffusionLoraLoaderMixin,
Expand Down
475 changes: 475 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/diffusers/loaders/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
"SD3Transformer2DModel": lambda model_cls, weights: weights,
"FluxTransformer2DModel": lambda model_cls, weights: weights,
}


Expand Down
10 changes: 5 additions & 5 deletions src/diffusers/pipelines/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from ...image_processor import VaeImageProcessor
from ...loaders import SD3LoraLoaderMixin
from ...loaders import FluxLoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
Expand Down Expand Up @@ -137,7 +137,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
r"""
The Flux pipeline for text-to-image generation.

Expand Down Expand Up @@ -321,7 +321,7 @@ def encode_prompt(

# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
Expand Down Expand Up @@ -354,12 +354,12 @@ def encode_prompt(
)

if self.text_encoder is not None:
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)

if self.text_encoder_2 is not None:
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2, lora_scale)

Expand Down
92 changes: 92 additions & 0 deletions tests/lora/test_lora_layers_flux.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest

import torch
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend


sys.path.append(".")

from utils import PeftLoraLoaderMixinTests # noqa: E402


@require_peft_backend
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler()
scheduler_kwargs = {}
uses_flow_matching = True
transformer_kwargs = {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"pooled_projection_dim": 32,
"axes_dims_rope": [4, 4, 8],
}
transformer_cls = FluxTransformer2DModel
vae_kwargs = {
"sample_size": 32,
"in_channels": 3,
"out_channels": 3,
"block_out_channels": (4,),
"layers_per_block": 1,
"latent_channels": 1,
"norm_num_groups": 1,
"use_quant_conv": False,
"use_post_quant_conv": False,
"shift_factor": 0.0609,
"scaling_factor": 1.5035,
}
has_two_text_encoders = True
tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"

@property
def output_shape(self):
return (1, 8, 8, 3)

def get_dummy_inputs(self, with_generator=True):
batch_size = 1
sequence_length = 10
num_channels = 4
sizes = (32, 32)

generator = torch.manual_seed(0)
noise = floats_tensor((batch_size, num_channels) + sizes)
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger",
"num_inference_steps": 4,
"guidance_scale": 0.0,
"height": 8,
"width": 8,
"output_type": "np",
}
if with_generator:
pipeline_inputs.update({"generator": generator})

return noise, input_ids, pipeline_inputs
7 changes: 7 additions & 0 deletions tests/lora/test_lora_layers_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from huggingface_hub import hf_hub_download
from huggingface_hub.repocard import RepoCard
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers import (
AutoPipelineForImage2Image,
Expand Down Expand Up @@ -80,6 +81,12 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
"latent_channels": 4,
}
text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"

@property
def output_shape(self):
return (1, 64, 64, 3)

def setUp(self):
super().setUp()
Expand Down
19 changes: 15 additions & 4 deletions tests/lora/test_lora_layers_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
import sys
import unittest

from diffusers import (
FlowMatchEulerDiscreteScheduler,
StableDiffusion3Pipeline,
)
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

from diffusers import FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device


Expand All @@ -35,6 +34,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler()
scheduler_kwargs = {}
uses_flow_matching = True
transformer_kwargs = {
"sample_size": 32,
"patch_size": 1,
Expand All @@ -47,6 +47,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"pooled_projection_dim": 64,
"out_channels": 4,
}
transformer_cls = SD3Transformer2DModel
vae_kwargs = {
"sample_size": 32,
"in_channels": 3,
Expand All @@ -61,6 +62,16 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"scaling_factor": 1.5035,
}
has_three_text_encoders = True
tokenizer_cls, tokenizer_id = CLIPTokenizer, "hf-internal-testing/tiny-random-clip"
tokenizer_2_cls, tokenizer_2_id = CLIPTokenizer, "hf-internal-testing/tiny-random-clip"
tokenizer_3_cls, tokenizer_3_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
text_encoder_cls, text_encoder_id = CLIPTextModelWithProjection, "hf-internal-testing/tiny-sd3-text_encoder"
text_encoder_2_cls, text_encoder_2_id = CLIPTextModelWithProjection, "hf-internal-testing/tiny-sd3-text_encoder-2"
text_encoder_3_cls, text_encoder_3_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"

@property
def output_shape(self):
return (1, 32, 32, 3)

@require_torch_gpu
def test_sd3_lora(self):
Expand Down
9 changes: 9 additions & 0 deletions tests/lora/test_lora_layers_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import numpy as np
import torch
from packaging import version
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from diffusers import (
ControlNetModel,
Expand Down Expand Up @@ -89,6 +90,14 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
"latent_channels": 4,
"sample_size": 128,
}
text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
text_encoder_2_cls, text_encoder_2_id = CLIPTextModelWithProjection, "peft-internal-testing/tiny-clip-text-2"
tokenizer_2_cls, tokenizer_2_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"

@property
def output_shape(self):
return (1, 64, 64, 3)

def setUp(self):
super().setUp()
Expand Down
Loading
Loading