Commit 8ac462b

make the type-casting conditional.
1 parent 3aed05c commit 8ac462b

2 files changed: 12 additions, 7 deletions

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 8 additions & 4 deletions
@@ -1025,10 +1025,14 @@ def main(args):
     text_encoder_two.add_adapter(text_lora_config)
 
     # Make sure the trainable params are in float32.
-    for model in [unet, text_encoder_one, text_encoder_two]:
-        for param in model.parameters():
-            if param.requires_grad:
-                param.data = param.to(torch.float32)
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        for model in models:
+            for param in model.parameters():
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
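Why the upcast exists at all: under fp16 mixed precision the LoRA parameters are created in half precision, and small optimizer updates can round away entirely in fp16, so the script upcasts just the trainable parameters to float32 while frozen base weights stay in low precision. A minimal, self-contained sketch of the pattern this hunk applies (the helper name and the toy modules are illustrative, not part of this commit):

import torch
import torch.nn as nn

def upcast_trainable_params(models, dtype=torch.float32):
    # Cast only parameters that will receive gradients (the LoRA adapters);
    # frozen base weights keep their low-precision dtype to save memory.
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.to(dtype)

# Toy stand-ins: a frozen fp16 "base" layer and a trainable fp16 adapter.
base = nn.Linear(4, 4).to(torch.float16).requires_grad_(False)
adapter = nn.Linear(4, 4).to(torch.float16)

mixed_precision = "fp16"  # mirrors args.mixed_precision in the scripts
if mixed_precision == "fp16":
    upcast_trainable_params([adapter])

assert adapter.weight.dtype == torch.float32  # trainable params upcast
assert base.weight.dtype == torch.float16     # frozen weights untouched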

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 4 additions & 3 deletions
@@ -495,9 +495,10 @@ def main():
 
     # Add adapter and make sure the trainable params are in float32.
     unet.add_adapter(unet_lora_config)
-    for param in unet.parameters():
-        if param.requires_grad:
-            param.data = param.to(torch.float32)
+    if args.mixed_precision == "fp16":
+        for param in unet.parameters():
+            if param.requires_grad:
+                param.data = param.to(torch.float32)
 
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
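A plausible reading of why the guard checks fp16 specifically: bf16 shares fp32's exponent range and usually trains stably without fp32 copies, while fp16's coarse mantissa can swallow small weight updates outright. A quick illustration of the effect the upcast avoids (the values are just an example):

import torch

w_fp16 = torch.tensor(1.0, dtype=torch.float16)
w_fp32 = torch.tensor(1.0, dtype=torch.float32)
step = 1e-4  # a typical small learning_rate * grad

# The step is smaller than half the gap between 1.0 and the next
# representable fp16 value below it (~4.9e-4), so the fp16 update
# rounds back to 1.0 and the weight never moves.
print(w_fp16 - step)  # tensor(1., dtype=torch.float16)
print(w_fp32 - step)  # tensor(0.9999)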
