diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index c3dfc923f0d6..51db6bb99f22 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -62,6 +62,7 @@
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
 from diffusers.utils import (
+    _collate_lora_metadata,
     check_min_version,
     convert_all_state_dict_to_peft,
     convert_state_dict_to_diffusers,
@@ -659,6 +660,12 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
@@ -1202,10 +1209,10 @@ def main(args):
             text_encoder_one.gradient_checkpointing_enable()
             text_encoder_two.gradient_checkpointing_enable()

-    def get_lora_config(rank, dropout, use_dora, target_modules):
+    def get_lora_config(rank, lora_alpha, dropout, use_dora, target_modules):
         base_config = {
             "r": rank,
-            "lora_alpha": rank,
+            "lora_alpha": lora_alpha,
             "lora_dropout": dropout,
             "init_lora_weights": "gaussian",
             "target_modules": target_modules,
@@ -1224,6 +1231,7 @@ def get_lora_config(rank, dropout, use_dora, target_modules):
     unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
     unet_lora_config = get_lora_config(
         rank=args.rank,
+        lora_alpha=args.lora_alpha,
         dropout=args.lora_dropout,
         use_dora=args.use_dora,
         target_modules=unet_target_modules,
@@ -1236,6 +1244,7 @@ def get_lora_config(rank, dropout, use_dora, target_modules):
         text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
         text_lora_config = get_lora_config(
             rank=args.rank,
+            lora_alpha=args.lora_alpha,
             dropout=args.lora_dropout,
             use_dora=args.use_dora,
             target_modules=text_target_modules,
@@ -1256,10 +1265,12 @@ def save_model_hook(models, weights, output_dir):
             unet_lora_layers_to_save = None
             text_encoder_one_lora_layers_to_save = None
             text_encoder_two_lora_layers_to_save = None
+            modules_to_save = {}

             for model in models:
                 if isinstance(model, type(unwrap_model(unet))):
                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+                    modules_to_save["unet"] = model
                 elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
@@ -1279,6 +1290,7 @@ def save_model_hook(models, weights, output_dir):
                 unet_lora_layers=unet_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
                 text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )

     def load_model_hook(models, input_dir):
@@ -1945,6 +1957,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        modules_to_save = {}
         unet = unwrap_model(unet)
         unet = unet.to(torch.float32)
         unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
@@ -1967,6 +1980,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             unet_lora_layers=unet_lora_layers,
             text_encoder_lora_layers=text_encoder_lora_layers,
             text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )
         if args.output_kohya_format:
             lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
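
Note for reviewers: a minimal sketch of what decoupling `--lora_alpha` from `--rank` changes, assuming PEFT's usual `lora_alpha / r` scaling of the adapter update. The `build_unet_lora_config` helper below is hypothetical and only mirrors the `get_lora_config` call pattern in the script (DoRA handling omitted); it is not part of the diff.

```python
# Hypothetical illustration: PEFT scales each LoRA update by lora_alpha / r,
# so the old behaviour (lora_alpha hard-coded to rank) always gave a scale of 1.0.
from peft import LoraConfig


def build_unet_lora_config(rank: int, lora_alpha: int, dropout: float) -> LoraConfig:
    # Mirrors get_lora_config() for the UNet attention projections targeted by the script.
    return LoraConfig(
        r=rank,
        lora_alpha=lora_alpha,
        lora_dropout=dropout,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )


# e.g. --rank 4 --lora_alpha 8 gives an effective scale of 8 / 4 = 2.0,
# versus the previous fixed 4 / 4 = 1.0.
config = build_unet_lora_config(rank=4, lora_alpha=8, dropout=0.0)
```

With the default `--lora_alpha 4` and `--rank 4` the behaviour is unchanged; raising or lowering `--lora_alpha` now adjusts the LoRA contribution without touching the rank.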