
Commit 1ac07d8

[Training examples] Follow up of #6306 (#6346)
* add to dreambooth lora.
* add: t2i lora.
* add: sdxl t2i lora.
* style
* lcm lora sdxl.
* unwrap
* fix: enable_adapters().
1 parent 1fff527 commit 1ac07d8

File tree

4 files changed: +29 −21 lines
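
All four scripts receive the same two-part fix. First, every LoRA state dict extracted with PEFT's get_peft_model_state_dict is now passed through convert_state_dict_to_diffusers before saving, so the keys follow the naming scheme that load_lora_weights() expects. A minimal sketch of the pattern, assuming a UNet with a LoRA adapter attached as in the training scripts (the model id, LoRA hyperparameters, and output directory are illustrative, not from the commit):

from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from diffusers.utils import convert_state_dict_to_diffusers
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

# Illustrative setup: attach a LoRA adapter to the UNet, as the scripts do.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_k", "to_q", "to_v", "to_out.0"]))

# PEFT emits keys in its own naming scheme; convert them to the diffusers
# scheme so the saved file can later be loaded with `load_lora_weights()`.
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
StableDiffusionPipeline.save_lora_weights(
    save_directory="lora-out",  # illustrative output directory
    unet_lora_layers=unet_lora_state_dict,
)

Second, the LCM script now routes enable_adapters()/disable_adapters() through accelerator.unwrap_model; see the sketch after that file's diff below.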

examples/consistency_distillation/train_lcm_distill_lora_sdxl.py

Lines changed: 6 additions & 6 deletions

@@ -51,7 +51,7 @@
     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available


@@ -113,7 +113,7 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
         if unet is None:
             raise ValueError("Must provide a `unet` when doing intermediate validation.")
         unet = accelerator.unwrap_model(unet)
-        state_dict = get_peft_model_state_dict(unet)
+        state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
         to_load = state_dict
     else:
         to_load = args.output_dir

@@ -819,7 +819,7 @@ def save_model_hook(models, weights, output_dir):
             unet_ = accelerator.unwrap_model(unet)
             # also save the checkpoints in native `diffusers` format so that it can be easily
             # be independently loaded via `load_lora_weights()`.
-            state_dict = get_peft_model_state_dict(unet_)
+            state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_))
             StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict)

             for _, model in enumerate(models):

@@ -1184,7 +1184,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 # solver timestep.

                 # With the adapters disabled, the `unet` is the regular teacher model.
-                unet.disable_adapters()
+                accelerator.unwrap_model(unet).disable_adapters()
                 with torch.no_grad():
                     # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                     cond_teacher_output = unet(

@@ -1248,7 +1248,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
                     x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(unet.dtype)

                 # re-enable unet adapters to turn the `unet` into a student unet.
-                unet.enable_adapters()
+                accelerator.unwrap_model(unet).enable_adapters()

                 # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
                 # Note that we do not use a separate target network for LCM-LoRA distillation.

@@ -1332,7 +1332,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
-        unet_lora_state_dict = get_peft_model_state_dict(unet)
+        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
         StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict)

         if args.push_to_hub:
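
The unwrap half of the fix matters because accelerator.prepare can return the unet wrapped in DistributedDataParallel, and PEFT's disable_adapters()/enable_adapters() are defined on the underlying model rather than on the wrapper, so calling them on the prepared object can fail in multi-GPU runs. A hypothetical context-manager formulation of the same idea (not part of the commit, which keeps the two calls inline around the teacher forward pass):

from contextlib import contextmanager

@contextmanager
def adapters_disabled(accelerator, model):
    # Unwrap first: a DDP wrapper does not expose PEFT's adapter methods.
    unwrapped = accelerator.unwrap_model(model)
    unwrapped.disable_adapters()   # LoRA off -> the unet behaves as the teacher
    try:
        yield model                # forward passes still go through the prepared model
    finally:
        unwrapped.enable_adapters()  # LoRA back on -> the unet is the student again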

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 7 additions & 5 deletions

@@ -54,7 +54,7 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available


@@ -853,9 +853,11 @@ def save_model_hook(models, weights, output_dir):

         for model in models:
             if isinstance(model, type(accelerator.unwrap_model(unet))):
-                unet_lora_layers_to_save = get_peft_model_state_dict(model)
+                unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
             elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
-                text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
+                text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
+                    get_peft_model_state_dict(model)
+                )
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1285,11 +1287,11 @@ def compute_text_embeddings(prompt):
         unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)

-        unet_lora_state_dict = get_peft_model_state_dict(unet)
+        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

         if args.train_text_encoder:
             text_encoder = accelerator.unwrap_model(text_encoder)
-            text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
+            text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
         else:
             text_encoder_state_dict = None
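
Both converted dicts then feed the save call that follows these lines in the script, which writes a single diffusers-format LoRA file covering the UNet and, when trained, the text encoder. For reference, the call has this shape:

LoraLoaderMixin.save_lora_weights(
    save_directory=args.output_dir,
    unet_lora_layers=unet_lora_state_dict,
    text_encoder_lora_layers=text_encoder_state_dict,
)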

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 5 additions & 3 deletions

@@ -44,7 +44,7 @@
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available


@@ -809,7 +809,9 @@ def collate_fn(examples):
                         accelerator.save_state(save_path)

                         unwrapped_unet = accelerator.unwrap_model(unet)
-                        unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
+                        unet_lora_state_dict = convert_state_dict_to_diffusers(
+                            get_peft_model_state_dict(unwrapped_unet)
+                        )

                         StableDiffusionPipeline.save_lora_weights(
                             save_directory=save_path,

@@ -876,7 +878,7 @@ def collate_fn(examples):
         unet = unet.to(torch.float32)

         unwrapped_unet = accelerator.unwrap_model(unet)
-        unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
+        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
         StableDiffusionPipeline.save_lora_weights(
             save_directory=args.output_dir,
             unet_lora_layers=unet_lora_state_dict,
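
The first hunk here runs inside the periodic checkpointing branch, so each checkpoint directory now carries a standalone diffusers-format LoRA file alongside the accelerate training state. That should make intermediate checkpoints directly sampleable without restoring optimizer state; for example (model id, paths, and prompt are illustrative):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("sd-model-finetuned-lora/checkpoint-500")  # illustrative checkpoint path
image = pipe("a pokemon with blue eyes", num_inference_steps=25).images[0]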

examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 11 additions & 7 deletions

@@ -52,7 +52,7 @@
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available


@@ -651,11 +651,15 @@ def save_model_hook(models, weights, output_dir):

         for model in models:
             if isinstance(model, type(accelerator.unwrap_model(unet))):
-                unet_lora_layers_to_save = get_peft_model_state_dict(model)
+                unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
             elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+                    get_peft_model_state_dict(model)
+                )
             elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
+                text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+                    get_peft_model_state_dict(model)
+                )
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1160,14 +1164,14 @@ def compute_time_ids(original_size, crops_coords_top_left):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
-        unet_lora_state_dict = get_peft_model_state_dict(unet)
+        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

         if args.train_text_encoder:
             text_encoder_one = accelerator.unwrap_model(text_encoder_one)
             text_encoder_two = accelerator.unwrap_model(text_encoder_two)

-            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
-            text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
+            text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
+            text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))
         else:
             text_encoder_lora_layers = None
             text_encoder_2_lora_layers = None
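
The SDXL variant has two text encoders, so both converted text-encoder dicts, plus the UNet dict, go into the SDXL save call that follows in the script. For reference, that call has this shape:

StableDiffusionXLPipeline.save_lora_weights(
    save_directory=args.output_dir,
    unet_lora_layers=unet_lora_state_dict,
    text_encoder_lora_layers=text_encoder_lora_layers,
    text_encoder_2_lora_layers=text_encoder_2_lora_layers,
)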
