@@ -1785,6 +1785,11 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
@@ -2314,6 +2319,11 @@ def __call__(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
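
Both hunks add the same guarded QK-normalization step: if the attention module carries optional norm_q / norm_k layers, they are applied to the reshaped per-head query and key tensors before F.scaled_dot_product_attention. The sketch below is a minimal, self-contained illustration of that pattern only; the SimpleQKNormAttention class, its LayerNorm choice, and the projection names are assumptions for illustration, not the diffusers processor code this diff modifies.

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleQKNormAttention(nn.Module):
    """Illustrative attention block with optional QK-norm, mirroring the diff's guards."""

    def __init__(self, dim: int, heads: int, qk_norm: bool = True):
        super().__init__()
        self.heads = heads
        self.head_dim = dim // heads
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        # Optional per-head norms over head_dim; None reproduces the `is not None` guards.
        self.norm_q = nn.LayerNorm(self.head_dim) if qk_norm else None
        self.norm_k = nn.LayerNorm(self.head_dim) if qk_norm else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape
        # Reshape to (batch, num_heads, seq_len, head_dim), as in the diff context lines.
        query = self.to_q(x).view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2)
        key = self.to_k(x).view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2)
        value = self.to_v(x).view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2)

        # QK-norm: normalize query/key per head before attention (the added lines).
        if self.norm_q is not None:
            query = self.norm_q(query)
        if self.norm_k is not None:
            key = self.norm_k(key)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = F.scaled_dot_product_attention(query, key, value)
        return hidden_states.transpose(1, 2).reshape(batch_size, seq_len, -1)


# Usage: a shape check confirming the guarded norms slot in before SDPA.
if __name__ == "__main__":
    attn = SimpleQKNormAttention(dim=64, heads=8)
    out = attn(torch.randn(2, 16, 64))
    print(out.shape)  # torch.Size([2, 16, 64])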