From 4fc3887509f1b8c9eda30643c2d41985ca5ef58e Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Sat, 8 Apr 2023 03:10:09 +0000
Subject: [PATCH 1/3] add wandb team and min-snr loss

---
 examples/controlnet/train_controlnet_flax.py | 36 ++++++++++++++++++--
 1 file changed, 33 insertions(+), 3 deletions(-)

diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index 292b665a8a42..992a0edbac7f 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -289,6 +289,13 @@ def parse_args():
             ' "constant", "constant_with_warmup"]'
         ),
     )
+    parser.add_argument(
+        "--snr_gamma",
+        type=float,
+        default=None,
+        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+        "More details here: https://arxiv.org/abs/2303.09556.",
+    )
     parser.add_argument(
         "--dataloader_num_workers",
         type=int,
@@ -328,10 +335,9 @@ def parse_args():
     parser.add_argument(
         "--report_to",
         type=str,
-        default="tensorboard",
+        default="wandb",
         help=(
-            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
-            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+            'The integration to report the results and logs to. Currently only supported platforms are `"wandb"`'
         ),
     )
     parser.add_argument(
@@ -442,6 +448,12 @@ def parse_args():
             " `args.validation_prompt` and logging the images."
         ),
     )
+    parser.agg_argument(
+        "--wandb_entity",
+        type=str,
+        default=None,
+        help=("The wandb entity to use (for teams).")
+    )
     parser.add_argument(
         "--tracker_project_name",
         type=str,
@@ -668,6 +680,7 @@ def main():
     # wandb init
     if jax.process_index() == 0 and args.report_to == "wandb":
         wandb.init(
+            entity=args.wandb_entity,
             project=args.tracker_project_name,
             job_type="train",
             config=args,
@@ -806,6 +819,17 @@ def main():
     validation_rng, train_rngs = jax.random.split(rng)
     train_rngs = jax.random.split(train_rngs, jax.local_device_count())
 
+    def compute_snr(timesteps):
+        alphas_cumprod = noise_scheduler_state.common.alphas_cumprod
+        sqrt_alphas_cumprod = alphas_cumprod**0.5
+        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+        alpha = sqrt_alphas_cumprod[timesteps]
+        sigma = sqrt_one_minus_alphas_cumprod[timesteps]
+        # Compute SNR.
+        snr = (alpha / sigma) ** 2
+        return snr
+
     def train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rng):
         # reshape batch, add grad_step_dim if gradient_accumulation_steps > 1
         if args.gradient_accumulation_steps > 1:
@@ -876,6 +900,12 @@ def compute_loss(params, minibatch, sample_rng):
                 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
             loss = (target - model_pred) ** 2
+
+            if args.snr_gamma is not None:
+                snr = jnp.array(compute_snr(timesteps))
+                snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
+                loss = loss * snr_loss_weights
+
             loss = loss.mean()
 
             return loss

From f2ecd49f7e4f453b24bf9dbd4c4a94d0be625917 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Sat, 8 Apr 2023 03:11:08 +0000
Subject: [PATCH 2/3] make style

---
 examples/controlnet/train_controlnet_flax.py | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index 992a0edbac7f..d81ef8628e85 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -336,9 +336,7 @@ def parse_args():
         "--report_to",
         type=str,
         default="wandb",
-        help=(
-            'The integration to report the results and logs to. Currently only supported platforms are `"wandb"`'
-        ),
+        help=('The integration to report the results and logs to. Currently only supported platforms are `"wandb"`'),
     )
     parser.add_argument(
         "--mixed_precision",
@@ -448,12 +446,7 @@ def parse_args():
             " `args.validation_prompt` and logging the images."
         ),
     )
-    parser.agg_argument(
-        "--wandb_entity",
-        type=str,
-        default=None,
-        help=("The wandb entity to use (for teams).")
-    )
+    parser.agg_argument("--wandb_entity", type=str, default=None, help=("The wandb entity to use (for teams).")))
     parser.add_argument(
         "--tracker_project_name",
         type=str,
@@ -905,7 +898,7 @@ def compute_loss(params, minibatch, sample_rng):
                 snr = jnp.array(compute_snr(timesteps))
                 snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
                 loss = loss * snr_loss_weights
-
+
             loss = loss.mean()
 
             return loss

From 70a98723753228b449deb4d01beca4407e3d7496 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Mon, 10 Apr 2023 00:17:35 +0000
Subject: [PATCH 3/3] apply feedbacks

---
 examples/controlnet/README.md                | 6 +++++-
 examples/controlnet/train_controlnet_flax.py | 5 ++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md
index f3621ac61309..4b388d92a195 100644
--- a/examples/controlnet/README.md
+++ b/examples/controlnet/README.md
@@ -408,4 +408,8 @@ You can then start your training from this saved checkpoint with
 
 ```bash
 --controlnet_model_name_or_path="./control_out/500"
-```
\ No newline at end of file
+```
+
+We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`.
+
+We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
\ No newline at end of file
diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index d81ef8628e85..224a50bb7fbe 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -446,7 +446,7 @@ def parse_args():
             " `args.validation_prompt` and logging the images."
         ),
     )
-    parser.agg_argument("--wandb_entity", type=str, default=None, help=("The wandb entity to use (for teams).")))
+    parser.add_argument("--wandb_entity", type=str, default=None, help=("The wandb entity to use (for teams).")))
     parser.add_argument(
         "--tracker_project_name",
         type=str,
@@ -813,6 +813,9 @@ def main():
         train_rngs = jax.random.split(train_rngs, jax.local_device_count())
 
     def compute_snr(timesteps):
+        """
+        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+        """
        alphas_cumprod = noise_scheduler_state.common.alphas_cumprod
         sqrt_alphas_cumprod = alphas_cumprod**0.5
         sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
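
To make the change easier to follow outside the diff, here is a minimal, self-contained sketch of the Min-SNR weighting these patches introduce. It is an illustration, not the PR code: the helper names `scaled_linear_alphas_cumprod` and `min_snr_loss_weights` are made up for this example, and the scaled-linear beta defaults are an assumption standing in for `noise_scheduler_state.common.alphas_cumprod`, which the training script uses instead.

```python
import jax.numpy as jnp


def scaled_linear_alphas_cumprod(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012):
    # Assumed scaled-linear beta schedule; the script reads alphas_cumprod from the
    # Flax DDPM scheduler state rather than rebuilding it like this.
    betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps) ** 2
    return jnp.cumprod(1.0 - betas)


def compute_snr(timesteps, alphas_cumprod):
    # SNR(t) = (alpha_t / sigma_t) ** 2, as in the compute_snr helper added by the patch.
    alpha = alphas_cumprod[timesteps] ** 0.5
    sigma = (1.0 - alphas_cumprod[timesteps]) ** 0.5
    return (alpha / sigma) ** 2


def min_snr_loss_weights(timesteps, alphas_cumprod, snr_gamma=5.0):
    # min(SNR, gamma) / SNR, equivalent to the jnp.where expression in compute_loss:
    # low-noise (high-SNR) timesteps are down-weighted, noisier ones keep a weight of 1.
    snr = compute_snr(timesteps, alphas_cumprod)
    return jnp.minimum(snr, snr_gamma) / snr


alphas_cumprod = scaled_linear_alphas_cumprod()
timesteps = jnp.array([10, 250, 500, 999])  # batch of sampled training timesteps
weights = min_snr_loss_weights(timesteps, alphas_cumprod)
# The per-sample squared error is multiplied by these weights before the final .mean(),
# mirroring the `loss = loss * snr_loss_weights` line in the patch.
```

In the script the weighting stays opt-in: `--snr_gamma` defaults to `None`, and passing `--snr_gamma=5.0` (the value recommended in the argument help and in the paper) enables it; the new `--wandb_entity` flag from the same series lets the wandb run land in a team workspace.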