Merged
36 changes: 22 additions & 14 deletions examples/consistency_distillation/README.md
@@ -1,6 +1,6 @@
# Latent Consistency Distillation Example:

[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is method to distill latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use the latent consistency distillation to distill stable-diffusion-v1.5 for less timestep inference.
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method to distill a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill stable-diffusion-v1.5 for inference with few timesteps.

## Full model distillation

@@ -24,7 +24,7 @@ Then cd in the example folder and run
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
@@ -46,20 +46,24 @@ write_basic_config()
When running `accelerate config`, setting torch compile mode to True can give dramatic speedups.
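
If you prefer not to answer the interactive prompts, a default configuration can also be written from Python. A minimal sketch using Accelerate's `write_basic_config` utility (the `fp16` choice here matches the training commands below):

```python
# Write a default 🤗 Accelerate config file without the interactive prompts.
# fp16 mixed precision matches the --mixed_precision=fp16 flag used below.
from accelerate.utils import write_basic_config

write_basic_config(mixed_precision="fp16")
```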


#### Example with LAION-A6+ dataset
#### Example

The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example, for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). You may also need to search the hyperparameter space for the dataset you use.

```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
PROGRAM="train_lcm_distill_sd_wds.py \
--pretrained_teacher_model=$MODEL_DIR \
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_sd_wds.py \
--pretrained_teacher_model=$MODEL_NAME \
--output_dir=$OUTPUT_DIR \
--mixed_precision=fp16 \
--resolution=512 \
--learning_rate=1e-6 --loss_type="huber" --ema_decay=0.95 --adam_weight_decay=0.0 \
```

**Collaborator:**
Do the hyperparameters in the examples work well with the CC12M dataset or does it potentially make sense to revisit them?

**Member (author):**
Not sure to be honest. I haven't experimented much with it, considering that CC12M is very small by current standards.

Perhaps it would be easier to include a disclaimer stating that this dataset is used for illustrative purposes and users are encouraged to bring their own.

```bash
--max_train_steps=1000 \
--max_train_samples=4000000 \
--dataloader_num_workers=8 \
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
--validation_steps=200 \
--checkpointing_steps=200 --checkpoints_total_limit=10 \
--train_batch_size=12 \
@@ -69,19 +73,23 @@ PROGRAM="train_lcm_distill_sd_wds.py \
--resume_from_checkpoint=latest \
--report_to=wandb \
--seed=453645634 \
--push_to_hub \
--push_to_hub
```
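
Once training completes, the distilled UNet can be sampled in just a few steps by pairing it with `LCMScheduler`. A minimal sketch, assuming the UNet was saved under your `--output_dir` (the exact layout may differ depending on how checkpoints are written; paths and prompt are placeholders):

```python
import torch
from diffusers import DiffusionPipeline, LCMScheduler, UNet2DConditionModel

# Placeholder path: point this at the UNet saved by the training script.
unet = UNet2DConditionModel.from_pretrained("path/to/saved/model", torch_dtype=torch.float16)

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16
).to("cuda")
# LCMs are sampled with the dedicated LCM scheduler in very few steps;
# low guidance is typical for LCM-distilled models.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

image = pipe(
    "a photo of an astronaut riding a horse on mars",
    num_inference_steps=4,
    guidance_scale=1.0,
).images[0]
image.save("lcm_sample.png")
```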

## LCM-LoRA

Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SD model.

### Example with LAION-A6+ dataset

### Example

The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).

```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
PROGRAM="train_lcm_distill_lora_sd_wds.py \
--pretrained_teacher_model=$MODEL_DIR \
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_lora_sd_wds.py \
--pretrained_teacher_model=$MODEL_NAME \
--output_dir=$OUTPUT_DIR \
--mixed_precision=fp16 \
--resolution=512 \
@@ -90,7 +98,7 @@ PROGRAM="train_lcm_distill_lora_sd_wds.py \
--max_train_steps=1000 \
--max_train_samples=4000000 \
--dataloader_num_workers=8 \
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
--validation_steps=200 \
--checkpointing_steps=200 --checkpoints_total_limit=10 \
--train_batch_size=12 \
```
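
After LoRA distillation, the adapter can be tried out by loading it into the teacher pipeline. A minimal sketch, assuming the LoRA weights ended up in your `--output_dir` (placeholder path):

```python
import torch
from diffusers import DiffusionPipeline, LCMScheduler

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Placeholder path: the LoRA weights written by the training script.
pipe.load_lora_weights("path/to/saved/model")

image = pipe("a cozy cabin in a snowy forest", num_inference_steps=4, guidance_scale=1.0).images[0]
```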
32 changes: 20 additions & 12 deletions examples/consistency_distillation/README_sdxl.md
@@ -1,6 +1,6 @@
# Latent Consistency Distillation Example:

[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is method to distill latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use the latent consistency distillation to distill SDXL for less timestep inference.
[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method to distill a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill SDXL for inference with few timesteps.

## Full model distillation

@@ -24,7 +24,7 @@ Then cd in the example folder and run
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
@@ -46,12 +46,16 @@ write_basic_config()
When running `accelerate config`, setting torch compile mode to True can give dramatic speedups.


#### Example with LAION-A6+ dataset
#### Example

The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example, for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). You may also need to search the hyperparameter space for the dataset you use.

```bash
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
PROGRAM="train_lcm_distill_sdxl_wds.py \
--pretrained_teacher_model=$MODEL_DIR \
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_sdxl_wds.py \
--pretrained_teacher_model=$MODEL_NAME \
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
--output_dir=$OUTPUT_DIR \
--mixed_precision=fp16 \
@@ -60,7 +64,7 @@ PROGRAM="train_lcm_distill_sdxl_wds.py \
--max_train_steps=1000 \
--max_train_samples=4000000 \
--dataloader_num_workers=8 \
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
--validation_steps=200 \
--checkpointing_steps=200 --checkpoints_total_limit=10 \
--train_batch_size=12 \
```

@@ -77,11 +81,15 @@ PROGRAM="train_lcm_distill_sdxl_wds.py \

Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SDXL model.

### Example with LAION-A6+ dataset

### Example

The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).

```bash
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
PROGRAM="train_lcm_distill_lora_sdxl_wds.py \
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_lora_sdxl_wds.py \
--pretrained_teacher_model=$MODEL_DIR \
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
--output_dir=$OUTPUT_DIR \
@@ -92,7 +100,7 @@ PROGRAM="train_lcm_distill_lora_sdxl_wds.py \
--max_train_steps=1000 \
--max_train_samples=4000000 \
--dataloader_num_workers=8 \
--train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
--validation_steps=200 \
--checkpointing_steps=200 --checkpoints_total_limit=10 \
--train_batch_size=12 \
```
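
As in the SD README, the distilled SDXL LCM-LoRA can be tested by injecting it into the base pipeline. A minimal sketch (adapter path is a placeholder):

```python
import torch
from diffusers import DiffusionPipeline, LCMScheduler

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Placeholder path: the LoRA adapter produced by the training script.
pipe.load_lora_weights("path/to/saved/model")

image = pipe("close-up photo of a hummingbird", num_inference_steps=4, guidance_scale=1.0).images[0]
```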
@@ -1123,7 +1123,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
image, text, _, _ = batch
image, text = batch

image = image.to(accelerator.device, non_blocking=True)
encoded_text = compute_embeddings_fn(text)
@@ -68,6 +68,11 @@

MAX_SEQ_LENGTH = 77

# Adjust for your dataset
WDS_JSON_WIDTH = "width" # original_width for LAION
WDS_JSON_HEIGHT = "height" # original_height for LAION
MIN_SIZE = 700 # ~960 for LAION, ideal: 1024 if the dataset contains large images

if is_wandb_available():
import wandb

@@ -146,10 +151,10 @@ def __call__(self, x):
try:
if "json" in x:
x_json = json.loads(x["json"])
filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
"original_height", 0
filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(
WDS_JSON_HEIGHT, 0
) >= self.min_size
filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark
return filter_size and filter_watermark
else:
return False
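
Note the flipped `pwatermark` default: a sample whose metadata lacks the key was previously treated as watermark probability 1.0 and dropped by any `max_pwatermark < 1.0` filter, whereas now it is treated as 0.0 and kept. This matters for datasets such as CC12M whose metadata carries no watermark score. A quick illustration with a hypothetical metadata dict:

```python
# Hypothetical CC12M-style metadata: size keys present, no "pwatermark" key.
meta = {"width": 800, "height": 600}
max_pwatermark = 0.5

old_behavior = (meta.get("pwatermark", 1.0) or 1.0) <= max_pwatermark  # False -> sample dropped
new_behavior = (meta.get("pwatermark", 0.0) or 0.0) <= max_pwatermark  # True  -> sample kept
print(old_behavior, new_behavior)
```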
@@ -180,7 +185,7 @@ def get_orig_size(json):
if use_fix_crop_and_size:
return (resolution, resolution)
else:
return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))

def transform(example):
# resize image
@@ -212,7 +217,7 @@ def transform(example):
pipeline = [
wds.ResampledShards(train_shards_path_or_url),
tarfile_to_samples_nothrow,
wds.select(WebdatasetFilter(min_size=960)),
wds.select(WebdatasetFilter(min_size=MIN_SIZE)),
wds.shuffle(shuffle_buffer_size),
*processing_pipeline,
wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
@@ -1097,7 +1097,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
image, text, _, _ = batch
image, text = batch
**Collaborator:**
Just to confirm, is this change because the SD distillation script currently has a bug where it assumes that the `Text2ImageDataset.train_dataloader` batch consists of `(image, text, orig_size, crop_coords)` like the SDXL dataloader instead of just `(image, text)`:

```python
wds.map(filter_keys({"image", "text"})),
wds.map(transform),
wds.to_tuple("image", "text"),
```

and is otherwise unrelated to CC12M support?

**Collaborator:**
Looks like 711e468 mentions this. Feel free to close :).

**Member (author):**
Yes, exactly, it's actually a bug and not directly related to CC12M :)


image = image.to(accelerator.device, non_blocking=True)
encoded_text = compute_embeddings_fn(text)
15 changes: 10 additions & 5 deletions examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -67,6 +67,11 @@

MAX_SEQ_LENGTH = 77

# Adjust for your dataset
WDS_JSON_WIDTH = "width" # original_width for LAION
WDS_JSON_HEIGHT = "height" # original_height for LAION
MIN_SIZE = 700 # ~960 for LAION, ideal: 1024 if the dataset contains large images

if is_wandb_available():
import wandb

@@ -128,10 +133,10 @@ def __call__(self, x):
try:
if "json" in x:
x_json = json.loads(x["json"])
filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
"original_height", 0
filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(
WDS_JSON_HEIGHT, 0
) >= self.min_size
filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark
return filter_size and filter_watermark
else:
return False
@@ -162,7 +167,7 @@ def get_orig_size(json):
if use_fix_crop_and_size:
return (resolution, resolution)
else:
return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))

def transform(example):
# resize image
@@ -194,7 +199,7 @@ def transform(example):
pipeline = [
wds.ResampledShards(train_shards_path_or_url),
tarfile_to_samples_nothrow,
wds.select(WebdatasetFilter(min_size=960)),
wds.select(WebdatasetFilter(min_size=MIN_SIZE)),
wds.shuffle(shuffle_buffer_size),
*processing_pipeline,
wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),