
Commit b0791b4

[Executorch][llm] Make custom update cache op operate on indices
Differential Revision: D73891424
Pull Request resolved: #10610
1 parent 4935c16 commit b0791b4

File tree

6 files changed: +598 −74 lines


examples/models/llama/source_transformation/custom_kv_cache.py (+82 −25)
```diff
@@ -6,7 +6,7 @@
 
 import logging
 from enum import Enum
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -93,7 +93,7 @@ def _quantize(self, value):
         )
         return quantized_value, scales, zero_points
 
-    def _quantize_and_update(self, input_pos, k_val, v_val):
+    def _quantize_and_update(self, input_pos, k_val, v_val, indices=None):
         quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)
 
@@ -104,26 +104,57 @@ def _quantize_and_update(self, input_pos, k_val, v_val):
 
         if self.use_custom_update_cache_op:
             start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
-            _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
-            _ = torch.ops.llama.update_cache(
-                k_zero_points, self.k_cache_zero_points, start_pos
-            )
-            _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
-            _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
-            _ = torch.ops.llama.update_cache(
-                v_zero_points, self.v_cache_zero_points, start_pos
-            )
+            if indices is not None:
+                _ = torch.ops.llama.update_cache_with_indices(
+                    quantized_k_val, self.k_cache, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_scales, self.k_cache_scales, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_zero_points, self.k_cache_zero_points, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    quantized_v_val, self.v_cache, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_scales, self.v_cache_scales, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_zero_points, self.v_cache_zero_points, start_pos, indices
+                )
+            else:
+                _ = torch.ops.llama.update_cache(
+                    quantized_k_val, self.k_cache, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    k_scales, self.k_cache_scales, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    k_zero_points, self.k_cache_zero_points, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    quantized_v_val, self.v_cache, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    v_scales, self.v_cache_scales, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    v_zero_points, self.v_cache_zero_points, start_pos
+                )
         else:
+            assert indices is None, "Indices not supported for this path"
+            # Following is also broken because in prefill input_pos = [0]
+            # but we need to update some slice of cache
             self.k_cache[:, input_pos] = quantized_k_val
             self.k_cache_scales[:, input_pos] = k_scales
             self.k_cache_zero_points[:, input_pos] = k_zero_points
             self.v_cache[:, input_pos] = quantized_v_val
             self.v_cache_scales[:, input_pos] = v_scales
             self.v_cache_zero_points[:, input_pos] = v_zero_points
 
-    def _update_and_return_float_values(self, input_pos, k_val, v_val):
-        self._quantize_and_update(input_pos, k_val, v_val)
+    def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None):
+        self._quantize_and_update(input_pos, k_val, v_val, indices)
 
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
@@ -144,24 +175,34 @@ def _update_and_return_float_values(self, input_pos, k_val, v_val):
             self.cache_fp_type,
         )
 
-        # When returning float values we jsut use the last value
+        # When returning float values we just use the last value
         # instead of dequantized value.
         start_pos = input_pos[0].item()
         if self.use_custom_update_cache_op:
-            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
-            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
+            if indices is not None:
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_val, k_out, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_val, v_out, start_pos, indices
+                )
+            else:
+                _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
+                _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
         else:
             k_out[:, input_pos] = k_val
             v_out[:, input_pos] = v_val
 
         return k_out, v_out
 
-    def _update_and_return_quantized_values(self, input_pos, k_val, v_val):
-        self._quantize_and_update(input_pos, k_val, v_val)
+    def _update_and_return_quantized_values(
+        self, input_pos, k_val, v_val, indices=None
+    ):
+        self._quantize_and_update(input_pos, k_val, v_val, indices)
 
         return self.k_cache, self.v_cache
 
-    def update(self, input_pos, k_val, v_val):
+    def update(self, input_pos, k_val, v_val, indices=None):
         """
        k_val, v_val: [B, H, S, D]
        return: [B, H, S, D]
@@ -172,10 +213,12 @@ def update(self, input_pos, k_val, v_val):
         v_val = v_val.transpose(1, 2)
 
         if self.return_float_values:
-            k_out, v_out = self._update_and_return_float_values(input_pos, k_val, v_val)
+            k_out, v_out = self._update_and_return_float_values(
+                input_pos, k_val, v_val, indices
+            )
         else:
             k_out, v_out = self._update_and_return_quantized_values(
-                input_pos, k_val, v_val
+                input_pos, k_val, v_val, indices
             )
         return k_out.transpose(1, 2), v_out.transpose(1, 2)
 
@@ -277,14 +320,28 @@ def __init__(
         )
 
     def update(
-        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+        self,
+        input_pos: torch.Tensor,
+        k_val: torch.Tensor,
+        v_val: torch.Tensor,
+        indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D]
         k_val = k_val.transpose(1, 2)
        v_val = v_val.transpose(1, 2)
         start_pos = input_pos[0].item()
-        _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
-        _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
+
+        if indices is not None:
+            _ = torch.ops.llama.update_cache_with_indices(
+                k_val, self.k_cache, start_pos, indices
+            )
+            _ = torch.ops.llama.update_cache_with_indices(
+                v_val, self.v_cache, start_pos, indices
+            )
+        else:
+            _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
+            _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
+
         return (
             self.k_cache.transpose(1, 2),
             self.v_cache.transpose(1, 2),
```

extension/llm/custom_ops/custom_ops.py (+45 −11)
```diff
@@ -184,6 +184,7 @@ def _validate_update_cache_params(
     value,
     cache,
     start_pos,
+    indices=None,
 ):
     seq_len = value.size(1)
     assert (
@@ -200,17 +201,30 @@ def _validate_update_cache_params(
     ), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"
 
     torch._check_is_size(start_pos)
-    # Setting to arbitrary limit of 256 for now since there is no way
-    # to plumb this information from model config
-    torch._check(start_pos < cache.size(1))
-    assert start_pos < cache.size(
-        1
-    ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
-
-    torch._check((start_pos + seq_len) < cache.size(1))
-    assert (start_pos + seq_len) < cache.size(
-        1
-    ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
+    if indices is None:
+        torch._check(start_pos < cache.size(1))
+        assert start_pos < cache.size(
+            1
+        ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
+
+        torch._check((start_pos + seq_len) < cache.size(1))
+        assert (start_pos + seq_len) < cache.size(
+            1
+        ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
+
+    if indices is not None:
+        assert (
+            indices.dim() == 2
+        ), f"Expected indices to be 2 dimensional but got {indices.dim()} dimensions."
+        assert (
+            indices.dtype == torch.int64
+        ), f"Expected indices to be int64 but got {indices.dtype}"
+        assert indices.size(0) == value.size(
+            0
+        ), f"Expected indices batch dimension to match value batch dimension but got {indices.size(0)} and {value.size(0)}"
+        assert indices.size(1) == value.size(
+            1
+        ), f"Expected indices sequence length dimension to match value sequence length dimension but got {indices.size(1)} and {value.size(1)}"
 
 
 @impl(custom_ops_lib, "update_cache", "Meta")
@@ -231,6 +245,26 @@ def update_cache_meta(
     return torch.empty((1,), dtype=value.dtype, device="meta")
 
 
+@impl(custom_ops_lib, "update_cache_with_indices", "Meta")
+def update_cache_with_indices_meta(
+    value,
+    cache,
+    start_pos,
+    indices,
+):
+    _validate_update_cache_params(
+        value,
+        cache,
+        start_pos,
+        indices,
+    )
+
+    # Update cache doesnt really return anything but I dont know a better
+    # workaround. Should we just return cache instead? But I am afraid that
+    # will result in extra memory allocation
+    return torch.empty((1,), dtype=value.dtype, device="meta")
+
+
 def _validate_quantized_sdpa_params(
     query,
     key,
```
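
As a quick reference, the sketch below spells out the constraints that `_validate_update_cache_params` places on `indices` when one is supplied; the concrete sizes (`B`, `S`, `H`, `D`, `max_seq_len`) are illustrative assumptions, not values from the diff.

```python
import torch

# Layout used by the custom op: value is [B, S, H, D], cache is [B, max_seq_len, H, D].
B, S, H, D, max_seq_len = 2, 4, 8, 64, 128

value = torch.randn(B, S, H, D)
cache = torch.zeros(B, max_seq_len, H, D)

# Constraints enforced for the indexed path:
indices = torch.randint(0, max_seq_len, (B, S), dtype=torch.int64)

assert indices.dim() == 2                # 2-D: one destination per (batch, token)
assert indices.dtype == torch.int64      # must be int64
assert indices.size(0) == value.size(0)  # batch dimension matches value
assert indices.size(1) == value.size(1)  # sequence dimension matches value

# Unlike the contiguous path, start_pos + seq_len is not checked against
# cache.size(1) when indices are given; the indices themselves pick the slots.
```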

extension/llm/custom_ops/op_sdpa_aot.cpp (+51)
```diff
@@ -129,6 +129,20 @@ at::Tensor update_cache_aten(
     at::Tensor& cache,
     const int64_t start_pos);
 
+// New functions for update_cache_with_indices
+Tensor& update_cache_with_indices_out_no_context(
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    const Tensor& indices,
+    Tensor& output);
+
+at::Tensor update_cache_with_indices_aten(
+    const at::Tensor& value,
+    at::Tensor& cache,
+    const int64_t start_pos,
+    const at::Tensor& indices);
+
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
     const Tensor& k_projected,
@@ -340,6 +354,29 @@ at::Tensor update_cache_aten(
   return output;
 }
 
+// Implementations for update_cache_with_indices
+Tensor& update_cache_with_indices_out_no_context(
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    const Tensor& indices,
+    Tensor& output) {
+  executorch::aten::RuntimeContext context{};
+  return torch::executor::native::update_cache_with_indices_out(
+      context, value, cache, start_pos, indices, output);
+}
+
+at::Tensor update_cache_with_indices_aten(
+    const at::Tensor& value,
+    at::Tensor& cache,
+    const int64_t start_pos,
+    const at::Tensor& indices) {
+  auto output = at::empty({1});
+  WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 4)
+  (value, cache, start_pos, indices, output);
+  return output;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
@@ -367,6 +404,12 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
   m.def(
       "update_cache.out(Tensor value, Tensor(a!) cache, "
       "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
+  m.def(
+      "update_cache_with_indices(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos, Tensor indices) -> Tensor");
+  m.def(
+      "update_cache_with_indices.out(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos, Tensor indices, *, Tensor(b!) out) -> Tensor(b!)");
   m.def(
       "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
@@ -397,6 +440,14 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl(
       "update_cache.out",
       WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3));
+  m.impl(
+      "update_cache_with_indices",
+      torch::executor::native::update_cache_with_indices_aten);
+  m.impl(
+      "update_cache_with_indices.out",
+      WRAP_TO_ATEN(
+          torch::executor::native::update_cache_with_indices_out_no_context,
+          4));
   m.impl(
       "custom_quantized_sdpa",
       torch::executor::native::custom_quantized_sdpa_aten);
```
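
The fragment above registers both a functional and an `.out` schema in the `llama` library. A hedged sketch of how the two variants could be called from eager PyTorch, assuming the built AOT ops shared library has already been loaded (the loading step and its path are outside this diff), and reusing the illustrative shapes from earlier:

```python
import torch

# Assumption: the shared object built from op_sdpa_aot.cpp was loaded, e.g.
# torch.ops.load_library("<path to the built AOT custom ops library>").
B, S, H, D, max_seq_len = 1, 3, 8, 64, 16

value = torch.randn(B, S, H, D)
cache = torch.zeros(B, max_seq_len, H, D)
indices = torch.tensor([[5, 9, 12]], dtype=torch.int64)
start_pos = 0

# Functional variant, per the registered schema:
#   update_cache_with_indices(Tensor value, Tensor(a!) cache,
#                             SymInt start_pos, Tensor indices) -> Tensor
torch.ops.llama.update_cache_with_indices(value, cache, start_pos, indices)

# Out variant; `out` is a dummy 1-element tensor, mirroring update_cache_aten,
# which allocates at::empty({1}) for its return value.
out = torch.empty(1)
torch.ops.llama.update_cache_with_indices.out(
    value, cache, start_pos, indices, out=out
)
```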
