diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index c030c59693c3..df5c5cda44e3 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -479,8 +479,7 @@ def main(): elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - # Move unet, vae and text_encoder to device and cast to weight_dtype - unet.to(accelerator.device, dtype=weight_dtype) + # Move vae and text_encoder to device and cast to weight_dtype vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) @@ -536,6 +535,8 @@ def main(): unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) + # Move unet and lora to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers