From 5706bc1be6c60525f3a93c21760e43e7a6e4421e Mon Sep 17 00:00:00 2001
From: Shen Xu
Date: Tue, 11 Mar 2025 09:39:47 -0700
Subject: [PATCH] Linear to Conv2d transform for static attention (#9025)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/9025

This is needed by some backends such as HTP.

Reviewed By: billmguo

Differential Revision: D70726317
---
 examples/models/llama/static_attention.py    | 67 ++++++++++++++++++-
 .../llama/tests/test_static_attention.py     | 56 +++++++++-------
 2 files changed, 97 insertions(+), 26 deletions(-)

diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index f35efa38151..a9dac59051a 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -210,6 +210,7 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
         self.attention_qkv_bias = config.attention_qkv_bias
         self.use_qk_norm = config.use_qk_norm
+        self.use_conv2d = False
         assert not self.use_qk_norm, "QK norm not supported in static attention yet"
 
         self.wqs = nn.ModuleList(
@@ -255,9 +256,25 @@ def forward(
         in_cache_state = kwargs.get("in_cache_state")
         out_cache_state = kwargs.get("out_cache_state")
 
+        bsz, seq_len, dim = x.shape
+        if self.use_conv2d:
+            x = x.reshape(bsz, seq_len, 1, dim).transpose(1, 3)
+
         new_qs = [self.wqs[i](x) for i in range(self.n_heads)]
         new_ks = [self.wks[i](x) for i in range(self.n_kv_heads)]
         new_vs = [self.wvs[i](x) for i in range(self.n_kv_heads)]
+
+        if self.use_conv2d:
+
+            def from_conv2ds(ts):
+                return [
+                    t.reshape(bsz, self.head_dim, seq_len).transpose(1, 2) for t in ts
+                ]
+
+            new_qs = from_conv2ds(new_qs)
+            new_ks = from_conv2ds(new_ks)
+            new_vs = from_conv2ds(new_vs)
+
         new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
         new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
         all_ks = []
@@ -282,7 +299,14 @@
             heads.append(attn @ all_vs[kv_idx])
 
         y = torch.cat(heads, dim=-1)
-        y = self.wo(y)
+        if self.use_conv2d:
+            y = (
+                self.wo(y.reshape(bsz, seq_len, 1, -1).transpose(1, 3))
+                .transpose(1, 3)
+                .reshape(bsz, seq_len, -1)
+            )
+        else:
+            y = self.wo(y)
         return y, {"out_cache_state": out_cache_state}
 
     def load_weights_from_attention_mha(self, other: AttentionMHA):
@@ -300,3 +324,44 @@
             )
 
         self.wo.weight.data.copy_(other.wo.weight)
+
+    def linear_to_conv2d(self):
+        def transfer_weight(linear, conv2d):
+            conv2d.weight.data.copy_(linear.weight[:, :, None, None])
+            return conv2d
+
+        self.wqs = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wqs
+            ]
+        )
+        self.wks = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wks
+            ]
+        )
+        self.wvs = nn.ModuleList(
+            [
+                transfer_weight(
+                    linear,
+                    nn.Conv2d(self.dim, self.head_dim, 1, bias=self.attention_qkv_bias),
+                )
+                for linear in self.wvs
+            ]
+        )
+        self.wo = transfer_weight(
+            self.wo,
+            nn.Conv2d(
+                self.n_heads * self.head_dim, self.dim, 1, bias=self.attention_qkv_bias
+            ),
+        )
+
+        self.use_conv2d = True
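[Note, not part of the patch] The reshape/transpose pattern in forward() and the transfer_weight() helper above rely on a 1x1 nn.Conv2d over an input laid out as (bsz, dim, 1, seq_len) producing the same values as an nn.Linear over (bsz, seq_len, dim) once the weights are copied across. A minimal standalone sketch of that equivalence, with sizes (dim=64, head_dim=16, seq_len=8) chosen here purely for illustration:

import torch
import torch.nn as nn

dim, head_dim, seq_len = 64, 16, 8  # illustrative sizes, not taken from the patch
linear = nn.Linear(dim, head_dim, bias=False)
conv = nn.Conv2d(dim, head_dim, 1, bias=False)
conv.weight.data.copy_(linear.weight[:, :, None, None])  # same transfer as transfer_weight above

x = torch.rand(1, seq_len, dim)                     # (bsz, seq_len, dim)
ref = linear(x)                                     # (bsz, seq_len, head_dim)
x4 = x.reshape(1, seq_len, 1, dim).transpose(1, 3)  # (bsz, dim, 1, seq_len), as in forward()
out = conv(x4).reshape(1, head_dim, seq_len).transpose(1, 2)
assert torch.allclose(ref, out, atol=1e-5)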
diff --git a/examples/models/llama/tests/test_static_attention.py b/examples/models/llama/tests/test_static_attention.py
index 45364b1d5ec..2f6f9639b55 100644
--- a/examples/models/llama/tests/test_static_attention.py
+++ b/examples/models/llama/tests/test_static_attention.py
@@ -17,32 +17,38 @@ def setUp(self):
         torch.manual_seed(42)
 
     def test_without_cache(self):
-        config = ModelArgs(
-            dim=64,
-            n_heads=4,
-            n_kv_heads=2,
-            max_seq_len=8,
-        )
-        layer_id = 0
-        rope = Rope(config)
-        attn_mha = AttentionMHA(config, layer_id, rope).eval()
-        static_attn = StaticAttention(config, layer_id, rope).eval()
-        static_attn.load_weights_from_attention_mha(attn_mha)
+        def test(use_conv2d):
+            config = ModelArgs(
+                dim=64,
+                n_heads=4,
+                n_kv_heads=2,
+                max_seq_len=8,
+            )
+            layer_id = 0
+            rope = Rope(config)
+            attn_mha = AttentionMHA(config, layer_id, rope).eval()
+            static_attn = StaticAttention(config, layer_id, rope).eval()
+            static_attn.load_weights_from_attention_mha(attn_mha)
+            if use_conv2d:
+                static_attn.linear_to_conv2d()
+
+            x = torch.rand(1, config.max_seq_len, config.dim)
+            freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
+            expected, _ = attn_mha(x, freqs_cos, freqs_sin)
+            mask = torch.triu(
+                torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
+                diagonal=1,
+            )
+            y, _ = static_attn(
+                x,
+                freqs_cos,
+                freqs_sin,
+                mask=mask,
+            )
+            self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
 
-        x = torch.rand(1, config.max_seq_len, config.dim)
-        freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
-        expected, _ = attn_mha(x, freqs_cos, freqs_sin)
-        mask = torch.triu(
-            torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
-            diagonal=1,
-        )
-        y, _ = static_attn(
-            x,
-            freqs_cos,
-            freqs_sin,
-            mask=mask,
-        )
-        self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
+        test(True)
+        test(False)
 
     def test_hf_rope_without_cache(self):
         config = ModelArgs(
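[Note, not part of the patch] A usage sketch mirroring the updated test: load reference weights into StaticAttention, convert its projections to 1x1 Conv2d, then run it exactly as before. It assumes torch and the classes used in the test module (ModelArgs, Rope, AttentionMHA, StaticAttention) are already imported; the sizes are illustrative only.

config = ModelArgs(dim=64, n_heads=4, n_kv_heads=2, max_seq_len=8)
rope = Rope(config)
attn_mha = AttentionMHA(config, 0, rope).eval()
static_attn = StaticAttention(config, 0, rope).eval()
static_attn.load_weights_from_attention_mha(attn_mha)
static_attn.linear_to_conv2d()  # wq/wk/wv/wo Linears become 1x1 Conv2d, use_conv2d is set

x = torch.rand(1, config.max_seq_len, config.dim)
freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
mask = torch.triu(
    torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
    diagonal=1,
)
y, _ = static_attn(x, freqs_cos, freqs_sin, mask=mask)  # same output as before the transform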