Skip to content

Commit b7d2395

Browse files
committed
fix checkpointing.
1 parent f4ed5e8 commit b7d2395

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -724,11 +724,15 @@ def load_model_hook(models, input_dir):
724724

725725
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
726726
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
727+
728+
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
727729
LoraLoaderMixin.load_lora_into_text_encoder(
728-
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
730+
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
729731
)
732+
733+
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
730734
LoraLoaderMixin.load_lora_into_text_encoder(
731-
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
735+
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
732736
)
733737

734738
accelerator.register_save_state_pre_hook(save_model_hook)

0 commit comments

Comments
 (0)