Merged
52 commits
ba28006
add noise_offset param
linoytsaban Jan 22, 2024
e3d81d6
micro conditioning - wip
linoytsaban Jan 24, 2024
5501c93
Merge branch 'huggingface:main' into advanced_script_new_features
linoytsaban Jan 24, 2024
72a7586
image processing adjusted and moved to support micro conditioning
linoytsaban Jan 24, 2024
b0d2fc1
change time ids to be computed inside train loop
linoytsaban Jan 24, 2024
d42e178
Merge branch 'main' into advanced_script_new_features
linoytsaban Jan 24, 2024
3217ec2
change time ids to be computed inside train loop
linoytsaban Jan 24, 2024
3a61265
change time ids to be computed inside train loop
linoytsaban Jan 24, 2024
aea7181
time ids shape fix
linoytsaban Jan 24, 2024
b704379
move token replacement of validation prompt to the same section of in…
linoytsaban Jan 25, 2024
ea2fc04
add offset noise to sd15 advanced script
linoytsaban Jan 25, 2024
3140ed7
fix token loading during validation
linoytsaban Jan 25, 2024
8af2295
fix token loading during validation in sdxl script
linoytsaban Jan 25, 2024
5e8bf00
a little clean
linoytsaban Jan 25, 2024
eb4fb29
Merge branch 'huggingface:main' into advanced_script_new_features
linoytsaban Jan 25, 2024
3b1bb3c
style
linoytsaban Jan 25, 2024
ac6fb6d
a little clean
linoytsaban Jan 25, 2024
c6808fd
style
linoytsaban Jan 25, 2024
08137fe
Merge branch 'huggingface:main' into advanced_script_new_features
linoytsaban Jan 26, 2024
319f576
sdxl script - a little clean + minor path fix
linoytsaban Jan 26, 2024
c2b14f9
Merge remote-tracking branch 'origin/advanced_script_new_features' in…
linoytsaban Jan 26, 2024
b1b42a8
ad 1.5 script - minor path fix
linoytsaban Jan 26, 2024
d21893e
fix missing comma in code example in model card
linoytsaban Jan 26, 2024
1922313
clean up commented lines
linoytsaban Jan 28, 2024
b56fa6d
style
linoytsaban Jan 28, 2024
2b2f225
Merge branch 'main' into advanced_script_new_features
sayakpaul Jan 29, 2024
9edfc60
remove time ids computed outside training loop - no longer used now t…
linoytsaban Jan 29, 2024
b8e395c
Merge remote-tracking branch 'origin/advanced_script_new_features' in…
linoytsaban Jan 29, 2024
01eeaf4
style
linoytsaban Jan 29, 2024
d8ffdba
[WIP] - added draft readme, building off of examples/dreambooth/READM…
linoytsaban Jan 29, 2024
ee96926
Merge branch 'main' into advanced_script_new_features
linoytsaban Jan 30, 2024
7973a48
readme
linoytsaban Jan 31, 2024
3e903dd
Merge remote-tracking branch 'origin/advanced_script_new_features' in…
linoytsaban Jan 31, 2024
2fda589
Merge branch 'main' into advanced_script_new_features
linoytsaban Jan 31, 2024
b497c2d
readme
linoytsaban Jan 31, 2024
208c3fe
readme
linoytsaban Jan 31, 2024
d36ec92
readme
linoytsaban Jan 31, 2024
07b40e6
readme
linoytsaban Jan 31, 2024
01fde2b
readme
linoytsaban Jan 31, 2024
d6f103c
readme
linoytsaban Jan 31, 2024
a6e079e
readme
linoytsaban Jan 31, 2024
4974e43
Merge branch 'main' into advanced_script_new_features
linoytsaban Jan 31, 2024
ff17aa9
removed --crops_coords_top_left from CLI args
linoytsaban Jan 31, 2024
bb6ca0d
Merge remote-tracking branch 'origin/advanced_script_new_features' in…
linoytsaban Jan 31, 2024
dca9739
style
linoytsaban Jan 31, 2024
9ed0a65
Merge branch 'main' into advanced_script_new_features
linoytsaban Jan 31, 2024
c5de519
fix missing shape bug due to missing RGB if statement
linoytsaban Feb 1, 2024
dccb8c0
Merge branch 'main' into advanced_script_new_features
linoytsaban Feb 1, 2024
93e508b
add blog mention at the start of the reamde as well
linoytsaban Feb 2, 2024
2daf9aa
Update examples/advanced_diffusion_training/README.md
linoytsaban Feb 3, 2024
ce5eaa5
Merge branch 'main' into advanced_script_new_features
linoytsaban Feb 3, 2024
70ca0f4
change note to render nicely as well
linoytsaban Feb 3, 2024
244 changes: 244 additions & 0 deletions examples/advanced_diffusion_training/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
# Advanced diffusion training examples

## Train Dreambooth LoRA with Stable Diffusion XL
> [!TIP]
> 💡 This example follows the techniques and recommended practices covered in the blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script). Make sure to check it out before starting 🤗

[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3~5) images of a subject.

LoRA (Low-Rank Adaptation of Large Language Models) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.

In a nutshell, LoRA adapts pretrained models by adding pairs of rank-decomposition matrices to existing weights and training **only** those newly added weights. This has a few advantages:
- The pretrained weights are kept frozen, so the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers let you control the extent to which the model is adapted towards new training images via a `scale` parameter.

[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
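
As a rough illustration of the idea above (not the code used by the training script), the sketch below applies a scaled low-rank update to a tiny frozen weight matrix and counts the trainable parameters of a rank-8 update. The helper names `lora_param_count`, `matmul` and `apply_lora` are hypothetical, the 1280x1280 layer size is only an example, and the matrices are plain Python lists to keep the sketch dependency-free.

```python
def lora_param_count(d_out: int, d_in: int, rank: int) -> int:
    """Trainable parameters of one LoRA pair: B is (d_out x rank), A is (rank x d_in)."""
    return rank * (d_out + d_in)


def matmul(x, y):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def apply_lora(weight, lora_b, lora_a, scale=1.0):
    """Return W + scale * (B @ A); the frozen weight itself is not modified."""
    delta = matmul(lora_b, lora_a)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


# A frozen 2x2 weight and a rank-1 update (illustrative values only).
W = [[1.0, 0.0], [0.0, 1.0]]
B = [[1.0], [2.0]]  # shape (2, 1)
A = [[0.5, 0.5]]    # shape (1, 2)

adapted = apply_lora(W, B, A, scale=1.0)
unchanged = apply_lora(W, B, A, scale=0.0)  # scale=0 recovers the pretrained weights

# Example layer size: a 1280x1280 projection versus its rank-8 LoRA pair.
full = 1280 * 1280
low_rank = lora_param_count(1280, 1280, 8)
```

Setting `scale` between 0 and 1 interpolates between the pretrained behaviour and the fully adapted one, which is the knob the last bullet refers to.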

The `train_dreambooth_lora_sdxl_advanced.py` script shows how to implement dreambooth-LoRA, combining the training process shown in `train_dreambooth_lora_sdxl.py`, with
advanced features and techniques, inspired and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl),
[Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️

> [!NOTE]
> 💡If this is your first time training a Dreambooth LoRA, congrats!🥳
> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora)

📚 Read more about the advanced features and best practices in this community derived blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script)


## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd into the `examples/advanced_diffusion_training` folder and run:
```bash
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or, for a default accelerate configuration without answering questions about your environment:

```bash
accelerate config default
```

Or, if your environment doesn't support an interactive shell (e.g. a notebook):

```python
from accelerate.utils import write_basic_config
write_basic_config()
```

When running `accelerate config`, setting torch compile mode to True can yield dramatic speedups.
Note also that we use the PEFT library as the backend for LoRA training, so make sure `peft>=0.6.0` is installed in your environment.
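
If you want to confirm the PEFT requirement before launching a run, a small check along these lines may help. This is only a sketch: `parse_version`, `meets_minimum` and `check_peft` are hypothetical helpers written for this README, and the version parsing is deliberately simplistic (it keeps only the leading numeric components).

```python
from importlib.metadata import PackageNotFoundError, version


def parse_version(text: str) -> tuple:
    """Turn '0.7.0' or '0.7.0.dev0' into (0, 7, 0), keeping leading numeric parts only."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    # Pad so that '0.6' compares equal to '0.6.0'.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def meets_minimum(installed: str, minimum: str = "0.6.0") -> bool:
    """True when the installed version is at least the minimum."""
    return parse_version(installed) >= parse_version(minimum)


def check_peft() -> str:
    """Report whether the installed peft satisfies the minimum version."""
    try:
        installed = version("peft")
    except PackageNotFoundError:
        return "peft is not installed"
    status = "OK" if meets_minimum(installed) else "too old"
    return f"peft {installed}: {status}"


print(check_peft())
```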

### Pivotal Tuning
**Training with text encoder(s)**

Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. Beyond regular text encoder optimization,
`train_dreambooth_lora_sdxl_advanced.py` also supports **pivotal tuning**.
[Pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning:
we insert new tokens into the text encoders of the model, instead of reusing existing ones,
and then optimize the newly inserted token embeddings to represent the new concept.

To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimizations, use `--train_text_encoder`).
Please keep the following points in mind:

* SDXL has two text encoders. So, we fine-tune both using LoRA.
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
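
To make the token-insertion idea concrete, here is a toy sketch. It is not the embedding handler used by the script: `ToyTextEncoder` and its methods are hypothetical stand-ins for a tokenizer plus its token-embedding table, and the embeddings are plain Python lists. It only shows the two ingredients described above: new tokens get fresh embedding rows, and only those rows are updated.

```python
import random


class ToyTextEncoder:
    """Toy stand-in for a tokenizer plus its token-embedding table."""

    def __init__(self, vocab, dim, seed=0):
        self.rng = random.Random(seed)
        self.dim = dim
        self.vocab = {tok: i for i, tok in enumerate(vocab)}
        self.embeddings = [self._random_vector() for _ in vocab]
        self.train_ids = []  # rows that textual inversion is allowed to update

    def _random_vector(self):
        return [self.rng.gauss(0.0, 0.02) for _ in range(self.dim)]

    def insert_tokens(self, new_tokens):
        """Append new tokens to the vocabulary and give each a fresh embedding row."""
        for tok in new_tokens:
            if tok in self.vocab:
                raise ValueError(f"{tok!r} already exists in the vocabulary")
            self.vocab[tok] = len(self.embeddings)
            self.embeddings.append(self._random_vector())
            self.train_ids.append(self.vocab[tok])

    def apply_update(self, row_updates):
        """Apply updates only to newly inserted rows; pretrained rows stay frozen."""
        for idx, delta in row_updates.items():
            if idx in self.train_ids:
                self.embeddings[idx] = [e + d for e, d in zip(self.embeddings[idx], delta)]


encoder = ToyTextEncoder(vocab=["a", "icon", "of"], dim=4)
frozen_before = list(encoder.embeddings[1])

encoder.insert_tokens(["<s0>", "<s1>"])
new_before = list(encoder.embeddings[3])

# Try to update both a pretrained row (1) and a new row (3); only row 3 changes.
encoder.apply_update({1: [1.0] * 4, 3: [1.0] * 4})
```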


### 3D icon example

Now let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.

Let's first download it locally:

```python
from huggingface_hub import snapshot_download

local_dir = "./3d_icon"
snapshot_download(
    "LinoyTsaban/3d_icon",
    local_dir=local_dir, repo_type="dataset",
    ignore_patterns=".gitattributes",
)
```

Let's review some of the advanced features we're going to be using for this example:
- **custom captions**:
To use custom captioning, first make sure the `datasets` library is installed. If it is not, install it with:
```bash
pip install datasets
```

Now we'll simply specify the name of the dataset and the caption column (in this case it's "prompt"):

```
--dataset_name=./3d_icon
--caption_column=prompt
```

You can also load a dataset straight from the Hub by specifying its name in `dataset_name`.
Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loading your own caption dataset.

- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer
- **pivotal tuning**
- **min SNR gamma**
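
For intuition about the min SNR gamma option, the sketch below computes the per-timestep loss weight `min(SNR, gamma) / SNR`, which is the usual form of Min-SNR weighting for a noise-prediction objective. This is an assumption-laden illustration rather than the script's implementation: the function names are hypothetical, and the `alpha_bar` values are made up to show the trend.

```python
def snr_from_alpha_bar(alpha_bar: float) -> float:
    """Signal-to-noise ratio of a timestep with cumulative alpha `alpha_bar`."""
    return alpha_bar / (1.0 - alpha_bar)


def min_snr_weight(snr: float, gamma: float = 5.0) -> float:
    """Loss weight min(SNR, gamma) / SNR for a noise-prediction objective."""
    return min(snr, gamma) / snr


# Low-noise timesteps (high SNR) are down-weighted; noisy ones keep weight 1.
for alpha_bar in (0.99, 0.9, 0.5, 0.1):
    snr = snr_from_alpha_bar(alpha_bar)
    print(f"alpha_bar={alpha_bar:<4} snr={snr:8.3f} weight={min_snr_weight(snr):.3f}")
```

With `gamma=5.0`, any timestep whose SNR is at most 5 keeps its full weight, while cleaner timesteps contribute proportionally less to the loss.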

**Now, we can launch training:**

```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export DATASET_NAME="./3d_icon"
export OUTPUT_DIR="3d-icon-SDXL-LoRA"
export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"

accelerate launch train_dreambooth_lora_sdxl_advanced.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=$VAE_PATH \
  --dataset_name=$DATASET_NAME \
  --instance_prompt="3d icon in the style of TOK" \
  --validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
  --output_dir=$OUTPUT_DIR \
  --caption_column="prompt" \
  --mixed_precision="bf16" \
  --resolution=1024 \
  --train_batch_size=3 \
  --repeats=1 \
  --report_to="wandb" \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --learning_rate=1.0 \
  --text_encoder_lr=1.0 \
  --optimizer="prodigy" \
  --train_text_encoder_ti \
  --train_text_encoder_ti_frac=0.5 \
  --snr_gamma=5.0 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --rank=8 \
  --max_train_steps=1000 \
  --checkpointing_steps=2000 \
  --seed="0" \
  --push_to_hub
```

To better track our training experiments, we're using the following flags in the command above:

* `report_to="wandb"` ensures the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs, so we can qualitatively check whether training is progressing as expected.

Our experiments were conducted on a single 40GB A100 GPU.


### Inference

Once training is done, we can perform inference like so:
1. Start by loading the UNet LoRA weights:
```python
import torch
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline
from safetensors.torch import load_file

username = "linoyts"
repo_id = f"{username}/3d-icon-SDXL-LoRA"

# Load the SDXL base pipeline in half precision
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

# Load the trained LoRA weights from the Hub repository
pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")
```
2. Now load the pivotal tuning embeddings:

```python
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]

embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-SDXL-LoRA_emb.safetensors", repo_type="model")

state_dict = load_file(embedding_path)
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
pipe.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
```

3. Generate images:

```python
instance_token = "<s0><s1>"
prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"

image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
image.save("llama.png")
```
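
Note that the training command used the placeholder `TOK`, while the inference prompt spells out the learned tokens `<s0><s1>`. A small helper like the one below can do that substitution for you. It is a sketch: `expand_token_abstractions` is a hypothetical name, and the mapping of `TOK` to two tokens is assumed from this example rather than read from a trained model.

```python
def expand_token_abstractions(prompt: str, token_map: dict) -> str:
    """Replace each placeholder with the concatenation of its learned tokens."""
    for placeholder, new_tokens in token_map.items():
        prompt = prompt.replace(placeholder, "".join(new_tokens))
    return prompt


token_map = {"TOK": ["<s0>", "<s1>"]}  # assumed mapping for this example
prompt = expand_token_abstractions(
    "a TOK icon of an orange llama eating ramen, in the style of TOK", token_map
)
print(prompt)
```

Because this uses plain substring replacement, a placeholder that also occurs inside ordinary words would be replaced there too, so pick a placeholder that cannot collide with normal prompt text.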

### ComfyUI / AUTOMATIC1111 Inference
The new script fully supports textual inversion loading with ComfyUI and AUTOMATIC1111 formats!

**AUTOMATIC1111 / SD.Next** \
In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.
- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory.
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory.

You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls <lora:y2k:0.9>`. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`.

**ComfyUI** \
In ComfyUI we will load a LoRA and a textual embedding at the same time.
- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/).
### Specifying a better VAE

SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).


### Tips and Tricks
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices).

## Running on Colab Notebook
Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_advanced_example.ipynb)
to train using the advanced features (including pivotal tuning), and [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb) to train on a free Colab, using some of the advanced features (excluding pivotal tuning).

7 changes: 7 additions & 0 deletions examples/advanced_diffusion_training/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
peft==0.7.0
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,9 @@ def save_model_card(
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
"""
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
state_dict = load_file(embedding_path)
pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
"""
webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**.
- Place it on it on your `embeddings` folder
Expand Down Expand Up @@ -389,7 +388,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--resolution",
type=int,
default=1024,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
Expand Down Expand Up @@ -645,6 +644,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
parser.add_argument(
"--rank",
type=int,
Expand Down Expand Up @@ -745,10 +745,11 @@ def initialize_new_tokens(self, inserting_toks: List[str]):

idx += 1

# copied from train_dreambooth_lora_sdxl_advanced.py
def save_embeddings(self, file_path: str):
assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
tensors = {}
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 - TODO - change for sd
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
for idx, text_encoder in enumerate(self.text_encoders):
assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
Expand Down Expand Up @@ -1634,6 +1635,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn(
(model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
)
bsz = model_input.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
Expand Down Expand Up @@ -1788,6 +1794,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
tokenizer=tokenizer_one,
text_encoder=accelerator.unwrap_model(text_encoder_one),
unet=accelerator.unwrap_model(unet),
revision=args.revision,
Expand Down Expand Up @@ -1860,6 +1867,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
)

if args.train_text_encoder_ti:
embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors"
embedding_handler.save_embeddings(embeddings_path)

images = []
if args.validation_prompt and args.num_validation_images > 0:
# Final inference
Expand Down Expand Up @@ -1895,6 +1907,18 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
# load attention processors
pipeline.load_lora_weights(args.output_dir)

# load new tokens
if args.train_text_encoder_ti:
state_dict = load_file(embeddings_path)
all_new_tokens = []
for key, value in token_abstraction_dict.items():
all_new_tokens.extend(value)
pipeline.load_textual_inversion(
state_dict["clip_l"],
token=all_new_tokens,
text_encoder=pipeline.text_encoder,
tokenizer=pipeline.tokenizer,
)
# run inference
pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
Expand All @@ -1917,11 +1941,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
}
)

if args.train_text_encoder_ti:
embedding_handler.save_embeddings(
f"{args.output_dir}/{args.output_dir}_emb.safetensors",
)

# Conver to WebUI format
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)