
Commit 44a4886

a-r-r-o-w authored and sayakpaul committed
[refactor] apply qk norm in attention processors (#9071)
* apply qk norm in attention processors
* revert attention processor
* qk-norm in only attention proc 2.0 and fused variant
1 parent: 01829c6

File tree

1 file changed (+10, -0 lines changed)


src/diffusers/models/attention_processor.py

Lines changed: 10 additions & 0 deletions
@@ -1785,6 +1785,11 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
@@ -2314,6 +2319,11 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
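
For context, the snippet below is a minimal standalone sketch (not diffusers code) of the pattern this commit adds: optionally normalizing the query and key projections per attention head right before F.scaled_dot_product_attention. The ToyAttention module, its qk_norm constructor flag, and the choice of nn.LayerNorm for norm_q / norm_k are illustrative assumptions; in diffusers these norms live on the Attention module itself and the processors above only apply them when they are set.

# Minimal standalone sketch of qk-norm before SDPA (illustrative only).
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyAttention(nn.Module):
    def __init__(self, dim: int, heads: int = 8, qk_norm: bool = True):
        super().__init__()
        self.heads = heads
        self.head_dim = dim // heads
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.Linear(dim, dim)
        # qk-norm: normalize queries/keys over the head dimension.
        # LayerNorm is used here for illustration; the actual norm type
        # depends on how the attention module is configured.
        self.norm_q = nn.LayerNorm(self.head_dim) if qk_norm else None
        self.norm_k = nn.LayerNorm(self.head_dim) if qk_norm else None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape
        query = self.to_q(hidden_states)
        key = self.to_k(hidden_states)
        value = self.to_v(hidden_states)

        # (batch, seq_len, dim) -> (batch, heads, seq_len, head_dim)
        query = query.view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2)

        # The pattern from the diff: apply the norms only if they are set.
        if self.norm_q is not None:
            query = self.norm_q(query)
        if self.norm_k is not None:
            key = self.norm_k(key)

        hidden_states = F.scaled_dot_product_attention(query, key, value)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, seq_len, -1)
        return self.to_out(hidden_states)


# Usage example
attn = ToyAttention(dim=64, heads=8)
out = attn(torch.randn(2, 16, 64))
print(out.shape)  # torch.Size([2, 16, 64])

Normalizing q and k per head (often with RMSNorm in practice) keeps the attention logits bounded and is commonly used to stabilize training of large transformer models.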
