diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 6e621b3caee3..9cd321f6d055 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -67,6 +67,7 @@
     convert_state_dict_to_diffusers,
     convert_state_dict_to_kohya,
     convert_unet_state_dict_to_peft,
+    is_peft_version,
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
@@ -1183,26 +1184,33 @@ def main(args):
             text_encoder_one.gradient_checkpointing_enable()
             text_encoder_two.gradient_checkpointing_enable()

+    def get_lora_config(rank, use_dora, target_modules):
+        base_config = {
+            "r": rank,
+            "lora_alpha": rank,
+            "init_lora_weights": "gaussian",
+            "target_modules": target_modules,
+        }
+        if use_dora:
+            if is_peft_version("<", "0.9.0"):
+                raise ValueError(
+                    "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                )
+            else:
+                base_config["use_dora"] = True
+
+        return LoraConfig(**base_config)
+
     # now we will add new LoRA weights to the attention layers
-    unet_lora_config = LoraConfig(
-        r=args.rank,
-        use_dora=args.use_dora,
-        lora_alpha=args.rank,
-        init_lora_weights="gaussian",
-        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
-    )
+    unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+    unet_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=unet_target_modules)
     unet.add_adapter(unet_lora_config)

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
-        text_lora_config = LoraConfig(
-            r=args.rank,
-            use_dora=args.use_dora,
-            lora_alpha=args.rank,
-            init_lora_weights="gaussian",
-            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
-        )
+        text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
+        text_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=text_target_modules)
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
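For reference, a minimal standalone sketch of the refactored `get_lora_config` helper as it behaves after this change, assuming `diffusers` and `peft` are installed (`peft >= 0.9.0` is only needed for the DoRA path). The `rank=4` values and the module-level calls below are illustrative only and are not part of the patch.

```python
# Hypothetical standalone usage of the helper introduced by this diff.
# `is_peft_version` is imported from `diffusers.utils`; `LoraConfig` from `peft`.
from diffusers.utils import is_peft_version
from peft import LoraConfig


def get_lora_config(rank, use_dora, target_modules):
    # Keyword arguments shared by the UNet and text-encoder LoRA adapters.
    base_config = {
        "r": rank,
        "lora_alpha": rank,
        "init_lora_weights": "gaussian",
        "target_modules": target_modules,
    }
    if use_dora:
        # DoRA support landed in peft 0.9.0; fail early with a clear message otherwise.
        if is_peft_version("<", "0.9.0"):
            raise ValueError(
                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. "
                "Please upgrade your installation of `peft`."
            )
        base_config["use_dora"] = True
    return LoraConfig(**base_config)


# One config targets the UNet attention projections, the other the CLIP text encoders.
unet_lora_config = get_lora_config(rank=4, use_dora=False, target_modules=["to_k", "to_q", "to_v", "to_out.0"])
text_lora_config = get_lora_config(rank=4, use_dora=False, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"])
```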