@@ -56,7 +56,8 @@ def __init__(
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Union[bool, str] = False,
+        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        out_bias: bool = True,
@@ -68,7 +69,6 @@ def __init__(
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
-        self.cross_attention_norm = cross_attention_norm

        self.scale = dim_head**-0.5 if scale_qk else 1.0
@@ -85,8 +85,28 @@ def __init__(
        else:
            self.group_norm = None

-        if cross_attention_norm:
+        if cross_attention_norm is False:
+            self.norm_cross = None
+        elif cross_attention_norm is True or cross_attention_norm == "layer_norm":
            self.norm_cross = nn.LayerNorm(cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be False, True, 'layer_norm' or 'group_norm'"
+            )

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
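For context, a minimal usage sketch of the new constructor options (not part of this diff; the import path and the other constructor arguments shown are assumed from recent `diffusers` versions and may differ):

```python
# Hypothetical usage sketch -- assumes Attention lives in
# diffusers.models.attention_processor (not verified against every version).
from diffusers.models.attention_processor import Attention

# `cross_attention_norm` now also accepts string identifiers:
#   False                -> no normalization of encoder_hidden_states
#   True / "layer_norm"  -> nn.LayerNorm over cross_attention_dim
#   "group_norm"         -> nn.GroupNorm with cross_attention_norm_num_groups groups
attn = Attention(
    query_dim=320,
    cross_attention_dim=768,   # 768 is divisible by 32, as GroupNorm requires
    heads=8,
    dim_head=40,
    cross_attention_norm="group_norm",
    cross_attention_norm_num_groups=32,
)
```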
@@ -291,6 +311,24 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        return attention_mask

+    def prepare_encoder_hidden_states(self, hidden_states, encoder_hidden_states=None):
+        if encoder_hidden_states is None:
+            return hidden_states
+
+        if self.norm_cross is None:
+            return encoder_hidden_states
+
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            assert False
+
+        return encoder_hidden_states
+

class AttnProcessor:
    def __call__(
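The transpose in the `GroupNorm` branch is needed because `nn.GroupNorm` normalizes over the channel dimension, which it expects at dim 1, while `encoder_hidden_states` arrives as `(batch, seq_len, channels)`. A self-contained sketch with illustrative shapes (values assumed, not taken from this PR):

```python
# Minimal, standalone illustration of the GroupNorm branch above.
import torch
import torch.nn as nn

batch, seq_len, channels = 2, 77, 768
encoder_hidden_states = torch.randn(batch, seq_len, channels)

norm_cross = nn.GroupNorm(num_channels=channels, num_groups=32, eps=1e-5, affine=True)

# Move channels to dim 1 for GroupNorm, then restore (batch, seq_len, channels).
normed = norm_cross(encoder_hidden_states.transpose(1, 2)).transpose(1, 2)
assert normed.shape == encoder_hidden_states.shape
```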
@@ -306,10 +344,7 @@ def __call__(
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
@@ -375,7 +410,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -402,6 +437,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -449,10 +485,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
        query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
@@ -493,10 +526,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
        query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
@@ -545,7 +575,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query).contiguous()

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -580,10 +610,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
@@ -630,6 +657,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)