
Commit c3307c3

add group norm type to attention processor cross attention norm
This lets the cross attention norm use either a group norm block or a layer norm block. The group norm operates along the channels dimension and requires an input of shape (batch_size, channels, *), whereas a layer norm with a single `normalized_shape` dimension only operates over the last dimension, i.e. (*, channels). The channels we want to normalize are the hidden dimension of the encoder hidden states. By convention, the encoder hidden states are always passed as (batch_size, sequence_length, hidden_states). This means the layer norm can operate on the tensor without modification, but the group norm requires flipping the last two dimensions so it operates on (batch_size, hidden_states, sequence_length). All existing attention processors need the same logic, so it is consolidated in a helper function `prepare_encoder_hidden_states`.
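As a rough sketch of the shape handling described above (plain PyTorch, illustrative sizes only, not taken from this commit):

import torch
from torch import nn

batch_size, sequence_length, hidden_dim = 2, 77, 768

# Encoder hidden states follow the (batch_size, sequence_length, hidden_dim) convention.
encoder_hidden_states = torch.randn(batch_size, sequence_length, hidden_dim)

# LayerNorm with a single normalized_shape dimension normalizes over the last
# dimension, so it can be applied to the tensor as-is.
layer_norm = nn.LayerNorm(hidden_dim)
normed_ln = layer_norm(encoder_hidden_states)

# GroupNorm normalizes along the channels dimension and expects (batch_size, channels, *),
# so the last two dimensions are swapped before and after the norm.
group_norm = nn.GroupNorm(num_groups=32, num_channels=hidden_dim, eps=1e-5, affine=True)
normed_gn = group_norm(encoder_hidden_states.transpose(1, 2)).transpose(1, 2)

assert normed_ln.shape == normed_gn.shape == (batch_size, sequence_length, hidden_dim)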
1 parent b8d88b8 commit c3307c3

File tree

3 files changed: +56 -29 lines changed


src/diffusers/models/attention_processor.py

Lines changed: 54 additions & 21 deletions
@@ -56,7 +56,8 @@ def __init__(
         bias=False,
         upcast_attention: bool = False,
         upcast_softmax: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Union[bool, str] = False,
+        cross_attention_norm_num_groups: int = 32,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         out_bias: bool = True,
@@ -68,7 +69,6 @@ def __init__(
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
-        self.cross_attention_norm = cross_attention_norm

         self.scale = dim_head**-0.5 if scale_qk else 1.0

@@ -85,8 +85,28 @@ def __init__(
         else:
             self.group_norm = None

-        if cross_attention_norm:
+        if cross_attention_norm is False:
+            self.norm_cross = None
+        elif cross_attention_norm is True or cross_attention_norm == "layer_norm":
             self.norm_cross = nn.LayerNorm(cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be False, True, 'layer_norm' or 'group_norm'"
+            )

         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
         self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
@@ -291,6 +311,29 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
             attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
         return attention_mask

+    def prepare_encoder_hidden_states(self, hidden_states, encoder_hidden_states=None):
+        if encoder_hidden_states is None:
+            return hidden_states
+
+        if self.norm_cross is None:
+            return encoder_hidden_states
+
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            # Group norm norms along the channels dimension and expects
+            # input to be in the shape of (N, C, *). In this case, we want
+            # to norm along the hidden dimension, so we need to move
+            # (batch_size, sequence_length, hidden_size) ->
+            # (batch_size, hidden_size, sequence_length)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            assert False
+
+        return encoder_hidden_states
+

 class AttnProcessor:
     def __call__(
@@ -306,10 +349,7 @@ def __call__(
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -375,7 +415,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query)

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -402,6 +442,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         batch_size, sequence_length, _ = hidden_states.shape

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

@@ -449,10 +490,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -493,10 +531,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -545,7 +580,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query).contiguous()

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -580,10 +615,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -630,6 +662,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
         batch_size, sequence_length, _ = hidden_states.shape

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py

Lines changed: 1 addition & 4 deletions
@@ -241,10 +241,7 @@ def __call__(
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

Lines changed: 1 addition & 4 deletions
@@ -63,10 +63,7 @@ def __call__(
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
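For context only (not part of this commit's diff), a hedged sketch of how the new option might be constructed, assuming the `Attention` class shown above is importable from `diffusers.models.attention_processor`; all dimensions are illustrative:

from diffusers.models.attention_processor import Attention

# Hypothetical cross attention block: its encoder hidden states are normalized
# by a 32-group GroupNorm over the hidden dimension (768 channels) before the
# key/value projections.
attn = Attention(
    query_dim=320,
    cross_attention_dim=768,
    heads=8,
    dim_head=40,
    cross_attention_norm="group_norm",
    cross_attention_norm_num_groups=32,
)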
