From 362dcb7dd21bc2f8cff17407ef3d7a89e7e5c179 Mon Sep 17 00:00:00 2001
From: Tanuj Rai
Date: Tue, 17 Jun 2025 18:23:44 +0530
Subject: [PATCH 1/4] Update train_dreambooth_lora_sdxl.py

---
 examples/dreambooth/train_dreambooth_lora_sdxl.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index c3dfc923f0d6..d100d66e183d 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -62,6 +62,7 @@
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
 from diffusers.utils import (
+    _collate_lora_metadata,
     check_min_version,
     convert_all_state_dict_to_peft,
     convert_state_dict_to_diffusers,
@@ -659,6 +660,12 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
@@ -1205,7 +1212,7 @@ def main(args):
     def get_lora_config(rank, dropout, use_dora, target_modules):
         base_config = {
             "r": rank,
-            "lora_alpha": rank,
+            "lora_alpha"=args.lora_alpha,
             "lora_dropout": dropout,
             "init_lora_weights": "gaussian",
             "target_modules": target_modules,
@@ -1224,6 +1231,7 @@ def get_lora_config(rank, dropout, use_dora, target_modules):
     unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
     unet_lora_config = get_lora_config(
         rank=args.rank,
+        lora_alpha=args.lora_alpha,
         dropout=args.lora_dropout,
         use_dora=args.use_dora,
         target_modules=unet_target_modules,
@@ -1256,10 +1264,12 @@ def save_model_hook(models, weights, output_dir):
             unet_lora_layers_to_save = None
             text_encoder_one_lora_layers_to_save = None
             text_encoder_two_lora_layers_to_save = None
+            modules_to_save = {}
             for model in models:
                 if isinstance(model, type(unwrap_model(unet))):
                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+                    modules_to_save["transformer"] = model
                 elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
                     )
@@ -1279,6 +1289,7 @@ def save_model_hook(models, weights, output_dir):
                 unet_lora_layers=unet_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
                 text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )

     def load_model_hook(models, input_dir):
@@ -1945,6 +1956,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        modules_to_save = {}
         unet = unwrap_model(unet)
         unet = unet.to(torch.float32)
         unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
@@ -1967,6 +1979,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             unet_lora_layers=unet_lora_layers,
             text_encoder_lora_layers=text_encoder_lora_layers,
             text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )
         if args.output_kohya_format:
             lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
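
Note: patch 1/4 exposes a --lora_alpha CLI flag and records LoRA metadata when saving. For orientation, the sketch below is not part of the patch; it assumes peft is installed and uses illustrative values. In peft, the LoRA update is scaled by lora_alpha / r, so raising lora_alpha strengthens the adapter without changing its rank.

from peft import LoraConfig

# Illustrative only: values mirror the script's defaults and target modules, not the patch itself.
config = LoraConfig(
    r=4,                    # --rank: dimension of the LoRA update matrices
    lora_alpha=8,           # --lora_alpha: effective scale is lora_alpha / r = 2.0 here
    lora_dropout=0.0,       # --lora_dropout
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # UNet attention projections
)
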
From 7fd96a31f8354878e812d80861fb76aec3c878fa Mon Sep 17 00:00:00 2001
From: Tanuj Rai
Date: Wed, 18 Jun 2025 17:13:26 +0530
Subject: [PATCH 2/4] Update examples/dreambooth/train_dreambooth_lora_sdxl.py

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
---
 examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index d100d66e183d..b9cc7c1bf686 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1212,7 +1212,7 @@ def main(args):
     def get_lora_config(rank, dropout, use_dora, target_modules):
         base_config = {
             "r": rank,
-            "lora_alpha"=args.lora_alpha,
+            "lora_alpha":lora_alpha,
             "lora_dropout": dropout,
             "init_lora_weights": "gaussian",
             "target_modules": target_modules,

From 74adabd10bc67f29ff1275b2aa5f42f0b5bd0f68 Mon Sep 17 00:00:00 2001
From: Tanuj Rai
Date: Wed, 18 Jun 2025 17:13:35 +0530
Subject: [PATCH 3/4] Update examples/dreambooth/train_dreambooth_lora_sdxl.py

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
---
 examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index b9cc7c1bf686..b68439065dfa 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1209,7 +1209,7 @@ def main(args):
             text_encoder_one.gradient_checkpointing_enable()
             text_encoder_two.gradient_checkpointing_enable()

-    def get_lora_config(rank, dropout, use_dora, target_modules):
+    def get_lora_config(rank, lora_alpha, dropout, use_dora, target_modules):
         base_config = {
             "r": rank,
             "lora_alpha":lora_alpha,

From a54ecb1ad1678834ebe4556f1ffd4818bcfe3d7d Mon Sep 17 00:00:00 2001
From: Tanuj Rai
Date: Fri, 20 Jun 2025 10:23:59 +0530
Subject: [PATCH 4/4] Update train_dreambooth_lora_sdxl.py

---
 examples/dreambooth/train_dreambooth_lora_sdxl.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index b68439065dfa..51db6bb99f22 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1244,6 +1244,7 @@ def get_lora_config(rank, lora_alpha, dropout, use_dora, target_modules):
         text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
         text_lora_config = get_lora_config(
             rank=args.rank,
+            lora_alpha=args.lora_alpha,
             dropout=args.lora_dropout,
             use_dora=args.use_dora,
             target_modules=text_target_modules,
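
Note: with all four patches applied, get_lora_config takes lora_alpha explicitly and both the UNet and text-encoder configs receive args.lora_alpha. A minimal sketch of the resulting helper and call sites, reconstructed from the hunks above; the use_dora branch and the LoraConfig return are assumed from context, and literal values stand in for args.*.

from peft import LoraConfig

def get_lora_config(rank, lora_alpha, dropout, use_dora, target_modules):
    base_config = {
        "r": rank,
        "lora_alpha": lora_alpha,
        "lora_dropout": dropout,
        "init_lora_weights": "gaussian",
        "target_modules": target_modules,
    }
    if use_dora:  # assumption: DoRA is toggled through the same config
        base_config["use_dora"] = True
    return LoraConfig(**base_config)

# Call sites mirroring the patched script; 4 is the script's default for both --rank and --lora_alpha.
unet_lora_config = get_lora_config(
    rank=4, lora_alpha=4, dropout=0.0, use_dora=False,
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
text_lora_config = get_lora_config(
    rank=4, lora_alpha=4, dropout=0.0, use_dora=False,
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)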