diff --git a/examples/consistency_distillation/README.md b/examples/consistency_distillation/README.md
index c584736dfe82..d1c874147173 100644
--- a/examples/consistency_distillation/README.md
+++ b/examples/consistency_distillation/README.md
@@ -1,6 +1,6 @@
 # Latent Consistency Distillation Example:
 
-[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is method to distill latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use the latent consistency distillation to distill stable-diffusion-v1.5 for less timestep inference.
+[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method to distill a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill stable-diffusion-v1.5 for inference with few timesteps.
 
 ## Full model distillation
 
@@ -24,7 +24,7 @@ Then cd in the example folder and run
 pip install -r requirements.txt
 ```
 
-And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
 
 ```bash
 accelerate config
@@ -46,12 +46,16 @@ write_basic_config()
 
 When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
 
-#### Example with LAION-A6+ dataset
+#### Example
+
+The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). You may also need to tune the hyperparameters to suit the dataset you use.
 
 ```bash
-runwayml/stable-diffusion-v1-5
-PROGRAM="train_lcm_distill_sd_wds.py \
-    --pretrained_teacher_model=$MODEL_DIR \
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="path/to/saved/model"
+
+accelerate launch train_lcm_distill_sd_wds.py \
+    --pretrained_teacher_model=$MODEL_NAME \
     --output_dir=$OUTPUT_DIR \
     --mixed_precision=fp16 \
     --resolution=512 \
@@ -59,7 +63,7 @@ PROGRAM="train_lcm_distill_sd_wds.py \
     --max_train_steps=1000 \
     --max_train_samples=4000000 \
     --dataloader_num_workers=8 \
-    --train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
+    --train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
     --validation_steps=200 \
     --checkpointing_steps=200 --checkpoints_total_limit=10 \
     --train_batch_size=12 \
@@ -69,19 +73,23 @@ PROGRAM="train_lcm_distill_sd_wds.py \
     --resume_from_checkpoint=latest \
     --report_to=wandb \
     --seed=453645634 \
-    --push_to_hub \
+    --push_to_hub
 ```
 
 ## LCM-LoRA
 
-Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SDXL model.
+Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any Stable Diffusion v1.5 model.
 
-### Example with LAION-A6+ dataset
-
+### Example
+
+The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).
+
 ```bash
-runwayml/stable-diffusion-v1-5
-PROGRAM="train_lcm_distill_lora_sd_wds.py \
-    --pretrained_teacher_model=$MODEL_DIR \
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="path/to/saved/model"
+
+accelerate launch train_lcm_distill_lora_sd_wds.py \
+    --pretrained_teacher_model=$MODEL_NAME \
     --output_dir=$OUTPUT_DIR \
     --mixed_precision=fp16 \
     --resolution=512 \
@@ -90,7 +98,7 @@ PROGRAM="train_lcm_distill_lora_sd_wds.py \
     --max_train_steps=1000 \
     --max_train_samples=4000000 \
     --dataloader_num_workers=8 \
-    --train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
+    --train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
     --validation_steps=200 \
     --checkpointing_steps=200 --checkpoints_total_limit=10 \
     --train_batch_size=12 \
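Note: once distillation finishes, the student model can be sampled in very few steps with `LCMScheduler`. A minimal sketch, assuming the training run saved the distilled UNet under `$OUTPUT_DIR/unet` (the output layout and paths are assumptions; adjust them to your run):

```python
# Few-step sampling with the distilled SD v1.5 student (illustrative sketch).
import torch
from diffusers import DiffusionPipeline, LCMScheduler, UNet2DConditionModel

# Assumed location of the distilled UNet saved by the training script.
unet = UNet2DConditionModel.from_pretrained("path/to/saved/model", subfolder="unet", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)  # consistency sampling requires LCMScheduler

# LCMs work with 2-8 steps and little to no classifier-free guidance.
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=4, guidance_scale=1.0).images[0]
image.save("astronaut.png")
```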
diff --git a/examples/consistency_distillation/README_sdxl.md b/examples/consistency_distillation/README_sdxl.md
index 00577f9fa2b8..4d2177669a90 100644
--- a/examples/consistency_distillation/README_sdxl.md
+++ b/examples/consistency_distillation/README_sdxl.md
@@ -1,6 +1,6 @@
 # Latent Consistency Distillation Example:
 
-[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is method to distill latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use the latent consistency distillation to distill SDXL for less timestep inference.
+[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method to distill a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill SDXL for inference with few timesteps.
 
 ## Full model distillation
 
@@ -24,7 +24,7 @@ Then cd in the example folder and run
 pip install -r requirements.txt
 ```
 
-And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
 
 ```bash
 accelerate config
@@ -46,12 +46,16 @@ write_basic_config()
 
 When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
 
-#### Example with LAION-A6+ dataset
+#### Example
+
+The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). You may also need to tune the hyperparameters to suit the dataset you use.
 
 ```bash
-export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
-PROGRAM="train_lcm_distill_sdxl_wds.py \
-    --pretrained_teacher_model=$MODEL_DIR \
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export OUTPUT_DIR="path/to/saved/model"
+
+accelerate launch train_lcm_distill_sdxl_wds.py \
+    --pretrained_teacher_model=$MODEL_NAME \
     --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
     --output_dir=$OUTPUT_DIR \
     --mixed_precision=fp16 \
@@ -60,7 +64,7 @@ PROGRAM="train_lcm_distill_sdxl_wds.py \
     --max_train_steps=1000 \
     --max_train_samples=4000000 \
     --dataloader_num_workers=8 \
-    --train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
+    --train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
     --validation_steps=200 \
     --checkpointing_steps=200 --checkpoints_total_limit=10 \
     --train_batch_size=12 \
@@ -77,11 +81,15 @@ PROGRAM="train_lcm_distill_sdxl_wds.py \
 
 Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SDXL model.
 
-### Example with LAION-A6+ dataset
-
+### Example
+
+The following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).
+
 ```bash
-export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
-PROGRAM="train_lcm_distill_lora_sdxl_wds.py \
-    --pretrained_teacher_model=$MODEL_DIR \
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export OUTPUT_DIR="path/to/saved/model"
+
+accelerate launch train_lcm_distill_lora_sdxl_wds.py \
+    --pretrained_teacher_model=$MODEL_NAME \
     --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
     --output_dir=$OUTPUT_DIR \
@@ -92,7 +100,7 @@ PROGRAM="train_lcm_distill_lora_sdxl_wds.py \
     --max_train_steps=1000 \
     --max_train_samples=4000000 \
     --dataloader_num_workers=8 \
-    --train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
+    --train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
     --validation_steps=200 \
     --checkpointing_steps=200 --checkpoints_total_limit=10 \
     --train_batch_size=12 \
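Note: a trained SDXL LCM-LoRA is loaded at inference time with `load_lora_weights` and sampled with `LCMScheduler`. A minimal sketch, assuming the adapter was saved to `$OUTPUT_DIR` (the save path is an assumption):

```python
# Few-step sampling with a trained SDXL LCM-LoRA (illustrative sketch).
import torch
from diffusers import DiffusionPipeline, LCMScheduler

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("path/to/saved/model")  # assumed location of the LCM-LoRA adapter

image = pipe("close-up photography of an old man standing in the rain", num_inference_steps=4, guidance_scale=1.0).images[0]
```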
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
index 6fa8d2c57832..62b4e9a4bf0b 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -1123,7 +1123,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
     for epoch in range(first_epoch, args.num_train_epochs):
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-                image, text, _, _ = batch
+                image, text = batch
 
                 image = image.to(accelerator.device, non_blocking=True)
                 encoded_text = compute_embeddings_fn(text)
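Note: the SD (non-SDXL) training loop now unpacks two-element batches, since the extra size/crop metadata is only needed for SDXL micro-conditioning. A rough, self-contained sketch of a WebDataset pipeline that yields such `(image, text)` batches (illustrative only; shard names are hypothetical and the real pipeline lives in `Text2ImageDataset`):

```python
# Illustrative (image, text) sample flow matching the two-element unpacking above.
import webdataset as wds

pipeline = wds.DataPipeline(
    wds.SimpleShardList("shards/{00000..00009}.tar"),  # hypothetical shard names
    wds.tarfile_to_samples(),
    wds.decode("pil"),
    wds.to_tuple("jpg;png", "txt"),  # two fields per sample -> two-element batches
    wds.batched(4),
)
image, text = next(iter(pipeline))  # mirrors `image, text = batch`
```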
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
index 25faedf714b9..3a7712ee92d8 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
@@ -68,6 +68,11 @@
 
 MAX_SEQ_LENGTH = 77
 
+# Adjust for your dataset
+WDS_JSON_WIDTH = "width"  # original_width for LAION
+WDS_JSON_HEIGHT = "height"  # original_height for LAION
+MIN_SIZE = 700  # ~960 for LAION, ideal: 1024 if the dataset contains large images
+
 if is_wandb_available():
     import wandb
 
@@ -146,10 +151,10 @@ def __call__(self, x):
         try:
             if "json" in x:
                 x_json = json.loads(x["json"])
-                filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
-                    "original_height", 0
+                filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(
+                    WDS_JSON_HEIGHT, 0
                 ) >= self.min_size
-                filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
+                filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark
                 return filter_size and filter_watermark
             else:
                 return False
@@ -180,7 +185,7 @@ def get_orig_size(json):
             if use_fix_crop_and_size:
                 return (resolution, resolution)
             else:
-                return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
+                return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))
 
         def transform(example):
             # resize image
@@ -212,7 +217,7 @@ def transform(example):
         pipeline = [
             wds.ResampledShards(train_shards_path_or_url),
             tarfile_to_samples_nothrow,
-            wds.select(WebdatasetFilter(min_size=960)),
+            wds.select(WebdatasetFilter(min_size=MIN_SIZE)),
             wds.shuffle(shuffle_buffer_size),
             *processing_pipeline,
             wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
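Note: the new module-level constants make the metadata filter schema-agnostic, and a missing `pwatermark` now defaults to 0.0 (keep) rather than 1.0 (drop), which matters for datasets like CC12M that carry no watermark scores. A condensed, self-contained restatement of the filter logic (simplified from `WebdatasetFilter`; the 0.5 watermark threshold assumes the class's default `max_pwatermark`):

```python
import json

WDS_JSON_WIDTH, WDS_JSON_HEIGHT = "width", "height"  # CC12M-style keys
MIN_SIZE, MAX_PWATERMARK = 700, 0.5

def keep(sample: dict) -> bool:
    meta = json.loads(sample["json"])
    # `or 0.0` also guards against explicit nulls in the metadata
    big_enough = (meta.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= MIN_SIZE and (meta.get(WDS_JSON_HEIGHT, 0.0) or 0.0) >= MIN_SIZE
    # a missing pwatermark now counts as unwatermarked (0.0), so CC12M samples pass
    unwatermarked = (meta.get("pwatermark", 0.0) or 0.0) <= MAX_PWATERMARK
    return big_enough and unwatermarked

print(keep({"json": '{"width": 1024, "height": 768}'}))  # True
print(keep({"json": '{"width": 1024, "height": 600}'}))  # False: height < MIN_SIZE
```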
diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
index ec4bf432f03d..f6e2882ef8e4 100644
--- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -1097,7 +1097,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
     for epoch in range(first_epoch, args.num_train_epochs):
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-                image, text, _, _ = batch
+                image, text = batch
 
                 image = image.to(accelerator.device, non_blocking=True)
                 encoded_text = compute_embeddings_fn(text)
diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
index 7d2b1e103208..a75f2bd8f881 100644
--- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -67,6 +67,11 @@
 
 MAX_SEQ_LENGTH = 77
 
+# Adjust for your dataset
+WDS_JSON_WIDTH = "width"  # original_width for LAION
+WDS_JSON_HEIGHT = "height"  # original_height for LAION
+MIN_SIZE = 700  # ~960 for LAION, ideal: 1024 if the dataset contains large images
+
 if is_wandb_available():
     import wandb
 
@@ -128,10 +133,10 @@ def __call__(self, x):
         try:
             if "json" in x:
                 x_json = json.loads(x["json"])
-                filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
-                    "original_height", 0
+                filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(
+                    WDS_JSON_HEIGHT, 0
                 ) >= self.min_size
-                filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
+                filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark
                 return filter_size and filter_watermark
             else:
                 return False
@@ -162,7 +167,7 @@ def get_orig_size(json):
             if use_fix_crop_and_size:
                 return (resolution, resolution)
             else:
-                return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
+                return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))
 
         def transform(example):
             # resize image
@@ -194,7 +199,7 @@ def transform(example):
         pipeline = [
             wds.ResampledShards(train_shards_path_or_url),
             tarfile_to_samples_nothrow,
-            wds.select(WebdatasetFilter(min_size=960)),
+            wds.select(WebdatasetFilter(min_size=MIN_SIZE)),
             wds.shuffle(shuffle_buffer_size),
             *processing_pipeline,
             wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
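Note: the new `--train_shards_path_or_url` values use WebDataset's `pipe:` URL scheme, where shards are read from the stdout of a shell command and `{00000..01099}` brace-expands over the 1100 CC12M shards. A quick way to sanity-check that a single shard streams correctly (shard `00000` substituted here for the brace range):

```python
# Stream one CC12M shard and list the first few tar members, without saving it.
import tarfile
import urllib.request

url = "https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/00000.tar?download=true"
# mode="r|*" reads the tar as a non-seekable stream, like the training pipe does.
with urllib.request.urlopen(url) as resp, tarfile.open(fileobj=resp, mode="r|*") as tf:
    for i, member in enumerate(tf):
        print(member.name)
        if i == 4:
            break
```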