diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 1b75402c3550..367a3422de33 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -743,7 +743,7 @@ def main(args): ) temp_pipeline._modify_text_encoder(text_lora_attn_procs) text_encoder = temp_pipeline.text_encoder - accelerator.register_for_checkpointing(unet_lora_layers) + accelerator.register_for_checkpointing(text_encoder_lora_layers) del temp_pipeline if args.scale_lr: