@@ -149,6 +149,9 @@ def set_use_memory_efficient_attention_xformers(
        is_lora = hasattr(self, "processor") and isinstance(
            self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
        )
+        is_custom_diffusion = hasattr(self, "processor") and isinstance(
+            self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
+        )

        if use_memory_efficient_attention_xformers:
            if self.added_kv_proj_dim is not None:
@@ -192,6 +195,17 @@ def set_use_memory_efficient_attention_xformers(
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
+            elif is_custom_diffusion:
+                processor = CustomDiffusionXFormersAttnProcessor(
+                    train_kv=self.processor.train_kv,
+                    train_q_out=self.processor.train_q_out,
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_custom_diffusion"):
+                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
            else:
                processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
@@ -203,6 +217,16 @@ def set_use_memory_efficient_attention_xformers(
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
+            elif is_custom_diffusion:
+                processor = CustomDiffusionAttnProcessor(
+                    train_kv=self.processor.train_kv,
+                    train_q_out=self.processor.train_q_out,
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_custom_diffusion"):
+                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
            else:
                processor = AttnProcessor()
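The two elif branches above mirror the existing LoRA handling: when xformers is toggled, the active custom-diffusion processor is rebuilt in the other flavor and its trained weights are carried across. A minimal sketch of that hand-off, assuming the two classes added later in this diff and illustrative Stable Diffusion v1 sizes (hidden_size=320, cross_attention_dim=768):

# Illustrative only: both flavors register their parameters under the same names
# (to_k_custom_diffusion, to_v_custom_diffusion, ...), so a plain state-dict
# round-trip moves the trained weights across the swap.
eager_proc = CustomDiffusionAttnProcessor(
    train_kv=True, train_q_out=False, hidden_size=320, cross_attention_dim=768
)
xformers_proc = CustomDiffusionXFormersAttnProcessor(
    train_kv=True, train_q_out=False, hidden_size=320, cross_attention_dim=768
)
xformers_proc.load_state_dict(eager_proc.state_dict())
xformers_proc.to(eager_proc.to_k_custom_diffusion.weight.device)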
@@ -459,6 +483,84 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
        return hidden_states


+class CustomDiffusionAttnProcessor(nn.Module):
+    def __init__(
+        self,
+        train_kv=True,
+        train_q_out=True,
+        hidden_size=None,
+        cross_attention_dim=None,
+        out_bias=True,
+        dropout=0.0,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
class AttnAddedKVProcessor:
    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
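CustomDiffusionAttnProcessor leaves the attention module's own query and output projections untouched (unless train_q_out is set) and learns fresh key/value projections over the text embedding; during cross-attention it detaches the key/value of the first text token so only the remaining tokens receive gradients. A hedged usage sketch for wiring it into a UNet, assuming diffusers' UNet2DConditionModel with its attn_processors / set_attn_processor API and the usual down_blocks / mid_block / up_blocks naming; the hidden-size derivation below is an assumption based on that naming, not part of this diff:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

custom_procs = {}
for name in unet.attn_processors.keys():
    # "attn1" processors are self-attention; only cross-attention sees the text embedding
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    custom_procs[name] = CustomDiffusionAttnProcessor(
        train_kv=cross_attention_dim is not None,  # only learn new K/V on cross-attention layers
        train_q_out=False,
        hidden_size=hidden_size,
        cross_attention_dim=cross_attention_dim,
    )
unet.set_attn_processor(custom_procs)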
@@ -699,6 +801,91 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
        return hidden_states


+class CustomDiffusionXFormersAttnProcessor(nn.Module):
+    def __init__(
+        self,
+        train_kv=True,
+        train_q_out=False,
+        hidden_size=None,
+        cross_attention_dim=None,
+        out_bias=True,
+        dropout=0.0,
+        attention_op: Optional[Callable] = None,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.attention_op = attention_op
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
class SlicedAttnProcessor:
    def __init__(self, slice_size):
        self.slice_size = slice_size
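With the processors above attached, enabling xformers routes through the is_custom_diffusion branch added earlier in this diff, so each custom-diffusion processor is swapped for its xformers counterpart with the trained weights reloaded. A minimal sketch, assuming the unet from the previous example and a working xformers install:

unet.enable_xformers_memory_efficient_attention()
# processors that were CustomDiffusionAttnProcessor are now the xformers flavor
print({type(p).__name__ for p in unet.attn_processors.values()})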
@@ -834,4 +1021,6 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
    AttnAddedKVProcessor2_0,
    LoRAAttnProcessor,
    LoRAXFormersAttnProcessor,
+    CustomDiffusionAttnProcessor,
+    CustomDiffusionXFormersAttnProcessor,
]