From a009f1d1fe03fe622b57de5e53cbe283257f91ec Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 25 Mar 2023 09:37:05 +0530 Subject: [PATCH 01/19] improve stable unclip doc. --- .../source/en/api/pipelines/stable_unclip.mdx | 58 +++++++++++++++---- 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/docs/source/en/api/pipelines/stable_unclip.mdx b/docs/source/en/api/pipelines/stable_unclip.mdx index c8b5d58705ba..372242ae2dff 100644 --- a/docs/source/en/api/pipelines/stable_unclip.mdx +++ b/docs/source/en/api/pipelines/stable_unclip.mdx @@ -42,12 +42,9 @@ Coming soon! ### Text guided Image-to-Image Variation ```python -import requests -import torch -from PIL import Image -from io import BytesIO - from diffusers import StableUnCLIPImg2ImgPipeline +from diffusers.utils import load_image +import torch pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variation="fp16" @@ -55,12 +52,10 @@ pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( pipe = pipe.to("cuda") url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png" - -response = requests.get(url) -init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = load_image(url) images = pipe(init_image).images -images[0].save("fantasy_landscape.png") +images[0].save("variation_image.png") ``` Optionally, you can also pass a prompt to `pipe` such as: @@ -69,7 +64,50 @@ Optionally, you can also pass a prompt to `pipe` such as: prompt = "A fantasy landscape, trending on artstation" images = pipe(init_image, prompt=prompt).images -images[0].save("fantasy_landscape.png") +images[0].save("variation_image_two.png") +``` + +### Memory optimization + +If you are short on GPU memory, you can enable smart CPU offloading so that models that are not needed +immediately for a computation can be offloaded to CPU: + +```python +from diffusers import 
StableUnCLIPImg2ImgPipeline +from diffusers.utils import load_image +import torch + +pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variation="fp16" +) +# Offload to CPU. +pipe.enable_model_cpu_offload() + +url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png" +init_image = load_image(url) + +images = pipe(init_image).images +images[0] +``` + +Further memory optimizations are possible by enabling VAE slicing on the pipeline: + +```python +from diffusers import StableUnCLIPImg2ImgPipeline +from diffusers.utils import load_image +import torch + +pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variation="fp16" +) +pipe.enable_model_cpu_offload() +pipe.enable_vae_slicing() + +url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png" +init_image = load_image(url) + +images = pipe(init_image).images +images[0] ``` ### StableUnCLIPPipeline From c8a2856c6d5e45577bf4c24dee06b1a4a2f5c050 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 30 Mar 2023 11:25:48 +0530 Subject: [PATCH 02/19] feat: support for applying min-snr weighting for faster convergence. 
--- examples/text_to_image/train_text_to_image.py | 63 +++++++++++++++++-- 1 file changed, 58 insertions(+), 5 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 6139a0e6514d..c7ebf65edcb9 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -51,6 +51,10 @@ logger = get_logger(__name__, log_level="INFO") +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") @@ -193,6 +197,13 @@ def parse_args(): parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) + parser.add_argument( + "--_snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) parser.add_argument( "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) @@ -325,9 +336,32 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: return f"{organization}/{model_id}" -dataset_name_mapping = { - "lambdalabs/pokemon-blip-captions": ("image", "text"), -} +def expand_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. 
+ Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + """ + res = arr.to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) + + +def compute_snr(noise_scheduler): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + def fn(timesteps): + alpha = expand_tensor(sqrt_alphas_cumprod, timesteps, timesteps.shape) + sigma = expand_tensor(sqrt_one_minus_alphas_cumprod, timesteps, timesteps.shape) + snr = (alpha / sigma) ** 2 + return snr + + return fn def main(): @@ -476,6 +510,9 @@ def load_model_hook(models, input_dir): args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) + if args.snr_gamma is not None: + snr_fn = compute_snr(noise_scheduler) + # Initialize the optimizer if args.use_8bit_adam: try: @@ -526,7 +563,7 @@ def load_model_hook(models, input_dir): column_names = dataset["train"].column_names # 6. Get the column names for input/target. 
- dataset_columns = dataset_name_mapping.get(args.dataset_name, None) + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) if args.image_column is None: image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] else: @@ -734,7 +771,23 @@ def collate_fn(examples): # Predict the noise residual and compute loss model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = snr_fn(timesteps) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] + # We first calculate the original loss. Then we mean over the non-batch dimensions and + # rebalance the sample-wise losses with their respective loss weights. + # Finally, we take the mean of the rebalanced loss. + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = (mse_loss_weights * loss).mean() # Gather the losses across all processes for logging (if we use distributed training). 
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() From ca0c158232bf91cb114e19d106bd774231ef7e96 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 30 Mar 2023 11:58:59 +0530 Subject: [PATCH 03/19] add: support for validation logging with wandb --- examples/text_to_image/train_text_to_image.py | 103 +++++++++++++++++- 1 file changed, 101 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index c7ebf65edcb9..527745f2175e 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -42,10 +42,14 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version, deprecate +from diffusers.utils import check_min_version, deprecate, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +if is_wandb_available(): + import wandb + + # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.15.0.dev0") @@ -116,6 +120,13 @@ def parse_args(): "value if set." ), ) + parser.add_argument( + "--validation_promptss", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."), + ) parser.add_argument( "--output_dir", type=str, @@ -309,6 +320,22 @@ def parse_args(): "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
) parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help="Run validation every X steps.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + required=True, + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -364,6 +391,57 @@ def fn(timesteps): return fn +def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, step): + logger.info("Running validation... ") + + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=accelerate.unwrap(unet), + safety_checker=None, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + images = [] + for i in range(len(args.validation_prompts)): + with torch.autocast("cuda"): + image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0] + + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}") 
+ for i, image in enumerate(images) + ] + } + ) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + torch.cuda.empty_cache() + + def main(): args = parse_args() @@ -682,7 +760,9 @@ def collate_fn(examples): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - accelerator.init_trackers("text2image-fine-tune", config=vars(args)) + tracker_config = dict(vars(args)) + tracker_config.pop("validation_prompt") + accelerator.init_trackers(args.tracker_project_name, tracker_config) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -816,6 +896,25 @@ def collate_fn(examples): accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + log_validation( + vae, + text_encoder, + tokenizer, + unet, + args, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(unet.parameters()) + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) From 76e94461b095494d467d727afc8f7f6945a972f8 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 30 Mar 2023 12:47:54 +0530 Subject: [PATCH 04/19] make not a required arg. 
--- examples/text_to_image/train_text_to_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 527745f2175e..db39d1979eb8 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -330,7 +330,6 @@ def parse_args(): "--tracker_project_name", type=str, default="text2image-fine-tune", - required=True, help=( "The `project_name` argument passed to Accelerator.init_trackers for" " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" From 052bc8848d51509304384d44d713397876a2fd7c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 30 Mar 2023 13:11:58 +0530 Subject: [PATCH 05/19] fix: arg name. --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index db39d1979eb8..c04cd62f9076 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -121,7 +121,7 @@ def parse_args(): ), ) parser.add_argument( - "--validation_promptss", + "--validation_prompts", type=str, default=None, nargs="+", From c4811476c007e071712053fdc5c9acb1faae93a5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 30 Mar 2023 13:19:04 +0530 Subject: [PATCH 06/19] fix: cli args. --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index c04cd62f9076..e3d03157ec2b 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -209,7 +209,7 @@ def parse_args(): "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
) parser.add_argument( - "--_snr_gamma", + "--snr_gamma", type=float, default=None, help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " From 835b5eee268fbf9ce6dd02e99c048555ac2d0318 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 30 Mar 2023 13:21:22 +0530 Subject: [PATCH 07/19] fix: tracker config. --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index e3d03157ec2b..c3c391260ea7 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -760,7 +760,7 @@ def collate_fn(examples): # The trackers initializes automatically on the main process. if accelerator.is_main_process: tracker_config = dict(vars(args)) - tracker_config.pop("validation_prompt") + tracker_config.pop("validation_prompts") accelerator.init_trackers(args.tracker_project_name, tracker_config) # Train! From 7c842f2adf47f7f56a3ac67b7804a94464826a57 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 30 Mar 2023 13:31:34 +0530 Subject: [PATCH 08/19] fix: loss calculation. --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index c3c391260ea7..e67f90c303cf 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -866,7 +866,7 @@ def collate_fn(examples): # Finally, we take the mean of the rebalanced loss. loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights - loss = (mse_loss_weights * loss).mean() + loss = loss.mean() # Gather the losses across all processes for logging (if we use distributed training). 
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() From 3f078bc89a426f3ecd7165b068d305f67bfe7da0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 30 Mar 2023 13:47:52 +0530 Subject: [PATCH 09/19] fix: validation logging. --- examples/text_to_image/train_text_to_image.py | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index e67f90c303cf..56b74db6d467 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -895,24 +895,25 @@ def collate_fn(examples): accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") - if args.validation_prompt is not None and global_step % args.validation_steps == 0: - if args.use_ema: - # Store the UNet parameters temporarily and load the EMA parameters to perform inference. - ema_unet.store(unet.parameters()) - ema_unet.copy_to(unet.parameters()) - log_validation( - vae, - text_encoder, - tokenizer, - unet, - args, - accelerator, - weight_dtype, - global_step, - ) - if args.use_ema: - # Switch back to the original UNet parameters. - ema_unet.restore(unet.parameters()) + if accelerator.is_main_process: + if args.validation_prompts is not None and global_step % args.validation_steps == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + log_validation( + vae, + text_encoder, + tokenizer, + unet, + args, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + # Switch back to the original UNet parameters. 
+ ema_unet.restore(unet.parameters()) logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) From 1d9f3bc34a09fe029f4f40bd9be096bca45cb1c4 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 30 Mar 2023 13:51:51 +0530 Subject: [PATCH 10/19] fix: unwrap call. --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 56b74db6d467..a494672bf392 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -398,7 +398,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, - unet=accelerate.unwrap(unet), + unet=accelerator.unwrap(unet), safety_checker=None, revision=args.revision, torch_dtype=weight_dtype, From d2ce5e697bd07f6469b816d6053c51016f5221c2 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 30 Mar 2023 14:01:07 +0530 Subject: [PATCH 11/19] fix: validation logging. 
--- examples/text_to_image/train_text_to_image.py | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index a494672bf392..f22e27966473 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -398,7 +398,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, - unet=accelerator.unwrap(unet), + unet=accelerator.unwrap_model(unet), safety_checker=None, revision=args.revision, torch_dtype=weight_dtype, @@ -895,32 +895,32 @@ def collate_fn(examples): accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") - if accelerator.is_main_process: - if args.validation_prompts is not None and global_step % args.validation_steps == 0: - if args.use_ema: - # Store the UNet parameters temporarily and load the EMA parameters to perform inference. - ema_unet.store(unet.parameters()) - ema_unet.copy_to(unet.parameters()) - log_validation( - vae, - text_encoder, - tokenizer, - unet, - args, - accelerator, - weight_dtype, - global_step, - ) - if args.use_ema: - # Switch back to the original UNet parameters. - ema_unet.restore(unet.parameters()) - logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) if global_step >= args.max_train_steps: break + if accelerator.is_main_process: + if args.validation_prompts is not None and global_step % args.validation_steps == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + log_validation( + vae, + text_encoder, + tokenizer, + unet, + args, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + # Switch back to the original UNet parameters. 
+ ema_unet.restore(unet.parameters()) + # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: From 667d23da54de525c26fd24bdde55d897c253be22 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 30 Mar 2023 15:43:18 +0530 Subject: [PATCH 12/19] fix: internval. --- examples/text_to_image/train_text_to_image.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index f22e27966473..eef9df171ce3 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -125,7 +125,7 @@ def parse_args(): type=str, default=None, nargs="+", - help=("A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."), + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), ) parser.add_argument( "--output_dir", @@ -321,10 +321,10 @@ def parse_args(): ) parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") parser.add_argument( - "--validation_steps", + "--validation_epochs", type=int, - default=100, - help="Run validation every X steps.", + default=5, + help="Run validation every X epochs.", ) parser.add_argument( "--tracker_project_name", @@ -390,7 +390,7 @@ def fn(timesteps): return fn -def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, step): +def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch): logger.info("Running validation... 
") pipeline = StableDiffusionPipeline.from_pretrained( @@ -424,7 +424,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, step, dataformats="NHWC") + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") elif tracker.name == "wandb": tracker.log( { @@ -902,7 +902,7 @@ def collate_fn(examples): break if accelerator.is_main_process: - if args.validation_prompts is not None and global_step % args.validation_steps == 0: + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: if args.use_ema: # Store the UNet parameters temporarily and load the EMA parameters to perform inference. ema_unet.store(unet.parameters()) From a154335c3bf91fad4f270eefb35cdbeaeac07739 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 30 Mar 2023 15:45:15 +0530 Subject: [PATCH 13/19] fix: checkpointing push to hub. 
--- examples/text_to_image/train_text_to_image.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index eef9df171ce3..e5aba8beb908 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -500,6 +500,8 @@ def main(): gitignore.write("step_*\n") if "epoch_*" not in gitignore: gitignore.write("epoch_*\n") + if "checkpoint-*" not in gitignore: + gitignore.write("checkpoint-*\n") elif args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) @@ -938,7 +940,7 @@ def collate_fn(examples): pipeline.save_pretrained(args.output_dir) if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + repo.push_to_hub(commit_message="End of training", blocking=True, auto_lfs_prune=True) accelerator.end_training() From ad3fb9283a6da8f19818583eedf45a710e3bda0a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 31 Mar 2023 17:48:45 +0530 Subject: [PATCH 14/19] fix: https://github.com/huggingface/diffusers/commit/c8a2856c6d5e45577bf4c24dee06b1a4a2f5c050\#commitcomment-106913193 --- examples/text_to_image/train_text_to_image.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index e5aba8beb908..af410fe1a0b0 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -860,9 +860,9 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. 
snr = snr_fn(timesteps) - mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( - dim=1 - )[0] + mse_loss_weights = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. From f91f6bd1ef954eef1f8fadea995b42e3db1a39a3 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 4 Apr 2023 09:06:38 +0530 Subject: [PATCH 15/19] fix: norm group test for UNet3D. --- tests/models/test_models_unet_3d_condition.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index 729367a0c164..5a0d74a3ea5a 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -119,12 +119,11 @@ def test_xformers_enable_works(self): == "XFormersAttnProcessor" ), "xformers is not enabled" - # Overriding because `block_out_channels` needs to be different for this model. + # Overriding because `norm_num_groups` needs to be different for this model. def test_forward_with_norm_groups(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict["norm_num_groups"] = 32 - init_dict["block_out_channels"] = (32, 64, 64, 64) model = self.model_class(**init_dict) model.to(torch_device) From 7434dcd4aa07b7bd447b5db0a407e9295c32757c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 5 Apr 2023 09:31:53 +0530 Subject: [PATCH 16/19] address PR comments. 
--- examples/text_to_image/train_text_to_image.py | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index af410fe1a0b0..8abbe011a29f 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -542,6 +542,30 @@ def main(): else: raise ValueError("xformers is not available. Make sure it is installed correctly") + def compute_snr(timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. 
+ snr = (alpha / sigma) ** 2 + return snr + # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format @@ -589,8 +613,8 @@ def load_model_hook(models, input_dir): args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) - if args.snr_gamma is not None: - snr_fn = compute_snr(noise_scheduler) + # if args.snr_gamma is not None: + # snr_fn = compute_snr(noise_scheduler) # Initialize the optimizer if args.use_8bit_adam: @@ -859,7 +883,7 @@ def collate_fn(examples): # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. - snr = snr_fn(timesteps) + snr = compute_snr(timesteps) mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) From db8bbbda5b648bf1db5bc39730b6a6db2758885b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 5 Apr 2023 15:01:55 +0530 Subject: [PATCH 17/19] remove unneeded code. 
--- examples/text_to_image/train_text_to_image.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index e1ceef4b921e..d4d8dae608e3 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -564,9 +564,6 @@ def load_model_hook(models, input_dir): args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) - # if args.snr_gamma is not None: - # snr_fn = compute_snr(noise_scheduler) - # Initialize the optimizer if args.use_8bit_adam: try: From 96e725433b2bad026c1d5104bcbaa406907dbf72 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 5 Apr 2023 17:22:07 +0530 Subject: [PATCH 18/19] add: entry in the readme and docs. --- docs/source/en/training/text2image.mdx | 22 ++++++++++++++++++++++ examples/text_to_image/README.md | 16 ++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/docs/source/en/training/text2image.mdx b/docs/source/en/training/text2image.mdx index 851be61bcf97..9e4e85718848 100644 --- a/docs/source/en/training/text2image.mdx +++ b/docs/source/en/training/text2image.mdx @@ -155,6 +155,28 @@ python train_text_to_image_flax.py \ +## Training with better convergence + +We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence +by rebalancing the loss. In order to use it, one needs to set the `--snr_gamma` argument. The recommended +value when using it is 5.0. 
+ +You can find [this project on Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr) that compares the loss surfaces of the following setups: + +* Training without the Min-SNR weighting strategy +* Training with the Min-SNR weighting strategy (`snr_gamma` set to 5.0) +* Training with the Min-SNR weighting strategy (`snr_gamma` set to 1.0) + +For our small Pokemon dataset, the effects of the Min-SNR weighting strategy might not appear to be pronounced, but we believe they will be more noticeable on larger datasets. + +Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds. + + + +Training with the Min-SNR weighting strategy is only supported in PyTorch. + + + ## LoRA You can also use Low-Rank Adaptation of Large Language Models (LoRA), a fine-tuning technique for accelerating training large models, for fine-tuning text-to-image models. For more details, take a look at the [LoRA training](lora#text-to-image) guide. diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index 0c378ffde2e5..004e54b3a072 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -111,6 +111,22 @@ image = pipe(prompt="yoda").images[0] image.save("yoda-pokemon.png") ``` +#### Training with better convergence + +We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence +by rebalancing the loss. In order to use it, one needs to set the `--snr_gamma` argument. The recommended +value when using it is 5.0. 
+ +You can find [this project on Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr) that compares the loss surfaces of the following setups: + +* Training without the Min-SNR weighting strategy +* Training with the Min-SNR weighting strategy (`snr_gamma` set to 5.0) +* Training with the Min-SNR weighting strategy (`snr_gamma` set to 1.0) + +For our small Pokemon dataset, the effects of the Min-SNR weighting strategy might not appear to be pronounced, but we believe they will be more noticeable on larger datasets. + +Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds. + ## Training with LoRA Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. From 245b558b43acd6e2178bd26bcb11e9f79c6dc992 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 5 Apr 2023 18:12:03 +0530 Subject: [PATCH 19/19] Apply suggestions from code review Co-authored-by: Suraj Patil --- docs/source/en/training/text2image.mdx | 2 +- examples/text_to_image/README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/training/text2image.mdx b/docs/source/en/training/text2image.mdx index 9e4e85718848..4f57ccf94de0 100644 --- a/docs/source/en/training/text2image.mdx +++ b/docs/source/en/training/text2image.mdx @@ -155,7 +155,7 @@ python train_text_to_image_flax.py \ -## Training with better convergence +## Training with Min-SNR weighting We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. 
In order to use it, one needs to set the `--snr_gamma` argument. The recommended diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index 004e54b3a072..c84db0ceee64 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -111,7 +111,7 @@ image = pipe(prompt="yoda").images[0] image.save("yoda-pokemon.png") ``` -#### Training with better convergence +#### Training with Min-SNR weighting We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. In order to use it, one needs to set the `--snr_gamma` argument. The recommended