diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index a9dac59051a..ce3b01b6d68 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -212,7 +212,6 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.use_qk_norm = config.use_qk_norm
         self.use_conv2d = False
 
-        assert not self.use_qk_norm, "QK norm not supported in static attention yet"
         self.wqs = nn.ModuleList(
             [
                 nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
@@ -241,6 +240,13 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
         self.rope = _Rope(rope.params.use_hf_rope)
 
+        if self.use_qk_norm:
+            self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
+            self.k_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
+        else:
+            self.q_norm = torch.nn.Identity()
+            self.k_norm = torch.nn.Identity()
+
     def forward(
         self,
         x: torch.Tensor,
@@ -275,6 +281,10 @@ def from_conv2ds(ts):
             new_ks = from_conv2ds(new_ks)
             new_vs = from_conv2ds(new_vs)
 
+        if self.use_qk_norm:
+            new_qs = [self.q_norm(q) for q in new_qs]
+            new_ks = [self.k_norm(k) for k in new_ks]
+
         new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
         new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
         all_ks = []
@@ -325,6 +335,13 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):
 
         self.wo.weight.data.copy_(other.wo.weight)
 
+        if other.use_qk_norm:
+            self.use_qk_norm = True
+            self.q_norm = torch.nn.RMSNorm(other.q_norm_fn.dim, other.q_norm_fn.eps)
+            self.q_norm.load_state_dict(other.q_norm_fn.state_dict())
+            self.k_norm = torch.nn.RMSNorm(other.k_norm_fn.dim, other.k_norm_fn.eps)
+            self.k_norm.load_state_dict(other.k_norm_fn.state_dict())
+
     def linear_to_conv2d(self):
         def transfer_weight(linear, conv2d):
             conv2d.weight.data.copy_(linear.weight[:, :, None, None])
diff --git a/examples/models/llama/tests/test_static_attention.py b/examples/models/llama/tests/test_static_attention.py
index 2f6f9639b55..a1b6742416e 100644
--- a/examples/models/llama/tests/test_static_attention.py
+++ b/examples/models/llama/tests/test_static_attention.py
@@ -17,12 +17,13 @@ def setUp(self):
         torch.manual_seed(42)
 
     def test_without_cache(self):
-        def test(use_conv2d):
+        def test(use_qk_norm, use_conv2d):
             config = ModelArgs(
                 dim=64,
                 n_heads=4,
                 n_kv_heads=2,
                 max_seq_len=8,
+                use_qk_norm=use_qk_norm,
             )
             layer_id = 0
             rope = Rope(config)
@@ -47,8 +48,10 @@ def test(use_conv2d):
             )
             self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
 
-        test(True)
-        test(False)
+        test(True, True)
+        test(True, False)
+        test(False, True)
+        test(False, False)
 
     def test_hf_rope_without_cache(self):
         config = ModelArgs(
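
Note on the change (not part of the diff): the constructor now registers `torch.nn.RMSNorm` modules for queries and keys when `use_qk_norm` is set (and `torch.nn.Identity` otherwise), and `forward` applies them to the per-head projections before RoPE; `load_weights_from_attention_mha` copies the norm weights over from the reference `AttentionMHA`. The snippet below is a minimal standalone sketch of that placement, assuming a single head and omitting RoPE, KV caching, and the conv2d path; the class and variable names are illustrative and this is not the ExecuTorch `StaticAttention` module itself.

```python
import torch
from torch import nn


class QKNormAttentionSketch(nn.Module):
    """Single-head attention showing where QK norm sits (illustrative only)."""

    def __init__(self, dim: int, head_dim: int, use_qk_norm: bool, eps: float = 1e-5):
        super().__init__()
        self.wq = nn.Linear(dim, head_dim, bias=False)
        self.wk = nn.Linear(dim, head_dim, bias=False)
        self.wv = nn.Linear(dim, head_dim, bias=False)
        # torch.nn.RMSNorm requires PyTorch >= 2.4; Identity stands in when the
        # norm is disabled, mirroring the constructor change in the diff.
        self.q_norm = nn.RMSNorm(head_dim, eps) if use_qk_norm else nn.Identity()
        self.k_norm = nn.RMSNorm(head_dim, eps) if use_qk_norm else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The norm is applied to the projected queries/keys; in the real module
        # this happens per head and before RoPE, both omitted here for brevity.
        q = self.q_norm(self.wq(x))
        k = self.k_norm(self.wk(x))
        v = self.wv(x)
        scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
        return scores.softmax(dim=-1) @ v


x = torch.randn(1, 8, 64)
print(QKNormAttentionSketch(64, 16, use_qk_norm=True)(x).shape)  # torch.Size([1, 8, 16])
```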