diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 0522b3fb8f8a..9bec871c31ad 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -11,6 +11,8 @@ import torch.utils.checkpoint from torch.utils.data import Dataset +from torch.cuda.amp import autocast + from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed @@ -629,7 +631,8 @@ def collate_fn(examples): encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + with autocast(): + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon":