Skip to content

Commit 4a98d6e

Browse files
authored
Update train_text_to_image_lora.py (#2795)
1 parent b94880e commit 4a98d6e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/research_projects/lora/train_text_to_image_lora.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -542,9 +542,9 @@ def main():
542542
lora_layers = AttnProcsLayers(unet.attn_processors)
543543

544544
# Move unet, vae and text_encoder to device and cast to weight_dtype
545-
unet.to(accelerator.device, dtype=weight_dtype)
546545
vae.to(accelerator.device, dtype=weight_dtype)
547-
text_encoder.to(accelerator.device, dtype=weight_dtype)
546+
if not args.train_text_encoder:
547+
text_encoder.to(accelerator.device, dtype=weight_dtype)
548548

549549
if args.enable_xformers_memory_efficient_attention:
550550
if is_xformers_available():

0 commit comments

Comments
 (0)