diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py
index 1158a8ba7a6..4674074f8a5 100644
--- a/examples/models/llama/source_transformation/custom_kv_cache.py
+++ b/examples/models/llama/source_transformation/custom_kv_cache.py
@@ -6,7 +6,7 @@
 import logging
 from enum import Enum
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -93,7 +93,7 @@ def _quantize(self, value):
         )
         return quantized_value, scales, zero_points
 
-    def _quantize_and_update(self, input_pos, k_val, v_val):
+    def _quantize_and_update(self, input_pos, k_val, v_val, indices=None):
         quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)
@@ -104,17 +104,48 @@ def _quantize_and_update(self, input_pos, k_val, v_val):
 
         if self.use_custom_update_cache_op:
             start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
-            _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
-            _ = torch.ops.llama.update_cache(
-                k_zero_points, self.k_cache_zero_points, start_pos
-            )
-            _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
-            _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
-            _ = torch.ops.llama.update_cache(
-                v_zero_points, self.v_cache_zero_points, start_pos
-            )
+            if indices is not None:
+                _ = torch.ops.llama.update_cache_with_indices(
+                    quantized_k_val, self.k_cache, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_scales, self.k_cache_scales, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_zero_points, self.k_cache_zero_points, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    quantized_v_val, self.v_cache, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_scales, self.v_cache_scales, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_zero_points, self.v_cache_zero_points, start_pos, indices
+                )
+            else:
+                _ = torch.ops.llama.update_cache(
+                    quantized_k_val, self.k_cache, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    k_scales, self.k_cache_scales, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    k_zero_points, self.k_cache_zero_points, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    quantized_v_val, self.v_cache, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    v_scales, self.v_cache_scales, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    v_zero_points, self.v_cache_zero_points, start_pos
+                )
         else:
+            assert indices is None, "Indices not supported for this path"
+            # The following is also broken: during prefill input_pos = [0],
+            # but we need to update a larger slice of the cache.
             self.k_cache[:, input_pos] = quantized_k_val
             self.k_cache_scales[:, input_pos] = k_scales
             self.k_cache_zero_points[:, input_pos] = k_zero_points
@@ -122,8 +153,8 @@ def _quantize_and_update(self, input_pos, k_val, v_val):
             self.v_cache_scales[:, input_pos] = v_scales
             self.v_cache_zero_points[:, input_pos] = v_zero_points
 
-    def _update_and_return_float_values(self, input_pos, k_val, v_val):
-        self._quantize_and_update(input_pos, k_val, v_val)
+    def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None):
+        self._quantize_and_update(input_pos, k_val, v_val, indices)
 
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
@@ -144,24 +175,34 @@ def _update_and_return_float_values(self, input_pos, k_val, v_val):
             self.cache_fp_type,
         )
 
-        # When returning float values we jsut use the last value
+        # When returning float values we just use the last value
         # instead of dequantized value.
         start_pos = input_pos[0].item()
         if self.use_custom_update_cache_op:
-            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
-            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
+            if indices is not None:
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_val, k_out, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_val, v_out, start_pos, indices
+                )
+            else:
+                _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
+                _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
         else:
             k_out[:, input_pos] = k_val
             v_out[:, input_pos] = v_val
 
         return k_out, v_out
 
-    def _update_and_return_quantized_values(self, input_pos, k_val, v_val):
-        self._quantize_and_update(input_pos, k_val, v_val)
+    def _update_and_return_quantized_values(
+        self, input_pos, k_val, v_val, indices=None
+    ):
+        self._quantize_and_update(input_pos, k_val, v_val, indices)
 
         return self.k_cache, self.v_cache
 
-    def update(self, input_pos, k_val, v_val):
+    def update(self, input_pos, k_val, v_val, indices=None):
         """
         k_val, v_val: [B, H, S, D]
         return: [B, H, S, D]
@@ -172,10 +213,12 @@ def update(self, input_pos, k_val, v_val):
         v_val = v_val.transpose(1, 2)
 
         if self.return_float_values:
-            k_out, v_out = self._update_and_return_float_values(input_pos, k_val, v_val)
+            k_out, v_out = self._update_and_return_float_values(
+                input_pos, k_val, v_val, indices
+            )
         else:
             k_out, v_out = self._update_and_return_quantized_values(
-                input_pos, k_val, v_val
+                input_pos, k_val, v_val, indices
             )
         return k_out.transpose(1, 2), v_out.transpose(1, 2)
 
@@ -277,14 +320,28 @@ def __init__(
         )
 
     def update(
-        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+        self,
+        input_pos: torch.Tensor,
+        k_val: torch.Tensor,
+        v_val: torch.Tensor,
+        indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D]
         k_val = k_val.transpose(1, 2)
         v_val = v_val.transpose(1, 2)
         start_pos = input_pos[0].item()
-        _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
-        _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
+
+        if indices is not None:
+            _ = torch.ops.llama.update_cache_with_indices(
+                k_val, self.k_cache, start_pos, indices
+            )
+            _ = torch.ops.llama.update_cache_with_indices(
+                v_val, self.v_cache, start_pos, indices
+            )
+        else:
+            _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
+            _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
+
         return (
             self.k_cache.transpose(1, 2),
             self.v_cache.transpose(1, 2),
diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py
index 6d96a926497..947eae6c0d0 100644
--- a/extension/llm/custom_ops/custom_ops.py
+++ b/extension/llm/custom_ops/custom_ops.py
@@ -184,6 +184,7 @@ def _validate_update_cache_params(
     value,
     cache,
     start_pos,
+    indices=None,
 ):
     seq_len = value.size(1)
     assert (
@@ -200,17 +201,30 @@ def _validate_update_cache_params(
         ), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"
 
     torch._check_is_size(start_pos)
-    # Setting to arbitrary limit of 256 for now since there is no way
-    # to plumb this information from model config
-    torch._check(start_pos < cache.size(1))
-    assert start_pos < cache.size(
-        1
-    ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
-
-    torch._check((start_pos + seq_len) < cache.size(1))
-    assert (start_pos + seq_len) < cache.size(
-        1
-    ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
+    if indices is None:
+        torch._check(start_pos < cache.size(1))
+        assert start_pos < cache.size(
+            1
+        ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
+
+        torch._check((start_pos + seq_len) < cache.size(1))
+        assert (start_pos + seq_len) < cache.size(
+            1
+        ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
+
+    if indices is not None:
+        assert (
+            indices.dim() == 2
+        ), f"Expected indices to be 2 dimensional but got {indices.dim()} dimensions."
+        assert (
+            indices.dtype == torch.int64
+        ), f"Expected indices to be int64 but got {indices.dtype}"
+        assert indices.size(0) == value.size(
+            0
+        ), f"Expected indices batch dimension to match value batch dimension but got {indices.size(0)} and {value.size(0)}"
+        assert indices.size(1) == value.size(
+            1
+        ), f"Expected indices sequence length dimension to match value sequence length dimension but got {indices.size(1)} and {value.size(1)}"
 
 
 @impl(custom_ops_lib, "update_cache", "Meta")
@@ -231,6 +245,26 @@ def update_cache_meta(
     return torch.empty((1,), dtype=value.dtype, device="meta")
 
 
+@impl(custom_ops_lib, "update_cache_with_indices", "Meta")
+def update_cache_with_indices_meta(
+    value,
+    cache,
+    start_pos,
+    indices,
+):
+    _validate_update_cache_params(
+        value,
+        cache,
+        start_pos,
+        indices,
+    )
+
+    # Update cache doesn't really return anything, but I don't know a better
+    # workaround. Should we just return cache instead? I am afraid that
+    # would result in an extra memory allocation.
+    return torch.empty((1,), dtype=value.dtype, device="meta")
+
+
 def _validate_quantized_sdpa_params(
     query,
     key,
     value,
diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp
index ff367c85c8a..5bbf22d336e 100644
--- a/extension/llm/custom_ops/op_sdpa_aot.cpp
+++ b/extension/llm/custom_ops/op_sdpa_aot.cpp
@@ -129,6 +129,20 @@ at::Tensor update_cache_aten(
     at::Tensor& cache,
     const int64_t start_pos);
 
+// New functions for update_cache_with_indices
+Tensor& update_cache_with_indices_out_no_context(
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    const Tensor& indices,
+    Tensor& output);
+
+at::Tensor update_cache_with_indices_aten(
+    const at::Tensor& value,
+    at::Tensor& cache,
+    const int64_t start_pos,
+    const at::Tensor& indices);
+
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
     const Tensor& k_projected,
@@ -340,6 +354,29 @@ at::Tensor update_cache_aten(
   return output;
 }
 
+// Implementations for update_cache_with_indices
+Tensor& update_cache_with_indices_out_no_context(
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    const Tensor& indices,
+    Tensor& output) {
+  executorch::aten::RuntimeContext context{};
+  return torch::executor::native::update_cache_with_indices_out(
+      context, value, cache, start_pos, indices, output);
+}
+
+at::Tensor update_cache_with_indices_aten(
+    const at::Tensor& value,
+    at::Tensor& cache,
+    const int64_t start_pos,
+    const at::Tensor& indices) {
+  auto output = at::empty({1});
+  WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 4)
+  (value, cache, start_pos, indices, output);
+  return output;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
@@ -367,6 +404,12 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
   m.def(
       "update_cache.out(Tensor value, Tensor(a!) cache, "
       "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
+  m.def(
+      "update_cache_with_indices(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos, Tensor indices) -> Tensor");
+  m.def(
+      "update_cache_with_indices.out(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos, Tensor indices, *, Tensor(b!) out) -> Tensor(b!)");
   m.def(
       "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
@@ -397,6 +440,14 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl(
       "update_cache.out",
       WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3));
+  m.impl(
+      "update_cache_with_indices",
+      torch::executor::native::update_cache_with_indices_aten);
+  m.impl(
+      "update_cache_with_indices.out",
+      WRAP_TO_ATEN(
+          torch::executor::native::update_cache_with_indices_out_no_context,
+          4));
   m.impl(
       "custom_quantized_sdpa",
       torch::executor::native::custom_quantized_sdpa_aten);
diff --git a/extension/llm/custom_ops/op_update_cache.cpp b/extension/llm/custom_ops/op_update_cache.cpp
index 323b7a65ddb..7ab994deb5f 100644
--- a/extension/llm/custom_ops/op_update_cache.cpp
+++ b/extension/llm/custom_ops/op_update_cache.cpp
@@ -20,30 +20,58 @@ namespace executor {
 namespace native {
 
 namespace {
+// Helper function to validate cache parameters
 bool validate_cache_params(
     const Tensor& quantized_value,
     const Tensor& quantized_cache,
     int64_t start_pos,
-    int64_t seq_length) {
+    int64_t seq_length,
+    const optional<Tensor>& indices = nullopt) {
   ET_CHECK_OR_RETURN_FALSE(
       quantized_cache.dim() == 4, "quantized cache must be a 4D tensor");
   ET_CHECK_OR_RETURN_FALSE(
       quantized_value.dim() == 4, "quantized_value must be a 4D tensor");
 
-  ET_CHECK_OR_RETURN_FALSE(
-      start_pos < quantized_cache.size(1),
-      "start_pos must be less than cache size at dim 1");
+  if (indices.has_value()) {
+    const auto& indices_tensor = indices.value();
+    ET_CHECK_OR_RETURN_FALSE(
+        indices_tensor.dim() == 2,
+        "indices must be a 2D tensor [batch_size, seq_len]");
 
-  ET_CHECK_OR_RETURN_FALSE(
-      (start_pos + seq_length) <= quantized_cache.size(1),
-      "start_post + seq_length must be less than max seq length supported by cache."
-      "start pos: %" PRId64 ", seq_length: %" PRId64
-      "."
-      "cache size: %zd",
-      start_pos,
-      seq_length,
-      quantized_cache.size(1));
+    ET_CHECK_OR_RETURN_FALSE(
+        indices_tensor.size(0) == quantized_value.size(0),
+        "indices batch dimension must match value batch dimension");
+
+    ET_CHECK_OR_RETURN_FALSE(
+        indices_tensor.size(1) == quantized_value.size(1),
+        "indices sequence length dimension must match value sequence length dimension");
+
+    ET_CHECK_OR_RETURN_FALSE(
+        indices_tensor.scalar_type() == ScalarType::Long,
+        "indices must be of Long (int64_t) type");
+
+    ET_CHECK_OR_RETURN_FALSE(
+        is_contiguous_dim_order(
+            indices_tensor.dim_order().data(), indices_tensor.dim()),
+        "indices must be in contiguous dim order");
+  } else {
+    ET_CHECK_OR_RETURN_FALSE(
+        start_pos < quantized_cache.size(1),
+        "start_pos: %" PRId64 " must be less than cache size at dim 1: %zd",
+        start_pos,
+        quantized_cache.size(1));
+
+    ET_CHECK_OR_RETURN_FALSE(
+        (start_pos + seq_length) <= quantized_cache.size(1),
+        "start_pos + seq_length must be less than or equal to the max seq length supported by cache."
+        "start pos: %" PRId64 ", seq_length: %" PRId64
+        "."
+ "cache size: %zd", + start_pos, + seq_length, + quantized_cache.size(1)); + } // Make sure they are in contiguous dim order ET_CHECK_OR_RETURN_FALSE( @@ -58,34 +86,37 @@ bool validate_cache_params( return true; } -} // anonymous namespace -Tensor& update_cache_out( +// Helper function for the actual update operation +Tensor& update_cache_impl( RuntimeContext& ctx, const Tensor& value, Tensor& cache, const int64_t start_pos, - Tensor& output) { + Tensor& output, + const optional& indices = nullopt) { (void)ctx; - int64_t seq_len = value.size(1); - ET_KERNEL_CHECK( - ctx, - validate_cache_params(value, cache, start_pos, seq_len), - InvalidArgument, - output); ET_CHECK_MSG( value.size(0) == cache.size(0), - "projected_value batch size should be equal to the cache batch size."); + "projected_value batch size (%zd) should be equal to the cache batch size (%zd).", + value.size(0), + cache.size(0)); ET_CHECK_MSG( value.size(2) == cache.size(2), - "projected_value number of heads should be equal to the cache number of heads."); + "projected_value number of heads (%zd) should be equal to the cache number of heads (%zd).", + value.size(2), + cache.size(2)); ET_CHECK_MSG( value.size(3) == cache.size(3), - "projected_value embedding dimension should be equal to the cache embedding dimension."); + "projected_value embedding dimension (%zd) should be equal to the cache embedding dimension (%zd).", + value.size(3), + cache.size(3)); ET_CHECK_MSG( value.element_size() == cache.element_size(), - "projected_value data type size should be equal to the cache data type size."); + "projected_value data type size (%zd) should be equal to the cache data type size (%zd).", + value.element_size(), + cache.element_size()); ET_CHECK_MSG( is_contiguous_dim_order(value.dim_order().data(), value.dim()), @@ -110,23 +141,107 @@ Tensor& update_cache_out( executorch::aten::SizesType num_bytes_to_copy = (value.numel() / value.size(0)) * value.element_size(); - for (int64_t batch_line = 0; batch_line < value.size(0); ++batch_line) { - executorch::aten::SizesType cache_pos_offset = - (batch_line * cache_batch_dim_stride + - start_pos * cache_seq_dim_stride) * - cache.element_size(); - executorch::aten::SizesType value_pos_offset = - (batch_line * value_batch_dim_stride) * cache.element_size(); - - std::memcpy( - (uint8_t*)cache_data + cache_pos_offset, - (uint8_t*)value_data + value_pos_offset, - num_bytes_to_copy); + if (indices.has_value()) { + // Use the provided indices tensor for each batch and sequence position + const Tensor& indices_tensor = indices.value(); + const int64_t* indices_data = + static_cast(indices_tensor.const_data_ptr()); + auto indices_strides = indices_tensor.strides(); + executorch::aten::StridesType indices_batch_stride = indices_strides[0]; + executorch::aten::StridesType indices_seq_stride = indices_strides[1]; + + // Calculate bytes to copy for a single token + executorch::aten::SizesType bytes_per_token = + (value.numel() / (value.size(0) * value.size(1))) * + value.element_size(); + + for (int64_t batch_line = 0; batch_line < value.size(0); ++batch_line) { + for (int64_t seq_idx = 0; seq_idx < value.size(1); ++seq_idx) { + // Get the target position from the indices tensor + int64_t target_pos = indices_data + [batch_line * indices_batch_stride + seq_idx * indices_seq_stride]; + + // Ensure the target position is valid + ET_CHECK_MSG( + target_pos >= 0 && target_pos < cache.size(1), + "Index out of bounds: %" PRId64 " not in [0, %zd)", + target_pos, + cache.size(1)); + + // Calculate offsets 
+        // Calculate offsets for cache and value
+        executorch::aten::SizesType cache_pos_offset =
+            (batch_line * cache_batch_dim_stride +
+             target_pos * cache_seq_dim_stride) *
+            cache.element_size();
+
+        executorch::aten::SizesType value_pos_offset =
+            (batch_line * value_batch_dim_stride + seq_idx * value_strides[1]) *
+            value.element_size();
+
+        // Copy a single token
+        std::memcpy(
+            (uint8_t*)cache_data + cache_pos_offset,
+            (uint8_t*)value_data + value_pos_offset,
+            bytes_per_token);
+      }
+    }
+  } else {
+    // Use the original implementation with start_pos
+    for (int64_t batch_line = 0; batch_line < value.size(0); ++batch_line) {
+      executorch::aten::SizesType cache_pos_offset =
+          (batch_line * cache_batch_dim_stride +
+           start_pos * cache_seq_dim_stride) *
+          cache.element_size();
+      executorch::aten::SizesType value_pos_offset =
+          (batch_line * value_batch_dim_stride) * cache.element_size();
+
+      std::memcpy(
+          (uint8_t*)cache_data + cache_pos_offset,
+          (uint8_t*)value_data + value_pos_offset,
+          num_bytes_to_copy);
+    }
   }
 
   // Noone uses output. Just a placeholder.
   return output;
 }
+} // anonymous namespace
+
+// Original update_cache_out function without indices parameter
+Tensor& update_cache_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output) {
+  int64_t seq_len = value.size(1);
+  ET_KERNEL_CHECK(
+      ctx,
+      validate_cache_params(value, cache, start_pos, seq_len),
+      InvalidArgument,
+      output);
+
+  return update_cache_impl(ctx, value, cache, start_pos, output);
+}
+
+// New function that explicitly takes indices
+Tensor& update_cache_with_indices_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    const Tensor& indices,
+    Tensor& output) {
+  int64_t seq_len = value.size(1);
+  ET_KERNEL_CHECK(
+      ctx,
+      validate_cache_params(value, cache, start_pos, seq_len, indices),
+      InvalidArgument,
+      output);
+
+  return update_cache_impl(ctx, value, cache, start_pos, output, indices);
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
@@ -141,3 +256,9 @@ EXECUTORCH_LIBRARY(
     llama,
     "update_cache.out",
     torch::executor::native::update_cache_out);
+
+// Register the new update_cache_with_indices.out op
+EXECUTORCH_LIBRARY(
+    llama,
+    "update_cache_with_indices.out",
+    torch::executor::native::update_cache_with_indices_out);
diff --git a/extension/llm/custom_ops/op_update_cache.h b/extension/llm/custom_ops/op_update_cache.h
index cf518b4e108..84c73039469 100644
--- a/extension/llm/custom_ops/op_update_cache.h
+++ b/extension/llm/custom_ops/op_update_cache.h
@@ -15,12 +15,22 @@ namespace executor {
 namespace native {
 
+// Original update_cache_out function without indices parameter
 Tensor& update_cache_out(
     RuntimeContext& ctx,
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
     Tensor& output);
+
+// New function that explicitly takes indices
+Tensor& update_cache_with_indices_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    const Tensor& indices,
+    Tensor& output);
 } // namespace native
 } // namespace executor
 } // namespace torch
diff --git a/extension/llm/custom_ops/test_update_cache.py b/extension/llm/custom_ops/test_update_cache.py
index 1d2f392c129..78c30d5f8b7 100644
--- a/extension/llm/custom_ops/test_update_cache.py
+++ b/extension/llm/custom_ops/test_update_cache.py
@@ -6,11 +6,28 @@
 # pyre-unsafe
 
+import multiprocessing
 import unittest
 
 import torch
 
 
+def run_in_subprocess(target):
+    """
+    Decorator to run the target function in a separate subprocess
+    so that C++ code can call runtime abort without killing the test process.
+    """
+
+    def wrapper(*args, **kwargs):
+        p = multiprocessing.Process(target=target, args=args, kwargs=kwargs)
+        p.start()
+        p.join()
+        if p.exitcode != 0:
+            raise Exception(f"Subprocess failed with exit code {p.exitcode}")
+
+    return wrapper
+
+
 class UpdateQuantizedKVCacheTest(unittest.TestCase):
 
     def _reset(self):
@@ -82,6 +99,38 @@ def _update_and_validate(
         self.assertTrue(torch.allclose(k_zero_points_cache, self.k_zero_points_cache))
         self.assertTrue(torch.allclose(v_zero_points_cache, self.v_zero_points_cache))
 
+    def _update_with_indices_and_validate(
+        self, k, k_scales, k_zero_points, start_pos, indices
+    ):
+        k_cache = self.quantized_k_cache.clone()
+        k_scales_cache = self.k_scales_cache.clone()
+        k_zero_points_cache = self.k_zero_points_cache.clone()
+
+        # Update using Python indexing for reference
+        for batch_idx in range(self.batch_size):
+            for seq_idx in range(indices.size(1)):
+                idx = indices[batch_idx, seq_idx].item()
+                if idx >= 0 and idx < self.seq_len:
+                    self.quantized_k_cache[batch_idx, idx] = k[batch_idx, seq_idx]
+                    self.k_scales_cache[batch_idx, idx] = k_scales[batch_idx, seq_idx]
+                    self.k_zero_points_cache[batch_idx, idx] = k_zero_points[
+                        batch_idx, seq_idx
+                    ]
+
+        # Update using custom op
+        torch.ops.llama.update_cache_with_indices(k, k_cache, start_pos, indices)
+        torch.ops.llama.update_cache_with_indices(
+            k_scales, k_scales_cache, start_pos, indices
+        )
+        torch.ops.llama.update_cache_with_indices(
+            k_zero_points, k_zero_points_cache, start_pos, indices
+        )
+
+        # Validate results
+        self.assertTrue(torch.allclose(k_cache, self.quantized_k_cache))
+        self.assertTrue(torch.allclose(k_scales_cache, self.k_scales_cache))
+        self.assertTrue(torch.allclose(k_zero_points_cache, self.k_zero_points_cache))
+
     def test_update_kv_cache_simple(self):
         k = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
         v = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
@@ -94,6 +143,208 @@ def test_update_kv_cache_simple(self):
             k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
         )
 
+    # Tests for update_cache_with_indices functionality
+
+    def test_basic_update_with_indices(self):
+        """Test basic update with indices functionality."""
+        self._reset()
+        k = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((1, 3, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(0, 20, (1, 3, 8, 1), dtype=torch.int64)
+
+        # Update positions 2, 5, 7
+        indices = torch.tensor([[2, 5, 7]], dtype=torch.int64)
+        start_pos = 0  # start_pos is ignored when indices are provided
+
+        self._update_with_indices_and_validate(
+            k, k_scales, k_zero_points, start_pos, indices
+        )
+
+    def test_single_index_update(self):
+        """Test updating a single position with indices."""
+        self._reset()
+        k = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64)
+
+        # Update only position 4
+        indices = torch.tensor([[4]], dtype=torch.int64)
+        start_pos = 0
+
+        self._update_with_indices_and_validate(
+            k, k_scales, k_zero_points, start_pos, indices
+        )
+
+    def test_sparse_indices(self):
+        """Test updating non-contiguous positions."""
+        self._reset()
+        k = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((1, 3, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(0, 20, (1, 3, 8, 1), dtype=torch.int64)
+
+        # Update positions 1, 4, 8 (sparse, non-contiguous)
+        indices = torch.tensor([[1, 4, 8]], dtype=torch.int64)
+        start_pos = 0
+
+        self._update_with_indices_and_validate(
+            k, k_scales, k_zero_points, start_pos, indices
+        )
+
+    def test_out_of_order_indices(self):
+        """Test updating positions in a non-sequential order."""
+        self._reset()
+        k = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((1, 3, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(0, 20, (1, 3, 8, 1), dtype=torch.int64)
+
+        # Update positions in reverse order: 8, 5, 2
+        indices = torch.tensor([[8, 5, 2]], dtype=torch.int64)
+        start_pos = 0
+
+        self._update_with_indices_and_validate(
+            k, k_scales, k_zero_points, start_pos, indices
+        )
+
+    def test_indices_exceeding_cache_size(self):
+        """Test behavior when indices exceed the cache size."""
+        self._reset()
+        k = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8)
+
+        # Try to update positions 5, 9, 15 (where 15 is out of bounds)
+        indices = torch.tensor([[5, 9, 15]], dtype=torch.int64)
+        start_pos = 0
+
+        @run_in_subprocess
+        def run_and_catch(k, k_cache, start_pos, indices):
+            torch.ops.llama.update_cache_with_indices(k, k_cache, start_pos, indices)
+
+        exception_raised = False
+        try:
+            run_and_catch(k, self.quantized_k_cache, start_pos, indices)
+        except Exception:
+            exception_raised = True
+        self.assertTrue(exception_raised)
+
+    def test_negative_indices(self):
+        """Test behavior with negative indices."""
+        self._reset()
+        k = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8)
+
+        # Try to update with negative indices
+        indices = torch.tensor([[5, -1, 8]], dtype=torch.int64)
+        start_pos = 0
+
+        @run_in_subprocess
+        def run_and_catch(k, k_cache, start_pos, indices):
+            torch.ops.llama.update_cache_with_indices(k, k_cache, start_pos, indices)
+
+        exception_raised = False
+        try:
+            run_and_catch(k, self.quantized_k_cache, start_pos, indices)
+        except Exception:
+            exception_raised = True
+        self.assertTrue(exception_raised)
+
+    def test_duplicate_indices(self):
+        """Test behavior when the same position is updated multiple times."""
+        self._reset()
+        k = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8)
+        v = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((1, 3, 8, 1), dtype=torch.float64)
+        v_scales = torch.rand((1, 3, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(0, 20, (1, 3, 8, 1), dtype=torch.int64)
+        v_zero_points = torch.randint(0, 20, (1, 3, 8, 1), dtype=torch.int64)
+
+        # Update with duplicate indices - the last value should be used
+        indices = torch.tensor([[3, 5, 3]], dtype=torch.int64)
+        start_pos = 0
+
+        # For our reference implementation, we need to handle this case specially
+        k_cache = self.quantized_k_cache.clone()
+        v_cache = self.quantized_v_cache.clone()
+        k_scales_cache = self.k_scales_cache.clone()
+        v_scales_cache = self.v_scales_cache.clone()
+        k_zero_points_cache = self.k_zero_points_cache.clone()
+        v_zero_points_cache = self.v_zero_points_cache.clone()
+
+        # Update using custom op
+        torch.ops.llama.update_cache_with_indices(k, k_cache, start_pos, indices)
+        torch.ops.llama.update_cache_with_indices(
+            k_scales, k_scales_cache, start_pos, indices
+        )
+        torch.ops.llama.update_cache_with_indices(
+            k_zero_points, k_zero_points_cache, start_pos, indices
+        )
+        torch.ops.llama.update_cache_with_indices(v, v_cache, start_pos, indices)
+        torch.ops.llama.update_cache_with_indices(
+            v_scales, v_scales_cache, start_pos, indices
+        )
+        torch.ops.llama.update_cache_with_indices(
+            v_zero_points, v_zero_points_cache, start_pos, indices
+        )
+
+        # Position 3 should have the value from the last update (index 2 in the sequence)
+        self.assertTrue(torch.allclose(k_cache[0, 3], k[0, 2]))
+        self.assertTrue(torch.allclose(v_cache[0, 3], v[0, 2]))
+        self.assertTrue(torch.allclose(k_scales_cache[0, 3], k_scales[0, 2]))
+        self.assertTrue(torch.allclose(v_scales_cache[0, 3], v_scales[0, 2]))
+        self.assertTrue(torch.allclose(k_zero_points_cache[0, 3], k_zero_points[0, 2]))
+        self.assertTrue(torch.allclose(v_zero_points_cache[0, 3], v_zero_points[0, 2]))
+
+        # Position 5 should have the value from index 1
+        self.assertTrue(torch.allclose(k_cache[0, 5], k[0, 1]))
+        self.assertTrue(torch.allclose(v_cache[0, 5], v[0, 1]))
+
+    def test_batched_update_with_indices(self):
+        """Test updating with indices in a batched setting."""
+        self.batch_size = 2
+        self._reset()
+        k = torch.randint(0, 50, (self.batch_size, 3, 8, 4), dtype=torch.int8)
+        k_scales = torch.rand((self.batch_size, 3, 8, 1), dtype=torch.float64)
+        k_zero_points = torch.randint(
+            0, 20, (self.batch_size, 3, 8, 1), dtype=torch.int64
+        )
+
+        # Different indices for each batch
+        indices = torch.tensor(
+            [[1, 4, 7], [2, 5, 8]],  # indices for batch 0 and batch 1
+            dtype=torch.int64,
+        )
+        start_pos = 0
+
+        self._update_with_indices_and_validate(
+            k, k_scales, k_zero_points, start_pos, indices
+        )
+
+    def test_different_seq_lengths_per_batch(self):
+        """Test updating with different sequence lengths per batch using padding."""
+        self.batch_size = 2
+        self._reset()
+
+        # Create inputs with 3 tokens
+        k = torch.randint(0, 50, (self.batch_size, 3, 8, 4), dtype=torch.int8)
+
+        # Batch 0: update 3 positions, Batch 1: update only 2 positions (use -1 as padding)
+        indices = torch.tensor(
+            [
+                [1, 3, 5],  # 3 valid indices for batch 0
+                [2, 4, -1],  # 2 valid indices for batch 1, with -1 as padding
+            ],
+            dtype=torch.int64,
+        )
+        start_pos = 0
+
+        @run_in_subprocess
+        def run_and_catch(k, k_cache, start_pos, indices):
+            torch.ops.llama.update_cache_with_indices(k, k_cache, start_pos, indices)
+
+        exception_raised = False
+        try:
+            run_and_catch(k, self.quantized_k_cache, start_pos, indices)
+        except Exception:
+            exception_raised = True
+        self.assertTrue(exception_raised)
+
     def test_update_kv_cache_large_update(self):
         self._reset()
         k = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8)
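
Usage note (not part of the patch): below is a minimal sketch of how the new op can be exercised from eager Python once the llama custom-op library is loaded, following the same call pattern as the tests above. The import path, shapes, and values are assumptions for illustration only; the op scatters each token of `value` into `cache` at the per-batch position given by `indices`, and `start_pos` is ignored on this path.

import torch

# Assumption: importing this module registers torch.ops.llama.* (adjust the
# path to however the op library is loaded in your setup).
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

cache = torch.zeros(2, 16, 8, 4)   # [batch, max_seq_len, n_heads, head_dim]
value = torch.rand(2, 3, 8, 4)     # 3 new tokens per batch line

# Each batch line writes its tokens to its own positions, e.g. when the two
# sequences in the batch have different current lengths.
indices = torch.tensor([[0, 1, 2], [5, 6, 7]], dtype=torch.int64)

# start_pos (0 here) is unused when indices are given; it is kept for schema parity.
torch.ops.llama.update_cache_with_indices(value, cache, 0, indices)

assert torch.equal(cache[1, 5:8], value[1])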