From 38ecc5aaaaacd1d32101009c5c91cf54971d1b4f Mon Sep 17 00:00:00 2001 From: thuliu-yt16 Date: Sun, 24 Dec 2023 21:59:05 +0800 Subject: [PATCH 1/3] fix minsnr implementation for v-prediction case --- examples/controlnet/train_controlnet_flax.py | 11 ++++++---- .../train_text_to_image_decoder.py | 22 ++++++++++++++----- .../train_text_to_image_lora_decoder.py | 22 ++++++++++++++----- .../train_text_to_image_lora_prior.py | 22 ++++++++++++++----- .../train_text_to_image_prior.py | 22 ++++++++++++++----- .../text_to_image/train_text_to_image.py | 22 ++++++++++++++----- examples/text_to_image/train_text_to_image.py | 22 ++++++++++++++----- .../text_to_image/train_text_to_image_lora.py | 22 ++++++++++++++----- .../train_text_to_image_lora_sdxl.py | 22 ++++++++++++++----- .../text_to_image/train_text_to_image_sdxl.py | 22 ++++++++++++++----- 10 files changed, 151 insertions(+), 58 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index b3c09325fc4d..2ff0c849d25e 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -907,10 +907,13 @@ def compute_loss(params, minibatch, sample_rng): if args.snr_gamma is not None: snr = jnp.array(compute_snr(timesteps)) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective requires that we add one to SNR values before we divide by them. - snr = snr + 1 - snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr + if noise_scheduler.config.prediction_type == "epsilon": + snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr + elif noise_scheduler.config.prediction_type == "v_prediction": + snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / (snr + 1) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + loss = loss * snr_loss_weights loss = loss.mean() diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py index a4017c85e1b5..83b6450dd4d3 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py @@ -781,12 +781,22 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective requires that we add one to SNR values before we divide by them. - snr = snr + 1 - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + if noise_scheduler.config.prediction_type == "epsilon": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / snr + ) + elif noise_scheduler.config.prediction_type == "v_prediction": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / (snr + 1) + ) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py index 90cf540c6425..60d0b6525321 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py @@ -631,12 +631,22 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective requires that we add one to SNR values before we divide by them. - snr = snr + 1 - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + if noise_scheduler.config.prediction_type == "epsilon": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / snr + ) + elif noise_scheduler.config.prediction_type == "v_prediction": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / (snr + 1) + ) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py index b64986ecf5ae..8ceda8a4cfd4 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py @@ -664,12 +664,22 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective requires that we add one to SNR values before we divide by them. - snr = snr + 1 - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + if noise_scheduler.config.prediction_type == "epsilon": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / snr + ) + elif noise_scheduler.config.prediction_type == "v_prediction": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / (snr + 1) + ) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py index a6855abcee75..a2c54599b11c 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py @@ -811,12 +811,22 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective requires that we add one to SNR values before we divide by them. - snr = snr + 1 - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + if noise_scheduler.config.prediction_type == "epsilon": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / snr + ) + elif noise_scheduler.config.prediction_type == "v_prediction": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / (snr + 1) + ) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py index f7100788cde2..8763947ba0fd 100644 --- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py +++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py @@ -848,12 +848,22 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective requires that we add one to SNR values before we divide by them. - snr = snr + 1 - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + if noise_scheduler.config.prediction_type == "epsilon": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / snr + ) + elif noise_scheduler.config.prediction_type == "v_prediction": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / (snr + 1) + ) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index c5371c95469e..9fe50c943f1d 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -936,12 +936,22 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective requires that we add one to SNR values before we divide by them. - snr = snr + 1 - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + if noise_scheduler.config.prediction_type == "epsilon": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / snr + ) + elif noise_scheduler.config.prediction_type == "v_prediction": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / (snr + 1) + ) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 2efbaf298d2e..3e732160f067 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -752,12 +752,22 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective requires that we add one to SNR values before we divide by them. - snr = snr + 1 - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + if noise_scheduler.config.prediction_type == "epsilon": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / snr + ) + elif noise_scheduler.config.prediction_type == "v_prediction": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / (snr + 1) + ) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index d95fcbbba033..8cb6f46db228 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -1046,12 +1046,22 @@ def compute_time_ids(original_size, crops_coords_top_left): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective requires that we add one to SNR values before we divide by them. - snr = snr + 1 - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + if noise_scheduler.config.prediction_type == "epsilon": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / snr + ) + elif noise_scheduler.config.prediction_type == "v_prediction": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / (snr + 1) + ) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 5c024f4080ae..986001ae904c 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -1081,12 +1081,22 @@ def compute_time_ids(original_size, crops_coords_top_left): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective requires that we add one to SNR values before we divide by them. - snr = snr + 1 - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + if noise_scheduler.config.prediction_type == "epsilon": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / snr + ) + elif noise_scheduler.config.prediction_type == "v_prediction": + mse_loss_weights = ( + torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / (snr + 1) + ) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights From b91f77e0e6309d38bcbd2c7fd8fe736a5b612a34 Mon Sep 17 00:00:00 2001 From: thuliu-yt16 Date: Sun, 24 Dec 2023 22:25:31 +0800 Subject: [PATCH 2/3] format code --- examples/controlnet/train_controlnet_flax.py | 4 +++- .../text_to_image/train_text_to_image_decoder.py | 14 ++++---------- .../train_text_to_image_lora_decoder.py | 14 ++++---------- .../train_text_to_image_lora_prior.py | 14 ++++---------- .../text_to_image/train_text_to_image_prior.py | 14 ++++---------- .../text_to_image/train_text_to_image.py | 14 ++++---------- examples/text_to_image/train_text_to_image.py | 14 ++++---------- examples/text_to_image/train_text_to_image_lora.py | 14 ++++---------- .../text_to_image/train_text_to_image_lora_sdxl.py | 14 ++++---------- examples/text_to_image/train_text_to_image_sdxl.py | 14 ++++---------- 10 files changed, 39 insertions(+), 91 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 2ff0c849d25e..8c761c49ef62 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -910,7 +910,9 @@ def compute_loss(params, minibatch, sample_rng): if noise_scheduler.config.prediction_type == "epsilon": snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr elif noise_scheduler.config.prediction_type == "v_prediction": - snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / (snr + 1) + snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / ( + snr + 1 + ) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py index 83b6450dd4d3..d5a9394bb5d9 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py @@ -783,18 +783,12 @@ def collate_fn(examples): snr = compute_snr(noise_scheduler, timesteps) if noise_scheduler.config.prediction_type == "epsilon": mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / snr + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / (snr + 1) - ) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] / (snr + 1) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py index 60d0b6525321..bd153d5dc83d 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py @@ -633,18 +633,12 @@ def collate_fn(examples): snr = compute_snr(noise_scheduler, timesteps) if noise_scheduler.config.prediction_type == "epsilon": mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / snr + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / (snr + 1) - ) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] / (snr + 1) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py index 8ceda8a4cfd4..e2c0f5883d5a 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py @@ -666,18 +666,12 @@ def collate_fn(examples): snr = compute_snr(noise_scheduler, timesteps) if noise_scheduler.config.prediction_type == "epsilon": mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / snr + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / (snr + 1) - ) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] / (snr + 1) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py index a2c54599b11c..08ce1e0a5c0e 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py @@ -813,18 +813,12 @@ def collate_fn(examples): snr = compute_snr(noise_scheduler, timesteps) if noise_scheduler.config.prediction_type == "epsilon": mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / snr + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / (snr + 1) - ) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] / (snr + 1) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py index 8763947ba0fd..2a19b8ce8d28 100644 --- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py +++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py @@ -850,18 +850,12 @@ def collate_fn(examples): snr = compute_snr(noise_scheduler, timesteps) if noise_scheduler.config.prediction_type == "epsilon": mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / snr + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / (snr + 1) - ) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] / (snr + 1) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 9fe50c943f1d..7a2712b967ba 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -938,18 +938,12 @@ def collate_fn(examples): snr = compute_snr(noise_scheduler, timesteps) if noise_scheduler.config.prediction_type == "epsilon": mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / snr + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / (snr + 1) - ) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] / (snr + 1) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 3e732160f067..4878023de55c 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -754,18 +754,12 @@ def collate_fn(examples): snr = compute_snr(noise_scheduler, timesteps) if noise_scheduler.config.prediction_type == "epsilon": mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / snr + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / (snr + 1) - ) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] / (snr + 1) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 8cb6f46db228..4701a647b9b5 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -1048,18 +1048,12 @@ def compute_time_ids(original_size, crops_coords_top_left): snr = compute_snr(noise_scheduler, timesteps) if noise_scheduler.config.prediction_type == "epsilon": mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / snr + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / (snr + 1) - ) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] / (snr + 1) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 986001ae904c..601a17ef5c1a 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -1083,18 +1083,12 @@ def compute_time_ids(original_size, crops_coords_top_left): snr = compute_snr(noise_scheduler, timesteps) if noise_scheduler.config.prediction_type == "epsilon": mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / snr + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = ( - torch.stack( - [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 - ).min(dim=1)[0] - / (snr + 1) - ) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] / (snr + 1) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") From 0b27ab8b4fa8123e4fc8cdfb2a79ca39e805b9b8 Mon Sep 17 00:00:00 2001 From: thuliu-yt16 Date: Thu, 28 Dec 2023 14:28:51 +0800 Subject: [PATCH 3/3] always compute snr when snr_gamma is specified --- examples/controlnet/train_controlnet_flax.py | 9 +++------ .../text_to_image/train_text_to_image_decoder.py | 13 +++++-------- .../train_text_to_image_lora_decoder.py | 13 +++++-------- .../text_to_image/train_text_to_image_lora_prior.py | 13 +++++-------- .../text_to_image/train_text_to_image_prior.py | 13 +++++-------- .../text_to_image/train_text_to_image.py | 13 +++++-------- examples/text_to_image/train_text_to_image.py | 13 +++++-------- examples/text_to_image/train_text_to_image_lora.py | 13 +++++-------- .../text_to_image/train_text_to_image_lora_sdxl.py | 13 +++++-------- examples/text_to_image/train_text_to_image_sdxl.py | 13 +++++-------- 10 files changed, 48 insertions(+), 78 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 8c761c49ef62..ea25f3748a93 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -907,14 +907,11 @@ def compute_loss(params, minibatch, sample_rng): if args.snr_gamma is not None: snr = jnp.array(compute_snr(timesteps)) + snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) if noise_scheduler.config.prediction_type == "epsilon": - snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr + snr_loss_weights = snr_loss_weights / snr elif noise_scheduler.config.prediction_type == "v_prediction": - snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / ( - snr + 1 - ) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + snr_loss_weights = snr_loss_weights / (snr + 1) loss = loss * snr_loss_weights diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py index d5a9394bb5d9..cd7f28b5324f 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py @@ -781,16 +781,13 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] if noise_scheduler.config.prediction_type == "epsilon": - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + mse_loss_weights = mse_loss_weights / snr elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( - dim=1 - )[0] / (snr + 1) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + mse_loss_weights = mse_loss_weights / (snr + 1) loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py index bd153d5dc83d..96822b463911 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py @@ -631,16 +631,13 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] if noise_scheduler.config.prediction_type == "epsilon": - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + mse_loss_weights = mse_loss_weights / snr elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( - dim=1 - )[0] / (snr + 1) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + mse_loss_weights = mse_loss_weights / (snr + 1) loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py index e2c0f5883d5a..04ac50768111 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py @@ -664,16 +664,13 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] if noise_scheduler.config.prediction_type == "epsilon": - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + mse_loss_weights = mse_loss_weights / snr elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( - dim=1 - )[0] / (snr + 1) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + mse_loss_weights = mse_loss_weights / (snr + 1) loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py index 08ce1e0a5c0e..e2c3ef0f1755 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py @@ -811,16 +811,13 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] if noise_scheduler.config.prediction_type == "epsilon": - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + mse_loss_weights = mse_loss_weights / snr elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( - dim=1 - )[0] / (snr + 1) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + mse_loss_weights = mse_loss_weights / (snr + 1) loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py index 2a19b8ce8d28..287aecd40027 100644 --- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py +++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py @@ -848,16 +848,13 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] if noise_scheduler.config.prediction_type == "epsilon": - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + mse_loss_weights = mse_loss_weights / snr elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( - dim=1 - )[0] / (snr + 1) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + mse_loss_weights = mse_loss_weights / (snr + 1) loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 7a2712b967ba..864e0d86d173 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -936,16 +936,13 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] if noise_scheduler.config.prediction_type == "epsilon": - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + mse_loss_weights = mse_loss_weights / snr elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( - dim=1 - )[0] / (snr + 1) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + mse_loss_weights = mse_loss_weights / (snr + 1) loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 4878023de55c..26ff76818727 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -752,16 +752,13 @@ def collate_fn(examples): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] if noise_scheduler.config.prediction_type == "epsilon": - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + mse_loss_weights = mse_loss_weights / snr elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( - dim=1 - )[0] / (snr + 1) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + mse_loss_weights = mse_loss_weights / (snr + 1) loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 4701a647b9b5..9453edcc6e9f 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -1046,16 +1046,13 @@ def compute_time_ids(original_size, crops_coords_top_left): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] if noise_scheduler.config.prediction_type == "epsilon": - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + mse_loss_weights = mse_loss_weights / snr elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( - dim=1 - )[0] / (snr + 1) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + mse_loss_weights = mse_loss_weights / (snr + 1) loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 601a17ef5c1a..4a2a577ae2b1 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -1081,16 +1081,13 @@ def compute_time_ids(original_size, crops_coords_top_left): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] if noise_scheduler.config.prediction_type == "epsilon": - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) + mse_loss_weights = mse_loss_weights / snr elif noise_scheduler.config.prediction_type == "v_prediction": - mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( - dim=1 - )[0] / (snr + 1) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + mse_loss_weights = mse_loss_weights / (snr + 1) loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights