Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/dreambooth/README_sd3.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,6 @@ accelerate launch train_dreambooth_lora_sd3.py \

## Other notes

We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities.
1. We default to the "logit_normal" weighting scheme for the loss, following the SD3 paper. Thanks to @bghira for helping us discover that the other weighting schemes supported by the training script may cause numerical instabilities during training.
2. Thanks to `bghira`, `JinxuXiang`, and `bendanzzc` for helping us discover a bug in how VAE encoding was previously done. This has been fixed in [#8917](https://github.com/huggingface/diffusers/pull/8917).
3. Additionally, the `--precondition_outputs` CLI argument now controls whether preconditioning is applied to the model outputs. It also affects how the model `target` is calculated.
20 changes: 16 additions & 4 deletions examples/dreambooth/train_dreambooth_lora_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,13 @@ def parse_args(input_args=None):
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
)
parser.add_argument(
"--precondition_outputs",
type=int,
default=1,
help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how "
"model `target` is calculated.",
)
parser.add_argument(
"--optimizer",
type=str,
Expand Down Expand Up @@ -1636,7 +1643,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

# Convert images to latent space
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
model_input = model_input.to(dtype=weight_dtype)

# Sample noise that we'll add to the latents
Expand All @@ -1656,8 +1663,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

# Add noise according to flow matching.
# zt = (1 - texp) * x + texp * z1
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

# Predict the noise residual
model_pred = transformer(
Expand All @@ -1670,14 +1678,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
# Preconditioning of the model outputs.
model_pred = model_pred * (-sigmas) + noisy_model_input
if args.precondition_outputs:
model_pred = model_pred * (-sigmas) + noisy_model_input

# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

# flow matching loss
target = model_input
if args.precondition_outputs:
target = model_input
else:
target = noise - model_input

if args.with_prior_preservation:
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
Expand Down
21 changes: 17 additions & 4 deletions examples/dreambooth/train_dreambooth_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,13 @@ def parse_args(input_args=None):
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
)
parser.add_argument(
"--precondition_outputs",
type=int,
default=1,
help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how "
"model `target` is calculated.",
)
parser.add_argument(
"--optimizer",
type=str,
Expand Down Expand Up @@ -1549,7 +1556,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

# Convert images to latent space
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
model_input = model_input.to(dtype=weight_dtype)

# Sample noise that we'll add to the latents
Expand All @@ -1569,8 +1576,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

# Add noise according to flow matching.
# zt = (1 - texp) * x + texp * z1
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

# Predict the noise residual
if not args.train_text_encoder:
Expand Down Expand Up @@ -1598,13 +1606,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
# Preconditioning of the model outputs.
model_pred = model_pred * (-sigmas) + noisy_model_input
if args.precondition_outputs:
model_pred = model_pred * (-sigmas) + noisy_model_input

# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

# flow matching loss
target = model_input
if args.precondition_outputs:
target = model_input
else:
target = noise - model_input

if args.with_prior_preservation:
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
Expand Down