# [Flux] Dreambooth LoRA training scripts #9086
Merged

**59 commits** (showing changes from all commits). All commits are by linoytsaban unless noted otherwise:

- `fb5488d` initial commit - dreambooth for flux
- `3062daf` update transformer to be FluxTransformer2DModel
- `1448aa1` update training loop and validation inference
- `a59e012` fix sd3->flux docs
- `b225283` Merge branch 'huggingface:main' into flux-fine-tuning
- `df074d4` add guidance handling, not sure if it makes sense(?)
- `4f7c88c` Merge remote-tracking branch 'origin/flux-fine-tuning' into flux-fine…
- `c897c31` inital dreambooth lora commit
- `c3c38a4` fix text_ids in compute_text_embeddings
- `5e92ba9` Merge branch 'main' into flux-fine-tuning
- `407459e` Merge branch 'main' into flux-fine-tuning
- `259e443` fix imports of static methods
- `f3f3c7b` Merge remote-tracking branch 'origin/flux-fine-tuning' into flux-fine…
- `3d5b713` fix pipeline loading in readme, remove auto1111 docs for now
- `cdbb69c` fix pipeline loading in readme, remove auto1111 docs for now, remove …
- `b249d36` Update examples/dreambooth/train_dreambooth_flux.py
- `37b8e64` fix te2 loading and remove te2 refs from text encoder training
- `0714278` fix tokenizer_2 initialization
- `80c3fe0` remove text_encoder training refs from lora script (for now)
- `77a0235` try with vae in bfloat16, fix model hook save
- `2fcba98` Merge branch 'main' into flux-fine-tuning
- `a64f4a5` fix tokenization
- `18a23d4` Merge remote-tracking branch 'origin/flux-fine-tuning' into flux-fine…
- `08a1296` fix static imports
- `e474683` fix CLIP import
- `97f1546` remove text_encoder training refs (for now) from lora script
- `187e421` fix minor bug in encode_prompt, add guidance def in lora script, ...
- `b24f673` fix unpack_latents args
- `72cb859` Merge branch 'main' into flux-fine-tuning
- `ad1d236` fix license in readme
- `bcb752b` add "none" to weighting_scheme options for uniform sampling
- `df880f3` style
- `e69244a` adapt model saving - remove text encoder refs
- `6612532` adapt model loading - remove text encoder refs
- `155dbb2` initial commit for readme
- `c7ea620` Merge branch 'main' into flux-fine-tuning (sayakpaul)
- `1526141` Update examples/dreambooth/train_dreambooth_lora_flux.py
- `6b78e19` Update examples/dreambooth/train_dreambooth_lora_flux.py
- `a2ac0eb` fix vae casting
- `7f0fe8a` remove precondition_outputs
- `d0fb727` readme
- `dcd26d1` readme
- `60c5b65` style
- `aac9183` readme
- `0ebc3dd` Merge remote-tracking branch 'origin/flux-fine-tuning' into flux-fine…
- `56059d8` readme
- `911306a` update weighting scheme default & docs
- `8e4d230` style
- `573026c` add text_encoder training to lora script, change vae_scale_factor val…
- `aea5d1f` style
- `bde7ded` text encoder training fixes
- `4e35d2b` Merge remote-tracking branch 'origin/flux-fine-tuning' into flux-fine…
- `f781630` style
- `5a86a55` Merge branch 'main' into flux-fine-tuning
- `d77b67f` update readme
- `1d8e25f` minor fixes
- `dc1b10e` fix te params
- `569f2e1` fix te params
- `c7097ab` Merge branch 'main' into flux-fine-tuning
# DreamBooth training example for FLUX.1 [dev]

[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject.

The `train_dreambooth_flux.py` script shows how to implement the training procedure and adapt it for [FLUX.1 [dev]](https://blackforestlabs.ai/announcing-black-forest-labs/). We also provide a LoRA implementation in the `train_dreambooth_lora_flux.py` script.

> [!NOTE]
> **Memory consumption**
>
> Flux can be quite expensive to run on consumer hardware devices, and as a result fine-tuning it comes with high memory requirements:
> a LoRA with rank 16 (with all components trained) can exceed 40GB of VRAM during training.
> For more tips & guidance on training on a resource-constrained device, please visit [`@bghira`'s guide](documentation/quickstart/FLUX.md).
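
The memory figure above can be sanity-checked with a rough, optimizer-only estimate. The sketch below is illustrative only: the parameter counts are hypothetical, the function name is mine, and activations (which often dominate in practice) are ignored.

```python
# Back-of-envelope VRAM estimate: frozen base weights in bf16 plus gradients
# and Adam moment buffers for the trainable parameters only.
# Activations and temporary buffers are ignored, so real usage is higher.
def training_vram_gb(total_params, trainable_params,
                     weight_bytes=2, grad_bytes=2, adam_state_bytes=8):
    weights = total_params * weight_bytes        # base weights (bf16)
    grads = trainable_params * grad_bytes        # gradients
    adam = trainable_params * adam_state_bytes   # fp32 first/second moments
    return (weights + grads + adam) / 1e9

# Hypothetical sizes: a ~12B-parameter transformer with ~20M LoRA parameters.
full_ft = training_vram_gb(12e9, 12e9)   # 144.0 GB
lora_ft = training_vram_gb(12e9, 20e6)   # 24.2 GB
```

Even with only the LoRA parameters trainable, the frozen bf16 base weights account for most of this floor, which is why the guide linked above focuses on further memory-saving techniques.
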
> [!NOTE]
> **Gated model**
>
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate. Use the command below to log in:

```bash
huggingface-cli login
```

This will also allow us to push the trained model parameters to the Hugging Face Hub platform.

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd into the `examples/dreambooth` folder and run:

```bash
pip install -r requirements_flux.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or for a default accelerate configuration without answering questions about your environment:

```bash
accelerate config default
```

Or if your environment doesn't support an interactive shell (e.g., a notebook):

```python
from accelerate.utils import write_basic_config
write_basic_config()
```

When running `accelerate config`, setting torch compile mode to True can bring dramatic speedups.
Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
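
If you are unsure whether your environment meets the `peft>=0.6.0` requirement, a quick check along these lines can help. This is a minimal sketch of my own that only handles plain `X.Y.Z` version strings; `packaging.version` is the robust tool for real use.

```python
# Minimal semantic-version comparison for plain "X.Y.Z" strings.
# For anything fancier (release candidates, dev releases), use
# packaging.version instead.
def version_tuple(version: str):
    return tuple(int(part) for part in version.split(".")[:3])

def meets_requirement(installed: str, minimum: str = "0.6.0") -> bool:
    return version_tuple(installed) >= version_tuple(minimum)
```

You could then compare `peft.__version__` against the minimum before launching training.
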

### Dog toy example

Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.

Let's first download it locally:

```python
from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir,
    repo_type="dataset",
    ignore_patterns=".gitattributes",
)
```
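
Before launching training, it can be worth confirming that the snapshot actually contains instance images. This is an optional helper of my own (the set of extensions is an assumption), not part of the training script:

```python
from pathlib import Path

# Collect likely instance images from the downloaded folder, skipping
# non-image files such as README or .gitattributes leftovers.
def list_instance_images(folder):
    image_exts = {".jpg", ".jpeg", ".png", ".webp"}
    return sorted(p for p in Path(folder).iterdir()
                  if p.suffix.lower() in image_exts)

# e.g. list_instance_images("./dog") after the snapshot_download above
```
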

This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.

Now, we can launch training using:

```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux"

accelerate launch train_dreambooth_flux.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub
```

To better track our training experiments, we're using the following flags in the command above:

* `--report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `--validation_prompt` and `--validation_epochs` allow the script to do a few validation inference runs. This lets us qualitatively check whether training is progressing as expected.

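
For intuition on how the batching and validation flags interact, the arithmetic behind the command can be sketched as follows (function names are mine, and the exact validation condition in the script may differ slightly):

```python
# With gradient accumulation, each optimizer step effectively sees
# train_batch_size * gradient_accumulation_steps samples per process.
def effective_batch_size(train_batch_size, grad_accum_steps, num_processes=1):
    return train_batch_size * grad_accum_steps * num_processes

# --validation_epochs=25 means validation images are generated every
# 25th epoch (here: the epoch indices where validation would run).
def validation_epoch_list(num_epochs, validation_epochs):
    return [e for e in range(num_epochs) if e % validation_epochs == 0]
```

So with `--train_batch_size=1` and `--gradient_accumulation_steps=4` on a single GPU, each optimizer step aggregates 4 samples.
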
> [!NOTE]
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.

> [!TIP]
> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.

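
To make the token limit concrete: conceptually, prompt token ids beyond `max_sequence_length` are dropped and shorter prompts are padded to the fixed length. The sketch below only illustrates this idea; it is not the tokenizer's actual code, and `pad_id=0` is a placeholder value.

```python
# Conceptual illustration of a sequence-length cap: truncate long prompts,
# pad short ones to a fixed length. pad_id=0 is a placeholder, not the
# real tokenizer's pad token id.
def clip_to_max_length(token_ids, max_sequence_length, pad_id=0):
    ids = list(token_ids)[:max_sequence_length]
    return ids + [pad_id] * (max_sequence_length - len(ids))
```
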
## LoRA + DreamBooth

[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve performance comparable to full fine-tuning with only a fraction of the learnable parameters.

Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.

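
The "fraction of learnable parameters" claim is easy to quantify: a rank-`r` LoRA adapter on a `d_out x d_in` linear layer trains `r * (d_in + d_out)` parameters instead of `d_in * d_out`. The layer size below is only an example, not a specific Flux layer.

```python
# Trainable-parameter fraction for a rank-r LoRA adapter on one linear layer.
def lora_param_fraction(d_in, d_out, rank):
    full = d_in * d_out               # parameters of the frozen weight matrix
    lora = rank * (d_in + d_out)      # parameters of the low-rank A and B factors
    return lora / full

# Example: a 3072x3072 projection with rank 16 trains ~1% as many parameters.
frac = lora_param_fraction(3072, 3072, 16)   # ~0.0104
```
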
To perform DreamBooth with LoRA, run:

```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux-lora"

accelerate launch train_dreambooth_lora_flux.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub
```

### Text Encoder Training

Alongside the transformer, fine-tuning of the CLIP text encoder is also supported.
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:

> [!NOTE]
> FLUX.1 has 2 text encoders (CLIP L/14 and T5-v1.1-XXL).
> By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is performed.
> At the moment, T5 fine-tuning is not supported and its weights remain frozen when text encoder training is enabled.

To perform DreamBooth LoRA with text-encoder training, run:

```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux-dev-dreambooth-lora"

accelerate launch train_dreambooth_lora_flux.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --train_text_encoder \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --seed="0" \
  --push_to_hub
```

## Other notes

Thanks to `bghira` for their help with reviewing & insight sharing ♥️