From 418729d19ebd98625657dac8f7afee1f7d551688 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 21 Jul 2024 08:22:14 +0530 Subject: [PATCH 1/4] SD3 training fixes Co-authored-by: bghira <59658056+bghira@users.noreply.github.com> --- examples/dreambooth/README_sd3.md | 4 +++- .../dreambooth/train_dreambooth_lora_sd3.py | 17 ++++++++++++++--- examples/dreambooth/train_dreambooth_sd3.py | 18 +++++++++++++++--- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/examples/dreambooth/README_sd3.md b/examples/dreambooth/README_sd3.md index e3d2247d974e..de07ff5ba1a6 100644 --- a/examples/dreambooth/README_sd3.md +++ b/examples/dreambooth/README_sd3.md @@ -183,4 +183,6 @@ accelerate launch train_dreambooth_lora_sd3.py \ ## Other notes -We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities. \ No newline at end of file +1. We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities. +2. Thanks to `bghira`, `JinxuXiang`, and `bendanzzc` for helping us discover a bug in how VAE encoding was being done previously. This has been fixed in . +3. Additionally, we now have the option to control if we want to apply preconditioning to the model outputs via a `--precondition_outputs` CLI arg. It affects how the model `target` is calculated as well. \ No newline at end of file diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 5401ee570a34..31c5efdc0323 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -523,6 +523,13 @@ def parse_args(input_args=None): default=1.29, help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", ) + parser.add_argument( + "--precondition_outputs", + type=int, + default=1, + help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how " + "model `target` is calculated.", + ) parser.add_argument( "--optimizer", type=str, @@ -1636,7 +1643,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = model_input * vae.config.scaling_factor + model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) # Sample noise that we'll add to the latents @@ -1670,14 +1677,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Follow: Section 5 of https://arxiv.org/abs/2206.00364. # Preconditioning of the model outputs. 
- model_pred = model_pred * (-sigmas) + noisy_model_input + if args.precondition_outputs: + model_pred = model_pred * (-sigmas) + noisy_model_input # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss - target = model_input + if args.precondition_outputs: + target = model_input + else: + target = noise - model_input if args.with_prior_preservation: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 9a72294c20bd..9092b1d6b079 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -494,6 +494,13 @@ def parse_args(input_args=None): default=1.29, help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", ) + parser.add_argument( + "--precondition_outputs", + type=int, + default=1, + help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how " + "model `target` is calculated.", + ) parser.add_argument( "--optimizer", type=str, @@ -1549,7 +1556,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = model_input * vae.config.scaling_factor + model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) # Sample noise that we'll add to the latents @@ -1598,13 +1605,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Follow: Section 5 of https://arxiv.org/abs/2206.00364. # Preconditioning of the model outputs. - model_pred = model_pred * (-sigmas) + noisy_model_input + if args.precondition_outputs: + model_pred = model_pred * (-sigmas) + noisy_model_input + # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss - target = model_input + if args.precondition_outputs: + target = model_input + else: + target = noise - model_input if args.with_prior_preservation: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. From 905009305252e05afa73528e7850de0cd4211eeb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 21 Jul 2024 08:28:35 +0530 Subject: [PATCH 2/4] rewrite noise addition part to respect the eqn. --- examples/dreambooth/train_dreambooth_lora_sd3.py | 3 ++- examples/dreambooth/train_dreambooth_sd3.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 31c5efdc0323..5d5762a6f4a6 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1663,8 +1663,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) # Add noise according to flow matching. 
+ # zt = (1 - texp) * x + texp * z1 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise # Predict the noise residual model_pred = transformer( diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 9092b1d6b079..17eca6da8891 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -1576,8 +1576,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise # Predict the noise residual if not args.train_text_encoder: From 93f40c403c8e4e7791c898511411633f9ae80c10 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 21 Jul 2024 08:29:38 +0530 Subject: [PATCH 3/4] styler --- examples/dreambooth/train_dreambooth_lora_sd3.py | 2 +- examples/dreambooth/train_dreambooth_sd3.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 5d5762a6f4a6..1d7078628e4c 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1665,7 +1665,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Add noise according to flow matching. # zt = (1 - texp) * x + texp * z1 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) - noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise # Predict the noise residual model_pred = transformer( diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 17eca6da8891..ebd30468b313 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -1578,7 +1578,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Add noise according to flow matching. # zt = (1 - texp) * x + texp * z1 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) - noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise # Predict the noise residual if not args.train_text_encoder: From 08d7501f80f014ad5229740ed48c5e12f66d96ef Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sun, 21 Jul 2024 14:00:22 +0530 Subject: [PATCH 4/4] Update examples/dreambooth/README_sd3.md Co-authored-by: Kashif Rasul --- examples/dreambooth/README_sd3.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/README_sd3.md b/examples/dreambooth/README_sd3.md index de07ff5ba1a6..052e383ef6f0 100644 --- a/examples/dreambooth/README_sd3.md +++ b/examples/dreambooth/README_sd3.md @@ -184,5 +184,5 @@ accelerate launch train_dreambooth_lora_sd3.py \ ## Other notes 1. We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. 
Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities. -2. Thanks to `bghira`, `JinxuXiang`, and `bendanzzc` for helping us discover a bug in how VAE encoding was being done previously. This has been fixed in . +2. Thanks to `bghira`, `JinxuXiang`, and `bendanzzc` for helping us discover a bug in how VAE encoding was being done previously. This has been fixed in [#8917](https://github.com/huggingface/diffusers/pull/8917). 3. Additionally, we now have the option to control if we want to apply preconditioning to the model outputs via a `--precondition_outputs` CLI arg. It affects how the model `target` is calculated as well. \ No newline at end of file
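For reference, below is a minimal, self-contained PyTorch sketch of the training-step math these patches converge on: the VAE latents are shifted before scaling, noise is added along `zt = (1 - sigma) * x + sigma * z1`, and the new `--precondition_outputs` flag (an integer, default `1`) switches between predicting the clean latents and predicting the flow-matching velocity `noise - x`. The tensor shapes, the random stand-in for the transformer prediction, and the shift/scale constants are illustrative assumptions for this sketch only, not the actual SD3 VAE or transformer.

```python
# Minimal sketch of the SD3 flow-matching training-step math touched by these patches.
# All shapes, constants, and the toy "prediction" below are illustrative stand-ins.
import torch

torch.manual_seed(0)

batch, channels, height, width = 2, 16, 8, 8
precondition_outputs = True  # mirrors the `--precondition_outputs` CLI flag (default 1)

# Stand-ins for `vae.config.shift_factor` / `vae.config.scaling_factor` and raw VAE latents.
shift_factor, scaling_factor = 0.0609, 1.5305
raw_latents = torch.randn(batch, channels, height, width)

# 1) VAE latent normalization: subtract the shift factor *before* scaling
#    (the encoding bug fixed in patch 1).
model_input = (raw_latents - shift_factor) * scaling_factor

# 2) Flow-matching noise addition: zt = (1 - sigma) * x + sigma * z1 (patch 2).
noise = torch.randn_like(model_input)
sigmas = torch.rand(batch).view(-1, 1, 1, 1)  # per-sample sigma in (0, 1), broadcastable
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

# 3) Toy stand-in for the transformer output so the snippet runs end to end.
model_pred = torch.randn_like(noisy_model_input)

# 4) Optional EDM-style preconditioning of the outputs, and the matching target.
if precondition_outputs:
    model_pred = model_pred * (-sigmas) + noisy_model_input
    target = model_input            # the model is asked to recover the clean latents
else:
    target = noise - model_input    # the model is asked to predict the flow-matching velocity

loss = torch.mean((model_pred.float() - target.float()) ** 2)
print(loss.item())
```

With either training script, the same behavior is toggled from the command line by passing `--precondition_outputs 1` (the default) or `--precondition_outputs 0`.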