diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 243e433a5b83..6955c9f60a2b 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -143,7 +143,8 @@ def transforms(examples): loss = F.mse_loss(noise_pred, noise) accelerator.backward(loss) - accelerator.clip_grad_norm_(model.parameters(), 1.0) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() lr_scheduler.step() if args.use_ema: