31 changes: 21 additions & 10 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1183,25 +1183,36 @@ def main(args):
             text_encoder_one.gradient_checkpointing_enable()
             text_encoder_two.gradient_checkpointing_enable()
 
+    def get_lora_config(rank, use_dora, target_modules):
+        base_config = {
+            "r": rank,
+            "lora_alpha": rank,
+            "init_lora_weights": "gaussian",
+            "target_modules": target_modules,
+        }
+
+        if use_dora:
+            base_config["use_dora"] = True
Comment on lines +1194 to +1195
Member: This would fail for lower versions of peft where use_dora isn't present in the call args.

Author: Thanks for pointing that out! Could you clarify which version of peft lacks use_dora in the call args?

Member:

if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
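
Taking the quoted check at face value (DoRA support arriving in peft 0.9.0), a minimal sketch of how that guard could be folded into the new helper; it assumes `is_peft_version` from `diffusers.utils`, and the error message is illustrative rather than the script's actual wording.

```python
# Sketch only: gate `use_dora` behind a peft version check so that older peft
# releases fail with a clear error instead of an unexpected TypeError.
from peft import LoraConfig

from diffusers.utils import is_peft_version


def get_lora_config(rank, use_dora, target_modules):
    base_config = {
        "r": rank,
        "lora_alpha": rank,
        "init_lora_weights": "gaussian",
        "target_modules": target_modules,
    }

    if use_dora:
        if is_peft_version("<", "0.9.0"):
            # Assumption: peft 0.9.0 is the first release with DoRA support,
            # per the version check quoted above.
            raise ValueError(
                "DoRA training requires peft>=0.9.0. Please upgrade peft or drop --use_dora."
            )
        base_config["use_dora"] = True

    return LoraConfig(**base_config)
```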


+        return LoraConfig(**base_config)
 
     # now we will add new LoRA weights to the attention layers
-    unet_lora_config = LoraConfig(
-        r=args.rank,
-        lora_alpha=args.rank,
-        init_lora_weights="gaussian",
-        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+    unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+    unet_lora_config = get_lora_config(
+        rank=args.rank,
+        use_dora=args.use_dora,
+        target_modules=unet_target_modules
     )
     unet.add_adapter(unet_lora_config)

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
-        text_lora_config = LoraConfig(
-            r=args.rank,
-            lora_alpha=args.rank,
-            init_lora_weights="gaussian",
-            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+        text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
+        text_lora_config = get_lora_config(
+            rank=args.rank,
+            use_dora=args.use_dora,
+            target_modules=text_target_modules
         )
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
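The hunk above only covers the config helper; since the new calls read `args.use_dora`, the PR presumably also wires up a matching CLI flag elsewhere in the file. A self-contained sketch of that wiring, where the help strings and defaults are assumptions rather than the PR's actual wording:

```python
import argparse

# Hypothetical excerpt: `--use_dora` and `--rank` mirror the attributes read in
# the diff (`args.use_dora`, `args.rank`); help text and defaults are assumed.
parser = argparse.ArgumentParser(description="DreamBooth LoRA SDXL training (excerpt)")
parser.add_argument("--rank", type=int, default=4, help="Rank of the LoRA update matrices.")
parser.add_argument(
    "--use_dora",
    action="store_true",
    default=False,
    help="Train a DoRA (weight-decomposed LoRA) adapter instead of a plain LoRA; requires peft>=0.9.0.",
)

args = parser.parse_args(["--rank", "8", "--use_dora"])
print(args.rank, args.use_dora)  # 8 True
```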