diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index a995eb3043dc..2cc2ab79db95 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1399,8 +1399,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): text_encoder_two.train() # set top parameter requires_grad = True for gradient checkpointing works - text_encoder_one.text_model.embeddings.requires_grad_(True) - text_encoder_two.text_model.embeddings.requires_grad_(True) + accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) + accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True) for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet):