Commit 1828f82

saving and loading weights without AttnProcsLayers class
1 parent a0bc6e8 commit 1828f82

File tree

3 files changed (+109, -55 lines changed)

examples/dreambooth/train_dreambooth_lora.py
src/diffusers/loaders.py
src/diffusers/models/unet_2d_condition.py

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 51 additions & 55 deletions
@@ -49,7 +49,14 @@
     StableDiffusionPipeline,
     UNet2DConditionModel,
 )
-from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, text_encoder_attn_modules
+from diffusers.loaders import (
+    LORA_WEIGHT_NAME,
+    TEXT_ENCODER_NAME,
+    UNET_NAME,
+    LoraLoaderMixin,
+    text_encoder_attn_modules,
+    text_encoder_lora_state_dict,
+)
 from diffusers.models.attention_processor import (
     AttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
@@ -832,6 +839,7 @@ def main(args):
 
     # Set correct lora layers
     unet_lora_attn_procs = {}
+    unet_lora_parameters = []
     for name, attn_processor in unet.attn_processors.items():
         cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
         if name.startswith("mid_block"):
@@ -849,18 +857,17 @@
         lora_attn_processor_class = (
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
-        unet_lora_attn_procs[name] = lora_attn_processor_class(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-        )
+        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        unet_lora_attn_procs[name] = module
+        unet_lora_parameters.extend(module.parameters())
 
     unet.set_attn_processor(unet_lora_attn_procs)
-    unet_lora_layers = AttnProcsLayers(unet.attn_processors)
 
     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
-    text_encoder_lora_layers = None
     if args.train_text_encoder:
         text_lora_attn_procs = {}
+        text_lora_parameters = []
 
         for name, module in text_encoder_attn_modules(text_encoder):
             if isinstance(text_encoder, CLIPTextModel):
@@ -872,9 +879,10 @@
             else:
                 raise ValueError(f"{text_encoder.__class__.__name__} does not support LoRA training")
 
-            text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, inner_dim=inner_dim)
+            module = LoRAAttnProcessor(hidden_size=hidden_size, inner_dim=inner_dim)
+            text_lora_attn_procs[name] = module
+            text_lora_parameters.extend(module.parameters())
 
-        text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
         LoraLoaderMixin._modify_text_encoder(text_lora_attn_procs, text_encoder)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
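
With the `AttnProcsLayers` wrapper gone, the LoRA parameters are collected into plain Python lists (`unet_lora_parameters`, `text_lora_parameters`) as the processors are created, and those lists are later chained and handed to the optimizer, as the optimizer-creation hunk further below shows. A minimal standalone sketch of that pattern with toy modules (class names and sizes here are illustrative, not the script's):

# Sketch only: toy stand-ins for LoRA attention processors, showing how their
# parameters can be gathered into flat lists and fed straight to an optimizer
# without an AttnProcsLayers wrapper.
import itertools

import torch


class ToyLoRAProcessor(torch.nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.down = torch.nn.Linear(hidden_size, 4, bias=False)
        self.up = torch.nn.Linear(4, hidden_size, bias=False)


unet_lora_attn_procs = {}
unet_lora_parameters = []  # flat list of nn.Parameter objects

for name in ["down_blocks.0.attn1.processor", "mid_block.attn1.processor"]:
    module = ToyLoRAProcessor(hidden_size=32)
    unet_lora_attn_procs[name] = module
    # extend (not append) so the list holds parameters, not generator objects
    unet_lora_parameters.extend(module.parameters())

text_lora_parameters = [p for _ in range(2) for p in ToyLoRAProcessor(16).parameters()]

# Mirrors the script: chain both lists when the text encoder is trained too.
params_to_optimize = itertools.chain(unet_lora_parameters, text_lora_parameters)
optimizer = torch.optim.AdamW(params_to_optimize, lr=1e-4)
print(sum(p.numel() for g in optimizer.param_groups for p in g["params"]))
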
@@ -884,23 +892,13 @@ def save_model_hook(models, weights, output_dir):
         unet_lora_layers_to_save = None
         text_encoder_lora_layers_to_save = None
 
-        if args.train_text_encoder:
-            text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
-        unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()
-
         for model in models:
-            state_dict = model.state_dict()
-
-            if (
-                text_encoder_lora_layers is not None
-                and text_encoder_keys is not None
-                and state_dict.keys() == text_encoder_keys
-            ):
-                # text encoder
-                text_encoder_lora_layers_to_save = state_dict
-            elif state_dict.keys() == unet_keys:
-                # unet
-                unet_lora_layers_to_save = state_dict
+            if isinstance(model, type(accelerator.unwrap_model(unet))):
+                unet_lora_layers_to_save = model.attn_processors_state_dict
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
 
             # make sure to pop weight so that corresponding model is not saved again
             weights.pop()
@@ -912,27 +910,23 @@ def save_model_hook(models, weights, output_dir):
         )
 
     def load_model_hook(models, input_dir):
-        # Note we DON'T pass the unet and text encoder here an purpose
-        # so that the we don't accidentally override the LoRA layers of
-        # unet_lora_layers and text_encoder_lora_layers which are stored in `models`
-        # with new torch.nn.Modules / weights. We simply use the pipeline class as
-        # an easy way to load the lora checkpoints
-        temp_pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path,
-            revision=args.revision,
-            torch_dtype=weight_dtype,
-        )
-        temp_pipeline.load_lora_weights(input_dir)
-
-        # load lora weights into models
-        models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
-        if len(models) > 1:
-            models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
+        lora_weights = torch.load(os.path.join(input_dir, LORA_WEIGHT_NAME))
+        unet_weights = {}
+        text_encoder_weights = {}
+
+        for k, v in lora_weights.items():
+            model, *k = k.split(".")
+            k = ".".join(k)
+
+            if model == UNET_NAME:
+                unet_weights[k] = v
+            elif model == TEXT_ENCODER_NAME:
+                text_encoder_weights[k] = v
+            else:
+                raise ValueError(f"unknown model name {model}")
 
-        # delete temporary pipeline and pop models
-        del temp_pipeline
-        for _ in range(len(models)):
-            models.pop()
+        unet.load_state_dict(unet_weights, strict=False)
+        text_encoder.load_state_dict(text_encoder_weights, strict=False)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
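
The new `load_model_hook` implies that the combined LoRA checkpoint keys every tensor with the owning model's name (the values of `UNET_NAME` and `TEXT_ENCODER_NAME`) as a prefix and stores them in a single file named by `LORA_WEIGHT_NAME`. A standalone sketch of that save/split round trip follows; the prefix strings and the file name below are assumptions, only the splitting logic mirrors the hook:

# Sketch only: one flat LoRA weight file keyed by model prefix, split back into
# per-model state dicts the way the new load_model_hook does. The prefix values
# ("unet", "text_encoder") and the file name are assumptions.
import os
import tempfile

import torch

UNET_NAME = "unet"
TEXT_ENCODER_NAME = "text_encoder"
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"  # assumed file name

unet = torch.nn.Linear(4, 4)          # stand-in for the UNet LoRA weights
text_encoder = torch.nn.Linear(4, 4)  # stand-in for the text encoder LoRA weights

# Save: a single dict whose keys carry the owning model's prefix.
lora_weights = {f"{UNET_NAME}.{k}": v for k, v in unet.state_dict().items()}
lora_weights.update({f"{TEXT_ENCODER_NAME}.{k}": v for k, v in text_encoder.state_dict().items()})

with tempfile.TemporaryDirectory() as input_dir:
    torch.save(lora_weights, os.path.join(input_dir, LORA_WEIGHT_NAME))

    # Load: split each key on the first "." and route it to the right model.
    loaded = torch.load(os.path.join(input_dir, LORA_WEIGHT_NAME))
    unet_weights, text_encoder_weights = {}, {}
    for k, v in loaded.items():
        model, *rest = k.split(".")
        k = ".".join(rest)
        if model == UNET_NAME:
            unet_weights[k] = v
        elif model == TEXT_ENCODER_NAME:
            text_encoder_weights[k] = v
        else:
            raise ValueError(f"unknown model name {model}")

    # strict=False because the real script only restores the LoRA subset of
    # each model's parameters; the toy modules here happen to match exactly.
    unet.load_state_dict(unet_weights, strict=False)
    text_encoder.load_state_dict(text_encoder_weights, strict=False)
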
@@ -962,9 +956,9 @@ def load_model_hook(models, input_dir):
 
     # Optimizer creation
     params_to_optimize = (
-        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+        itertools.chain(unet_lora_parameters, text_lora_parameters)
         if args.train_text_encoder
-        else unet_lora_layers.parameters()
+        else unet_lora_parameters
     )
     optimizer = optimizer_class(
         params_to_optimize,
@@ -1053,12 +1047,12 @@ def compute_text_embeddings(prompt):
 
     # Prepare everything with our `accelerator`.
     if args.train_text_encoder:
-        unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
+        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
         )
     else:
-        unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet_lora_layers, optimizer, train_dataloader, lr_scheduler
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, optimizer, train_dataloader, lr_scheduler
         )
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -1207,9 +1201,9 @@ def compute_text_embeddings(prompt):
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     params_to_clip = (
-                        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+                        itertools.chain(unet_lora_parameters, text_lora_parameters)
                         if args.train_text_encoder
-                        else unet_lora_layers.parameters()
+                        else unet_lora_parameters
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
@@ -1309,12 +1303,14 @@ def compute_text_embeddings(prompt):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
+        unet_lora_layers = unet.attn_processors_state_dict
 
-        if text_encoder is not None:
+        if text_encoder is not None and args.train_text_encoder:
+            text_encoder = accelerator.unwrap_model(text_encoder)
             text_encoder = text_encoder.to(torch.float32)
-            text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
+            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
 
         LoraLoaderMixin.save_lora_weights(
             save_directory=args.output_dir,
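
Once training finishes, the directory written by `LoraLoaderMixin.save_lora_weights` can be consumed the same way the removed `load_model_hook` did, through `load_lora_weights` on a pipeline. A brief usage sketch; the base model id, output path, and prompt are placeholders:

# Sketch only: loading the trained LoRA weights back into a pipeline for inference.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.to("cuda")  # assumes a CUDA device is available
pipe.load_lora_weights("path/to/output_dir")
image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("sks_dog.png")
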

src/diffusers/loaders.py

Lines changed: 34 additions & 0 deletions
@@ -97,6 +97,40 @@ def text_encoder_attn_modules(text_encoder):
     return attn_modules
 
 
+def text_encoder_lora_state_dict(text_encoder):
+    state_dict = {}
+
+    for name, module in text_encoder_attn_modules(text_encoder):
+        if isinstance(text_encoder, CLIPTextModel):
+            for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+                state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+            for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+                state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+            for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+                state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+            for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+                state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+        elif isinstance(text_encoder, T5EncoderModel):
+            for k, v in module.q.lora_linear_layer.state_dict().items():
+                state_dict[f"{name}.q.lora_linear_layer.{k}"] = v
+
+            for k, v in module.k.lora_linear_layer.state_dict().items():
+                state_dict[f"{name}.k.lora_linear_layer.{k}"] = v
+
+            for k, v in module.v.lora_linear_layer.state_dict().items():
+                state_dict[f"{name}.v.lora_linear_layer.{k}"] = v
+
+            for k, v in module.o.lora_linear_layer.state_dict().items():
+                state_dict[f"{name}.o.lora_linear_layer.{k}"] = v
+        else:
+            raise ValueError(f"do not know how to get state dict for: {text_encoder.__class__.__name__}")
+
+    return state_dict
+
+
 class AttnProcsLayers(torch.nn.Module):
     def __init__(self, state_dict: Dict[str, torch.Tensor]):
         super().__init__()
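
The helper flattens each projection's `lora_linear_layer` into keys of the form `<module>.<proj>.lora_linear_layer.<param>`. A standalone sketch of that key-prefixing pattern on toy CLIP-style projections (the real helper iterates `text_encoder_attn_modules` and also handles the T5 `q/k/v/o` layout):

# Sketch only: a toy attention module whose projections carry a
# `lora_linear_layer`, and the prefixed-key extraction that
# text_encoder_lora_state_dict performs for CLIP-style q/k/v/out projections.
import torch


class ToyProj(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.lora_linear_layer = torch.nn.Linear(dim, dim, bias=False)


class ToyAttn(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.q_proj, self.k_proj = ToyProj(dim), ToyProj(dim)
        self.v_proj, self.out_proj = ToyProj(dim), ToyProj(dim)


attn_modules = [("text_model.encoder.layers.0.self_attn", ToyAttn(8))]

state_dict = {}
for name, module in attn_modules:
    for proj_name in ("q_proj", "k_proj", "v_proj", "out_proj"):
        proj = getattr(module, proj_name)
        for k, v in proj.lora_linear_layer.state_dict().items():
            state_dict[f"{name}.{proj_name}.lora_linear_layer.{k}"] = v

print(list(state_dict))
# ['text_model.encoder.layers.0.self_attn.q_proj.lora_linear_layer.weight', ...]
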

src/diffusers/models/unet_2d_condition.py

Lines changed: 24 additions & 0 deletions
@@ -528,6 +528,30 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
 
         return processors
 
+    @property
+    def attn_processors_state_dict(self) -> Dict[str, torch.Tensor]:
+        r"""
+        Returns:
+            a state dict containing just the attention processor parameters.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "set_processor"):
+                for processor_key, processor_parameter in module.processor.state_dict().items():
+                    processors[f"{name}.processor.{processor_key}"] = processor_parameter
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Parameters:
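
The property reuses the same recursive walk as the existing `attn_processors` property, but flattens each processor's own parameters into one state dict keyed by `<block path>.processor.<param>`. A standalone sketch of that traversal on a toy module tree (class names here are illustrative; any module exposing `set_processor` is treated as an attention block):

# Sketch only: recursive collection of per-processor parameters into a flat
# state dict, mirroring the attn_processors_state_dict property.
from typing import Dict

import torch


class ToyProcessor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.to_lora = torch.nn.Linear(4, 4, bias=False)


class ToyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.processor = ToyProcessor()

    def set_processor(self, processor):
        self.processor = processor


class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn1 = ToyAttention()
        self.attn2 = ToyAttention()


root = torch.nn.ModuleDict({"down_blocks": ToyBlock()})


def attn_processors_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
    processors: Dict[str, torch.Tensor] = {}

    def fn_recursive(name: str, module: torch.nn.Module):
        if hasattr(module, "set_processor"):
            for key, param in module.processor.state_dict().items():
                processors[f"{name}.processor.{key}"] = param
        for sub_name, child in module.named_children():
            fn_recursive(f"{name}.{sub_name}", child)

    for name, module in model.named_children():
        fn_recursive(name, module)
    return processors


print(list(attn_processors_state_dict(root)))
# ['down_blocks.attn1.processor.to_lora.weight', 'down_blocks.attn2.processor.to_lora.weight']
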
