diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 6e621b3caee3..d63fb80aee33 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1183,25 +1183,36 @@ def main(args):
             text_encoder_one.gradient_checkpointing_enable()
             text_encoder_two.gradient_checkpointing_enable()
 
+    def get_lora_config(rank, use_dora, target_modules):
+        base_config = {
+            "r": rank,
+            "lora_alpha": rank,
+            "init_lora_weights": "gaussian",
+            "target_modules": target_modules,
+        }
+
+        if use_dora:
+            base_config["use_dora"] = True
+
+        return LoraConfig(**base_config)
+
     # now we will add new LoRA weights to the attention layers
-    unet_lora_config = LoraConfig(
-        r=args.rank,
+    unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+    unet_lora_config = get_lora_config(
+        rank=args.rank,
         use_dora=args.use_dora,
-        lora_alpha=args.rank,
-        init_lora_weights="gaussian",
-        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+        target_modules=unet_target_modules
     )
     unet.add_adapter(unet_lora_config)
 
     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
-        text_lora_config = LoraConfig(
-            r=args.rank,
+        text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
+        text_lora_config = get_lora_config(
+            rank=args.rank,
             use_dora=args.use_dora,
-            lora_alpha=args.rank,
-            init_lora_weights="gaussian",
-            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+            target_modules=text_target_modules
         )
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
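
For reference, a minimal standalone sketch of the helper introduced by this patch. It assumes `peft.LoraConfig` with DoRA support (peft >= 0.9.0) and lifts `get_lora_config` to module scope purely for illustration; the example configs and ranks below are hypothetical and not part of the patch.

```python
from peft import LoraConfig


def get_lora_config(rank, use_dora, target_modules):
    """Build a LoraConfig, passing use_dora only when DoRA is requested."""
    base_config = {
        "r": rank,
        "lora_alpha": rank,
        "init_lora_weights": "gaussian",
        "target_modules": target_modules,
    }
    if use_dora:
        base_config["use_dora"] = True
    return LoraConfig(**base_config)


# Plain LoRA for the UNet attention projections (use_dora defaults to False in peft).
unet_cfg = get_lora_config(
    rank=4, use_dora=False, target_modules=["to_k", "to_q", "to_v", "to_out.0"]
)
assert unet_cfg.use_dora is False

# Same helper with DoRA enabled, e.g. for the text encoder projections.
text_cfg = get_lora_config(
    rank=4, use_dora=True, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
)
assert text_cfg.use_dora is True
```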