Skip to content

Commit 3b5062c

Browse files
author
mhh001
committed
Fixed the bug related to saving DeepSpeed models.
1 parent 5ca062e commit 3b5062c

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -652,21 +652,22 @@ def save_model_hook(models, weights, output_dir):
652652
text_encoder_two_lora_layers_to_save = None
653653

654654
for model in models:
655-
if isinstance(model, type(unwrap_model(unet))):
655+
if isinstance(unwrap_model(model), type(unwrap_model(unet))):
656656
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
657-
elif isinstance(model, type(unwrap_model(text_encoder_one))):
657+
elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
658658
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
659659
get_peft_model_state_dict(model)
660660
)
661-
elif isinstance(model, type(unwrap_model(text_encoder_two))):
661+
elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
662662
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
663663
get_peft_model_state_dict(model)
664664
)
665665
else:
666666
raise ValueError(f"unexpected save model: {model.__class__}")
667667

668668
# make sure to pop weight so that corresponding model is not saved again
669-
weights.pop()
669+
if weights:
670+
weights.pop()
670671

671672
StableDiffusionXLPipeline.save_lora_weights(
672673
output_dir,

0 commit comments

Comments
 (0)