@@ -872,7 +872,9 @@ def main(args):
                 LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
             )

-        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        module = lora_attn_processor_class(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
+        )
         unet_lora_attn_procs[name] = module
         unet_lora_parameters.extend(module.parameters())

@@ -882,7 +884,7 @@ def main(args):
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32)
+        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank)

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
@@ -1364,7 +1366,7 @@ def compute_text_embeddings(prompt):
         pipeline = pipeline.to(accelerator.device)

         # load attention processors
-        pipeline.load_lora_weights(args.output_dir)
+        pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin")

         # run inference
         images = []
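
For context, a minimal sketch (not part of the diff) of how the new rank knob and the explicit weight_name are typically exercised together. The `--rank` argparse flag and its default shown here are assumptions about how `args.rank` is populated elsewhere in the script, and the base model id is only an example.

# Sketch only. Assumes the training script exposes a `--rank` flag feeding `args.rank`,
# and that the trained LoRA layers were saved to `args.output_dir` as pytorch_lora_weights.bin.
import argparse

import torch
from diffusers import DiffusionPipeline

parser = argparse.ArgumentParser()
parser.add_argument("--rank", type=int, default=4, help="Dimension of the LoRA update matrices.")
parser.add_argument("--output_dir", type=str, default="lora-dreambooth-model")
args = parser.parse_args()

# Load the fine-tuned LoRA weights back into a pipeline, naming the serialized file
# explicitly as the validation hunk above now does.
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin")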