@@ -312,8 +312,7 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
         return attention_mask
 
     def norm_encoder_hidden_states(self, encoder_hidden_states):
-        if self.norm_cross is None:
-            return encoder_hidden_states
+        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
 
         if isinstance(self.norm_cross, nn.LayerNorm):
             encoder_hidden_states = self.norm_cross(encoder_hidden_states)
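For context, a minimal stand-alone sketch of the contract this hunk introduces (not the diffusers source; ToyAttention and its fields are invented for illustration): norm_encoder_hidden_states now asserts that norm_cross exists instead of silently returning its input, so callers are expected to guard the call, as the processor hunks below do with elif attn.norm_cross.

import torch
from torch import nn

class ToyAttention(nn.Module):
    # Hypothetical, reduced stand-in for the Attention class touched by this diff.
    def __init__(self, dim, use_cross_norm=True):
        super().__init__()
        self.norm_cross = nn.LayerNorm(dim) if use_cross_norm else None

    def norm_encoder_hidden_states(self, encoder_hidden_states):
        # New contract: only reach this method when norm_cross is defined.
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
        return self.norm_cross(encoder_hidden_states)

attn = ToyAttention(dim=8, use_cross_norm=False)
states = torch.randn(2, 4, 8)
# Guarded call: with use_cross_norm=False the states pass through untouched.
states = attn.norm_encoder_hidden_states(states) if attn.norm_cross else states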
@@ -348,7 +347,7 @@ def __call__(
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        else:
+        elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
@@ -448,7 +447,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        else:
+        elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -499,7 +498,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        else:
+        elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
@@ -543,7 +542,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        else:
+        elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
@@ -595,7 +594,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        else:
+        elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
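The hunk above sits in the LoRA attention processor, where the key projection adds a scaled low-rank update on top of the base linear layer. A rough sketch of that projection shape, with a made-up ToyLoRALinear standing in for the library's LoRA layer:

import torch
from torch import nn

class ToyLoRALinear(nn.Module):
    # Illustrative low-rank adapter: down-project to a small rank, then up-project back.
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)

    def forward(self, x):
        return self.up(self.down(x))

to_k = nn.Linear(16, 16, bias=False)   # stands in for attn.to_k
to_k_lora = ToyLoRALinear(16, 16)      # stands in for self.to_k_lora
scale = 1.0
encoder_hidden_states = torch.randn(2, 4, 16)
key = to_k(encoder_hidden_states) + scale * to_k_lora(encoder_hidden_states)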
@@ -633,7 +632,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        else:
+        elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
@@ -684,7 +683,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        else:
+        elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
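Taken together, every processor hunk applies the same guarded pattern. A simplified helper showing the shared control flow (illustrative only; the real __call__ methods also handle attention masks, group norm, reshaping, and the output projection):

def resolve_encoder_hidden_states(attn, hidden_states, encoder_hidden_states=None):
    if encoder_hidden_states is None:
        # Self-attention: keys/values come from the hidden states themselves.
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        # Cross-attention with a configured norm: normalize the conditioning first.
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
    # Cross-attention without norm_cross: the conditioning passes through unchanged,
    # where the old unconditional else branch relied on the helper's early return.
    return encoder_hidden_states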