diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md
index f3621ac61309..4b388d92a195 100644
--- a/examples/controlnet/README.md
+++ b/examples/controlnet/README.md
@@ -408,4 +408,8 @@ You can then start your training from this saved checkpoint with
 
 ```bash
 --controlnet_model_name_or_path="./control_out/500"
-```
\ No newline at end of file
+```
+
+We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556), which helps achieve faster convergence by rebalancing the loss. To use it, set the `--snr_gamma` argument; the recommended value is `5.0`.
+
+We also support gradient accumulation, a technique that lets you train with a larger effective batch size than your machine can fit into memory. Set the number of accumulation steps with the `--gradient_accumulation_steps` argument. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
\ No newline at end of file
diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index 292b665a8a42..224a50bb7fbe 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -289,6 +289,13 @@ def parse_args():
             ' "constant", "constant_with_warmup"]'
         ),
     )
+    parser.add_argument(
+        "--snr_gamma",
+        type=float,
+        default=None,
+        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+        "More details here: https://arxiv.org/abs/2303.09556.",
+    )
     parser.add_argument(
         "--dataloader_num_workers",
         type=int,
@@ -328,11 +335,8 @@ def parse_args():
     parser.add_argument(
         "--report_to",
         type=str,
-        default="tensorboard",
-        help=(
-            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
-            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
-        ),
+        default="wandb",
+        help=('The integration to report the results and logs to. Currently only `"wandb"` is supported.'),
     )
     parser.add_argument(
         "--mixed_precision",
@@ -442,6 +446,7 @@ def parse_args():
             " `args.validation_prompt` and logging the images."
         ),
     )
+    parser.add_argument("--wandb_entity", type=str, default=None, help=("The wandb entity to use (for teams)."))
     parser.add_argument(
         "--tracker_project_name",
         type=str,
@@ -668,6 +673,7 @@ def main():
     # wandb init
     if jax.process_index() == 0 and args.report_to == "wandb":
         wandb.init(
+            entity=args.wandb_entity,
             project=args.tracker_project_name,
             job_type="train",
             config=args,
@@ -806,6 +812,20 @@ def main():
     validation_rng, train_rngs = jax.random.split(rng)
     train_rngs = jax.random.split(train_rngs, jax.local_device_count())
 
+    def compute_snr(timesteps):
+        """
+        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+        """
+        alphas_cumprod = noise_scheduler_state.common.alphas_cumprod
+        sqrt_alphas_cumprod = alphas_cumprod**0.5
+        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+        alpha = sqrt_alphas_cumprod[timesteps]
+        sigma = sqrt_one_minus_alphas_cumprod[timesteps]
+        # Compute SNR.
+        snr = (alpha / sigma) ** 2
+        return snr
+
     def train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rng):
         # reshape batch, add grad_step_dim if gradient_accumulation_steps > 1
         if args.gradient_accumulation_steps > 1:
@@ -876,6 +896,12 @@ def compute_loss(params, minibatch, sample_rng):
                 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
             loss = (target - model_pred) ** 2
+
+            if args.snr_gamma is not None:
+                snr = jnp.array(compute_snr(timesteps))
+                snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
+                loss = loss * snr_loss_weights
+
             loss = loss.mean()
             return loss
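For readers skimming the patch, here is a small standalone sketch of the Min-SNR weighting that the new `--snr_gamma` path applies. The linearly spaced `alphas_cumprod` below is a made-up stand-in for illustration only; the script itself reads `noise_scheduler_state.common.alphas_cumprod` from the Flax scheduler state.

```python
import jax.numpy as jnp

# Made-up cumulative-alpha schedule purely for illustration; the training
# script takes these values from `noise_scheduler_state.common.alphas_cumprod`.
alphas_cumprod = jnp.linspace(0.9999, 0.0001, 1000)

def compute_snr(timesteps):
    # SNR(t) = (alpha_t / sigma_t) ** 2, as in the helper added by the patch.
    alpha = alphas_cumprod[timesteps] ** 0.5
    sigma = (1.0 - alphas_cumprod[timesteps]) ** 0.5
    return (alpha / sigma) ** 2

snr_gamma = 5.0
timesteps = jnp.array([10, 500, 990])
snr = compute_snr(timesteps)

# Min-SNR weighting: cap each timestep's SNR at gamma, then divide by SNR.
# Low-noise timesteps (high SNR) get down-weighted; noisy ones keep weight ~1.
weights = jnp.where(snr < snr_gamma, snr, jnp.ones_like(snr) * snr_gamma) / snr
print(weights)
```

Note that `jnp.where(snr < snr_gamma, snr, jnp.ones_like(snr) * snr_gamma)` is just `jnp.minimum(snr, snr_gamma)` written out, so the weight is `min(SNR, gamma) / SNR`.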
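The README paragraph on gradient accumulation refers to the existing `grad_step_dim` reshape inside `train_step`. As a rough illustration of that idea (a toy quadratic loss and a hypothetical `accumulated_grads` helper, not the script's actual code), accumulating gradients over minibatches and averaging reproduces the gradient of one large batch:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Toy quadratic loss standing in for the diffusion training loss.
    return jnp.mean((batch @ params) ** 2)

grad_fn = jax.grad(loss_fn)

def accumulated_grads(params, batch, accumulation_steps):
    # Split the batch along a new leading "grad step" axis, loosely mirroring
    # the grad_step_dim reshape in train_step when gradient_accumulation_steps > 1.
    minibatches = batch.reshape(accumulation_steps, -1, batch.shape[-1])

    def step(grad_sum, minibatch):
        grads = grad_fn(params, minibatch)
        return jax.tree_util.tree_map(jnp.add, grad_sum, grads), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    grad_sum, _ = jax.lax.scan(step, zeros, minibatches)
    # Average so the update matches a single gradient over the full batch.
    return jax.tree_util.tree_map(lambda g: g / accumulation_steps, grad_sum)

params = jnp.ones((4,))
batch = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
grads = accumulated_grads(params, batch, accumulation_steps=2)
```

Memory stays bounded by the minibatch size because only the summed gradients, not the per-minibatch activations, are carried across steps, which is what lets a larger effective batch size fit on the same hardware.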