diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index c886a062c39..515fd0080fc 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -123,14 +123,12 @@ def __init__(
         head_dim: int,
         n_rep: int,
         max_context_len: int,
-        enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
         self.max_context_len = max_context_len
-        self.enable_dynamic_shape = enable_dynamic_shape

     def forward(
         self,
@@ -142,21 +140,12 @@ def forward(
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[-1].item()
-            torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_context_len)
-            seq_length = q.size(2)
-            # pyre-ignore: Incompatible parameter type [6]
-            attn_mask = mask.narrow(0, start_pos, seq_length)
-        else:
-            attn_mask = mask[None, None, input_pos]

         # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
         # can natively support GQA now. But needs enable_gqa=True
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

@@ -236,21 +225,79 @@ def __init__(
         enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
+        self.window_size = max_context_length
+        """
+        Why we want the KV cache size to be twice the context length:
+        Sliding window attention without a ring buffer
+        pos 0 1 2 3 4 5 6 7 8 9 10
+        0 x 0 0 0 0 0 0 0 0 0 0
+        1 x x 0 0 0 0 0 0 0 0 0
+        2 x x x 0 0 0 0 0 0 0 0
+        3 x x x x 0 0 0 0 0 0 0
+        4 0 x x x x 0 0 0 0 0 0
+        5 0 0 x x x x 0 0 0 0 0
+        6 0 0 0 x x x x 0 0 0 0
+        7 0 0 0 0 x x x x 0 0 0
+        8 0 0 0 0 0 x x x x 0 0
+        9 0 0 0 0 0 0 x x x x 0
+        10 0 0 0 0 0 0 0 x x x x
+
+        So when doing attention for pos = 5 and seq_len = 4, our attention
+        mask would be
+        5 0 0 x x x x 0 0 0 0 0
+        6 0 0 0 x x x x 0 0 0 0
+        7 0 0 0 0 x x x x 0 0 0
+        8 0 0 0 0 0 x x x x 0 0
+        Thus the token at pos = 5 is able to attend to tokens at pos 2, 3 and 4.
+        This is how training is done.
+
+        Now let's consider a ring KV cache of size 4. When we are at pos = 5,
+        before updating the KV cache, the state of the cache would be
+        [4 1 2 3]; that is, the token at pos = 0 has been evicted. During
+        attention calculation at pos = 5 with seq_len = 4, we update the cache and
+        the positions in the cache become [8 5 6 7]. Note that 5 can now only attend
+        to itself, not to 2, 3 and 4 as it would during training.
+        Not having kept 2, 3 and 4 in the cache means we get divergent behavior.
+        The worst case is when the update length equals the length of
+        the cache, as in our case of pos = 5 and seq_len = 4.
+        Thus we need a larger cache. How much larger? Larger by the
+        sliding window size, i.e. twice the max_context_length.
+        How does that help? Let's see. At pos = 5 our cache would hold
+        [0, 1, 2, 3, 4, NA, NA, NA]. After the cache update we would have
+        [8, 1, 2, 3, 4, 5, 6, 7]. We kicked out the token at pos = 0, but the
+        current step still has access to the tokens in [pos - sliding_window_size, pos].
+
+        To make sure we don't over-attend, i.e. pos = 5 does not attend
+        to pos = 1, the mask calculation has to account for the sliding window
+        size. 
+ """ super().__init__( max_batch_size, - max_context_length, + max_context_length * 2, n_heads, head_dim, enable_dynamic_shape, dtype, ) - self.cache_positions_manager = CachePositionsManager(max_context_length) + self.cache_positions_manager = CachePositionsManager(self.max_context_length) + self.is_ring_buffer = True + + def create_causal_mask_for_ring_buffer(self, start_pos, seq_len): + pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1) + cache_positions = self.cache_positions_manager.cache_positions + delta = pos_q - cache_positions + attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < self.window_size) + attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712 + return attn_mask def update( self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: # input_pos: [S], k_val: [B, H, S, D] seq_len = k_val.size(2) + assert seq_len <= self.k_cache.size( + 2 + ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" indices = self.cache_positions_manager.calculate_positions_and_update_indices( input_pos, seq_len ) @@ -286,6 +333,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): self.attention_qkv_bias = args.attention_qkv_bias self.use_qk_norm = args.use_qk_norm self.qk_norm_before_rope = args.qk_norm_before_rope + self.enable_dynamic_shape = args.enable_dynamic_shape if self.use_qk_norm: q_norm_dim = self.head_dim @@ -331,7 +379,6 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): head_dim=self.head_dim, n_rep=self.n_rep, max_context_len=self.max_context_len, - enable_dynamic_shape=args.enable_dynamic_shape, ) def forward( @@ -368,8 +415,22 @@ def forward( if self.use_kv_cache: assert input_pos is not None + if self.enable_dynamic_shape: + start_pos = input_pos[-1].item() + torch._check_is_size(start_pos) + torch._check(start_pos < self.max_context_len) + seq_length = q.size(2) + # pyre-ignore: Incompatible parameter type [6] + attn_mask = self.mask.narrow(0, start_pos, seq_length) + else: + # mask is always 2D + attn_mask = self.mask[input_pos] k, v = self.kv_cache.update(input_pos, k, v) - output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask) + if getattr(self.kv_cache, "is_ring_buffer", False): + attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer( + input_pos[0].item(), seqlen + ) + output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) return self.wo(output), None # grouped multiquery attention: expand out keys and values diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index 1bc54198fba..1fb3d97a9c7 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -22,15 +22,11 @@ class SDPACustom(torch.nn.Module): def __init__( self, dim: int, - max_context_len, - enable_dynamic_shape, use_attention_mask: bool = False, ): super().__init__() self.dim = dim - self.max_context_len = max_context_len self.use_attention_mask = use_attention_mask - self.enable_dynamic_shape = enable_dynamic_shape def forward( self, @@ -42,16 +38,6 @@ def forward( seqlen, mask, ): - if self.use_attention_mask: - if self.enable_dynamic_shape: - start_pos = input_pos[-1].item() - torch._check_is_size(start_pos) - torch._check(start_pos < self.max_context_len) - seq_length = q.size(2) - mask = mask.narrow(0, start_pos, seq_length) - else: - mask = mask[input_pos] - q = 
q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim) k = k.transpose(1, 2) v = v.transpose(1, 2) @@ -96,8 +82,6 @@ def _replace_sdpa_with_custom_op( name, SDPACustom( child.dim, - child.max_context_len, - child.enable_dynamic_shape, use_attention_mask=use_attention_mask, ), ) @@ -133,12 +117,15 @@ class QuantizedSDPA(torch.nn.Module): zero points, we need to pass kv_cache to SDPA. """ - def __init__(self, dim: int, kv_cache: QuantizedKVCache): + def __init__( + self, dim: int, kv_cache: QuantizedKVCache, use_attention_mask: bool = False + ): super().__init__() self.dim = dim self.quantized_dtype = torch.int8 self.float_dtype = torch.float32 self.kv_cache = kv_cache + self.use_attention_mask = use_attention_mask def forward( self, @@ -176,22 +163,40 @@ def forward( v_scale_fp32 = self.kv_cache.v_cache_scales start_pos = input_pos[0].item() - output = torch.ops.llama.custom_quantized_sdpa( - q_quantized, - k_quantized, - v_quantized, - start_pos, - None, - 0, - True, - None, - q_zero_point_int8, - q_scale_fp32, - k_zero_point_int8, - k_scale_fp32, - v_zero_point_int8, - v_scale_fp32, - ) + if self.use_attention_mask: + output = torch.ops.llama.custom_quantized_sdpa( + q_quantized, + k_quantized, + v_quantized, + start_pos, + mask, + 0, + False, + None, + q_zero_point_int8, + q_scale_fp32, + k_zero_point_int8, + k_scale_fp32, + v_zero_point_int8, + v_scale_fp32, + ) + else: + output = torch.ops.llama.custom_quantized_sdpa( + q_quantized, + k_quantized, + v_quantized, + start_pos, + None, + 0, + True, + None, + q_zero_point_int8, + q_scale_fp32, + k_zero_point_int8, + k_scale_fp32, + v_zero_point_int8, + v_scale_fp32, + ) return output.view(bsz, seqlen, self.dim) @@ -201,6 +206,7 @@ def _update_attention_module_with_quantized_sdpa( ): sdpa = getattr(module, "SDPA", None) assert sdpa is not None + # TODO: add support for SDPA with attention mask # pyre-ignore setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache)) # noqa: B010 @@ -254,7 +260,8 @@ def forward( seqlen, mask, ): - attn_mask = mask[None, None, input_pos] + # Input mask is slided however it is 2D + attn_mask = mask[None, None] k = k.repeat_interleave(self.n_rep, dim=1) v = v.repeat_interleave(self.n_rep, dim=1) @@ -310,7 +317,8 @@ def forward( """ k = repeat_kv(k, self.n_rep) v = repeat_kv(v, self.n_rep) - attn_mask = mask[input_pos] + # Mask is already sliced as needed + attn_mask = mask scale_factor = 1 / math.sqrt(q.size(-1)) attn_weight = q @ k.transpose(-2, -1) * scale_factor @@ -391,7 +399,8 @@ def forward( seqlen, mask, ): - attn_mask = mask[None, None, input_pos] + # Input mask is slided however it is 2D + attn_mask = mask[None, None] if self.n_rep > 1: k = k.repeat_interleave(self.n_rep, dim=1) diff --git a/examples/models/llama/source_transformation/test_quantized_sdpa.py b/examples/models/llama/source_transformation/test_quantized_sdpa.py index 242f3a0876d..4297221919e 100644 --- a/examples/models/llama/source_transformation/test_quantized_sdpa.py +++ b/examples/models/llama/source_transformation/test_quantized_sdpa.py @@ -31,7 +31,7 @@ def __init__( self.dim = dim self.head_dim = head_dim self.n_rep = n_rep - self.SDPA = SDPA(dim, head_dim, n_rep, max_context_len, enable_dynamic_shape) + self.SDPA = SDPA(dim, head_dim, n_rep, max_context_len) self.kv_cache = None def forward(self, x, freqs_cos, freqs_sin, **kwargs): @@ -159,15 +159,9 @@ def test_forward_functionality(self): k_quantized, v_quantized = model.attention.kv_cache.update(input_pos, k, v) # Run the forward pass with the quantized SDPA - try: - output = 
model.attention.SDPA( - input_pos, q, k_quantized, v_quantized, bsz, seqlen, None - ) + output = model.attention.SDPA( + input_pos, q, k_quantized, v_quantized, bsz, seqlen, None + ) - # Verify the output shape - self.assertEqual(output.shape, (bsz, seqlen, self.dim)) - except Exception: - # If the forward pass fails, it might be due to missing custom ops - self.skipTest( - "Custom ops not available, skipping forward functionality test" - ) + # Verify the output shape + self.assertEqual(output.shape, (bsz, seqlen, self.dim)) diff --git a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py index e5e278f8ce8..b2c93d7d93d 100644 --- a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py +++ b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py @@ -71,8 +71,8 @@ def test_simple(self, is_dynamic_shape=False): self.seq_len = 3 self._init_cache() q, k_val, v_val = self._init_kv() - self.float_sdpa = SDPACustom(self.dim, self.max_context_len, True) - self.quantized_sdpa = SDPACustom(self.dim, self.max_context_len, True) + self.float_sdpa = SDPACustom(self.dim) + self.quantized_sdpa = SDPACustom(self.dim) k, v = self.custom_kv_cache.update(input_pos, k_val, v_val) float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None) k, v = self.quantized_kv_cache.update(input_pos, k_val, v_val) diff --git a/examples/models/llama/tests/TARGETS b/examples/models/llama/tests/TARGETS index 09ca02868ed..0d52cfa19d3 100644 --- a/examples/models/llama/tests/TARGETS +++ b/examples/models/llama/tests/TARGETS @@ -49,3 +49,14 @@ python_unittest( "//executorch/examples/models/llama:llama_transformer", ], ) + +python_unittest( + name = "test_ring_attention", + srcs = [ + "test_ring_attention.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/examples/models/llama:llama_transformer", + ], +) diff --git a/examples/models/llama/tests/test_ring_attention.py b/examples/models/llama/tests/test_ring_attention.py new file mode 100644 index 00000000000..064be7f04e0 --- /dev/null +++ b/examples/models/llama/tests/test_ring_attention.py @@ -0,0 +1,241 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
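
Note: the new test below drives the same AttentionMHA module twice, once with a dense sliding-window mask over a regular KV cache and once with a RingKVCache sized to the sliding window (which the cache doubles internally), and checks that the outputs match. The masking rule it exercises is the one added in create_causal_mask_for_ring_buffer above; here is a minimal standalone sketch of that rule (the helper name and the example values are illustrative, not part of this change):

import torch

def ring_buffer_mask(cache_positions, start_pos, seq_len, window_size):
    # cache_positions[i] holds the original token position stored in ring slot i,
    # or -1 if the slot has never been written.
    pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1)
    delta = pos_q - cache_positions
    allowed = (cache_positions >= 0) & (delta >= 0) & (delta < window_size)
    # Additive mask: 0.0 where attention is allowed, -inf everywhere else.
    return torch.where(allowed, 0.0, float("-inf"))

# Window of 4, ring buffer of 8 slots (2x the window), positions 0..5 written so far.
cache_positions = torch.tensor([0, 1, 2, 3, 4, 5, -1, -1])
print(ring_buffer_mask(cache_positions, start_pos=5, seq_len=1, window_size=4))
# Only the slots holding positions 2, 3, 4 and 5 come out unmasked, matching the
# sliding-window pattern described in the RingKVCache docstring.
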
+ +import copy +import unittest + +import torch +from executorch.examples.models.llama.attention import AttentionMHA, RingKVCache +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.rope import Rope +from torch.nn.attention import SDPBackend + + +class TestRingAttention(unittest.TestCase): + def setUp(self): + # Common test parameters + self.batch_size = 1 + self.seq_len = 1 # Single token processing + self.dim = 64 + self.n_heads = 4 + self.n_kv_heads = 4 + self.head_dim = 16 + self.max_context_len = 16 + self.sliding_window = 8 + self.dtype = torch.float32 + self.device = "cpu" + + def _create_baseline_attention(self, seq_len: int): + """Create baseline attention with regular KV cache.""" + # Create model args + self.args = ModelArgs( + dim=self.dim, + n_heads=self.n_heads, + n_kv_heads=self.n_kv_heads, + head_dim=self.head_dim, + max_batch_size=self.batch_size, + max_context_len=self.max_context_len, + use_kv_cache=True, + enable_dynamic_shape=True, + ) + + # Create RoPE instance + self.rope = Rope(self.args) + + attention = AttentionMHA(self.args, layer_id=0, rope=self.rope) + attention.mask = self._create_sliding_window_mask( + seq_len, self.max_context_len, self.sliding_window + ) + + return attention + + def _create_ring_attention(self, attention): + """Create attention with ring buffer KV cache.""" + assert self.sliding_window is not None + # Create RoPE instance + self.rope = Rope(self.args) + baseline_attention = copy.deepcopy(attention) + + # Replace the KV cache with a ring buffer KV cache + baseline_attention.kv_cache = RingKVCache( + self.args.max_batch_size, + self.sliding_window, + self.n_kv_heads, + self.head_dim, + self.args.enable_dynamic_shape, + self.dtype, + ) + return baseline_attention + + def _create_sliding_window_mask(self, seq_len, context_len, window_size): + """Create a sliding window mask for the baseline.""" + mask = torch.full((seq_len, context_len), float("-inf"), dtype=self.dtype) + for i in range(seq_len): + pos = i + # Allow attention to window_size previous positions + start_idx = max(0, pos - window_size + 1) + mask[i, start_idx : pos + 1] = 0 + return mask + + def test_single_token_processing(self): + """Test that ring buffer and baseline produce the same output for single token processing.""" + seq_len = 10 + self.sliding_window = 4 + baseline_attn = self._create_baseline_attention(seq_len) + ring_attn = self._create_ring_attention(baseline_attn) + + # Process tokens one by one + with torch.nn.attention.sdpa_kernel( + [SDPBackend.FLASH_ATTENTION] + ), torch.no_grad(): + for pos in range(seq_len): + # Create input tensor for a single token + x = torch.randn((self.batch_size, 1, self.dim), dtype=self.dtype) + input_pos = torch.tensor([pos], dtype=torch.long) + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, 1) + + # Process with baseline attention + baseline_out, _ = baseline_attn.forward( + x, freqs_cos, freqs_sin, input_pos=input_pos + ) + + # Process with ring buffer attention + ring_out, _ = ring_attn.forward( + x, freqs_cos, freqs_sin, input_pos=input_pos + ) + + # Check that outputs are the same + self.assertTrue( + torch.allclose(baseline_out, ring_out, rtol=1e-7, atol=1e-7), + f"Outputs differ at position {pos}", + ) + + def test_sliding_window_attention(self): + """Test that ring buffer with sliding window size produces the same output as baseline with sliding window mask.""" + self.sliding_window = 4 + self.max_context_len = 16 + + seq_len = 10 + # Create baseline attention with 
full context length + baseline_attn = self._create_baseline_attention(seq_len) + + # Create ring attention with sliding window size + ring_attn = self._create_ring_attention(baseline_attn) + + # Process tokens one by one + with torch.nn.attention.sdpa_kernel( + [SDPBackend.FLASH_ATTENTION] + ), torch.no_grad(): + for pos in range(seq_len): + # Create input tensor for a single token + x = torch.randn((self.batch_size, 1, self.dim), dtype=self.dtype) + input_pos = torch.tensor([pos], dtype=torch.long) + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, 1) + + baseline_out, _ = baseline_attn.forward( + x, freqs_cos, freqs_sin, input_pos=input_pos + ) + + # Process with ring buffer attention + ring_out, _ = ring_attn.forward( + x, freqs_cos, freqs_sin, input_pos=input_pos + ) + + # Check that outputs are the same + self.assertTrue( + torch.allclose(baseline_out, ring_out, rtol=1e-7, atol=1e-7), + f"Outputs differ at position {pos}", + ) + + def test_ring_buffer_wrapping(self): + """Test that ring buffer correctly wraps around and maintains correct attention patterns.""" + self.sliding_window = 3 + self.max_context_len = 15 + + # Create baseline attention with full context length + baseline_attn = self._create_baseline_attention(self.max_context_len) + + # Create ring attention with sliding window size + ring_attn = self._create_ring_attention(baseline_attn) + + # Process enough tokens to cause wrapping + seq_len = 1 + with torch.nn.attention.sdpa_kernel( + [SDPBackend.FLASH_ATTENTION] + ), torch.no_grad(): + for pos in range(8): + # Create input tensor for a single token + x = torch.randn((self.batch_size, seq_len, self.dim), dtype=self.dtype) + input_pos = torch.tensor([pos], dtype=torch.long) + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seq_len) + + baseline_out, _ = baseline_attn.forward( + x, freqs_cos, freqs_sin, input_pos=input_pos + ) + + # Process with ring buffer attention + ring_out, _ = ring_attn.forward( + x, freqs_cos, freqs_sin, input_pos=input_pos + ) + self.assertTrue( + torch.allclose(baseline_out, ring_out, rtol=1e-7, atol=1e-7), + f"Outputs differ at position {pos}", + ) + + # After processing 8 tokens with window size 4, the ring buffer should have wrapped around + # Check the cache positions to verify wrapping + cache_positions = ring_attn.kv_cache.cache_positions_manager.cache_positions + + # The cache positions should contain the most recent 4 positions (4, 5, 6, 7) + # mapped to the ring buffer indices + expected_positions = torch.tensor([6, 7, 2, 3, 4, 5], dtype=torch.long) + + self.assertTrue( + torch.all(cache_positions == expected_positions), + f"Expected positions {expected_positions}, got {cache_positions}", + ) + + def test_large_context_with_sliding_window(self): + """Test with a large context length and compare baseline with sliding window to ring buffer.""" + # Use a larger context length and sliding window for this test + self.max_context_len = 64 + self.sliding_window = 8 + + token_lens = [8, 1, 3, 2, 1, 1, 1, 1, 7, 1, 5, 1, 1, 1, 4, 1, 1, 2, 1, 1] + seq_len = sum(token_lens) + # Create baseline attention with full context length + baseline_attn = self._create_baseline_attention(seq_len) + + # Create ring attention with sliding window size + ring_attn = self._create_ring_attention(baseline_attn) + + pos = 0 + with torch.nn.attention.sdpa_kernel( + [SDPBackend.FLASH_ATTENTION] + ), torch.no_grad(): + for token_len in token_lens: + # Create input tensor for a single token + x = torch.randn( + (self.batch_size, token_len, self.dim), 
dtype=self.dtype + ) + input_pos = torch.tensor([pos], dtype=torch.long) + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, token_len) + + baseline_out, _ = baseline_attn.forward( + x, freqs_cos, freqs_sin, input_pos=input_pos + ) + + # Process with ring buffer attention + ring_out, _ = ring_attn.forward( + x, freqs_cos, freqs_sin, input_pos=input_pos + ) + + # Check that outputs are the same + self.assertTrue( + torch.allclose(baseline_out, ring_out, rtol=1e-7, atol=1e-7), + f"Outputs differ at position {pos} with max difference {(baseline_out - ring_out).abs().max()}", + ) + pos += token_len diff --git a/examples/models/llama/tests/test_ring_kv_cache.py b/examples/models/llama/tests/test_ring_kv_cache.py index dd9971fa010..0a10e39cf16 100644 --- a/examples/models/llama/tests/test_ring_kv_cache.py +++ b/examples/models/llama/tests/test_ring_kv_cache.py @@ -61,7 +61,8 @@ def test_basic_update(self): # Check that cache_positions was updated correctly expected_positions = torch.tensor( - [0, 1, 2, -1, -1, -1, -1, -1], dtype=torch.long + [0, 1, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + dtype=torch.long, ) self.assertTrue( torch.all( @@ -81,8 +82,8 @@ def test_ring_buffer_wrapping(self): ) # Create input tensors for first update - input_pos = torch.tensor([6], dtype=torch.long) - seq_len = 4 # This will wrap around from position 6 to positions 6, 7, 0, 1 + input_pos = torch.tensor([14], dtype=torch.long) + seq_len = 4 # This will wrap around from position 14 to positions 14, 15, 0, 1 k_val = ( torch.ones( (self.max_batch_size, self.n_heads, seq_len, self.head_dim), @@ -102,8 +103,8 @@ def test_ring_buffer_wrapping(self): k_out, v_out = cache.update(input_pos, k_val, v_val) # Check that the cache was updated correctly with wrapping - # Positions 6, 7 should be updated - for i in range(6, 8): + # Positions 14, 15 should be updated + for i in range(14, 16): self.assertTrue(torch.all(k_out[:, :, i] == 3.0)) self.assertTrue(torch.all(v_out[:, :, i] == 4.0)) @@ -113,17 +114,19 @@ def test_ring_buffer_wrapping(self): self.assertTrue(torch.all(v_out[:, :, i] == 4.0)) # The rest should still be zeros - for i in range(2, 6): + for i in range(2, 14): self.assertTrue(torch.all(k_out[:, :, i] == 0.0)) self.assertTrue(torch.all(v_out[:, :, i] == 0.0)) # Check that cache_positions was updated correctly - # Note that positions 2, 3, 4, 5 are 0 instead of -1 because in actual ring + # Note that positions 2-13 are 0 instead of -1 because in actual ring # updates those positions would have been updated. - # But CachePositionsManager thinks they are updated because start_pos > (2, 3, 4, 5) + # But CachePositionsManager thinks they are updated because start_pos > (2-13) # As a result it does not fill them with -1 and instead uses original values # which is 0, the value cache_position buffer is initialized with. 
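
Note: to make the updated expectation below concrete, with max_context_length = 8 the ring buffer now has 16 slots, and an update at input_pos = 14 with seq_len = 4 wraps into slots 14, 15, 0, 1. A small sketch of that arithmetic (it assumes the manager wraps original token positions modulo the cache size, which is what the expected values encode):

import torch

cache_size = 16                                       # 2 x max_context_length of 8
input_pos, seq_len = torch.tensor([14]), 4
orig_positions = input_pos + torch.arange(seq_len)    # tensor([14, 15, 16, 17])
ring_indices = orig_positions % cache_size            # tensor([14, 15,  0,  1])
# cache_positions[ring_indices] is set to orig_positions, so slots 14 and 15
# hold 14 and 15, while slots 0 and 1 are overwritten with 16 and 17.
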
- expected_positions = torch.tensor([8, 9, 0, 0, 0, 0, 6, 7], dtype=torch.long) + expected_positions = torch.tensor( + [16, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 15], dtype=torch.long + ) self.assertTrue( torch.all( cache.cache_positions_manager.cache_positions == expected_positions @@ -198,7 +201,10 @@ def test_multiple_updates(self): self.assertTrue(torch.all(v_out2[:, :, i] == 0.0)) # Check that cache_positions was updated correctly - expected_positions = torch.tensor([0, 1, 2, 3, 4, -1, -1, -1], dtype=torch.long) + expected_positions = torch.tensor( + [0, 1, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + dtype=torch.long, + ) self.assertTrue( torch.all( cache.cache_positions_manager.cache_positions == expected_positions @@ -206,7 +212,7 @@ def test_multiple_updates(self): ) # Third update with wrapping - input_pos3 = torch.tensor([6], dtype=torch.long) + input_pos3 = torch.tensor([14], dtype=torch.long) seq_len3 = 4 k_val3 = ( torch.ones( @@ -236,17 +242,21 @@ def test_multiple_updates(self): self.assertTrue(torch.all(k_out3[:, :, i] == 7.0)) self.assertTrue(torch.all(v_out3[:, :, i] == 8.0)) - # Position 5 should still be zero - self.assertTrue(torch.all(k_out3[:, :, 5] == 0.0)) - self.assertTrue(torch.all(v_out3[:, :, 5] == 0.0)) + # Positions 5-13 should still be zero + for i in range(5, 14): + self.assertTrue(torch.all(k_out3[:, :, i] == 0.0)) + self.assertTrue(torch.all(v_out3[:, :, i] == 0.0)) - # Positions 6, 7 should have values from the third update - for i in range(6, 8): + # Positions 14, 15 should have values from the third update + for i in range(14, 16): self.assertTrue(torch.all(k_out3[:, :, i] == 9.0)) self.assertTrue(torch.all(v_out3[:, :, i] == 10.0)) # Check that cache_positions was updated correctly - expected_positions = torch.tensor([8, 9, 2, 3, 4, -1, 6, 7], dtype=torch.long) + expected_positions = torch.tensor( + [16, 17, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, 14, 15], + dtype=torch.long, + ) self.assertTrue( torch.all( cache.cache_positions_manager.cache_positions == expected_positions @@ -296,7 +306,8 @@ def test_edge_case_input_pos_zero(self): # Check that cache_positions was updated correctly expected_positions = torch.tensor( - [0, -1, -1, -1, -1, -1, -1, -1], dtype=torch.long + [0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + dtype=torch.long, ) self.assertTrue( torch.all( @@ -316,8 +327,10 @@ def test_edge_case_exceeding_context_length(self): ) # Create input tensors - input_pos = torch.tensor([5], dtype=torch.long) - seq_len = 5 # This will wrap around from position 5 to positions 5, 6, 7, 0, 1 + input_pos = torch.tensor([13], dtype=torch.long) + seq_len = ( + 5 # This will wrap around from position 13 to positions 13, 14, 15, 0, 1 + ) k_val = ( torch.ones( (self.max_batch_size, self.n_heads, seq_len, self.head_dim), @@ -336,8 +349,8 @@ def test_edge_case_exceeding_context_length(self): # Update the cache k_out, v_out = cache.update(input_pos, k_val, v_val) - # Check that positions 5, 6, 7 were updated - for i in range(5, 8): + # Check that positions 13, 14, 15 were updated + for i in range(13, 16): self.assertTrue(torch.all(k_out[:, :, i] == 13.0)) self.assertTrue(torch.all(v_out[:, :, i] == 14.0)) @@ -346,18 +359,20 @@ def test_edge_case_exceeding_context_length(self): self.assertTrue(torch.all(k_out[:, :, i] == 13.0)) self.assertTrue(torch.all(v_out[:, :, i] == 14.0)) - # Check that positions 2, 3, 4 are still zeros - for i in range(2, 5): + # Check that positions 2-12 are still zeros + for i in range(2, 13): 
self.assertTrue(torch.all(k_out[:, :, i] == 0.0)) self.assertTrue(torch.all(v_out[:, :, i] == 0.0)) # Check that cache_positions was updated correctly - # Note that positions 2, 3, 4 are 0 instead of -1 because in actual ring + # Note that positions 2-12 are 0 instead of -1 because in actual ring # updates those positions would have been updated. - # But CachePositionsManager thinks they are updated because start_pos > (2, 3, 4) + # But CachePositionsManager thinks they are updated because start_pos > (2-12) # As a result it does not fill them with -1 and instead uses original values # which is 0, the value cache_position buffer is initialized with. - expected_positions = torch.tensor([8, 9, 0, 0, 0, 5, 6, 7], dtype=torch.long) + expected_positions = torch.tensor( + [16, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 15], dtype=torch.long + ) self.assertTrue( torch.all( cache.cache_positions_manager.cache_positions == expected_positions @@ -375,7 +390,7 @@ def test_original_indices_tracking(self): self.dtype, ) - # First update at position 10 (will be mapped to position 2 in the ring buffer) + # First update at position 10 (will be mapped to position 10 in the ring buffer) input_pos = torch.tensor([10], dtype=torch.long) seq_len = 4 k_val = torch.ones( @@ -392,14 +407,14 @@ def test_original_indices_tracking(self): # Check that cache_positions correctly tracks the original indices # For input_pos=10 and seq_len=4, the original indices should be 10, 11, 12, 13 - # These map to positions 2, 3, 4, 5 in the ring buffer (since max_context_length=8) - # Note that positions 0, 1, 6 and 7 are 0 instead of -1 because in actual ring + # These map to positions 10, 11, 12, 13 in the ring buffer (since max_context_length=8 but buffer size is 16) + # Note that positions 0-9 are 0 because in actual ring # updates those positions would have been updated for start_pos = 0. - # So CachePositionsManager thinks they are updated because start_pos > (0, 1, 6, 7) + # So CachePositionsManager thinks they are updated because start_pos > (0-9) # As a result it does not fill them with -1 and instead uses original values # which is 0, the value cache_position buffer is initialized with. 
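
Note: the fill rule implied by this comment can be restated in a few lines. This is an illustration of the behavior the expectation below encodes, inferred from the test, not the actual CachePositionsManager code:

import torch

cache_size, start_pos, seq_len = 16, 10, 4
cache_positions = torch.zeros(cache_size, dtype=torch.long)  # buffer starts out at 0
orig_positions = start_pos + torch.arange(seq_len)           # 10, 11, 12, 13
# Slots that no position below start_pos + seq_len could have reached yet get -1;
# slots below start_pos keep whatever they already hold (still 0 here).
cache_positions[torch.arange(cache_size) >= start_pos + seq_len] = -1
cache_positions[orig_positions % cache_size] = orig_positions
# -> [0]*10 + [10, 11, 12, 13] + [-1, -1], matching expected_positions below.
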
expected_positions = torch.tensor( - [0, 0, 10, 11, 12, 13, 0, 0], dtype=torch.long + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 11, 12, 13, -1, -1], dtype=torch.long ) self.assertTrue( torch.all( @@ -407,7 +422,7 @@ def test_original_indices_tracking(self): ) ) - # Second update at position 14 (will be mapped to position 6 in the ring buffer) + # Second update at position 14 (will be mapped to position 14 in the ring buffer) input_pos = torch.tensor([14], dtype=torch.long) seq_len = 3 k_val = torch.ones( @@ -424,9 +439,9 @@ def test_original_indices_tracking(self): # Check that cache_positions correctly tracks the original indices # For input_pos=14 and seq_len=3, the original indices should be 14, 15, 16 - # These map to positions 6, 7, 0 in the ring buffer + # These map to positions 14, 15, 0 in the ring buffer expected_positions = torch.tensor( - [16, 0, 10, 11, 12, 13, 14, 15], dtype=torch.long + [16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 11, 12, 13, 14, 15], dtype=torch.long ) self.assertTrue( torch.all( diff --git a/examples/models/llama/tests/test_simple_sdpa.py b/examples/models/llama/tests/test_simple_sdpa.py index d60bc30b7d3..dbfa38ac590 100644 --- a/examples/models/llama/tests/test_simple_sdpa.py +++ b/examples/models/llama/tests/test_simple_sdpa.py @@ -35,13 +35,13 @@ def test_simple_sdpa(self): head_dim=head_dim, n_rep=n_rep, max_context_len=max_context_length, - enable_dynamic_shape=False, ) input_pos = torch.tensor([0]) query = torch.randn(1, 1, n_local_heads, head_dim) key = torch.randn(1, 1, n_local_heads, head_dim) value = torch.randn(1, 1, n_local_heads, head_dim) mask = torch.randn(max_context_length, max_context_length) + mask = mask[input_pos] query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2)
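
Note: taken together, these changes move mask slicing out of the SDPA modules and into their callers. Each SDPA variant now receives a mask that is already narrowed to the current positions (2D, one row per query position) and only adds broadcast dimensions where it needs them. A rough sketch of the new calling convention; the causal-mask construction below is illustrative, not taken from this diff:

import torch

max_context_len, seqlen = 16, 1
input_pos = torch.tensor([0])

# The attention module owns a (context, context) additive causal mask...
full_mask = torch.triu(
    torch.full((max_context_len, max_context_len), float("-inf")), diagonal=1
)
# ...and hands SDPA only the rows for the current positions, still 2D: mask[input_pos]
# here, mask.narrow(0, start_pos, seqlen) under dynamic shapes, or the output of
# create_causal_mask_for_ring_buffer(...) when the KV cache is a ring buffer.
attn_mask = full_mask[input_pos]
assert attn_mask.shape == (seqlen, max_context_len)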