
Commit d97a32a

move norm_cross defined check to outside norm_encoder_hidden_states
1 parent 3cf771f · commit d97a32a

3 files changed: +10 −11 lines
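In words: the "is norm_cross defined?" check is moved out of Attention.norm_encoder_hidden_states and into each attention processor. A minimal sketch of the contract change, assuming the cross-attention norm is a plain LayerNorm; the helper names normalize_before / normalize_after are hypothetical and only illustrate the before/after behaviour, they are not part of diffusers:

from typing import Optional

from torch import nn


def normalize_before(norm_cross: Optional[nn.Module], encoder_hidden_states):
    # Pre-commit behaviour: the method itself tolerated a missing norm and
    # silently returned its input unchanged.
    if norm_cross is None:
        return encoder_hidden_states
    return norm_cross(encoder_hidden_states)


def normalize_after(norm_cross: Optional[nn.Module], encoder_hidden_states):
    # Post-commit behaviour: reaching this call without a norm is treated as a
    # programming error, so it fails loudly; callers now check norm_cross first.
    assert norm_cross is not None, "norm_cross must be defined to call norm_encoder_hidden_states"
    return norm_cross(encoder_hidden_states)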

src/diffusers/models/attention_processor.py

Lines changed: 8 additions & 9 deletions
@@ -312,8 +312,7 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
         return attention_mask
 
     def norm_encoder_hidden_states(self, encoder_hidden_states):
-        if self.norm_cross is None:
-            return encoder_hidden_states
+        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
 
         if isinstance(self.norm_cross, nn.LayerNorm):
             encoder_hidden_states = self.norm_cross(encoder_hidden_states)
@@ -348,7 +347,7 @@ def __call__(
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        else:
+        elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
@@ -448,7 +447,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        else:
+        elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -499,7 +498,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        else:
+        elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
@@ -543,7 +542,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        else:
+        elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
@@ -595,7 +594,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        else:
+        elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
@@ -633,7 +632,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        else:
+        elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
@@ -684,7 +683,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        else:
+        elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
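Every other hunk in this commit, including the two pipeline files below, applies the matching caller-side guard: `else:` becomes `elif attn.norm_cross:`, so normalization is only requested when a cross-attention norm actually exists. A self-contained sketch of that pattern, using a simplified, hypothetical ToyAttention class rather than the real diffusers Attention module:

import torch
from torch import nn


class ToyAttention(nn.Module):
    # Hypothetical stand-in for diffusers' Attention; only the pieces touched
    # by this commit are modelled.
    def __init__(self, dim: int, cross_attention_norm: bool = False):
        super().__init__()
        self.norm_cross = nn.LayerNorm(dim) if cross_attention_norm else None
        self.to_k = nn.Linear(dim, dim)

    def norm_encoder_hidden_states(self, encoder_hidden_states):
        # Post-commit contract: the caller is responsible for the norm_cross check.
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
        return self.norm_cross(encoder_hidden_states)


def toy_processor(attn: ToyAttention, hidden_states, encoder_hidden_states=None):
    # Caller-side guard introduced by the commit: self-attention falls back to
    # hidden_states, and the norm is applied only when norm_cross is set.
    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
    return attn.to_k(encoder_hidden_states)


attn = ToyAttention(dim=8, cross_attention_norm=False)
keys = toy_processor(attn, torch.randn(2, 4, 8), torch.randn(2, 4, 8))
print(keys.shape)  # torch.Size([2, 4, 8]); encoder states passed through un-normalized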

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py

Lines changed: 1 addition & 1 deletion
@@ -243,7 +243,7 @@ def __call__(
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        else:
+        elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def __call__(
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        else:
+        elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
