From 286c258442da7eac10c441cfaef11f0ec8256ba2 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Thu, 1 May 2025 10:19:09 -0700 Subject: [PATCH] [Executorch][llm] Add ring buffer based kv cache and mask calculation to MHA Leveraging previous work now we allow MHA to have ring buffer cache. If ring buffer cache is used then we query the mask from kv cache and use that for sdpa instead of using precalculated mask. In this process we had to adjsut ring buffer implementation to allow keeping the context of full sliding window. See code for comment. Differential Revision: [D73891425](https://our.internmc.facebook.com/intern/diff/D73891425/) [ghstack-poisoned] --- examples/models/llama/attention.py | 93 +++++-- .../llama/source_transformation/sdpa.py | 72 +++--- .../test_quantized_sdpa.py | 18 +- .../test_sdpa_with_quantized_kv_cache.py | 4 +- examples/models/llama/tests/TARGETS | 11 + .../models/llama/tests/test_ring_attention.py | 226 ++++++++++++++++++ .../models/llama/tests/test_ring_kv_cache.py | 85 ++++--- 7 files changed, 411 insertions(+), 98 deletions(-) create mode 100644 examples/models/llama/tests/test_ring_attention.py diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 68636c6528f..862b43bc3f5 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -123,14 +123,12 @@ def __init__( head_dim: int, n_rep: int, max_context_len: int, - enable_dynamic_shape: bool, ): super().__init__() self.dim = dim self.head_dim = head_dim self.n_rep = n_rep self.max_context_len = max_context_len - self.enable_dynamic_shape = enable_dynamic_shape def forward( self, @@ -142,21 +140,12 @@ def forward( seqlen, mask: torch.Tensor, ) -> torch.Tensor: - if self.enable_dynamic_shape: - start_pos = input_pos[-1].item() - torch._check_is_size(start_pos) - torch._check(start_pos < self.max_context_len) - seq_length = q.size(2) - # pyre-ignore: Incompatible parameter type [6] - attn_mask = mask.narrow(0, start_pos, seq_length) - else: - attn_mask = mask[None, None, input_pos] # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention # can natively support GQA now. But needs enable_gqa=True k = k.repeat_interleave(self.n_rep, dim=1) v = v.repeat_interleave(self.n_rep, dim=1) - y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) @@ -236,21 +225,79 @@ def __init__( enable_dynamic_shape: bool, dtype=torch.float32, ): + self.window_size = max_context_length + """ + Reason why we want the kv cache size to be twice the context length: + Sliding window attention without ringbuffer + pos 0 1 2 3 4 5 6 7 8 9 10 + 0 x 0 0 0 0 0 0 0 0 0 0 + 1 x x 0 0 0 0 0 0 0 0 0 + 2 x x x 0 0 0 0 0 0 0 0 + 3 x x x x 0 0 0 0 0 0 0 + 4 0 x x x x 0 0 0 0 0 0 + 5 0 0 x x x x 0 0 0 0 0 + 6 0 0 0 x x x x 0 0 0 0 + 7 0 0 0 0 x x x x 0 0 0 + 8 0 0 0 0 0 x x x x 0 0 + 9 0 0 0 0 0 0 x x x x 0 + 10 0 0 0 0 0 0 0 x x x x + + So when doing attention for pos = 5 and seq_len = 4 our attention + mask would be + 5 0 0 x x x x 0 0 0 0 0 + 6 0 0 0 x x x x 0 0 0 0 + 7 0 0 0 0 x x x x 0 0 0 + 8 0 0 0 0 0 x x x x 0 0 + Thus tok at pos = 5 is able to attend to tokens at pos 2, 3 and 4. + This is how training is done. + + Now lets consider ring kv cache of size 4. When we are at pos = 5 + before updating the kv cache, state of the kv cache would be + [4 1 2 3]. 
That is we evicted token at pos = 0 out. Now during + attention calculation at pos = 5 seq len = 4, we will update cache and + new pos in the cache would be [8 5 6 7]. So note that 5 can now only attend + to itself. Not 2, 3 and 4 as you would have during training. + So not having kept 2, 3 and 4 in cache means we will have divergent behavior. + Worst case of this would have been when update it equal to the length of + the cache. like in our case pos = 5 seq len = 4. + Thus we need to have a cache that is larger. How much larger, as much as + the sliding window size. So twice the max_context_length. + How would that have helped. Lets see. At pos = 5 our cache would have + [0, 1, 2, 3, 4, NA, NA, NA] After cache update we would have + [8, 1, 2, 3, 4, 5, 6, 7]. We kicked out token at pos = 0. However, the + current step still has access to [pos - sliding_window_size, pos] tokens. + + To make sure we dont over attend, i.e. we dont have pos = 5 + to attend to pos = 1, mask calculaton has to account for the sliding window + size. + """ super().__init__( max_batch_size, - max_context_length, + max_context_length * 2, n_heads, head_dim, enable_dynamic_shape, dtype, ) - self.cache_positions_manager = CachePositionsManager(max_context_length) + self.cache_positions_manager = CachePositionsManager(self.max_context_length) + self.is_ring_buffer = True + + def create_causal_mask_for_ring_buffer(self, start_pos, seq_len): + pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1) + cache_positions = self.cache_positions_manager.cache_positions + delta = pos_q - cache_positions + attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < self.window_size) + attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712 + return attn_mask def update( self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: # input_pos: [S], k_val: [B, H, S, D] seq_len = k_val.size(2) + assert seq_len <= self.k_cache.size( + 2 + ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" indices = self.cache_positions_manager.calculate_positions_and_update_indices( input_pos, seq_len ) @@ -285,6 +332,8 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): self.dim = args.dim self.attention_qkv_bias = args.attention_qkv_bias self.use_qk_norm = args.use_qk_norm + self.use_causal_mask = None + self.enable_dynamic_shape = args.enable_dynamic_shape if self.use_qk_norm: q_norm_dim = self.head_dim @@ -330,7 +379,6 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): head_dim=self.head_dim, n_rep=self.n_rep, max_context_len=self.max_context_len, - enable_dynamic_shape=args.enable_dynamic_shape, ) def forward( @@ -363,8 +411,21 @@ def forward( if self.use_kv_cache: assert input_pos is not None + if self.enable_dynamic_shape: + start_pos = input_pos[-1].item() + torch._check_is_size(start_pos) + torch._check(start_pos < self.max_context_len) + seq_length = q.size(2) + # pyre-ignore: Incompatible parameter type [6] + attn_mask = self.mask.narrow(0, start_pos, seq_length) + else: + attn_mask = self.mask[None, None, input_pos] k, v = self.kv_cache.update(input_pos, k, v) - output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask) + if getattr(self.kv_cache, "is_ring_buffer", False): + attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer( + input_pos[0].item(), seqlen + ) + output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) return self.wo(output), None 
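# A minimal standalone Python sketch of the mask rule that
# create_causal_mask_for_ring_buffer applies, following the docstring example
# above: sliding window of 4, cache of size 8 (2x the window), and cache state
# [8, 1, 2, 3, 4, 5, 6, 7] after the update at pos = 5 with seq_len = 4.
# Values are illustrative only and do not come from the patch itself.
import torch

window_size = 4
# cache_positions holds the original token position stored in each cache slot;
# -1 would mark a slot that has never been written.
cache_positions = torch.tensor([8, 1, 2, 3, 4, 5, 6, 7], dtype=torch.long)

start_pos, seq_len = 5, 4
pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
delta = pos_q - cache_positions
# A slot is visible iff it has been written, is not in the future of the query,
# and falls inside the query's sliding window.
valid = (cache_positions >= 0) & (delta >= 0) & (delta < window_size)
attn_mask = torch.where(valid, 0.0, float("-inf"))
# The row for pos = 5 admits only the slots holding positions 2, 3, 4 and 5 and
# masks everything else, matching the sliding-window behavior described above.
print(attn_mask)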
# grouped multiquery attention: expand out keys and values diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index 1bc54198fba..28714a792e1 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -22,15 +22,11 @@ class SDPACustom(torch.nn.Module): def __init__( self, dim: int, - max_context_len, - enable_dynamic_shape, use_attention_mask: bool = False, ): super().__init__() self.dim = dim - self.max_context_len = max_context_len self.use_attention_mask = use_attention_mask - self.enable_dynamic_shape = enable_dynamic_shape def forward( self, @@ -42,16 +38,6 @@ def forward( seqlen, mask, ): - if self.use_attention_mask: - if self.enable_dynamic_shape: - start_pos = input_pos[-1].item() - torch._check_is_size(start_pos) - torch._check(start_pos < self.max_context_len) - seq_length = q.size(2) - mask = mask.narrow(0, start_pos, seq_length) - else: - mask = mask[input_pos] - q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim) k = k.transpose(1, 2) v = v.transpose(1, 2) @@ -96,8 +82,6 @@ def _replace_sdpa_with_custom_op( name, SDPACustom( child.dim, - child.max_context_len, - child.enable_dynamic_shape, use_attention_mask=use_attention_mask, ), ) @@ -133,12 +117,15 @@ class QuantizedSDPA(torch.nn.Module): zero points, we need to pass kv_cache to SDPA. """ - def __init__(self, dim: int, kv_cache: QuantizedKVCache): + def __init__( + self, dim: int, kv_cache: QuantizedKVCache, use_attention_mask: bool = False + ): super().__init__() self.dim = dim self.quantized_dtype = torch.int8 self.float_dtype = torch.float32 self.kv_cache = kv_cache + self.use_attention_mask = use_attention_mask def forward( self, @@ -176,22 +163,40 @@ def forward( v_scale_fp32 = self.kv_cache.v_cache_scales start_pos = input_pos[0].item() - output = torch.ops.llama.custom_quantized_sdpa( - q_quantized, - k_quantized, - v_quantized, - start_pos, - None, - 0, - True, - None, - q_zero_point_int8, - q_scale_fp32, - k_zero_point_int8, - k_scale_fp32, - v_zero_point_int8, - v_scale_fp32, - ) + if self.use_attention_mask: + output = torch.ops.llama.custom_quantized_sdpa( + q_quantized, + k_quantized, + v_quantized, + start_pos, + mask, + 0, + False, + None, + q_zero_point_int8, + q_scale_fp32, + k_zero_point_int8, + k_scale_fp32, + v_zero_point_int8, + v_scale_fp32, + ) + else: + output = torch.ops.llama.custom_quantized_sdpa( + q_quantized, + k_quantized, + v_quantized, + start_pos, + None, + 0, + True, + None, + q_zero_point_int8, + q_scale_fp32, + k_zero_point_int8, + k_scale_fp32, + v_zero_point_int8, + v_scale_fp32, + ) return output.view(bsz, seqlen, self.dim) @@ -201,6 +206,7 @@ def _update_attention_module_with_quantized_sdpa( ): sdpa = getattr(module, "SDPA", None) assert sdpa is not None + # TODO: add support for SDPA with attention mask # pyre-ignore setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache)) # noqa: B010 diff --git a/examples/models/llama/source_transformation/test_quantized_sdpa.py b/examples/models/llama/source_transformation/test_quantized_sdpa.py index 242f3a0876d..4297221919e 100644 --- a/examples/models/llama/source_transformation/test_quantized_sdpa.py +++ b/examples/models/llama/source_transformation/test_quantized_sdpa.py @@ -31,7 +31,7 @@ def __init__( self.dim = dim self.head_dim = head_dim self.n_rep = n_rep - self.SDPA = SDPA(dim, head_dim, n_rep, max_context_len, enable_dynamic_shape) + self.SDPA = SDPA(dim, head_dim, n_rep, max_context_len) 
self.kv_cache = None def forward(self, x, freqs_cos, freqs_sin, **kwargs): @@ -159,15 +159,9 @@ def test_forward_functionality(self): k_quantized, v_quantized = model.attention.kv_cache.update(input_pos, k, v) # Run the forward pass with the quantized SDPA - try: - output = model.attention.SDPA( - input_pos, q, k_quantized, v_quantized, bsz, seqlen, None - ) + output = model.attention.SDPA( + input_pos, q, k_quantized, v_quantized, bsz, seqlen, None + ) - # Verify the output shape - self.assertEqual(output.shape, (bsz, seqlen, self.dim)) - except Exception: - # If the forward pass fails, it might be due to missing custom ops - self.skipTest( - "Custom ops not available, skipping forward functionality test" - ) + # Verify the output shape + self.assertEqual(output.shape, (bsz, seqlen, self.dim)) diff --git a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py index e5e278f8ce8..b2c93d7d93d 100644 --- a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py +++ b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py @@ -71,8 +71,8 @@ def test_simple(self, is_dynamic_shape=False): self.seq_len = 3 self._init_cache() q, k_val, v_val = self._init_kv() - self.float_sdpa = SDPACustom(self.dim, self.max_context_len, True) - self.quantized_sdpa = SDPACustom(self.dim, self.max_context_len, True) + self.float_sdpa = SDPACustom(self.dim) + self.quantized_sdpa = SDPACustom(self.dim) k, v = self.custom_kv_cache.update(input_pos, k_val, v_val) float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None) k, v = self.quantized_kv_cache.update(input_pos, k_val, v_val) diff --git a/examples/models/llama/tests/TARGETS b/examples/models/llama/tests/TARGETS index 09ca02868ed..0d52cfa19d3 100644 --- a/examples/models/llama/tests/TARGETS +++ b/examples/models/llama/tests/TARGETS @@ -49,3 +49,14 @@ python_unittest( "//executorch/examples/models/llama:llama_transformer", ], ) + +python_unittest( + name = "test_ring_attention", + srcs = [ + "test_ring_attention.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/examples/models/llama:llama_transformer", + ], +) diff --git a/examples/models/llama/tests/test_ring_attention.py b/examples/models/llama/tests/test_ring_attention.py new file mode 100644 index 00000000000..a3f1bfd95ba --- /dev/null +++ b/examples/models/llama/tests/test_ring_attention.py @@ -0,0 +1,226 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import copy +import unittest + +import torch +from executorch.examples.models.llama.attention import AttentionMHA, RingKVCache +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.rope import Rope + + +class TestRingAttention(unittest.TestCase): + def setUp(self): + # Common test parameters + self.batch_size = 1 + self.seq_len = 1 # Single token processing + self.dim = 64 + self.n_heads = 4 + self.n_kv_heads = 4 + self.head_dim = 16 + self.max_context_len = 16 + self.sliding_window = 8 + self.dtype = torch.float32 + self.device = "cpu" + + def _create_baseline_attention(self, seq_len: int): + """Create baseline attention with regular KV cache.""" + # Create model args + self.args = ModelArgs( + dim=self.dim, + n_heads=self.n_heads, + n_kv_heads=self.n_kv_heads, + head_dim=self.head_dim, + max_batch_size=self.batch_size, + max_context_len=self.max_context_len, + use_kv_cache=True, + enable_dynamic_shape=True, + ) + + # Create RoPE instance + self.rope = Rope(self.args) + + attention = AttentionMHA(self.args, layer_id=0, rope=self.rope) + attention.mask = self._create_sliding_window_mask( + seq_len, self.max_context_len, self.sliding_window + ) + + return attention + + def _create_ring_attention(self, attention): + """Create attention with ring buffer KV cache.""" + assert self.sliding_window is not None + # Create RoPE instance + self.rope = Rope(self.args) + baseline_attention = copy.deepcopy(attention) + + # Replace the KV cache with a ring buffer KV cache + baseline_attention.kv_cache = RingKVCache( + self.args.max_batch_size, + self.sliding_window, + self.n_kv_heads, + self.head_dim, + self.args.enable_dynamic_shape, + self.dtype, + ) + return baseline_attention + + def _create_sliding_window_mask(self, seq_len, context_len, window_size): + """Create a sliding window mask for the baseline.""" + mask = torch.full((seq_len, context_len), float("-inf"), dtype=self.dtype) + for i in range(seq_len): + pos = i + # Allow attention to window_size previous positions + start_idx = max(0, pos - window_size + 1) + mask[i, start_idx : pos + 1] = 0 + return mask + + def test_single_token_processing(self): + """Test that ring buffer and baseline produce the same output for single token processing.""" + seq_len = 10 + self.sliding_window = 4 + baseline_attn = self._create_baseline_attention(seq_len) + ring_attn = self._create_ring_attention(baseline_attn) + + # Process tokens one by one + for pos in range(seq_len): + # Create input tensor for a single token + x = torch.randn((self.batch_size, 1, self.dim), dtype=self.dtype) + input_pos = torch.tensor([pos], dtype=torch.long) + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, 1) + + # Process with baseline attention + baseline_out, _ = baseline_attn.forward( + x, freqs_cos, freqs_sin, input_pos=input_pos + ) + + # Process with ring buffer attention + ring_out, _ = ring_attn.forward( + x, freqs_cos, freqs_sin, input_pos=input_pos + ) + + # Check that outputs are the same + self.assertTrue( + torch.allclose(baseline_out, ring_out, rtol=1e-7, atol=1e-7), + f"Outputs differ at position {pos}", + ) + + def test_sliding_window_attention(self): + """Test that ring buffer with sliding window size produces the same output as baseline with sliding window mask.""" + self.sliding_window = 4 + self.max_context_len = 16 + + seq_len = 10 + # Create baseline attention with full context length + baseline_attn = self._create_baseline_attention(seq_len) + + # Create ring attention with sliding window size + 
ring_attn = self._create_ring_attention(baseline_attn) + + # Process tokens one by one + for pos in range(seq_len): + # Create input tensor for a single token + x = torch.randn((self.batch_size, 1, self.dim), dtype=self.dtype) + input_pos = torch.tensor([pos], dtype=torch.long) + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, 1) + + baseline_out, _ = baseline_attn.forward( + x, freqs_cos, freqs_sin, input_pos=input_pos + ) + + # Process with ring buffer attention + ring_out, _ = ring_attn.forward( + x, freqs_cos, freqs_sin, input_pos=input_pos + ) + + # Check that outputs are the same + self.assertTrue( + torch.allclose(baseline_out, ring_out, rtol=1e-7, atol=1e-7), + f"Outputs differ at position {pos}", + ) + + def test_ring_buffer_wrapping(self): + """Test that ring buffer correctly wraps around and maintains correct attention patterns.""" + self.sliding_window = 3 + self.max_context_len = 15 + + # Create baseline attention with full context length + baseline_attn = self._create_baseline_attention(self.max_context_len) + + # Create ring attention with sliding window size + ring_attn = self._create_ring_attention(baseline_attn) + + # Process enough tokens to cause wrapping + seq_len = 1 + for pos in range(8): + # Create input tensor for a single token + x = torch.randn((self.batch_size, seq_len, self.dim), dtype=self.dtype) + input_pos = torch.tensor([pos], dtype=torch.long) + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seq_len) + + baseline_out, _ = baseline_attn.forward( + x, freqs_cos, freqs_sin, input_pos=input_pos + ) + + # Process with ring buffer attention + ring_out, _ = ring_attn.forward( + x, freqs_cos, freqs_sin, input_pos=input_pos + ) + self.assertTrue( + torch.allclose(baseline_out, ring_out, rtol=1e-7, atol=1e-7), + f"Outputs differ at position {pos}", + ) + + # After processing 8 tokens with window size 4, the ring buffer should have wrapped around + # Check the cache positions to verify wrapping + cache_positions = ring_attn.kv_cache.cache_positions_manager.cache_positions + + # The cache positions should contain the most recent 4 positions (4, 5, 6, 7) + # mapped to the ring buffer indices + expected_positions = torch.tensor([6, 7, 2, 3, 4, 5], dtype=torch.long) + + self.assertTrue( + torch.all(cache_positions == expected_positions), + f"Expected positions {expected_positions}, got {cache_positions}", + ) + + def test_large_context_with_sliding_window(self): + """Test with a large context length and compare baseline with sliding window to ring buffer.""" + # Use a larger context length and sliding window for this test + self.max_context_len = 64 + self.sliding_window = 8 + + token_lens = [8, 1, 3, 2, 1, 1, 1, 1, 7, 1, 5, 1, 1, 1, 4, 1, 1, 2, 1, 1] + seq_len = sum(token_lens) + # Create baseline attention with full context length + baseline_attn = self._create_baseline_attention(seq_len) + + # Create ring attention with sliding window size + ring_attn = self._create_ring_attention(baseline_attn) + + pos = 0 + for token_len in token_lens: + # Create input tensor for a single token + x = torch.randn((self.batch_size, token_len, self.dim), dtype=self.dtype) + input_pos = torch.tensor([pos], dtype=torch.long) + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, token_len) + + baseline_out, _ = baseline_attn.forward( + x, freqs_cos, freqs_sin, input_pos=input_pos + ) + + # Process with ring buffer attention + ring_out, _ = ring_attn.forward( + x, freqs_cos, freqs_sin, input_pos=input_pos + ) + + # Check that outputs are the same + self.assertTrue( + 
torch.allclose(baseline_out, ring_out, rtol=1e-7, atol=1e-7), + f"Outputs differ at position {pos} with max difference {(baseline_out - ring_out).abs().max()}", + ) + pos += token_len diff --git a/examples/models/llama/tests/test_ring_kv_cache.py b/examples/models/llama/tests/test_ring_kv_cache.py index dd9971fa010..0a10e39cf16 100644 --- a/examples/models/llama/tests/test_ring_kv_cache.py +++ b/examples/models/llama/tests/test_ring_kv_cache.py @@ -61,7 +61,8 @@ def test_basic_update(self): # Check that cache_positions was updated correctly expected_positions = torch.tensor( - [0, 1, 2, -1, -1, -1, -1, -1], dtype=torch.long + [0, 1, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + dtype=torch.long, ) self.assertTrue( torch.all( @@ -81,8 +82,8 @@ def test_ring_buffer_wrapping(self): ) # Create input tensors for first update - input_pos = torch.tensor([6], dtype=torch.long) - seq_len = 4 # This will wrap around from position 6 to positions 6, 7, 0, 1 + input_pos = torch.tensor([14], dtype=torch.long) + seq_len = 4 # This will wrap around from position 14 to positions 14, 15, 0, 1 k_val = ( torch.ones( (self.max_batch_size, self.n_heads, seq_len, self.head_dim), @@ -102,8 +103,8 @@ def test_ring_buffer_wrapping(self): k_out, v_out = cache.update(input_pos, k_val, v_val) # Check that the cache was updated correctly with wrapping - # Positions 6, 7 should be updated - for i in range(6, 8): + # Positions 14, 15 should be updated + for i in range(14, 16): self.assertTrue(torch.all(k_out[:, :, i] == 3.0)) self.assertTrue(torch.all(v_out[:, :, i] == 4.0)) @@ -113,17 +114,19 @@ def test_ring_buffer_wrapping(self): self.assertTrue(torch.all(v_out[:, :, i] == 4.0)) # The rest should still be zeros - for i in range(2, 6): + for i in range(2, 14): self.assertTrue(torch.all(k_out[:, :, i] == 0.0)) self.assertTrue(torch.all(v_out[:, :, i] == 0.0)) # Check that cache_positions was updated correctly - # Note that positions 2, 3, 4, 5 are 0 instead of -1 because in actual ring + # Note that positions 2-13 are 0 instead of -1 because in actual ring # updates those positions would have been updated. - # But CachePositionsManager thinks they are updated because start_pos > (2, 3, 4, 5) + # But CachePositionsManager thinks they are updated because start_pos > (2-13) # As a result it does not fill them with -1 and instead uses original values # which is 0, the value cache_position buffer is initialized with. 
- expected_positions = torch.tensor([8, 9, 0, 0, 0, 0, 6, 7], dtype=torch.long) + expected_positions = torch.tensor( + [16, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 15], dtype=torch.long + ) self.assertTrue( torch.all( cache.cache_positions_manager.cache_positions == expected_positions @@ -198,7 +201,10 @@ def test_multiple_updates(self): self.assertTrue(torch.all(v_out2[:, :, i] == 0.0)) # Check that cache_positions was updated correctly - expected_positions = torch.tensor([0, 1, 2, 3, 4, -1, -1, -1], dtype=torch.long) + expected_positions = torch.tensor( + [0, 1, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + dtype=torch.long, + ) self.assertTrue( torch.all( cache.cache_positions_manager.cache_positions == expected_positions @@ -206,7 +212,7 @@ def test_multiple_updates(self): ) # Third update with wrapping - input_pos3 = torch.tensor([6], dtype=torch.long) + input_pos3 = torch.tensor([14], dtype=torch.long) seq_len3 = 4 k_val3 = ( torch.ones( @@ -236,17 +242,21 @@ def test_multiple_updates(self): self.assertTrue(torch.all(k_out3[:, :, i] == 7.0)) self.assertTrue(torch.all(v_out3[:, :, i] == 8.0)) - # Position 5 should still be zero - self.assertTrue(torch.all(k_out3[:, :, 5] == 0.0)) - self.assertTrue(torch.all(v_out3[:, :, 5] == 0.0)) + # Positions 5-13 should still be zero + for i in range(5, 14): + self.assertTrue(torch.all(k_out3[:, :, i] == 0.0)) + self.assertTrue(torch.all(v_out3[:, :, i] == 0.0)) - # Positions 6, 7 should have values from the third update - for i in range(6, 8): + # Positions 14, 15 should have values from the third update + for i in range(14, 16): self.assertTrue(torch.all(k_out3[:, :, i] == 9.0)) self.assertTrue(torch.all(v_out3[:, :, i] == 10.0)) # Check that cache_positions was updated correctly - expected_positions = torch.tensor([8, 9, 2, 3, 4, -1, 6, 7], dtype=torch.long) + expected_positions = torch.tensor( + [16, 17, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, 14, 15], + dtype=torch.long, + ) self.assertTrue( torch.all( cache.cache_positions_manager.cache_positions == expected_positions @@ -296,7 +306,8 @@ def test_edge_case_input_pos_zero(self): # Check that cache_positions was updated correctly expected_positions = torch.tensor( - [0, -1, -1, -1, -1, -1, -1, -1], dtype=torch.long + [0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + dtype=torch.long, ) self.assertTrue( torch.all( @@ -316,8 +327,10 @@ def test_edge_case_exceeding_context_length(self): ) # Create input tensors - input_pos = torch.tensor([5], dtype=torch.long) - seq_len = 5 # This will wrap around from position 5 to positions 5, 6, 7, 0, 1 + input_pos = torch.tensor([13], dtype=torch.long) + seq_len = ( + 5 # This will wrap around from position 13 to positions 13, 14, 15, 0, 1 + ) k_val = ( torch.ones( (self.max_batch_size, self.n_heads, seq_len, self.head_dim), @@ -336,8 +349,8 @@ def test_edge_case_exceeding_context_length(self): # Update the cache k_out, v_out = cache.update(input_pos, k_val, v_val) - # Check that positions 5, 6, 7 were updated - for i in range(5, 8): + # Check that positions 13, 14, 15 were updated + for i in range(13, 16): self.assertTrue(torch.all(k_out[:, :, i] == 13.0)) self.assertTrue(torch.all(v_out[:, :, i] == 14.0)) @@ -346,18 +359,20 @@ def test_edge_case_exceeding_context_length(self): self.assertTrue(torch.all(k_out[:, :, i] == 13.0)) self.assertTrue(torch.all(v_out[:, :, i] == 14.0)) - # Check that positions 2, 3, 4 are still zeros - for i in range(2, 5): + # Check that positions 2-12 are still zeros + for i in range(2, 13): 
self.assertTrue(torch.all(k_out[:, :, i] == 0.0)) self.assertTrue(torch.all(v_out[:, :, i] == 0.0)) # Check that cache_positions was updated correctly - # Note that positions 2, 3, 4 are 0 instead of -1 because in actual ring + # Note that positions 2-12 are 0 instead of -1 because in actual ring # updates those positions would have been updated. - # But CachePositionsManager thinks they are updated because start_pos > (2, 3, 4) + # But CachePositionsManager thinks they are updated because start_pos > (2-12) # As a result it does not fill them with -1 and instead uses original values # which is 0, the value cache_position buffer is initialized with. - expected_positions = torch.tensor([8, 9, 0, 0, 0, 5, 6, 7], dtype=torch.long) + expected_positions = torch.tensor( + [16, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 15], dtype=torch.long + ) self.assertTrue( torch.all( cache.cache_positions_manager.cache_positions == expected_positions @@ -375,7 +390,7 @@ def test_original_indices_tracking(self): self.dtype, ) - # First update at position 10 (will be mapped to position 2 in the ring buffer) + # First update at position 10 (will be mapped to position 10 in the ring buffer) input_pos = torch.tensor([10], dtype=torch.long) seq_len = 4 k_val = torch.ones( @@ -392,14 +407,14 @@ def test_original_indices_tracking(self): # Check that cache_positions correctly tracks the original indices # For input_pos=10 and seq_len=4, the original indices should be 10, 11, 12, 13 - # These map to positions 2, 3, 4, 5 in the ring buffer (since max_context_length=8) - # Note that positions 0, 1, 6 and 7 are 0 instead of -1 because in actual ring + # These map to positions 10, 11, 12, 13 in the ring buffer (since max_context_length=8 but buffer size is 16) + # Note that positions 0-9 are 0 because in actual ring # updates those positions would have been updated for start_pos = 0. - # So CachePositionsManager thinks they are updated because start_pos > (0, 1, 6, 7) + # So CachePositionsManager thinks they are updated because start_pos > (0-9) # As a result it does not fill them with -1 and instead uses original values # which is 0, the value cache_position buffer is initialized with. expected_positions = torch.tensor( - [0, 0, 10, 11, 12, 13, 0, 0], dtype=torch.long + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 11, 12, 13, -1, -1], dtype=torch.long ) self.assertTrue( torch.all( @@ -407,7 +422,7 @@ def test_original_indices_tracking(self): ) ) - # Second update at position 14 (will be mapped to position 6 in the ring buffer) + # Second update at position 14 (will be mapped to position 14 in the ring buffer) input_pos = torch.tensor([14], dtype=torch.long) seq_len = 3 k_val = torch.ones( @@ -424,9 +439,9 @@ def test_original_indices_tracking(self): # Check that cache_positions correctly tracks the original indices # For input_pos=14 and seq_len=3, the original indices should be 14, 15, 16 - # These map to positions 6, 7, 0 in the ring buffer + # These map to positions 14, 15, 0 in the ring buffer expected_positions = torch.tensor( - [16, 0, 10, 11, 12, 13, 14, 15], dtype=torch.long + [16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 11, 12, 13, 14, 15], dtype=torch.long ) self.assertTrue( torch.all(