 torch.backends.cuda.matmul.allow_tf32 = False


-def create_lora_layers(model):
+def create_lora_layers(model, mock_weights: bool = True):
     lora_attn_procs = {}
     for name in model.attn_processors.keys():
         cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
@@ -57,12 +57,13 @@ def create_lora_layers(model):
         lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         lora_attn_procs[name] = lora_attn_procs[name].to(model.device)

-        # add 1 to weights to mock trained weights
-        with torch.no_grad():
-            lora_attn_procs[name].to_q_lora.up.weight += 1
-            lora_attn_procs[name].to_k_lora.up.weight += 1
-            lora_attn_procs[name].to_v_lora.up.weight += 1
-            lora_attn_procs[name].to_out_lora.up.weight += 1
+        if mock_weights:
+            # add 1 to weights to mock trained weights
+            with torch.no_grad():
+                lora_attn_procs[name].to_q_lora.up.weight += 1
+                lora_attn_procs[name].to_k_lora.up.weight += 1
+                lora_attn_procs[name].to_v_lora.up.weight += 1
+                lora_attn_procs[name].to_out_lora.up.weight += 1

     return lora_attn_procs

@@ -378,26 +379,7 @@ def test_lora_processors(self):
         with torch.no_grad():
             sample1 = model(**inputs_dict).sample

-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-
-            # add 1 to weights to mock trained weights
-            with torch.no_grad():
-                lora_attn_procs[name].to_q_lora.up.weight += 1
-                lora_attn_procs[name].to_k_lora.up.weight += 1
-                lora_attn_procs[name].to_v_lora.up.weight += 1
-                lora_attn_procs[name].to_out_lora.up.weight += 1
+        lora_attn_procs = create_lora_layers(model)

         # make sure we can set a list of attention processors
         model.set_attn_processor(lora_attn_procs)
@@ -465,28 +447,7 @@ def test_lora_save_load_safetensors(self):
         with torch.no_grad():
             old_sample = model(**inputs_dict).sample

-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
-            # add 1 to weights to mock trained weights
-            with torch.no_grad():
-                lora_attn_procs[name].to_q_lora.up.weight += 1
-                lora_attn_procs[name].to_k_lora.up.weight += 1
-                lora_attn_procs[name].to_v_lora.up.weight += 1
-                lora_attn_procs[name].to_out_lora.up.weight += 1
-
+        lora_attn_procs = create_lora_layers(model)
         model.set_attn_processor(lora_attn_procs)

         with torch.no_grad():
@@ -518,21 +479,7 @@ def test_lora_save_safetensors_load_torch(self):
         model = self.model_class(**init_dict)
         model.to(torch_device)

-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)
         model.set_attn_processor(lora_attn_procs)
         # Saving as torch, properly reloads with directly filename
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -553,21 +500,7 @@ def test_lora_save_torch_force_load_safetensors_error(self):
         model = self.model_class(**init_dict)
         model.to(torch_device)

-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)
         model.set_attn_processor(lora_attn_procs)
         # Saving as torch, properly reloads with directly filename
         with tempfile.TemporaryDirectory() as tmpdirname:
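For reference, a sketch of how the shared helper reads once the hunks above are applied. The hidden_size branch is assumed to match the duplicated blocks removed from the tests, since the first hunk elides those lines of the helper body; torch and LoRAAttnProcessor are assumed to be imported as in the test module.

# Sketch only, not part of the diff; assembled from the hunks above.
def create_lora_layers(model, mock_weights: bool = True):
    lora_attn_procs = {}
    for name in model.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
        # hidden_size selection assumed identical to the blocks removed from the tests
        if name.startswith("mid_block"):
            hidden_size = model.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(model.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = model.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
        lora_attn_procs[name] = lora_attn_procs[name].to(model.device)

        if mock_weights:
            # add 1 to weights to mock trained weights
            with torch.no_grad():
                lora_attn_procs[name].to_q_lora.up.weight += 1
                lora_attn_procs[name].to_k_lora.up.weight += 1
                lora_attn_procs[name].to_v_lora.up.weight += 1
                lora_attn_procs[name].to_out_lora.up.weight += 1

    return lora_attn_procs

The save/load round-trip tests pass mock_weights=False, as shown in the last two hunks, since they only need processors attached rather than perturbed weights.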