
Commit 0670b76

add group norm type to attention processor cross attention norm
This lets the cross attention norm use either a group norm block or a layer norm block. The group norm operates along the channels dimension and requires an input shape of (batch size, channels, *), whereas the layer norm with a single `normalized_shape` dimension only operates over the least significant dimension, i.e. (*, channels). The channels we want to normalize are the hidden dimension of the encoder hidden states. By convention, the encoder hidden states are always passed as (batch size, sequence length, hidden dim), so the layer norm can operate on the tensor without modification, but the group norm requires flipping the last two dimensions to operate on (batch size, hidden dim, sequence length). All existing attention processors need the same logic, so we consolidate it in a helper method, `prepare_encoder_hidden_states`.
1 parent b8d88b8 commit 0670b76
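
To make the shape convention described in the commit message concrete, here is a minimal standalone sketch (not part of this commit): `nn.LayerNorm` normalizes the trailing hidden dim of a (batch size, sequence length, hidden dim) tensor directly, while `nn.GroupNorm` expects the channels in dimension 1, so the tensor is transposed before and after the norm, mirroring the helper introduced below. The sizes are arbitrary example values.

# Minimal sketch of the shape convention (example sizes, not from the diff).
import torch
from torch import nn

batch, seq_len, hidden_dim = 2, 77, 768
encoder_hidden_states = torch.randn(batch, seq_len, hidden_dim)

# LayerNorm normalizes the last dimension, so (batch, seq_len, hidden_dim) works as-is.
layer_norm = nn.LayerNorm(hidden_dim)
out_ln = layer_norm(encoder_hidden_states)

# GroupNorm normalizes over channels in dim 1, so the last two dims are swapped around the call.
group_norm = nn.GroupNorm(num_groups=32, num_channels=hidden_dim)
out_gn = group_norm(encoder_hidden_states.transpose(1, 2))  # (batch, hidden_dim, seq_len)
out_gn = out_gn.transpose(1, 2)                             # back to (batch, seq_len, hidden_dim)

assert out_ln.shape == out_gn.shape == encoder_hidden_states.shape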

File tree

3 files changed: +51 -29 lines changed

src/diffusers/models/attention_processor.py

Lines changed: 49 additions & 21 deletions
@@ -56,7 +56,8 @@ def __init__(
         bias=False,
         upcast_attention: bool = False,
         upcast_softmax: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Union[bool, str] = False,
+        cross_attention_norm_num_groups: int = 32,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         out_bias: bool = True,
@@ -68,7 +69,6 @@ def __init__(
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
-        self.cross_attention_norm = cross_attention_norm

         self.scale = dim_head**-0.5 if scale_qk else 1.0

@@ -85,8 +85,28 @@ def __init__(
         else:
             self.group_norm = None

-        if cross_attention_norm:
+        if cross_attention_norm is False:
+            self.norm_cross = None
+        elif cross_attention_norm is True or cross_attention_norm == "layer_norm":
             self.norm_cross = nn.LayerNorm(cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be False, True, 'layer_norm' or 'group_norm'"
+            )

         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
         self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
@@ -291,6 +311,24 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
             attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
         return attention_mask

+    def prepare_encoder_hidden_states(self, hidden_states, encoder_hidden_states=None):
+        if encoder_hidden_states is None:
+            return hidden_states
+
+        if self.norm_cross is None:
+            return encoder_hidden_states
+
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            assert False
+
+        return encoder_hidden_states
+

 class AttnProcessor:
     def __call__(
@@ -306,10 +344,7 @@ def __call__(
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -375,7 +410,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query)

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -402,6 +437,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         batch_size, sequence_length, _ = hidden_states.shape

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

@@ -449,10 +485,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -493,10 +526,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         query = attn.to_q(hidden_states)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -545,7 +575,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query).contiguous()

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -580,10 +610,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)

-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -630,6 +657,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
         batch_size, sequence_length, _ = hidden_states.shape

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)

         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
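For reference, a hypothetical usage sketch (not from this commit) of the new constructor arguments. The import path and the `heads`/`dim_head` argument names are assumptions based on this file's surrounding code; the sizes are arbitrary example values chosen so the channel count divides evenly into the 32 groups.

# Hypothetical usage sketch: building an Attention block with the new group norm option.
import torch
from diffusers.models.attention_processor import Attention

attn = Attention(
    query_dim=320,
    cross_attention_dim=1024,
    heads=8,
    dim_head=40,
    cross_attention_norm="group_norm",   # previously only True/False (layer norm)
    cross_attention_norm_num_groups=32,  # 1024 channels split into 32 groups
)

hidden_states = torch.randn(2, 64, 320)           # (batch, seq_len, query_dim)
encoder_hidden_states = torch.randn(2, 77, 1024)  # (batch, seq_len, cross_attention_dim)

out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
assert out.shape == hidden_states.shape

Passing `cross_attention_norm=True` still selects the layer norm, so existing configurations keep their previous behavior.
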

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,10 +241,7 @@ def __call__(
241241
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
242242
query = attn.to_q(hidden_states)
243243

244-
if encoder_hidden_states is None:
245-
encoder_hidden_states = hidden_states
246-
elif attn.cross_attention_norm:
247-
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
244+
encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
248245

249246
key = attn.to_k(encoder_hidden_states)
250247
value = attn.to_v(encoder_hidden_states)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,7 @@ def __call__(
6363
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
6464
query = attn.to_q(hidden_states)
6565

66-
if encoder_hidden_states is None:
67-
encoder_hidden_states = hidden_states
68-
elif attn.cross_attention_norm:
69-
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
66+
encoder_hidden_states = attn.prepare_encoder_hidden_states(hidden_states, encoder_hidden_states)
7067

7168
key = attn.to_k(encoder_hidden_states)
7269
value = attn.to_v(encoder_hidden_states)

0 commit comments
