
Commit dcfa6e1

add Min-SNR loss to Controlnet flax train script (#3016)
* add wandb team and min-snr loss
* make style
* apply feedbacks
1 parent 1c96f82 commit dcfa6e1

2 files changed: +36 −6 lines


examples/controlnet/README.md

Lines changed: 5 additions & 1 deletion

@@ -408,4 +408,8 @@ You can then start your training from this saved checkpoint with
 
 ```bash
 --controlnet_model_name_or_path="./control_out/500"
-```
+```
+
+We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`.
+
+We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
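For readers who want to see what gradient accumulation looks like mechanically, here is a minimal JAX sketch that is independent of the training script: all names (`loss_fn`, `accumulated_grads`, `grad_steps`) are illustrative, and the toy linear model stands in for the real UNet/ControlNet forward pass. It splits one large batch into minibatches, sums per-minibatch gradients with `jax.lax.scan`, and averages them so the final update is equivalent to a single large-batch step while only one minibatch is in memory at a time.

```python
# Minimal gradient-accumulation sketch in JAX (illustrative, not from the script).
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Toy squared-error loss for a linear model; stands in for the diffusion loss.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def accumulated_grads(params, batch_x, batch_y, grad_steps):
    # Split the batch into `grad_steps` minibatches along the leading axis.
    xs = jnp.reshape(batch_x, (grad_steps, -1) + batch_x.shape[1:])
    ys = jnp.reshape(batch_y, (grad_steps, -1) + batch_y.shape[1:])

    def step(acc, xy):
        x, y = xy
        grads = jax.grad(loss_fn)(params, x, y)
        # Running sum of gradients across minibatches.
        acc = jax.tree_util.tree_map(jnp.add, acc, grads)
        return acc, None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    summed, _ = jax.lax.scan(step, zeros, (xs, ys))
    # Average so the update matches a single large-batch step.
    return jax.tree_util.tree_map(lambda g: g / grad_steps, summed)


params = {"w": jnp.ones((4, 1)), "b": jnp.zeros((1,))}
x = jnp.ones((8, 4))
y = jnp.zeros((8, 1))
grads = accumulated_grads(params, x, y, grad_steps=4)  # effective batch 8, minibatch 2
```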

examples/controlnet/train_controlnet_flax.py

Lines changed: 31 additions & 5 deletions

@@ -289,6 +289,13 @@ def parse_args():
             ' "constant", "constant_with_warmup"]'
         ),
     )
+    parser.add_argument(
+        "--snr_gamma",
+        type=float,
+        default=None,
+        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+        "More details here: https://arxiv.org/abs/2303.09556.",
+    )
     parser.add_argument(
         "--dataloader_num_workers",
         type=int,
@@ -328,11 +335,8 @@ def parse_args():
     parser.add_argument(
         "--report_to",
         type=str,
-        default="tensorboard",
-        help=(
-            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
-            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
-        ),
+        default="wandb",
+        help=('The integration to report the results and logs to. Currently only supported platforms are `"wandb"`'),
     )
     parser.add_argument(
         "--mixed_precision",
@@ -442,6 +446,7 @@ def parse_args():
             " `args.validation_prompt` and logging the images."
         ),
     )
+    parser.add_argument("--wandb_entity", type=str, default=None, help=("The wandb entity to use (for teams)."))
     parser.add_argument(
         "--tracker_project_name",
         type=str,
@@ -668,6 +673,7 @@ def main():
     # wandb init
     if jax.process_index() == 0 and args.report_to == "wandb":
         wandb.init(
+            entity=args.wandb_entity,
             project=args.tracker_project_name,
             job_type="train",
             config=args,
@@ -806,6 +812,20 @@ def main():
     validation_rng, train_rngs = jax.random.split(rng)
     train_rngs = jax.random.split(train_rngs, jax.local_device_count())
 
+    def compute_snr(timesteps):
+        """
+        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+        """
+        alphas_cumprod = noise_scheduler_state.common.alphas_cumprod
+        sqrt_alphas_cumprod = alphas_cumprod**0.5
+        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+        alpha = sqrt_alphas_cumprod[timesteps]
+        sigma = sqrt_one_minus_alphas_cumprod[timesteps]
+        # Compute SNR.
+        snr = (alpha / sigma) ** 2
+        return snr
+
     def train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rng):
         # reshape batch, add grad_step_dim if gradient_accumulation_steps > 1
         if args.gradient_accumulation_steps > 1:
@@ -876,6 +896,12 @@ def compute_loss(params, minibatch, sample_rng):
                 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
             loss = (target - model_pred) ** 2
+
+            if args.snr_gamma is not None:
+                snr = jnp.array(compute_snr(timesteps))
+                snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
+                loss = loss * snr_loss_weights
+
             loss = loss.mean()
 
             return loss
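Putting the two halves of the diff together, the sketch below shows what `compute_snr` and the `snr_gamma` branch compute end to end. It is illustrative, not an excerpt: the linear beta schedule is a stand-in for whatever `noise_scheduler_state.common.alphas_cumprod` actually holds, and the constant per-example loss replaces the real `(target - model_pred) ** 2` term. The key behavior is the clamp `min(SNR, γ)/SNR`: low-noise timesteps whose SNR exceeds γ are down-weighted to γ/SNR, while noisier timesteps keep weight 1.

```python
# Standalone sketch of the Min-SNR weighting added in this commit.
# The schedule below is illustrative; the script reads alphas_cumprod
# from noise_scheduler_state.common instead.
import jax.numpy as jnp

betas = jnp.linspace(1e-4, 0.02, 1000)     # DDPM-style linear beta schedule
alphas_cumprod = jnp.cumprod(1.0 - betas)  # cumulative product of alphas


def compute_snr(timesteps):
    # For the forward process x_t = alpha_t * x_0 + sigma_t * eps,
    # SNR(t) = alpha_t^2 / sigma_t^2.
    alpha = alphas_cumprod[timesteps] ** 0.5
    sigma = (1.0 - alphas_cumprod[timesteps]) ** 0.5
    return (alpha / sigma) ** 2


snr_gamma = 5.0                             # the README's recommended value
timesteps = jnp.array([10, 250, 500, 900])  # early = low noise = high SNR
snr = compute_snr(timesteps)

# min(SNR, gamma) / SNR: 1.0 where SNR <= gamma, gamma/SNR otherwise,
# exactly as the jnp.where expression in the diff above computes it.
snr_loss_weights = jnp.where(snr < snr_gamma, snr, jnp.ones_like(snr) * snr_gamma) / snr

per_example_loss = jnp.ones_like(snr)       # stand-in for (target - model_pred) ** 2
loss = (per_example_loss * snr_loss_weights).mean()
print(snr, snr_loss_weights, loss)
```

Running this shows weights well below 1 for the earliest (high-SNR) timesteps and exactly 1 once SNR drops under γ, which is the rebalancing the README paragraph describes.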
