1 parent b6de725 · commit 32bd473
examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1024,6 +1024,12 @@ def main(args):
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
 
+    # Make sure the trainable params are in float32.
+    for model in [unet, text_encoder_one, text_encoder_two]:
+        for param in model.parameters():
+            if param.requires_grad:
+                param.data = param.to(torch.float32)
+
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
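To make the intent of the added lines concrete outside the diff, here is a minimal, self-contained sketch of the same pattern. The `base` and `adapter` modules below are hypothetical stand-ins, not names from `train_dreambooth_lora_sdxl.py`; the point is that after a model is cast to fp16 for mixed-precision training, only the parameters with `requires_grad=True` are upcast back to float32, so frozen weights stay in half precision while the optimizer sees full-precision trainable (LoRA-style) weights.

```python
# Minimal sketch (not from the commit): upcast only trainable params to float32.
import torch
import torch.nn as nn

base = nn.Linear(16, 16)      # stand-in for the frozen base model (e.g. the UNet)
adapter = nn.Linear(16, 16)   # stand-in for the trainable adapter layers

# Cast everything to fp16, as a mixed-precision setup typically does.
model = nn.Sequential(base, adapter).to(dtype=torch.float16)

# Freeze the base weights; only the adapter remains trainable.
for p in base.parameters():
    p.requires_grad_(False)

# Same pattern as the diff: upcast only the trainable parameters.
for param in model.parameters():
    if param.requires_grad:
        param.data = param.to(torch.float32)

print({name: p.dtype for name, p in model.named_parameters()})
# Frozen base params stay torch.float16; adapter params are now torch.float32.
```

Training half-precision parameters directly tends to produce unstable gradients and optimizer states, so upcasting just the trainable subset keeps memory use close to pure fp16 while avoiding that instability.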