
Commit efb7cf3

Add a simple sdpa (#3037) (#3166)
Summary: Pull Request resolved: #3037

Add a simple SDPA module so that attention decomposes into a handful of simpler ops, instead of relying on the decomposition of `F.scaled_dot_product_attention`, which produces 29 ops, including `torch.where`:

```
def forward(self, q, k, v):
    aten_mul_scalar = executorch_exir_dialects_edge__ops_aten_mul_Scalar(q, 0.5946035575013605); q = None
    aten_full_default = executorch_exir_dialects_edge__ops_aten_full_default([8, 8], True, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_arange_start_step = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_unsqueeze_copy_default = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step, -2); aten_arange_start_step = None
    aten_arange_start_step_1 = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_unsqueeze_copy_default_1 = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step_1, -1); aten_arange_start_step_1 = None
    aten_sub_tensor = executorch_exir_dialects_edge__ops_aten_sub_Tensor(aten_unsqueeze_copy_default, aten_unsqueeze_copy_default_1); aten_unsqueeze_copy_default = aten_unsqueeze_copy_default_1 = None
    aten_le_scalar = executorch_exir_dialects_edge__ops_aten_le_Scalar(aten_sub_tensor, 0); aten_sub_tensor = None
    aten_logical_and_default = executorch_exir_dialects_edge__ops_aten_logical_and_default(aten_le_scalar, aten_full_default); aten_le_scalar = aten_full_default = None
    aten_full_like_default = executorch_exir_dialects_edge__ops_aten_full_like_default(aten_logical_and_default, 0, dtype = torch.float32, pin_memory = False, memory_format = torch.preserve_format)
    aten_logical_not_default = executorch_exir_dialects_edge__ops_aten_logical_not_default(aten_logical_and_default); aten_logical_and_default = None
    aten_scalar_tensor_default = executorch_exir_dialects_edge__ops_aten_scalar_tensor_default(-inf, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
    aten_where_self = executorch_exir_dialects_edge__ops_aten_where_self(aten_logical_not_default, aten_scalar_tensor_default, aten_full_like_default); aten_logical_not_default = aten_scalar_tensor_default = aten_full_like_default = None
    aten_permute_copy_default = executorch_exir_dialects_edge__ops_aten_permute_copy_default(k, [0, 1, 3, 2]); k = None
    aten_mul_scalar_1 = executorch_exir_dialects_edge__ops_aten_mul_Scalar(aten_permute_copy_default, 0.5946035575013605); aten_permute_copy_default = None
    aten_expand_copy_default = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar, [1, 1, 8, 8]); aten_mul_scalar = None
    aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default, [1, 8, 8]); aten_expand_copy_default = None
    aten_expand_copy_default_1 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar_1, [1, 1, 8, 8]); aten_mul_scalar_1 = None
    aten_view_copy_default_1 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_1, [1, 8, 8]); aten_expand_copy_default_1 = None
    aten_bmm_default = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default, aten_view_copy_default_1); aten_view_copy_default = aten_view_copy_default_1 = None
    aten_view_copy_default_2 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default, [1, 1, 8, 8]); aten_bmm_default = None
    aten_add_tensor = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_view_copy_default_2, aten_where_self); aten_view_copy_default_2 = aten_where_self = None
    aten__softmax_default = executorch_exir_dialects_edge__ops_aten__softmax_default(aten_add_tensor, -1, False); aten_add_tensor = None
    aten_expand_copy_default_2 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten__softmax_default, [1, 1, 8, 8]); aten__softmax_default = None
    aten_view_copy_default_3 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_2, [1, 8, 8]); aten_expand_copy_default_2 = None
    aten_expand_copy_default_3 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(v, [1, 1, 8, 8]); v = None
    aten_view_copy_default_4 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_3, [1, 8, 8]); aten_expand_copy_default_3 = None
    aten_bmm_default_1 = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default_3, aten_view_copy_default_4); aten_view_copy_default_3 = aten_view_copy_default_4 = None
    aten_view_copy_default_5 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default_1, [1, 1, 8, 8]); aten_bmm_default_1 = None
    return (aten_view_copy_default_5,)
```

After applying the diff, we remove the following ops:

```
%aten_full_like_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.full_like.default](args = (%aten_index_tensor_2, 0), kwargs = {dtype: torch.float32, pin_memory: False, memory_format: torch.preserve_format})
%aten_logical_not_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.logical_not.default](args = (%aten_index_tensor_2,), kwargs = {})
%aten_scalar_tensor_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.scalar_tensor.default](args = (-inf,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})
%aten_where_self : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.where.self](args = (%aten_logical_not_default, %aten_scalar_tensor_default, %aten_full_like_default), kwargs = {})
%aten_mul_scalar : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mul.Scalar](args = (%aten_permute_copy_default_3, 0.5946035575013605), kwargs = {})
...
%aten_mul_scalar_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mul.Scalar](args = (%aten_permute_copy_default_6, 0.5946035575013605), kwargs = {})
```

but introduce an add:

```
%aten_add_tensor_3 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_mul_tensor_11, %aten_index_tensor_2), kwargs = {})
```

ghstack-source-id: 223152096
exported-using-ghexport

Reviewed By: mergennachin, kimishpatel

Differential Revision: D56119737

fbshipit-source-id: ec8e875f0a4c4ec67b7493e4872c9a5b081e6de7
(cherry picked from commit cf78107)
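As context, here is a minimal standalone sketch (not part of this commit) of the equivalence the simplification relies on: writing attention as a matmul, an additive mask, a softmax, and a second matmul reproduces `F.scaled_dot_product_attention` when the causal mask is supplied as an additive `-inf` upper-triangular float mask, the same form that `replace_causal_mask` registers in the diff below. The helper name `simple_sdpa_reference` and the toy shapes are illustrative only.

```
# Minimal sketch: manual SDPA vs. F.scaled_dot_product_attention with an
# additive float mask. Shapes and names are illustrative, not from the diff.
import math

import torch
import torch.nn.functional as F


def simple_sdpa_reference(q, k, v, attn_mask):
    # q, k, v: (bsz, n_heads, seqlen, head_dim); attn_mask: additive float mask
    scale_factor = 1 / math.sqrt(q.size(-1))
    attn_weight = q @ k.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_mask
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return attn_weight @ v


q, k, v = (torch.randn(1, 1, 8, 8) for _ in range(3))
# -inf above the diagonal, 0 on and below it (causal additive mask).
mask = torch.triu(torch.full((8, 8), float("-inf")), diagonal=1)
expected = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
assert torch.allclose(simple_sdpa_reference(q, k, v, mask), expected, atol=1e-6)
```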
1 parent aa3f22c commit efb7cf3

File tree

3 files changed: 144 additions & 0 deletions


examples/models/llama2/export_llama_lib.py

Lines changed: 75 additions & 0 deletions
@@ -9,6 +9,7 @@
 import argparse
 import copy
 import logging
+import math
 import os
 import shlex

@@ -143,6 +144,80 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
     return module


+class SDPASimple(torch.nn.Module):
+
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        dim: int,
+        head_dim: int,
+        n_rep: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.dim = dim
+        self.head_dim = head_dim
+        self.n_rep = n_rep
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+        mask,
+    ):
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        attn_mask = mask[None, None, input_pos]
+
+        k = k.repeat_interleave(self.n_rep, dim=1)
+        v = v.repeat_interleave(self.n_rep, dim=1)
+        scale_factor = 1 / math.sqrt(q.size(-1))
+        attn_weight = q @ k.transpose(-2, -1) * scale_factor
+        attn_weight += attn_mask
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        y = attn_weight @ v
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
+def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPASimple(child.kv_cache, child.dim, child.head_dim, child.n_rep),
+            )
+        else:
+            replace_sdpa_with_simple_sdpa(child)
+    return module
+
+
+def replace_causal_mask(module: torch.nn.Module):
+    for buffer_fqn_name, buffer in module.named_buffers():
+        buffer_name = buffer_fqn_name.split(".")[-1]
+        if buffer_name == "mask":
+            max_seq_len = buffer.shape[-1]
+            mask = torch.full(
+                (max_seq_len, max_seq_len),
+                float("-inf"),
+                device="cpu",
+            )
+
+            mask = torch.triu(mask, diagonal=1)
+            module.register_buffer(buffer_name, mask)
+    for _, child in module.named_children():
+        replace_causal_mask(child)
+    return module
+
+
 def quantize(
     model: torch.nn.Module,
     qmode: str,
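A hedged usage sketch (not taken from this diff) of how the two new helpers could be applied to a Llama module before export. The wrapper name `prepare_module_for_export` is made up here; the actual export flow in `export_llama_lib.py` may wire these passes differently.

```
# Illustrative only: combine the module-rewriting passes added in this commit.
import torch

from executorch.examples.models.llama2.export_llama_lib import (
    replace_causal_mask,
    replace_sdpa_with_simple_sdpa,
)


def prepare_module_for_export(model: torch.nn.Module) -> torch.nn.Module:
    # Recursively swap every SDPA submodule for SDPASimple (same KV cache,
    # dim, head_dim, and n_rep), then rewrite each registered "mask" buffer
    # into the additive upper-triangular -inf mask that SDPASimple adds
    # directly to the attention scores.
    model = replace_sdpa_with_simple_sdpa(model)
    model = replace_causal_mask(model)
    return model
```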

examples/models/llama2/tests/TARGETS

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
+
+oncall("executorch")
+
+python_unittest(
+    name = "test_simple_sdpa",
+    srcs = [
+        "test_simple_sdpa.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama2:export_library",
+        "//executorch/examples/models/llama2:llama_transformer",
+    ],
+)
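The deps mirror the test's imports: `//caffe2:torch` provides `torch`, `:export_library` presumably exposes `SDPASimple` from `export_llama_lib.py`, and `:llama_transformer` provides `KVCache` and `SDPA`. (The exact target-to-module mapping is inferred from the imports below, not stated in this diff.)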
examples/models/llama2/tests/test_simple_sdpa.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import unittest
+
+import torch
+from executorch.examples.models.llama2.export_llama_lib import SDPASimple
+from executorch.examples.models.llama2.llama_transformer import KVCache, SDPA
+
+
+class SDPATest(unittest.TestCase):
+    def test_simple_sdpa(self):
+        # Verify that the simple SDPA matches the original SDPA module defined in llama_transformer.py
+        max_batch_size = 1
+        max_seq_length = 128
+        n_heads = 8
+        head_dim = 8
+        dim = 64
+        n_rep = 1
+        bsz = 1
+        seqlen = 1
+        n_local_heads = n_heads
+        kv_cache = KVCache(
+            max_batch_size=max_batch_size,
+            max_seq_length=max_seq_length,
+            n_heads=n_heads,
+            head_dim=head_dim,
+            transpose_cache=True,
+        )
+        sdpa = SDPA(
+            kv_cache=copy.deepcopy(kv_cache), dim=dim, head_dim=head_dim, n_rep=n_rep
+        )
+        input_pos = torch.tensor([0])
+        query = torch.randn(1, 1, n_local_heads, head_dim)
+        key = torch.randn(1, 1, n_local_heads, head_dim)
+        value = torch.randn(1, 1, n_local_heads, head_dim)
+        mask = torch.randn(max_seq_length, max_seq_length)
+        sdpa_output = sdpa(
+            input_pos, query, key, value, bsz=bsz, seqlen=seqlen, mask=mask
+        )
+
+        simple_sdpa = SDPASimple(
+            kv_cache=copy.deepcopy(kv_cache), dim=dim, head_dim=head_dim, n_rep=n_rep
+        )
+        simple_sdpa_output = simple_sdpa(
+            input_pos, query, key, value, bsz=bsz, seqlen=seqlen, mask=mask
+        )
+
+        # Compare the outputs from the two SDPA implementations
+        self.assertTrue(torch.allclose(sdpa_output, simple_sdpa_output))
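One design choice worth noting: each module receives its own `copy.deepcopy(kv_cache)` because the forward path goes through `kv_cache.update(input_pos, k, v)`, which mutates the cache in place (in `SDPASimple`, and presumably in the original `SDPA` as well); sharing a single cache would let the first module's writes leak into the second call and weaken the comparison.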
