Commit 6427aa9

[Enhance] Add rank in dreambooth (#4112)
add rank in dreambooth
1 parent 8b18cd8 commit 6427aa9

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 5 additions & 3 deletions
@@ -872,7 +872,9 @@ def main(args):
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
 
-        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        module = lora_attn_processor_class(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
+        )
         unet_lora_attn_procs[name] = module
         unet_lora_parameters.extend(module.parameters())
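
The new rank keyword sets the inner dimension of the low-rank update that LoRA adds on top of the frozen attention weights (args.rank presumably comes from a --rank CLI flag added elsewhere in this PR; that hunk is not shown here). A minimal sketch of the idea, using an illustrative LoRALinear class rather than the actual diffusers processor:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    # Illustrative LoRA layer: y = W x + B(A x), where A is (rank, in_features)
    # and B is (out_features, rank). A smaller rank means fewer trainable parameters.
    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)  # stands in for the frozen pretrained weight
        self.base.requires_grad_(False)
        self.lora_down = nn.Linear(in_features, rank, bias=False)  # "A"
        self.lora_up = nn.Linear(rank, out_features, bias=False)   # "B"
        nn.init.zeros_(self.lora_up.weight)  # the update starts as a no-op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.lora_up(self.lora_down(x))

For a 320-wide projection (the width of the first UNet block in Stable Diffusion), rank=4 trains 2 * 320 * 4 = 2,560 parameters per projection instead of the 320 * 320 = 102,400 in the full weight.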

@@ -882,7 +884,7 @@ def main(args):
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32)
+        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
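
Unlike the UNet, the text encoder exposes no attention-processor hooks, which is why the comment above talks about monkey-patching its forward calls. A rough sketch of that pattern; patch_linear_with_lora and everything inside it are illustrative, not the private diffusers implementation:

import torch.nn as nn

def patch_linear_with_lora(linear: nn.Linear, rank: int = 4):
    # Hypothetical helper: wrap an existing nn.Linear so its forward adds
    # a trainable low-rank update on top of the frozen pretrained weight.
    # (Real code would also move the new layers to the model's device/dtype.)
    lora_down = nn.Linear(linear.in_features, rank, bias=False)
    lora_up = nn.Linear(rank, linear.out_features, bias=False)
    nn.init.zeros_(lora_up.weight)  # no-op at initialization

    old_forward = linear.forward

    def new_forward(x):
        return old_forward(x) + lora_up(lora_down(x))

    linear.forward = new_forward  # monkey-patch in place
    return list(lora_down.parameters()) + list(lora_up.parameters())

In the script, the parameters returned as text_lora_parameters go to the optimizer, and the new rank=args.rank argument sets the inner dimension exactly as in the UNet path.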
@@ -1364,7 +1366,7 @@ def compute_text_embeddings(prompt):
         pipeline = pipeline.to(accelerator.device)
 
         # load attention processors
-        pipeline.load_lora_weights(args.output_dir)
+        pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin")
 
         # run inference
         images = []
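
Passing weight_name pins load_lora_weights to the exact file the script serialized, rather than letting it pick a default from output_dir. Loading the trained weights in a standalone inference script would look roughly like this (model ID, paths, and prompt are placeholders):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# point weight_name at the LoRA file produced by training
pipe.load_lora_weights("path/to/output_dir", weight_name="pytorch_lora_weights.bin")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("dog.png")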
