diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index bea1979f5cb1..c79b053d6cae 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -872,7 +872,9 @@ def main(args):
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
-        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        module = lora_attn_processor_class(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
+        )
         unet_lora_attn_procs[name] = module
         unet_lora_parameters.extend(module.parameters())
 
@@ -882,7 +884,7 @@ def main(args):
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32)
+        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
@@ -1364,7 +1366,7 @@ def compute_text_embeddings(prompt):
        pipeline = pipeline.to(accelerator.device)
 
        # load attention processors
-        pipeline.load_lora_weights(args.output_dir)
+        pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin")
 
        # run inference
        images = []
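
For reference, here is a minimal usage sketch of the loading path touched by the last hunk. It assumes a checkpoint that this script has already saved into some output directory as `pytorch_lora_weights.bin`; the base model ID, output path, and prompt below are placeholders and not taken from the diff — only the `load_lora_weights(..., weight_name="pytorch_lora_weights.bin")` call mirrors the change above.

```python
# Hedged sketch: loading LoRA weights produced by train_dreambooth_lora.py for inference.
# The model ID, output directory, and prompt are placeholder assumptions; only the
# weight_name keyword mirrors the diff above.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder base model
    torch_dtype=torch.float16,
).to("cuda")

# Pin weight_name to the exact file the training script serializes, so there is
# no ambiguity about which weight file in the directory should be loaded.
pipe.load_lora_weights("path/to/output_dir", weight_name="pytorch_lora_weights.bin")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("lora_sample.png")
```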