
Commit 929f8f8

[Executorch][llama] Enable quantized sdpa
Enable leveraging the quantized SDPA op when the quantized KV cache is used. Instead of adding yet another arg, for the moment I have chosen to leverage the quantize_kv_cache option.

Differential Revision: [D71833064](https://our.internmc.facebook.com/intern/diff/D71833064/)
ghstack-source-id: 276640303
Pull Request resolved: #9945
1 parent eef2a99 commit 929f8f8
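To make the commit message concrete, here is a minimal sketch of applying the same source transforms directly to an eager llama module. The transform order mirrors _get_source_transforms when both args.use_sdpa_with_kv_cache and args.quantize_kv_cache are set; the import paths follow the imports visible in this diff, but the helper name, the exact module layout, and the assumption that every transform returns the module it mutates are mine, not part of this commit.

import torch

# Import paths assumed from the imports shown in this diff.
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    replace_kv_cache_with_custom_kv_cache,
    replace_kv_cache_with_quantized_kv_cache,
)
from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
    replace_sdpa_with_quantized_sdpa,
)


def enable_quantized_sdpa(model: torch.nn.Module) -> torch.nn.Module:
    # Same order as in _get_source_transforms: the quantized SDPA swap only
    # fires on attention modules that already hold a QuantizedKVCache and an
    # SDPACustom instance, so it must run last.
    model = replace_kv_cache_with_custom_kv_cache(model)
    model = replace_sdpa_with_custom_op(model)
    model = replace_kv_cache_with_quantized_kv_cache(model)
    model = replace_sdpa_with_quantized_sdpa(model)
    return model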

6 files changed: +483 −20 lines

examples/models/llama/TARGETS

Lines changed: 17 additions & 0 deletions
@@ -274,3 +274,20 @@ runtime.python_test(
         ":export_library",
     ],
 )
+
+runtime.python_test(
+    name = "quantized_sdpa_source_transform_test",
+    srcs = [
+        "source_transformation/test_quantized_sdpa.py",
+    ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_py",
+    ],
+    deps = [
+        ":custom_kv_cache",
+        ":sdpa",
+        "//caffe2:torch",
+        "//executorch/examples/models/llama:llama_transformer",
+    ],
+)
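The preload_deps above matter because the custom SDPA kernels live in a separately loaded AOT library. A small illustrative sketch of what that buys the test at import time (the assert is mine, not part of the test, and assumes the library from preload_deps is actually loadable):

import torch

# Same import replace_sdpa_with_quantized_sdpa uses below; it registers the
# torch.ops.llama.* kernels (including custom_quantized_sdpa and update_cache)
# with the PyTorch dispatcher.
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

# Illustrative check only.
assert hasattr(torch.ops.llama, "custom_quantized_sdpa")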

examples/models/llama/export_llama_lib.py

Lines changed: 5 additions & 0 deletions
@@ -63,6 +63,7 @@
     replace_kv_cache_with_custom_kv_cache,
     replace_kv_cache_with_quantized_kv_cache,
 )
+
 from .source_transformation.quantize import (
     get_quant_embedding_transform,
     get_quant_weight_transform,
@@ -77,6 +78,7 @@
     replace_sdpa_with_coreml_sdpa,
     replace_sdpa_with_custom_op,
     replace_sdpa_with_flex_sdpa,
+    replace_sdpa_with_quantized_sdpa,
     replace_sdpa_with_simple_sdpa,
 )
 from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb
@@ -1226,11 +1228,14 @@ def _get_source_transforms( # noqa

     if args.use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
+        # todo: do this optionally
         transforms.append(replace_sdpa_with_custom_op)

     if args.quantize_kv_cache:
         assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
+        # Right now quantize_kv_cache also enables quantized SDPA, instead of adding yet another arg.
+        transforms.append(replace_sdpa_with_quantized_sdpa)

     if args.use_kv_cache:
         if args.qnn:
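For readers unfamiliar with _get_source_transforms: each appended entry is a Module -> Module callable, and the list is presumably applied in order to the eager model before export. A minimal sketch of that consumption pattern; the helper name is hypothetical, the real application happens elsewhere in export_llama_lib.py:

from typing import Callable, List

import torch


def apply_source_transforms(
    model: torch.nn.Module,
    transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
) -> torch.nn.Module:
    # Apply each transform in the order it was appended; later transforms can
    # depend on modules installed by earlier ones (e.g. quantized SDPA depends
    # on the quantized KV cache already being in place).
    for transform in transforms:
        model = transform(model)
    return model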

examples/models/llama/source_transformation/custom_kv_cache.py

Lines changed: 44 additions & 19 deletions
@@ -52,6 +52,8 @@ def __init__(
         self.use_custom_update_cache_op = use_custom_update_cache_op
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
+        self.return_float_values = True
+        self.max_context_length = max_context_length
         cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
         scale_shape = (max_batch_size, max_context_length, n_heads, 1)
         self.register_buffer(
@@ -61,17 +63,17 @@ def __init__(
             "v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
         )
         self.register_buffer(
-            "k_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
+            "k_cache_scales", torch.ones(scale_shape, dtype=torch.float32)
         )
         self.register_buffer(
-            "v_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
+            "v_cache_scales", torch.ones(scale_shape, dtype=torch.float32)
         )
         if cache_type == QuantizedCacheType.AffineAsymmetric:
             self.register_buffer(
-                "k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
+                "k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
             )
             self.register_buffer(
-                "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
+                "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
             )

     def _quantize(self, value):
@@ -91,20 +93,15 @@ def _quantize(self, value):
         )
         return quantized_value, scales, zero_points

-    def update(self, input_pos, k_val, v_val):
-        """
-        k_val, v_val: [B, H, S, D]
-        return: [B, H, S, D]
-        However the storage is [B, S, H, D] so we incur transpose in, transpose out
-        This shall be removed by subsequent post-export graph pass
-        """
-        k_val = k_val.transpose(1, 2)
-        v_val = v_val.transpose(1, 2)
-        # quantize current k_val and store it in the cache
+    def _quantize_and_update(self, input_pos, k_val, v_val):
         quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
-
         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

+        k_scales = k_scales.to(torch.float32)
+        k_zero_points = k_zero_points.to(self.quantized_cache_dtype)
+        v_scales = v_scales.to(torch.float32)
+        v_zero_points = v_zero_points.to(self.quantized_cache_dtype)
+
         if self.use_custom_update_cache_op:
             start_pos = input_pos[0].item()
             _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
@@ -125,25 +122,30 @@ def update(self, input_pos, k_val, v_val):
             self.v_cache_scales[:, input_pos] = v_scales
             self.v_cache_zero_points[:, input_pos] = v_zero_points

+    def _update_and_return_float_values(self, input_pos, k_val, v_val):
+        self._quantize_and_update(input_pos, k_val, v_val)
+
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
-            self.k_cache_scales,
-            self.k_cache_zero_points,
+            self.k_cache_scales.to(torch.float64),
+            self.k_cache_zero_points.to(torch.int64),
             torch.iinfo(self.quantized_cache_dtype).min,
             torch.iinfo(self.quantized_cache_dtype).max,
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )
         v_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.v_cache,
-            self.v_cache_scales,
-            self.v_cache_zero_points,
+            self.v_cache_scales.to(torch.float64),
+            self.v_cache_zero_points.to(torch.int64),
             torch.iinfo(self.quantized_cache_dtype).min,
             torch.iinfo(self.quantized_cache_dtype).max,
             self.quantized_cache_dtype,
             self.cache_fp_type,
         )

+        # When returning float values we just use the last value
+        # instead of the dequantized value.
         start_pos = input_pos[0].item()
         if self.use_custom_update_cache_op:
             _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
@@ -152,6 +154,29 @@ def update(self, input_pos, k_val, v_val):
             k_out[:, input_pos] = k_val
             v_out[:, input_pos] = v_val

+        return k_out, v_out
+
+    def _update_and_return_quantized_values(self, input_pos, k_val, v_val):
+        self._quantize_and_update(input_pos, k_val, v_val)
+
+        return self.k_cache, self.v_cache
+
+    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
+
+        if self.return_float_values:
+            k_out, v_out = self._update_and_return_float_values(input_pos, k_val, v_val)
+        else:
+            k_out, v_out = self._update_and_return_quantized_values(
+                input_pos, k_val, v_val
+            )
         return k_out.transpose(1, 2), v_out.transpose(1, 2)

     @classmethod
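The buffer dtype changes above (fp32 scales, int8 zero points) match what the quantized SDPA kernel consumes, while dequantize_per_token still expects fp64 scales and int64 zero points, hence the .to(...) casts. Below is a self-contained sketch of the per-token int8 round trip the cache performs; the tensor shape and the ops-registration import are assumptions about the installed torch build, not part of this diff.

import torch

# Registers torch.ops.quantized_decomposed.* (assumption: this registration
# module is present in the installed torch version).
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401


def per_token_int8_round_trip(x: torch.Tensor) -> torch.Tensor:
    """Quantize per token to int8 and dequantize back, mirroring the
    _quantize / dequantize_per_token pair used by QuantizedKVCache."""
    dtype = torch.int8
    scales, zero_points = (
        torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
            x, dtype
        )
    )
    quantized = torch.ops.quantized_decomposed.quantize_per_token(
        x, scales, zero_points, torch.iinfo(dtype).min, torch.iinfo(dtype).max, dtype
    )
    # The cache stores scales as fp32 and zero points as int8; the dequant op
    # wants fp64/int64, which is why the diff above casts with .to(...).
    return torch.ops.quantized_decomposed.dequantize_per_token(
        quantized,
        scales.to(torch.float64),
        zero_points.to(torch.int64),
        torch.iinfo(dtype).min,
        torch.iinfo(dtype).max,
        dtype,
        torch.float32,
    )


k_val = torch.randn(1, 16, 8, 64)  # [B, S, H, D], i.e. the post-transpose layout
print((k_val - per_token_int8_round_trip(k_val)).abs().max())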

examples/models/llama/source_transformation/sdpa.py

Lines changed: 120 additions & 1 deletion
@@ -13,7 +13,9 @@

 import torch

-from executorch.examples.models.llama.attention import KVCache, SDPA
+from executorch.examples.models.llama.attention import Attention, KVCache, SDPA
+
+from .custom_kv_cache import QuantizedKVCache


 class SDPACustom(torch.nn.Module):
@@ -76,6 +78,123 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
     return module


+class QuantizedSDPA(torch.nn.Module):
+    """
+    A quantized version of the SDPA (Scaled Dot Product Attention) module.
+
+    This module implements attention computation using quantized key-value pairs
+    to reduce memory footprint and potentially improve performance. It works with
+    a QuantizedKVCache to store and retrieve quantized key-value tensors.
+
+    The quantization process converts floating point tensors to int8, which requires
+    maintaining scale and zero point values for proper dequantization during computation.
+
+    Args:
+        dim (int): The dimension of the model
+        kv_cache (QuantizedKVCache): The cache for storing quantized key-value pairs.
+            Note that QuantizedSDPA needs to own kv_cache to access scales and zero
+            points; since the SDPA forward signature only accepts q, k and v, we pass
+            kv_cache to SDPA to make those accessible.
+    """
+
+    def __init__(self, dim: int, kv_cache: QuantizedKVCache):
+        super().__init__()
+        self.dim = dim
+        self.quantized_dtype = torch.int8
+        self.float_dtype = torch.float32
+        self.kv_cache = kv_cache
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k_quantized: torch.Tensor,
+        v_quantized: torch.Tensor,
+        bsz,
+        seqlen,
+        mask,
+    ):
+        q = q.transpose(1, 2)  # (bs, seqlen, n_local_heads, head_dim)
+        k_quantized = k_quantized.transpose(1, 2)
+        v_quantized = v_quantized.transpose(1, 2)
+
+        q_scale, q_zero_point = (
+            torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
+                q, self.quantized_dtype
+            )
+        )
+        q_quantized = torch.ops.quantized_decomposed.quantize_per_token(
+            q,
+            q_scale,
+            q_zero_point,
+            torch.iinfo(self.quantized_dtype).min,
+            torch.iinfo(self.quantized_dtype).max,
+            self.quantized_dtype,
+        )
+        q_zero_point_int8 = q_zero_point.to(dtype=torch.int8)
+        q_scale_fp32 = q_scale.to(dtype=torch.float32)
+
+        k_zero_point_int8 = self.kv_cache.k_cache_zero_points
+        k_scale_fp32 = self.kv_cache.k_cache_scales
+        v_zero_point_int8 = self.kv_cache.v_cache_zero_points
+        v_scale_fp32 = self.kv_cache.v_cache_scales
+
+        start_pos = input_pos[0].item()
+        output = torch.ops.llama.custom_quantized_sdpa(
+            q_quantized,
+            k_quantized,
+            v_quantized,
+            start_pos,
+            None,
+            0,
+            True,
+            None,
+            q_zero_point_int8,
+            q_scale_fp32,
+            k_zero_point_int8,
+            k_scale_fp32,
+            v_zero_point_int8,
+            v_scale_fp32,
+        )
+
+        return output.view(bsz, seqlen, self.dim)
+
+
+def _update_attention_module_with_quantized_sdpa(
+    module: torch.nn.Module, kv_cache: QuantizedKVCache
+):
+    sdpa = getattr(module, "SDPA", None)
+    assert sdpa is not None
+    setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache))
+
+
+def _replace_sdpa_with_quantized_sdpa(module: torch.nn.Module):
+    for _, child in module.named_children():
+        if isinstance(child, Attention):
+            kv_cache = getattr(child, "kv_cache", None)
+            if kv_cache is None:
+                continue
+            if not isinstance(kv_cache, QuantizedKVCache):
+                continue
+            # Only when kv_cache is QuantizedKVCache do we replace SDPA with QuantizedSDPA
+            sdpa = getattr(child, "SDPA", None)
+            if sdpa is None:
+                continue
+            if not isinstance(sdpa, SDPACustom):
+                continue
+            kv_cache.return_float_values = False
+            _update_attention_module_with_quantized_sdpa(child, kv_cache)
+        else:
+            _replace_sdpa_with_quantized_sdpa(child)
+
+
+def replace_sdpa_with_quantized_sdpa(module: torch.nn.Module) -> torch.nn.Module:
+    from executorch.extension.llm.custom_ops import custom_ops  # noqa
+
+    _replace_sdpa_with_quantized_sdpa(module)
+    return module
+
+
 class SDPASimple(torch.nn.Module):
     def __init__(
         self,
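Finally, a hedged verification sketch along the lines of what the new quantized_sdpa_source_transform_test target presumably checks; the helper below is illustrative, not taken from the test, and assumes `model` has already gone through the custom SDPA and quantized KV cache transforms.

import torch

from executorch.examples.models.llama.attention import Attention
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    QuantizedKVCache,
)
from executorch.examples.models.llama.source_transformation.sdpa import QuantizedSDPA


def assert_quantized_sdpa_installed(model: torch.nn.Module) -> None:
    """Check that every attention block backed by a QuantizedKVCache now runs
    QuantizedSDPA and that its cache returns quantized (not float) values."""
    for module in model.modules():
        if not isinstance(module, Attention):
            continue
        kv_cache = getattr(module, "kv_cache", None)
        if isinstance(kv_cache, QuantizedKVCache):
            assert isinstance(module.SDPA, QuantizedSDPA)
            assert kv_cache.return_float_values is False


# Typical call site:
# model = replace_sdpa_with_quantized_sdpa(model)
# assert_quantized_sdpa_installed(model)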
