From dcf7c5a8c03491fa567911208e6ea772c7dc6889 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Thu, 1 May 2025 10:19:05 -0700 Subject: [PATCH] [Executorch][llm] Add support for ring kv cache and ring attention Introduced CachePositionManager to keep track of what is the position for each slot in ring kv cache. This is used to generate mask. Differential Revision: [D73891427](https://our.internmc.facebook.com/intern/diff/D73891427/) [ghstack-poisoned] --- examples/models/llama/attention.py | 107 ++++ examples/models/llama/tests/TARGETS | 11 + .../models/llama/tests/test_ring_kv_cache.py | 477 ++++++++++++++++++ 3 files changed, 595 insertions(+) create mode 100644 examples/models/llama/tests/test_ring_kv_cache.py diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 54f738ba737..68636c6528f 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from enum import Enum from typing import Any, Dict, Optional, Tuple, Type, TypedDict import torch @@ -160,6 +161,112 @@ def forward( return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) +class CacheUpdateStrategy(Enum): + RING_BUFFER = "RingBuffer" + INVALID = "Invalid" + + +class CachePositionsManager(nn.Module): + def __init__( + self, + max_context_length: int, + cache_update_strategy: CacheUpdateStrategy = CacheUpdateStrategy.RING_BUFFER, + ): + super().__init__() + assert ( + cache_update_strategy == CacheUpdateStrategy.RING_BUFFER + ), "Only RingBuffer is supported" + self.max_context_length = max_context_length + self.register_buffer( + "cache_positions", + torch.zeros((self.max_context_length), dtype=torch.long, device="cpu"), + ) + + def calculate_positions_and_update_indices(self, input_pos: torch.Tensor, seq_len): + """ + Calculate indices, into k_cache, v_cache, where to put k_val tensor. + Given the input_pos and length of k_val at sequence dim, the input pos may + have to wrap around if it is smaller than the cache capacity. + If it is larger than the cache capacity then just pick the last + self.max_context_length entries. + + Additionally: + Update the cache positions buffer with the new indices. + Given the cache positions in sequence dim, indicated by indices, + we can just update cache_positions buffer using orig_indices. + For example + Given cache capacity of 4 and update of length 3 with start_pos = 2 + will have following values + indices = [2, 3, 0] + orig_indices = [2, 3, 4] + So cache_positions after the update will be [4, 1, 2, 3] + Note cache_positions[1] = 1 that is from previous write to the cache. + The corner case here is cache positions before cache rolls over. + For example when start_pos = 0 and update is of length 2, then we have + filled positions 0 and 1 in the buffer, while the rest are invalid. In this case + we have + indices = [0, 1] + orig_indices = [0, 1] + But if we have cache_positins = [0, 1, 0, 0] that is not valid. Hence we have + to make sure that invalid positions have a sentinel value of - 1. + """ + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + orig_indices = torch.arange(seq_len, dtype=torch.long) + start_pos + indices = orig_indices % self.max_context_length + + full_t = torch.full((self.max_context_length,), -1, dtype=torch.long) + arange_tensor = torch.arange(self.max_context_length, dtype=torch.long) + cache_positions = torch.where( + arange_tensor < start_pos, self.cache_positions, full_t + ) + self.cache_positions.copy_(cache_positions) + self.cache_positions.index_copy_(0, indices, orig_indices) + + return indices + + +class RingKVCache(KVCache): + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool, + dtype=torch.float32, + ): + super().__init__( + max_batch_size, + max_context_length, + n_heads, + head_dim, + enable_dynamic_shape, + dtype, + ) + self.cache_positions_manager = CachePositionsManager(max_context_length) + + def update( + self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [S], k_val: [B, H, S, D] + seq_len = k_val.size(2) + indices = self.cache_positions_manager.calculate_positions_and_update_indices( + input_pos, seq_len + ) + if self.enable_dynamic_shape: + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + + self.k_cache.index_copy_(2, indices, k_val) + self.v_cache.index_copy_(2, indices, v_val) + else: + self.k_cache[:, :, indices] = k_val + self.v_cache[:, :, indices] = v_val + + return self.k_cache, self.v_cache + + @register_attention("mha") class AttentionMHA(Attention): def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): diff --git a/examples/models/llama/tests/TARGETS b/examples/models/llama/tests/TARGETS index 0efaa9635c4..09ca02868ed 100644 --- a/examples/models/llama/tests/TARGETS +++ b/examples/models/llama/tests/TARGETS @@ -38,3 +38,14 @@ python_unittest( "//executorch/examples/models/llama:static_attention", ], ) + +python_unittest( + name = "test_ring_kv_cache", + srcs = [ + "test_ring_kv_cache.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/examples/models/llama:llama_transformer", + ], +) diff --git a/examples/models/llama/tests/test_ring_kv_cache.py b/examples/models/llama/tests/test_ring_kv_cache.py new file mode 100644 index 00000000000..dd9971fa010 --- /dev/null +++ b/examples/models/llama/tests/test_ring_kv_cache.py @@ -0,0 +1,477 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.examples.models.llama.attention import RingKVCache + + +class TestRingKVCache(unittest.TestCase): + def setUp(self): + # Common test parameters + self.max_batch_size = 2 + self.max_context_length = 8 + self.n_heads = 4 + self.head_dim = 16 + self.enable_dynamic_shape = True + self.dtype = torch.float32 + + def test_basic_update(self): + """Test basic update functionality of RingKVCache.""" + cache = RingKVCache( + self.max_batch_size, + self.max_context_length, + self.n_heads, + self.head_dim, + self.enable_dynamic_shape, + self.dtype, + ) + + # Create input tensors + input_pos = torch.tensor([0], dtype=torch.long) + seq_len = 3 + k_val = torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + v_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 2 + ) + + # Update the cache + k_out, v_out = cache.update(input_pos, k_val, v_val) + + # Check that the cache was updated correctly + for i in range(seq_len): + self.assertTrue(torch.all(k_out[:, :, i] == 1.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 2.0)) + + # Check that the rest of the cache is still zeros + for i in range(seq_len, self.max_context_length): + self.assertTrue(torch.all(k_out[:, :, i] == 0.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 0.0)) + + # Check that cache_positions was updated correctly + expected_positions = torch.tensor( + [0, 1, 2, -1, -1, -1, -1, -1], dtype=torch.long + ) + self.assertTrue( + torch.all( + cache.cache_positions_manager.cache_positions == expected_positions + ) + ) + + def test_ring_buffer_wrapping(self): + """Test that the ring buffer wraps around correctly.""" + cache = RingKVCache( + self.max_batch_size, + self.max_context_length, + self.n_heads, + self.head_dim, + self.enable_dynamic_shape, + self.dtype, + ) + + # Create input tensors for first update + input_pos = torch.tensor([6], dtype=torch.long) + seq_len = 4 # This will wrap around from position 6 to positions 6, 7, 0, 1 + k_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 3 + ) + v_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 4 + ) + + # Update the cache + k_out, v_out = cache.update(input_pos, k_val, v_val) + + # Check that the cache was updated correctly with wrapping + # Positions 6, 7 should be updated + for i in range(6, 8): + self.assertTrue(torch.all(k_out[:, :, i] == 3.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 4.0)) + + # Positions 0, 1 should also be updated due to wrapping + for i in range(0, 2): + self.assertTrue(torch.all(k_out[:, :, i] == 3.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 4.0)) + + # The rest should still be zeros + for i in range(2, 6): + self.assertTrue(torch.all(k_out[:, :, i] == 0.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 0.0)) + + # Check that cache_positions was updated correctly + # Note that positions 2, 3, 4, 5 are 0 instead of -1 because in actual ring + # updates those positions would have been updated. + # But CachePositionsManager thinks they are updated because start_pos > (2, 3, 4, 5) + # As a result it does not fill them with -1 and instead uses original values + # which is 0, the value cache_position buffer is initialized with. + expected_positions = torch.tensor([8, 9, 0, 0, 0, 0, 6, 7], dtype=torch.long) + self.assertTrue( + torch.all( + cache.cache_positions_manager.cache_positions == expected_positions + ) + ) + + def test_multiple_updates(self): + """Test multiple updates to the cache.""" + cache = RingKVCache( + self.max_batch_size, + self.max_context_length, + self.n_heads, + self.head_dim, + self.enable_dynamic_shape, + self.dtype, + ) + + # First update + input_pos1 = torch.tensor([0], dtype=torch.long) + seq_len1 = 2 + k_val1 = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len1, self.head_dim), + dtype=self.dtype, + ) + * 5 + ) + v_val1 = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len1, self.head_dim), + dtype=self.dtype, + ) + * 6 + ) + + _, _ = cache.update(input_pos1, k_val1, v_val1) + + # Second update + input_pos2 = torch.tensor([2], dtype=torch.long) + seq_len2 = 3 + k_val2 = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len2, self.head_dim), + dtype=self.dtype, + ) + * 7 + ) + v_val2 = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len2, self.head_dim), + dtype=self.dtype, + ) + * 8 + ) + + k_out2, v_out2 = cache.update(input_pos2, k_val2, v_val2) + + # Check that the cache was updated correctly after both updates + # First update (positions 0, 1) + for i in range(0, 2): + self.assertTrue(torch.all(k_out2[:, :, i] == 5.0)) + self.assertTrue(torch.all(v_out2[:, :, i] == 6.0)) + + # Second update (positions 2, 3, 4) + for i in range(2, 5): + self.assertTrue(torch.all(k_out2[:, :, i] == 7.0)) + self.assertTrue(torch.all(v_out2[:, :, i] == 8.0)) + + # The rest should still be zeros + for i in range(5, 8): + self.assertTrue(torch.all(k_out2[:, :, i] == 0.0)) + self.assertTrue(torch.all(v_out2[:, :, i] == 0.0)) + + # Check that cache_positions was updated correctly + expected_positions = torch.tensor([0, 1, 2, 3, 4, -1, -1, -1], dtype=torch.long) + self.assertTrue( + torch.all( + cache.cache_positions_manager.cache_positions == expected_positions + ) + ) + + # Third update with wrapping + input_pos3 = torch.tensor([6], dtype=torch.long) + seq_len3 = 4 + k_val3 = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len3, self.head_dim), + dtype=self.dtype, + ) + * 9 + ) + v_val3 = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len3, self.head_dim), + dtype=self.dtype, + ) + * 10 + ) + + k_out3, v_out3 = cache.update(input_pos3, k_val3, v_val3) + + # Check final state after third update with wrapping + # Positions 0, 1 should now have values from the third update (due to wrapping) + for i in range(0, 2): + self.assertTrue(torch.all(k_out3[:, :, i] == 9.0)) + self.assertTrue(torch.all(v_out3[:, :, i] == 10.0)) + + # Positions 2, 3, 4 should still have values from the second update + for i in range(2, 5): + self.assertTrue(torch.all(k_out3[:, :, i] == 7.0)) + self.assertTrue(torch.all(v_out3[:, :, i] == 8.0)) + + # Position 5 should still be zero + self.assertTrue(torch.all(k_out3[:, :, 5] == 0.0)) + self.assertTrue(torch.all(v_out3[:, :, 5] == 0.0)) + + # Positions 6, 7 should have values from the third update + for i in range(6, 8): + self.assertTrue(torch.all(k_out3[:, :, i] == 9.0)) + self.assertTrue(torch.all(v_out3[:, :, i] == 10.0)) + + # Check that cache_positions was updated correctly + expected_positions = torch.tensor([8, 9, 2, 3, 4, -1, 6, 7], dtype=torch.long) + self.assertTrue( + torch.all( + cache.cache_positions_manager.cache_positions == expected_positions + ) + ) + + def test_edge_case_input_pos_zero(self): + """Test the edge case where input_pos is 0.""" + cache = RingKVCache( + self.max_batch_size, + self.max_context_length, + self.n_heads, + self.head_dim, + self.enable_dynamic_shape, + self.dtype, + ) + + # Create input tensors + input_pos = torch.tensor([0], dtype=torch.long) + seq_len = 1 + k_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 11 + ) + v_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 12 + ) + + # Update the cache + k_out, v_out = cache.update(input_pos, k_val, v_val) + + # Check that position 0 was updated + self.assertTrue(torch.all(k_out[:, :, 0] == 11.0)) + self.assertTrue(torch.all(v_out[:, :, 0] == 12.0)) + + # Check that the rest of the cache is still zeros + for i in range(1, self.max_context_length): + self.assertTrue(torch.all(k_out[:, :, i] == 0.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 0.0)) + + # Check that cache_positions was updated correctly + expected_positions = torch.tensor( + [0, -1, -1, -1, -1, -1, -1, -1], dtype=torch.long + ) + self.assertTrue( + torch.all( + cache.cache_positions_manager.cache_positions == expected_positions + ) + ) + + def test_edge_case_exceeding_context_length(self): + """Test the edge case where input_pos + seq_len > max_context_length.""" + cache = RingKVCache( + self.max_batch_size, + self.max_context_length, + self.n_heads, + self.head_dim, + self.enable_dynamic_shape, + self.dtype, + ) + + # Create input tensors + input_pos = torch.tensor([5], dtype=torch.long) + seq_len = 5 # This will wrap around from position 5 to positions 5, 6, 7, 0, 1 + k_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 13 + ) + v_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 14 + ) + + # Update the cache + k_out, v_out = cache.update(input_pos, k_val, v_val) + + # Check that positions 5, 6, 7 were updated + for i in range(5, 8): + self.assertTrue(torch.all(k_out[:, :, i] == 13.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 14.0)) + + # Check that positions 0, 1 were also updated due to wrapping + for i in range(0, 2): + self.assertTrue(torch.all(k_out[:, :, i] == 13.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 14.0)) + + # Check that positions 2, 3, 4 are still zeros + for i in range(2, 5): + self.assertTrue(torch.all(k_out[:, :, i] == 0.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 0.0)) + + # Check that cache_positions was updated correctly + # Note that positions 2, 3, 4 are 0 instead of -1 because in actual ring + # updates those positions would have been updated. + # But CachePositionsManager thinks they are updated because start_pos > (2, 3, 4) + # As a result it does not fill them with -1 and instead uses original values + # which is 0, the value cache_position buffer is initialized with. + expected_positions = torch.tensor([8, 9, 0, 0, 0, 5, 6, 7], dtype=torch.long) + self.assertTrue( + torch.all( + cache.cache_positions_manager.cache_positions == expected_positions + ) + ) + + def test_original_indices_tracking(self): + """Test that the original indices are tracked correctly in cache_positions.""" + cache = RingKVCache( + self.max_batch_size, + self.max_context_length, + self.n_heads, + self.head_dim, + self.enable_dynamic_shape, + self.dtype, + ) + + # First update at position 10 (will be mapped to position 2 in the ring buffer) + input_pos = torch.tensor([10], dtype=torch.long) + seq_len = 4 + k_val = torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + v_val = torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + + # Update the cache + cache.update(input_pos, k_val, v_val) + + # Check that cache_positions correctly tracks the original indices + # For input_pos=10 and seq_len=4, the original indices should be 10, 11, 12, 13 + # These map to positions 2, 3, 4, 5 in the ring buffer (since max_context_length=8) + # Note that positions 0, 1, 6 and 7 are 0 instead of -1 because in actual ring + # updates those positions would have been updated for start_pos = 0. + # So CachePositionsManager thinks they are updated because start_pos > (0, 1, 6, 7) + # As a result it does not fill them with -1 and instead uses original values + # which is 0, the value cache_position buffer is initialized with. + expected_positions = torch.tensor( + [0, 0, 10, 11, 12, 13, 0, 0], dtype=torch.long + ) + self.assertTrue( + torch.all( + cache.cache_positions_manager.cache_positions == expected_positions + ) + ) + + # Second update at position 14 (will be mapped to position 6 in the ring buffer) + input_pos = torch.tensor([14], dtype=torch.long) + seq_len = 3 + k_val = torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + v_val = torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + + # Update the cache + cache.update(input_pos, k_val, v_val) + + # Check that cache_positions correctly tracks the original indices + # For input_pos=14 and seq_len=3, the original indices should be 14, 15, 16 + # These map to positions 6, 7, 0 in the ring buffer + expected_positions = torch.tensor( + [16, 0, 10, 11, 12, 13, 14, 15], dtype=torch.long + ) + self.assertTrue( + torch.all( + cache.cache_positions_manager.cache_positions == expected_positions + ) + ) + + def test_non_dynamic_shape(self): + """Test RingKVCache with enable_dynamic_shape=False.""" + cache = RingKVCache( + self.max_batch_size, + self.max_context_length, + self.n_heads, + self.head_dim, + enable_dynamic_shape=False, + dtype=self.dtype, + ) + + # Create input tensors + input_pos = torch.tensor([0], dtype=torch.long) + seq_len = 3 + k_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 15 + ) + v_val = ( + torch.ones( + (self.max_batch_size, self.n_heads, seq_len, self.head_dim), + dtype=self.dtype, + ) + * 16 + ) + + # Update the cache + k_out, v_out = cache.update(input_pos, k_val, v_val) + + # Check that the cache was updated correctly + for i in range(seq_len): + self.assertTrue(torch.all(k_out[:, :, i] == 15.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 16.0)) + + # Check that the rest of the cache is still zeros + for i in range(seq_len, self.max_context_length): + self.assertTrue(torch.all(k_out[:, :, i] == 0.0)) + self.assertTrue(torch.all(v_out[:, :, i] == 0.0))