[Flux] Dreambooth LoRA training scripts #9086


Merged · 59 commits · Aug 9, 2024
Commits
fb5488d
initial commit - dreambooth for flux
linoytsaban Aug 3, 2024
3062daf
update transformer to be FluxTransformer2DModel
linoytsaban Aug 5, 2024
1448aa1
update training loop and validation inference
linoytsaban Aug 5, 2024
a59e012
fix sd3->flux docs
linoytsaban Aug 5, 2024
b225283
Merge branch 'huggingface:main' into flux-fine-tuning
linoytsaban Aug 5, 2024
df074d4
add guidance handling, not sure if it makes sense(?)
linoytsaban Aug 5, 2024
4f7c88c
Merge remote-tracking branch 'origin/flux-fine-tuning' into flux-fine…
linoytsaban Aug 5, 2024
c897c31
inital dreambooth lora commit
linoytsaban Aug 5, 2024
c3c38a4
fix text_ids in compute_text_embeddings
linoytsaban Aug 5, 2024
5e92ba9
Merge branch 'main' into flux-fine-tuning
linoytsaban Aug 5, 2024
407459e
Merge branch 'main' into flux-fine-tuning
linoytsaban Aug 6, 2024
259e443
fix imports of static methods
linoytsaban Aug 6, 2024
f3f3c7b
Merge remote-tracking branch 'origin/flux-fine-tuning' into flux-fine…
linoytsaban Aug 6, 2024
3d5b713
fix pipeline loading in readme, remove auto1111 docs for now
linoytsaban Aug 6, 2024
cdbb69c
fix pipeline loading in readme, remove auto1111 docs for now, remove …
linoytsaban Aug 6, 2024
b249d36
Update examples/dreambooth/train_dreambooth_flux.py
linoytsaban Aug 6, 2024
37b8e64
fix te2 loading and remove te2 refs from text encoder training
linoytsaban Aug 6, 2024
0714278
fix tokenizer_2 initialization
linoytsaban Aug 6, 2024
80c3fe0
remove text_encoder training refs from lora script (for now)
linoytsaban Aug 6, 2024
77a0235
try with vae in bfloat16, fix model hook save
linoytsaban Aug 6, 2024
2fcba98
Merge branch 'main' into flux-fine-tuning
linoytsaban Aug 6, 2024
a64f4a5
fix tokenization
linoytsaban Aug 6, 2024
18a23d4
Merge remote-tracking branch 'origin/flux-fine-tuning' into flux-fine…
linoytsaban Aug 6, 2024
08a1296
fix static imports
linoytsaban Aug 6, 2024
e474683
fix CLIP import
linoytsaban Aug 6, 2024
97f1546
remove text_encoder training refs (for now) from lora script
linoytsaban Aug 6, 2024
187e421
fix minor bug in encode_prompt, add guidance def in lora script, ...
linoytsaban Aug 6, 2024
b24f673
fix unpack_latents args
linoytsaban Aug 7, 2024
72cb859
Merge branch 'main' into flux-fine-tuning
linoytsaban Aug 7, 2024
ad1d236
fix license in readme
linoytsaban Aug 7, 2024
bcb752b
add "none" to weighting_scheme options for uniform sampling
linoytsaban Aug 7, 2024
df880f3
style
linoytsaban Aug 7, 2024
e69244a
adapt model saving - remove text encoder refs
linoytsaban Aug 7, 2024
6612532
adapt model loading - remove text encoder refs
linoytsaban Aug 7, 2024
155dbb2
initial commit for readme
linoytsaban Aug 7, 2024
c7ea620
Merge branch 'main' into flux-fine-tuning
sayakpaul Aug 8, 2024
1526141
Update examples/dreambooth/train_dreambooth_lora_flux.py
linoytsaban Aug 8, 2024
6b78e19
Update examples/dreambooth/train_dreambooth_lora_flux.py
linoytsaban Aug 8, 2024
a2ac0eb
fix vae casting
linoytsaban Aug 8, 2024
7f0fe8a
remove precondition_outputs
linoytsaban Aug 8, 2024
d0fb727
readme
linoytsaban Aug 8, 2024
dcd26d1
readme
linoytsaban Aug 8, 2024
60c5b65
style
linoytsaban Aug 8, 2024
aac9183
readme
linoytsaban Aug 8, 2024
0ebc3dd
Merge remote-tracking branch 'origin/flux-fine-tuning' into flux-fine…
linoytsaban Aug 8, 2024
56059d8
readme
linoytsaban Aug 8, 2024
911306a
update weighting scheme default & docs
linoytsaban Aug 8, 2024
8e4d230
style
linoytsaban Aug 8, 2024
573026c
add text_encoder training to lora script, change vae_scale_factor val…
linoytsaban Aug 8, 2024
aea5d1f
style
linoytsaban Aug 8, 2024
bde7ded
text encoder training fixes
linoytsaban Aug 8, 2024
4e35d2b
Merge remote-tracking branch 'origin/flux-fine-tuning' into flux-fine…
linoytsaban Aug 8, 2024
f781630
style
linoytsaban Aug 8, 2024
5a86a55
Merge branch 'main' into flux-fine-tuning
linoytsaban Aug 8, 2024
d77b67f
update readme
linoytsaban Aug 8, 2024
1d8e25f
minor fixes
linoytsaban Aug 8, 2024
dc1b10e
fix te params
linoytsaban Aug 8, 2024
569f2e1
fix te params
linoytsaban Aug 8, 2024
c7097ab
Merge branch 'main' into flux-fine-tuning
linoytsaban Aug 8, 2024
195 changes: 195 additions & 0 deletions examples/dreambooth/README_flux.md
# DreamBooth training example for FLUX.1 [dev]

[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject.

The `train_dreambooth_flux.py` script shows how to implement the training procedure and adapt it for [FLUX.1 [dev]](https://blackforestlabs.ai/announcing-black-forest-labs/). We also provide a LoRA implementation in the `train_dreambooth_lora_flux.py` script.
> **Review comment:** Let's also add a disclaimer about the high memory requirements here. And refer people to @bghira's guide in case people are interested to try it out on a resource-constrained device.

> [!NOTE]
> **Memory consumption**
>
> Flux can be quite expensive to run on consumer hardware and, as a result, fine-tuning it comes with high memory requirements:
> a LoRA with a rank of 16 (with all components trained) can exceed 40GB of VRAM during training.
> For more tips and guidance on training on a resource-constrained device, please visit [`@bghira`'s guide](documentation/quickstart/FLUX.md).


> [!NOTE]
> **Gated model**
>
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:

```bash
huggingface-cli login
```

This will also allow us to push the trained model parameters to the Hugging Face Hub platform.

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd into the `examples/dreambooth` folder and run:
```bash
pip install -r requirements_flux.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or for a default accelerate configuration without answering questions about your environment

```bash
accelerate config default
```

Or if your environment doesn't support an interactive shell (e.g., a notebook)

```python
from accelerate.utils import write_basic_config
write_basic_config()
```

When running `accelerate config`, enabling torch compile mode can yield dramatic speedups.
Note also that we use the PEFT library as the backend for LoRA training; make sure to have `peft>=0.6.0` installed in your environment.


### Dog toy example

Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.

Let's first download it locally:

```python
from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir,
    repo_type="dataset",
    ignore_patterns=".gitattributes",
)
```


Now, we can launch training using:

```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux"

accelerate launch train_dreambooth_flux.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub
```

To better track our training experiments, we're using the following flags in the command above:

* `report_to="wandb"` will ensure the training runs are tracked on Weights & Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.

> [!NOTE]
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.

> [!TIP]
> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.

## LoRA + DreamBooth

[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that achieves performance close to full fine-tuning with only a fraction of the trainable parameters.

Note also that we use the PEFT library as the backend for LoRA training; make sure to have `peft>=0.6.0` installed in your environment.

To perform DreamBooth with LoRA, run:

```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux-lora"

accelerate launch train_dreambooth_lora_flux.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub
```

### Text Encoder Training

Alongside the transformer, fine-tuning of the CLIP text encoder is also supported.
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:

> [!NOTE]
> FLUX.1 has two text encoders (CLIP L/14 and T5-v1.1-XXL).
> Enabling `--train_text_encoder` fine-tunes only the **CLIP encoder**.
> T5 fine-tuning is not currently supported; its weights remain frozen even when text encoder training is enabled.

To perform DreamBooth LoRA with text-encoder training, run:
```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux-dev-dreambooth-lora"

accelerate launch train_dreambooth_lora_flux.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --train_text_encoder \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --seed="0" \
  --push_to_hub
```

## Other notes
Thanks to `bghira` for their help with reviewing & insight sharing ♥️