[Executorch][llm] Make custom update cache op operate on indices #10834


Merged: 1 commit merged on May 13, 2025

107 changes: 82 additions & 25 deletions examples/models/llama/source_transformation/custom_kv_cache.py
@@ -6,7 +6,7 @@

import logging
from enum import Enum
from typing import Tuple
from typing import Optional, Tuple

import torch
import torch.nn as nn
@@ -93,7 +93,7 @@ def _quantize(self, value):
)
return quantized_value, scales, zero_points

def _quantize_and_update(self, input_pos, k_val, v_val):
def _quantize_and_update(self, input_pos, k_val, v_val, indices=None):
quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

@@ -104,26 +104,57 @@ def _quantize_and_update(self, input_pos, k_val, v_val):

if self.use_custom_update_cache_op:
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
_ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
_ = torch.ops.llama.update_cache(
k_zero_points, self.k_cache_zero_points, start_pos
)
_ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
_ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
_ = torch.ops.llama.update_cache(
v_zero_points, self.v_cache_zero_points, start_pos
)
if indices is not None:
_ = torch.ops.llama.update_cache_with_indices(
quantized_k_val, self.k_cache, start_pos, indices
)
_ = torch.ops.llama.update_cache_with_indices(
k_scales, self.k_cache_scales, start_pos, indices
)
_ = torch.ops.llama.update_cache_with_indices(
k_zero_points, self.k_cache_zero_points, start_pos, indices
)
_ = torch.ops.llama.update_cache_with_indices(
quantized_v_val, self.v_cache, start_pos, indices
)
_ = torch.ops.llama.update_cache_with_indices(
v_scales, self.v_cache_scales, start_pos, indices
)
_ = torch.ops.llama.update_cache_with_indices(
v_zero_points, self.v_cache_zero_points, start_pos, indices
)
else:
_ = torch.ops.llama.update_cache(
quantized_k_val, self.k_cache, start_pos
)
_ = torch.ops.llama.update_cache(
k_scales, self.k_cache_scales, start_pos
)
_ = torch.ops.llama.update_cache(
k_zero_points, self.k_cache_zero_points, start_pos
)
_ = torch.ops.llama.update_cache(
quantized_v_val, self.v_cache, start_pos
)
_ = torch.ops.llama.update_cache(
v_scales, self.v_cache_scales, start_pos
)
_ = torch.ops.llama.update_cache(
v_zero_points, self.v_cache_zero_points, start_pos
)
else:
assert indices is None, "Indices not supported for this path"
# Following is also broken because in prefill input_pos = [0]
# but we need to update some slice of cache
self.k_cache[:, input_pos] = quantized_k_val
self.k_cache_scales[:, input_pos] = k_scales
self.k_cache_zero_points[:, input_pos] = k_zero_points
self.v_cache[:, input_pos] = quantized_v_val
self.v_cache_scales[:, input_pos] = v_scales
self.v_cache_zero_points[:, input_pos] = v_zero_points

def _update_and_return_float_values(self, input_pos, k_val, v_val):
self._quantize_and_update(input_pos, k_val, v_val)
def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None):
self._quantize_and_update(input_pos, k_val, v_val, indices)

k_out = torch.ops.quantized_decomposed.dequantize_per_token(
self.k_cache,
@@ -144,24 +175,34 @@ def _update_and_return_float_values(self, input_pos, k_val, v_val):
self.cache_fp_type,
)

# When returning float values we jsut use the last value
# When returning float values we just use the last value
# instead of dequantized value.
start_pos = input_pos[0].item()
if self.use_custom_update_cache_op:
_ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
_ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
if indices is not None:
_ = torch.ops.llama.update_cache_with_indices(
k_val, k_out, start_pos, indices
)
_ = torch.ops.llama.update_cache_with_indices(
v_val, v_out, start_pos, indices
)
else:
_ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
_ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
else:
k_out[:, input_pos] = k_val
v_out[:, input_pos] = v_val

return k_out, v_out

def _update_and_return_quantized_values(self, input_pos, k_val, v_val):
self._quantize_and_update(input_pos, k_val, v_val)
def _update_and_return_quantized_values(
self, input_pos, k_val, v_val, indices=None
):
self._quantize_and_update(input_pos, k_val, v_val, indices)

return self.k_cache, self.v_cache

def update(self, input_pos, k_val, v_val):
def update(self, input_pos, k_val, v_val, indices=None):
"""
k_val, v_val: [B, H, S, D]
return: [B, H, S, D]
@@ -172,10 +213,12 @@ def update(self, input_pos, k_val, v_val):
v_val = v_val.transpose(1, 2)

if self.return_float_values:
k_out, v_out = self._update_and_return_float_values(input_pos, k_val, v_val)
k_out, v_out = self._update_and_return_float_values(
input_pos, k_val, v_val, indices
)
else:
k_out, v_out = self._update_and_return_quantized_values(
input_pos, k_val, v_val
input_pos, k_val, v_val, indices
)
return k_out.transpose(1, 2), v_out.transpose(1, 2)

@@ -277,14 +320,28 @@ def __init__(
)

def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
self,
input_pos: torch.Tensor,
k_val: torch.Tensor,
v_val: torch.Tensor,
indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [S], k_val: [B, H, S, D]
k_val = k_val.transpose(1, 2)
v_val = v_val.transpose(1, 2)
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
_ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)

if indices is not None:
_ = torch.ops.llama.update_cache_with_indices(
k_val, self.k_cache, start_pos, indices
)
_ = torch.ops.llama.update_cache_with_indices(
v_val, self.v_cache, start_pos, indices
)
else:
_ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
_ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)

return (
self.k_cache.transpose(1, 2),
self.v_cache.transpose(1, 2),
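
For context, the `update()` methods in this file now take an optional `indices` tensor and, when it is provided, route every cache write (values, scales, and zero points in the quantized path) through `torch.ops.llama.update_cache_with_indices` instead of `torch.ops.llama.update_cache`. A minimal sketch of the new call surface from an attention layer, assuming `kv_cache` is one of the caches defined in this file; the helper name and example shapes are illustrative, not part of the diff:

import torch
from typing import Optional

def cache_step(kv_cache, input_pos, k_val, v_val,
               indices: Optional[torch.Tensor] = None):
    # k_val / v_val: [B, H, S, D]; input_pos: [S].
    # indices, if given, must be a 2-D int64 tensor whose batch and sequence
    # sizes match the (transposed) value, per the checks in custom_ops.py below.
    # With indices=None the write stays contiguous from input_pos[0]; with
    # indices it is dispatched to update_cache_with_indices.
    return kv_cache.update(input_pos, k_val, v_val, indices)
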
56 changes: 45 additions & 11 deletions extension/llm/custom_ops/custom_ops.py
@@ -184,6 +184,7 @@ def _validate_update_cache_params(
value,
cache,
start_pos,
indices=None,
):
seq_len = value.size(1)
assert (
@@ -200,17 +201,30 @@ def _validate_update_cache_params(
), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"

torch._check_is_size(start_pos)
# Setting to arbitrary limit of 256 for now since there is no way
# to plumb this information from model config
torch._check(start_pos < cache.size(1))
assert start_pos < cache.size(
1
), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"

torch._check((start_pos + seq_len) < cache.size(1))
assert (start_pos + seq_len) < cache.size(
1
), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
if indices is None:
torch._check(start_pos < cache.size(1))
assert start_pos < cache.size(
1
), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"

torch._check((start_pos + seq_len) < cache.size(1))
assert (start_pos + seq_len) < cache.size(
1
), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"

if indices is not None:
assert (
indices.dim() == 2
), f"Expected indices to be 2 dimensional but got {indices.dim()} dimensions."
assert (
indices.dtype == torch.int64
), f"Expected indices to be int64 but got {indices.dtype}"
assert indices.size(0) == value.size(
0
), f"Expected indices batch dimension to match value batch dimension but got {indices.size(0)} and {value.size(0)}"
assert indices.size(1) == value.size(
1
), f"Expected indices sequence length dimension to match value sequence length dimension but got {indices.size(1)} and {value.size(1)}"


@impl(custom_ops_lib, "update_cache", "Meta")
@@ -231,6 +245,26 @@ def update_cache_meta(
return torch.empty((1,), dtype=value.dtype, device="meta")


@impl(custom_ops_lib, "update_cache_with_indices", "Meta")
def update_cache_with_indices_meta(
value,
cache,
start_pos,
indices,
):
_validate_update_cache_params(
value,
cache,
start_pos,
indices,
)

    # Update cache doesn't really return anything, but I don't know a better
    # workaround. Should we just return cache instead? But I am afraid that
    # will result in extra memory allocation
return torch.empty((1,), dtype=value.dtype, device="meta")


def _validate_quantized_sdpa_params(
query,
key,
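
The new validation pins down the contract for `indices`: `value` and `cache` must agree in every dimension except dim 1 (the sequence axis), the contiguous `start_pos` bounds checks are skipped when `indices` is supplied, and `indices` itself must be a 2-D int64 tensor whose batch and sequence sizes match `value`. A small sketch of arguments that satisfy these checks, with illustrative shapes:

import torch

B, S, H, D, MAX_SEQ = 1, 4, 8, 64, 128
value = torch.randn(B, S, H, D)        # new tokens, [B, S, H, D]
cache = torch.zeros(B, MAX_SEQ, H, D)  # persistent cache, [B, max_seq_len, H, D]
indices = torch.tensor([[10, 11, 12, 13]], dtype=torch.int64)  # [B, S]
# indices.dim() == 2, its dtype is int64, and indices.size(0)/size(1) match
# value.size(0)/size(1), so _validate_update_cache_params(value, cache, 0, indices)
# raises no assertion.
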
51 changes: 51 additions & 0 deletions extension/llm/custom_ops/op_sdpa_aot.cpp
@@ -129,6 +129,20 @@ at::Tensor update_cache_aten(
at::Tensor& cache,
const int64_t start_pos);

// New functions for update_cache_with_indices
Tensor& update_cache_with_indices_out_no_context(
const Tensor& value,
Tensor& cache,
const int64_t start_pos,
const Tensor& indices,
Tensor& output);

at::Tensor update_cache_with_indices_aten(
const at::Tensor& value,
at::Tensor& cache,
const int64_t start_pos,
const at::Tensor& indices);

Tensor& sdpa_with_kv_cache_out_no_context(
const Tensor& q_projected,
const Tensor& k_projected,
@@ -340,6 +354,29 @@ at::Tensor update_cache_aten(
return output;
}

// Implementations for update_cache_with_indices
Tensor& update_cache_with_indices_out_no_context(
const Tensor& value,
Tensor& cache,
const int64_t start_pos,
const Tensor& indices,
Tensor& output) {
executorch::aten::RuntimeContext context{};
return torch::executor::native::update_cache_with_indices_out(
context, value, cache, start_pos, indices, output);
}

at::Tensor update_cache_with_indices_aten(
const at::Tensor& value,
at::Tensor& cache,
const int64_t start_pos,
const at::Tensor& indices) {
auto output = at::empty({1});
WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 4)
(value, cache, start_pos, indices, output);
return output;
}

} // namespace native
} // namespace executor
} // namespace torch
@@ -367,6 +404,12 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
m.def(
"update_cache.out(Tensor value, Tensor(a!) cache, "
"SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
m.def(
"update_cache_with_indices(Tensor value, Tensor(a!) cache, "
"SymInt start_pos, Tensor indices) -> Tensor");
m.def(
"update_cache_with_indices.out(Tensor value, Tensor(a!) cache, "
"SymInt start_pos, Tensor indices, *, Tensor(b!) out) -> Tensor(b!)");
m.def(
"custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
"Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
@@ -397,6 +440,14 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
m.impl(
"update_cache.out",
WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3));
m.impl(
"update_cache_with_indices",
torch::executor::native::update_cache_with_indices_aten);
m.impl(
"update_cache_with_indices.out",
WRAP_TO_ATEN(
torch::executor::native::update_cache_with_indices_out_no_context,
4));
m.impl(
"custom_quantized_sdpa",
torch::executor::native::custom_quantized_sdpa_aten);
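
With the schemas registered in the `llama` library fragment and the CompositeExplicitAutograd (plus the Meta kernel in custom_ops.py) implementations wired up through WRAP_TO_ATEN, the new op is callable from eager PyTorch and traceable for export. A hedged usage sketch, assuming the custom ops extension has been built and imported so the library is loaded; the exact write semantics of `start_pos` combined with `indices` are defined by the out-variant kernel `update_cache_with_indices_out`, which is not part of this diff:

import torch
# Importing the extension is assumed to register the llama custom ops;
# the module path may differ depending on how ExecuTorch was built.
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

B, S, H, D, MAX_SEQ = 1, 4, 8, 64, 128
value = torch.randn(B, S, H, D)
cache = torch.zeros(B, MAX_SEQ, H, D)
indices = torch.tensor([[10, 11, 12, 13]], dtype=torch.int64)
# Mutates `cache` in place; the return value is a dummy (1,) tensor,
# matching update_cache_with_indices_aten above.
torch.ops.llama.update_cache_with_indices(value, cache, 0, indices)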