Skip to content

Commit 32bd473

Browse files
committed
fix for dreambooth lora sdxl
1 parent b6de725 commit 32bd473

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,12 @@ def main(args):
10241024
text_encoder_one.add_adapter(text_lora_config)
10251025
text_encoder_two.add_adapter(text_lora_config)
10261026

1027+
# Make sure the trainable params are in float32.
1028+
for model in [unet, text_encoder_one, text_encoder_two]:
1029+
for param in model.parameters():
1030+
if param.requires_grad:
1031+
param.data = param.to(torch.float32)
1032+
10271033
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
10281034
def save_model_hook(models, weights, output_dir):
10291035
if accelerator.is_main_process:

0 commit comments

Comments
 (0)