diff --git a/test/test_pallas.py b/test/test_pallas.py index 0836ab9387a7..8173e3b2713c 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -87,8 +87,9 @@ def _pagedattention_generate_qkv( q = torch.randn(batch_size, query_len, num_heads, head_dim, dtype=dtype) return q, k_pages, v_pages, page_indices - def _round_up_closest_multiple_of(self, x, base): - return (x + base - 1) // base * base + def _ceil_div(self, a, b): + assert b != 0 + return (a + b - 1) // b def _ragged_pagedattention_generate_qkv( self, @@ -97,64 +98,50 @@ def _ragged_pagedattention_generate_qkv( head_dim, page_size, num_pages, - dtype=torch.float32, - num_queries_per_block=None, - pad_num_q_tokens=False, + dtype, + *, + num_kv_pages_per_block=None, + max_num_batched_tokens=None, + max_num_seqs=16, ): - num_seqs = len(seq_lens) - # Make sure the q_len is no longer than the kv_len. For example, - # seq_lens = [(1, 1328), (5, 18), (506, 463)] is not a valid test case because - # the 3rd sequence has q_len(506) > kv_len(463). - for i in range(num_seqs): - cur_q_len = seq_lens[i][0] - cur_kv_len = seq_lens[i][1] - assert cur_q_len <= cur_kv_len, f"cur_q_len must be less than or equal to cur_kv_len. Got {cur_q_len} and {cur_kv_len}" - - query_lens = [seq_len[0] for seq_len in seq_lens] - actual_num_q_tokens = sum(query_lens) - num_q_tokens = self._round_up_closest_multiple_of( - actual_num_q_tokens, - num_queries_per_block) if pad_num_q_tokens else actual_num_q_tokens - kv_lens = torch.tensor([seq_len[1] for seq_len in seq_lens], - dtype=torch.int32) - num_q_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0." - queries = torch.randn((num_q_tokens, num_q_heads, head_dim), dtype=dtype) - k_pages = torch.randn((num_kv_heads, num_pages, page_size, head_dim), + cu_q_lens = [0] + kv_lens = [] + for q_len, kv_len in seq_lens: + assert q_len <= kv_len + cu_q_lens.append(cu_q_lens[-1] + q_len) + kv_lens.append(kv_len) + + if max_num_batched_tokens is None: + max_num_batched_tokens = cu_q_lens[-1] + else: + max_num_batched_tokens = max(cu_q_lens[-1], max_num_batched_tokens) + if max_num_seqs is None: + max_num_seqs = len(seq_lens) + else: + max_num_seqs = max(len(seq_lens), max_num_seqs) + max_kv_len = max(kv_lens) + pages_per_seq = self._ceil_div(max_kv_len, page_size) + pages_per_seq = ( + self._ceil_div(pages_per_seq, num_kv_pages_per_block) * + num_kv_pages_per_block) + + num_q_heads, num_kv_heads = num_heads + cu_q_lens = torch.tensor(cu_q_lens, dtype=torch.int32) + kv_lens = torch.tensor(kv_lens, dtype=torch.int32) + cu_q_lens = torch.nn.functional.pad( + cu_q_lens, (0, max_num_seqs + 1 - cu_q_lens.shape[0]), "constant", 0) + kv_lens = torch.nn.functional.pad(kv_lens, + (0, max_num_seqs - kv_lens.shape[0]), + "constant", 0) + q = torch.randn((max_num_batched_tokens, num_q_heads, head_dim), + dtype=dtype) + k_pages = torch.randn((num_pages, page_size, num_kv_heads, head_dim), dtype=dtype) - v_pages = torch.randn((num_kv_heads, num_pages, page_size, head_dim), + v_pages = torch.randn((num_pages, page_size, num_kv_heads, head_dim), dtype=dtype) - - # Create a kv_lens: i32[num_tokens] - kv_lens_with_paddings = [0] * num_q_tokens - for i in range(num_seqs): - kv_lens_with_paddings[i] = kv_lens[i] - kv_lens_ = torch.tensor(kv_lens_with_paddings, dtype=torch.int32) - - # Create a page_indices i32[num_tokens, pages_per_sequence] - max_kv_len = max([seq_len[1] for seq_len in seq_lens]) - max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size - 
- # The reason why we need to pad max_num_pages_per_seq is that - # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0 - max_num_pages_per_seq = 2**int(np.ceil(np.log2(max_num_pages_per_seq))) - - # The assert below mimics the reality that each page get a unique index. - # But for testing, the assert could be omitted. - # assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}" page_indices = torch.randint( - 0, num_pages, (num_q_tokens, max_num_pages_per_seq), dtype=torch.int32) - - # Create a cu_q_lens i32[num_tokens + 1] - q_lens_with_paddings = [0] * num_q_tokens - for i in range(num_seqs): - q_lens_with_paddings[i] = query_lens[i] - cu_q_lens = torch.cumsum( - torch.tensor([0] + q_lens_with_paddings, dtype=torch.int32), - dim=0, - dtype=torch.int32) - return queries, k_pages, v_pages, page_indices, cu_q_lens, kv_lens_ + 0, num_pages, (max_num_seqs, pages_per_seq), dtype=torch.int32) + return q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") def test_tpu_custom_call_pallas_add(self): @@ -648,7 +635,7 @@ def test_paged_attention_wrapper(self): "This test only works on TPUv4+.") def test_ragged_paged_attention_wrapper_without_dynamo(self): from torch_xla.experimental.custom_kernel import ragged_paged_attention - from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as jax_ragged_paged_attention + from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention seq_lens = [ (1, 1328), @@ -663,18 +650,25 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self): (1, 17), (99, 123) ] # last 3 physical q blocks [(q_len, kv_len),...] 
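As a quick sanity check of the new metadata layout, here is a plain-Python sketch that mirrors what the rewritten _ragged_pagedattention_generate_qkv computes for this test's configuration, i.e. page_size=16, num_kv_pages_per_block=16, max_num_batched_tokens=1024, and max_num_seqs=16 as set just below; q itself becomes a randn tensor of shape (1024, 32, 128).

seq_lens = [(1, 1328), (5, 18), (506, 563), (239, 251), (2, 2), (32, 32),
            (9, 10), (2, 3), (12, 16), (1, 17), (99, 123)]
page_size, num_kv_pages_per_block, max_num_seqs = 16, 16, 16

cu_q_lens, kv_lens = [0], []
for q_len, kv_len in seq_lens:
  cu_q_lens.append(cu_q_lens[-1] + q_len)
  kv_lens.append(kv_len)
pages_per_seq = -(-max(kv_lens) // page_size)  # ceil(1328 / 16) = 83
pages_per_seq = -(-pages_per_seq // num_kv_pages_per_block) * num_kv_pages_per_block  # -> 96
print(cu_q_lens[-1], pages_per_seq)  # 908 total query tokens, 96 pages per sequence
# cu_q_lens is then zero-padded to max_num_seqs + 1 == 17 entries, kv_lens to 16
# entries, and page_indices is sampled with shape (max_num_seqs, pages_per_seq) == (16, 96).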
- num_heads = (4, 4) + num_heads = (32, 8) head_dim = 128 dtype = torch.float32 page_size = 16 num_pages = 32768 num_seqs = len(seq_lens) - num_kv_pages_per_block = 128 + num_kv_pages_per_block = 16 num_queries_per_block = 8 - block_kv_size = 256 q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv( - seq_lens, num_heads, head_dim, page_size, num_pages, dtype=dtype) + seq_lens, + num_heads, + head_dim, + page_size, + num_pages, + dtype, + num_kv_pages_per_block=num_kv_pages_per_block, + max_num_batched_tokens=1024, + max_num_seqs=16) q_xla = q.to("xla") k_pages_xla = k_pages.to("xla") @@ -693,7 +687,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self): num_seqs=num_seqs, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, - use_kernel=True) + use_kernel=True)[:cu_q_lens[num_seqs]] nonkernel_output = ragged_paged_attention( q_xla, @@ -726,7 +720,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self): num_seqs=num_seqs, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, - )[1])) + )[:cu_q_lens[num_seqs]])) self.assertTrue( torch.allclose( @@ -745,19 +739,25 @@ def _verify_ragged_paged_attention_with_dynamo( dtype, num_kv_pages_per_block, num_queries_per_block, - pad_num_q_tokens=False, + pad_tokens_and_seqs=False, sm_scale=1.0, ): num_seqs = len(seq_lens) + max_num_batched_tokens = None + max_num_seqs = None + if pad_tokens_and_seqs: + max_num_batched_tokens = 1024 + max_num_seqs = 16 q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv( seq_lens, num_heads, head_dim, page_size, num_pages, - dtype=dtype, - num_queries_per_block=num_queries_per_block, - pad_num_q_tokens=pad_num_q_tokens) + dtype, + num_kv_pages_per_block=num_kv_pages_per_block, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs) q_xla = q.to("xla") k_pages_xla = k_pages.to("xla") @@ -766,29 +766,7 @@ def _verify_ragged_paged_attention_with_dynamo( page_indices_xla = page_indices.to("xla") cu_q_lens_xla = cu_q_lens.to("xla") - def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens, - page_indices, cu_q_lens, num_seqs, - num_kv_pages_per_block, - num_queries_per_block, use_kernel, - sm_scale): - return torch.ops.xla.ragged_paged_attention( - q, - k_pages, - v_pages, - kv_lens, - page_indices, - cu_q_lens, - num_seqs, - num_kv_pages_per_block, - num_queries_per_block, - use_kernel=use_kernel, - sm_scale=sm_scale, - ) - - compiled_paged_attention = torch.compile( - ragged_paged_attention_wrapper, backend="openxla") - - kernel_output = compiled_paged_attention( + kernel_output = torch.ops.xla.ragged_paged_attention( q_xla, k_pages_xla, v_pages_xla, @@ -800,9 +778,9 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens, num_queries_per_block=num_queries_per_block, use_kernel=True, sm_scale=sm_scale, - ) + )[:cu_q_lens[num_seqs]] - nonkernel_output = compiled_paged_attention( + nonkernel_output = torch.ops.xla.ragged_paged_attention( q_xla, k_pages_xla, v_pages_xla, @@ -828,7 +806,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens, page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32) cu_q_lens_jax = jnp.array(cu_q_lens.numpy(), dtype=jnp.int32) - from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as jax_ragged_paged_attention + from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import 
ragged_paged_attention as jax_ragged_paged_attention jax_kernel_output = torch.from_numpy( np.array( jax_ragged_paged_attention( @@ -842,34 +820,19 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, sm_scale=sm_scale, - )[1])) + )[:cu_q_lens[num_seqs]])) jax_kernel_output_cpu = jax_kernel_output.cpu() - if pad_num_q_tokens: - actual_num_q_tokens = cu_q_lens[num_seqs] - self.assertTrue( - torch.allclose( - kernel_output_cpu[:actual_num_q_tokens], - nonkernel_output_cpu[:actual_num_q_tokens], - atol=2e-2, - rtol=1e-2)) - self.assertTrue( - torch.allclose( - kernel_output_cpu[:actual_num_q_tokens], - jax_kernel_output_cpu[:actual_num_q_tokens], - atol=2e-2, - rtol=1e-2)) - else: - self.assertTrue( - torch.allclose( - kernel_output_cpu, nonkernel_output_cpu, atol=2e-2, rtol=1e-2)) - self.assertTrue( - torch.allclose( - kernel_output_cpu, jax_kernel_output_cpu, atol=2e-2, rtol=1e-2)) + self.assertTrue( + torch.allclose( + kernel_output_cpu, nonkernel_output_cpu, atol=2e-2, rtol=1e-2)) + self.assertTrue( + torch.allclose( + kernel_output_cpu, jax_kernel_output_cpu, atol=2e-2, rtol=1e-2)) @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, "This test only works on TPUv4+.") - def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self): + def test_ragged_paged_attention_wrapper_no_padding_with_dynamo(self): seq_lens = [ (1, 1328), (5, 18), @@ -883,7 +846,7 @@ def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self): (1, 17), (99, 123) ] # last 3 physical q blocks [(q_len, kv_len),...] - num_heads = (4, 4) + num_heads = (32, 8) head_dim = 128 dtype = torch.float32 page_size = 16 @@ -897,7 +860,7 @@ def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self): page_size, num_pages, dtype, - num_kv_pages_per_block=128, + num_kv_pages_per_block=16, num_queries_per_block=8, sm_scale=sm_scale, ) @@ -908,12 +871,12 @@ def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self): ) @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, "This test only works on TPUv4+.") - def test_ragged_paged_attention_wrapper_with_query_padding_with_dynamo( + def test_ragged_paged_attention_wrapper_with_padding_with_dynamo( self, seq_lens, num_queries_per_block, ): - num_heads = (4, 4) + num_heads = (32, 8) head_dim = 128 dtype = torch.float32 page_size = 16 @@ -927,9 +890,9 @@ def test_ragged_paged_attention_wrapper_with_query_padding_with_dynamo( page_size, num_pages, dtype, - num_kv_pages_per_block=128, + num_kv_pages_per_block=16, num_queries_per_block=num_queries_per_block, - pad_num_q_tokens=True, + pad_tokens_and_seqs=True, sm_scale=sm_scale, ) diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 1870fc7e6eca..04e1db666e8f 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -14,6 +14,7 @@ from torch_xla.core.xla_model import XLA_LIB _XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1" +DEFAULT_MASK_VALUE = -0.7 * float(torch.finfo(torch.float32).max) def _shard_map(func, mesh, input_specs, output_specs): @@ -887,97 +888,132 @@ def flash_attention( sm_scale, ab, partition_spec, mesh) -def _ragged_paged_attention_nonkernel( - queries, # [num_tokens, num_q_heads, head_dim] - k_pages, # [num_kv_heads, total_num_pages, page_size, head_dim] - v_pages, # [num_kv_heads, total_num_pages, page_size, head_dim] - kv_lens, # 
i32[num_tokens] - page_indices, # i32[num_tokens, pages_per_sequence] - cu_q_lens, # i32[num_tokens + 1] - num_seqs, # int - sm_scale, # float -): - _, num_q_heads, head_dim = queries.shape - num_kv_heads, total_num_pages, page_size, _ = k_pages.shape - num_query_per_kv = num_q_heads // num_kv_heads - start_idx = 0 - kv_lens = kv_lens.cpu() - page_indices = page_indices.cpu() +def ceil_div(a, b): + assert b != 0 + return (a + b - 1) // b - outputs: List[torch.Tensor] = [] + +def validate_ragged_paged_attention_inputs( + q, # [max_num_batched_tokens, num_q_heads, head_dim] + k_pages, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_lens, # i32[max_num_seqs] + page_indices, # i32[max_num_seqs, pages_per_seq] + cu_q_lens, # i32[max_num_seqs + 1] + num_seqs, # i32 +): + max_num_batched_tokens, num_q_heads, head_dim = q.shape + _, page_size, num_kv_heads, head_dim_k = k_pages.shape + max_num_seqs, pages_per_seq = page_indices.shape + if k_pages.shape != v_pages.shape: + raise ValueError( + f"{k_pages.shape=} and {v_pages.shape=} must have the same shape.") + if head_dim_k != head_dim: + raise ValueError( + f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}.") + if kv_lens.shape != (max_num_seqs,): + raise ValueError(f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where" + " `max_num_seqs` is `page_indices.shape[0]`.") + if cu_q_lens.shape != (max_num_seqs + 1,): + raise ValueError( + f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where" + " `max_num_seqs` is `page_indices.shape[0]`.") + if max_num_seqs > max_num_batched_tokens: + raise ValueError( + f"{max_num_seqs=} must be less or equal to {max_num_batched_tokens=}.") + if (kv_lens.dtype != torch.int32 or page_indices.dtype != torch.int32 or + cu_q_lens.dtype != torch.int32): + raise ValueError( + "The dtype of `kv_lens`, `page_indices`, and `cu_q_lens` must be" + f" int32. Got {kv_lens.dtype=}, {page_indices.dtype=}," + f" {cu_q_lens.dtype=}.") + if num_q_heads % num_kv_heads != 0: + raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}") + + # Must check below on runtime! 
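+  # (They read tensor values, kv_lens and cu_q_lens, together with the dynamic
+  # num_seqs, so unlike the shape/dtype checks above they cannot be done
+  # statically.)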
+ if num_seqs > max_num_seqs: + raise ValueError(f"{num_seqs=} must be less or equal to {max_num_seqs=}") + max_kv_len = torch.max(kv_lens) + min_pages_per_seq = ceil_div(max_kv_len, page_size) + if pages_per_seq < min_pages_per_seq: + raise ValueError( + f"{pages_per_seq=} must be greater or equal to" + f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}.") + if cu_q_lens[num_seqs] > max_num_batched_tokens: + raise ValueError( + f"Total q tokens {cu_q_lens[num_seqs]} must be less or equal to" + f" {max_num_batched_tokens=}.") for i in range(num_seqs): - cur_q_len = cu_q_lens[i + 1] - cu_q_lens[i] - q = queries[start_idx:start_idx + - cur_q_len] # [cur_q_len, num_q_heads, head_dim] - cur_kv_len = kv_lens[i] - num_pages = (cur_kv_len + page_size - 1) // page_size - page_indices_to_use = page_indices[i, :num_pages] - - k = k_pages[:, - page_indices_to_use, :, :] # [num_kv_heads, page_indices_to_use, page_size, head_dim] - k = k.permute(1, 2, 0, 3) # [num_pages, page_size, num_kv_heads, head_dim] - k = k.reshape(num_pages * page_size, num_kv_heads, head_dim) - k = k[:cur_kv_len] # [cur_kv_len, num_kv_heads, head_dim] - - v = v_pages[:, - page_indices_to_use, :, :] # [num_kv_heads, num_pages, page_size, head_dim] - v = v.permute(1, 2, 0, 3) # [num_pages, page_size, num_kv_heads, head_dim] - v = v.reshape(num_pages * page_size, num_kv_heads, head_dim) - v = v[:cur_kv_len] # [cur_kv_len, num_kv_heads, head_dim] + q_len = cu_q_lens[i + 1] - cu_q_lens[i] + kv_len = kv_lens[i] + if q_len > kv_len: + raise ValueError( + f"{q_len=} must be less or equal to {kv_len=} at sequence {i}.") - if num_query_per_kv != 1: - # GQA/MQA - k = torch.repeat_interleave( - k, num_query_per_kv, dim=1) # [cur_kv_len, num_query_heads, head_dim] - v = torch.repeat_interleave( - v, num_query_per_kv, dim=1) # [cur_kv_len, num_query_heads, head_dim] - # NOTE: To balance efficiency and performance, we use the original dtype (e.g., bfloat16 or float16) - # for matrix multiplications (i.e., q @ k and attn @ v) while using float32 for softmax. - # However, the kernel doesn't have to strictly follow the dtypes here. - # For example, it can use bfloat16 instead of float32 or vice versa for performance or simplicity. 
- attn = torch.einsum("qhd,khd->hqk", q, - k) # [num_query_heads, cur_q_len, kv_len] - attn = attn.float() - attn = attn * sm_scale - empty_mask = torch.ones(cur_q_len, cur_kv_len, device=attn.device) - mask = torch.triu(empty_mask, diagonal=cur_kv_len - cur_q_len + 1).bool() - attn.masked_fill_(mask, float("-inf")) +def _ragged_paged_attention_nonkernel( + queries, # [max_num_batched_tokens, num_q_heads, head_dim] + k_pages, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_lens, # i32[max_num_seqs] + page_indices, # i32[max_num_seqs, pages_per_seq] + cu_q_lens, # i32[max_num_seqs + 1] + num_seqs, # i32 + *, + sm_scale=1.0, + mask_value=DEFAULT_MASK_VALUE, +): + _, _, num_kv_heads, head_dim = k_pages.shape + num_q_heads = queries.shape[1] + assert num_q_heads % num_kv_heads == 0 + num_query_per_kv = num_q_heads // num_kv_heads + outputs = [] + for i in range(num_seqs): + q_start = cu_q_lens[i] + q_end = cu_q_lens[i + 1] + q_len = q_end - q_start + kv_len = kv_lens[i] + indices = page_indices[i] + q = queries[q_start:q_end] + k = k_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + v = v_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + k = torch.repeat_interleave(k, num_query_per_kv, dim=1) + v = torch.repeat_interleave(v, num_query_per_kv, dim=1) + attn = torch.einsum("qhd,khd->hqk", q, k) + attn *= sm_scale + empty_mask = torch.ones(q_len, kv_len, device=attn.device) + mask = torch.triu(empty_mask, diagonal=kv_len - q_len + 1).bool() + attn.masked_fill_(mask, mask_value) attn = torch.softmax( attn, dim=-1).to(v.dtype) # [num_query_heads, cur_q_len, kv_len] out = torch.einsum("hqk,khd->qhd", attn, v) # [cur_q_len, num_query_heads, head_dim] outputs.append(out) - start_idx += cur_q_len - maybe_padded_num_q_tokens = queries.shape[0] - actual_num_tokens = cu_q_lens[num_seqs] - if actual_num_tokens < maybe_padded_num_q_tokens: - num_tokens_diff = maybe_padded_num_q_tokens - actual_num_tokens - outputs.append( - torch.zeros((num_tokens_diff, num_q_heads, head_dim), - dtype=queries[0].dtype, - device=queries.device)) - return torch.cat(outputs, dim=0) # [num_tokens, num_query_heads, head_dim] + return torch.cat(outputs, dim=0) @requires_jax def ragged_paged_attention( - q, # [num_tokens, num_q_heads, head_dim] - k_pages, # [num_kv_heads, total_num_pages, page_size, head_dim] - v_pages, # [num_kv_heads, total_num_pages, page_size, head_dim] - kv_lens, # i32[num_tokens] - page_indices, # i32[num_tokens, pages_per_sequence] - cu_q_lens, # i32[num_tokens + 1] - num_seqs, # int + q, # [max_num_batched_tokens, num_q_heads, head_dim] + k_pages, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_lens, # i32[max_num_seqs] + page_indices, # i32[max_num_seqs, pages_per_seq] + cu_q_lens, # i32[max_num_seqs + 1] + num_seqs, # i32 + *, + sm_scale=1.0, + mask_value=None, num_kv_pages_per_block, num_queries_per_block, + vmem_limit_bytes=None, use_kernel=True, - sm_scale=1.0, - # TODO(jevinjiang, xiowei): add attn_logits_soft_cap. - # attn_logits_soft_cap: float | None = None, -): # [batch_size, query_len, num_heads, head_dim]: - assert len(q.shape) == 3, "q should have 3 dimensions." 
+): + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE + validate_ragged_paged_attention_inputs(q, k_pages, v_pages, kv_lens, + page_indices, cu_q_lens, num_seqs) if not use_kernel: return _ragged_paged_attention_nonkernel( q, @@ -987,13 +1023,15 @@ def ragged_paged_attention( page_indices, cu_q_lens, num_seqs, - sm_scale, + sm_scale=sm_scale, + mask_value=mask_value, ) # Import JAX within the function such that we don't need to call the jax_import_guard() # in the global scope which could cause problems for xmp.spawn. - from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as ragged_attention, make_sequence_metadata - payload, tensor_args = trace_pallas( + from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_attention + + payload, _ = trace_pallas( ragged_attention, q, k_pages, @@ -1001,84 +1039,44 @@ def ragged_paged_attention( kv_lens, page_indices, cu_q_lens, - num_seqs=num_seqs, + num_seqs, + sm_scale=sm_scale, + mask_value=mask_value, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, - sm_scale=sm_scale, + vmem_limit_bytes=vmem_limit_bytes, static_argnames=[ + "sm_scale", + "mask_value", "num_kv_pages_per_block", "num_queries_per_block", - "mask_value", - "num_seqs", - "sm_scale", + "vmem_limit_bytes", ], ) - sequence_metadata, num_logical_q_tiles = make_sequence_metadata( - cu_q_lens=cu_q_lens.cpu().numpy(), - m=q.shape[0], - tm=num_queries_per_block, - # TODO(jevinjiang, xiowei): pass start_sequence as input. - start_sequence=torch.tensor([0]).cpu().numpy(), - num_sequences=num_seqs, - ) - assert len(sequence_metadata) == 2 - sequence_ids = torch.tensor( - sequence_metadata[0].tolist(), dtype=torch.int32).to("xla") - m_tile_ids = torch.tensor( - sequence_metadata[1].tolist(), dtype=torch.int32).to("xla") - num_q_tiles = torch.tensor( - num_logical_q_tiles.tolist(), dtype=torch.int32).to("xla") - - q_dtype_for_kernel_launch = q.dtype - page_indices_expanded = torch.unsqueeze(page_indices, 1) - buffer_index = torch.zeros((1,), dtype=torch.int32).to("xla") - step = torch.zeros((1,), dtype=torch.int32).to("xla") - # The jax checkify in ragged paged attention kernel will insert several scalar refs to both inputs - # (end of prefetch) and outputs (begining of the original outputs). - # TODO(jevinjiang, xiowei): consider seperate checkify from kernel! - s1 = torch.zeros((1, 1), dtype=torch.int32).to("xla") - s2 = torch.zeros((1, 1), dtype=torch.int32).to("xla") - s3 = torch.zeros((1, 1), dtype=torch.int32).to("xla") - s4 = torch.zeros((1, 1), dtype=torch.int32).to("xla") - q = q.permute(1, 0, 2) - MIN_BLOCK_SIZE = 128 - output_shape = torch.Size(list(q.shape[:-1]) + [MIN_BLOCK_SIZE]) - num_q_tiles_1d = torch.tensor([num_logical_q_tiles.tolist()], - dtype=torch.int32).to("xla") - - # TODO(jevinjiang, xiowei): check err returned by checkify! And add tests. - _, _, _, _, output, _, _ = torch_xla._XLAC._xla_tpu_custom_call( + num_q_blks = ceil_div(cu_q_lens[num_seqs], num_queries_per_block) + seq_buf_idx = torch.tensor([0, 0], dtype=torch.int32).to("xla") + num_seqs_ref = torch.tensor([num_seqs], dtype=torch.int32).to("xla") + output = torch_xla._XLAC._xla_tpu_custom_call( [ - num_q_tiles, - sequence_ids, - m_tile_ids, - # Need num_q_tiles_1d to work around a Mosaic internal error. 
- num_q_tiles_1d, + num_q_blks, kv_lens, + page_indices, cu_q_lens, - buffer_index, - step, - s1, - s2, - s3, - s4, - q.to(q_dtype_for_kernel_launch), + seq_buf_idx, + num_seqs_ref, + q, k_pages, v_pages, - page_indices_expanded, # for the current iteration - page_indices_expanded, # for the next iteration ], payload, [ # output shape - s1.shape, s2.shape, s3.shape, s4.shape, q.shape, output_shape, - output_shape + q.shape ], [ # output dtype - s1.dtype, s2.dtype, s3.dtype, s4.dtype, q_dtype_for_kernel_launch, - torch.float32, torch.float32 + torch.float32, ]) - return output.permute(1, 0, 2) + return output[0].to(q.dtype) def _multi_queries_paged_attention_nonkernel( @@ -1734,27 +1732,58 @@ def multi_queries_paged_attention_non_xla(q: torch.Tensor, XLA_LIB.define( - "ragged_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor kv_lens, Tensor page_indices, Tensor cu_q_lens, int num_seqs, int num_kv_pages_per_block, int num_queries_per_block, bool use_kernel, float sm_scale) -> Tensor", + "ragged_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor kv_lens, Tensor page_indices, " + "Tensor cu_q_lens, int num_seqs, int num_kv_pages_per_block, int num_queries_per_block, bool use_kernel, " + "float sm_scale=1.0, float? mask_value=None, int? vmem_limit_bytes=None) -> Tensor", ) @impl(XLA_LIB, "ragged_paged_attention", "XLA") def ragged_paged_attention_xla( - q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor, - kv_lens: torch.Tensor, page_indices: torch.Tensor, cu_q_lens: torch.Tensor, - num_seqs: int, num_kv_pages_per_block: int, num_queries_per_block: int, - use_kernel: bool, sm_scale: float): - return ragged_paged_attention(q, k_pages, v_pages, kv_lens, page_indices, - cu_q_lens, num_seqs, num_kv_pages_per_block, - num_queries_per_block, use_kernel, sm_scale) + q: torch.Tensor, + k_pages: torch.Tensor, + v_pages: torch.Tensor, + kv_lens: torch.Tensor, + page_indices: torch.Tensor, + cu_q_lens: torch.Tensor, + num_seqs: int, + num_kv_pages_per_block: int, + num_queries_per_block: int, + use_kernel: bool, + sm_scale: float = 1.0, + mask_value: float | None = None, + vmem_limit_bytes: int | None = None, +): + return ragged_paged_attention( + q, + k_pages, + v_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + mask_value=mask_value, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + vmem_limit_bytes=vmem_limit_bytes, + use_kernel=use_kernel) @impl(XLA_LIB, "ragged_paged_attention", "CompositeExplicitAutograd") -def ragged_paged_attention_non_xla( - q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor, - kv_lens: torch.Tensor, page_indices: torch.Tensor, cu_q_lens: torch.Tensor, - num_seqs: int, num_kv_pages_per_block: int, num_queries_per_block: int, - use_kernel: bool, sm_scale: float): +def ragged_paged_attention_non_xla(q: torch.Tensor, + k_pages: torch.Tensor, + v_pages: torch.Tensor, + kv_lens: torch.Tensor, + page_indices: torch.Tensor, + cu_q_lens: torch.Tensor, + num_seqs: int, + num_kv_pages_per_block: int, + num_queries_per_block: int, + use_kernel: bool, + sm_scale: float = 1.0, + mask_value: float | None = None, + vmem_limit_bytes: int | None = None): return non_xla_attetion(q, k_pages, v_pages, "paged") diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py new file mode 100644 index 000000000000..d1db0a79430a --- /dev/null +++ 
b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py @@ -0,0 +1,667 @@ +"""TPU-Friendly Ragged Paged Attention kernel. + +This kernel offers a highly optimized implementation of ragged paged attention, +specifically designed for TPU and compatible with a wide range of model +specifications. It supports mixed prefill and decoding, enhancing throughput +during inference. +""" + +import functools +import jax +from jax import lax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp + +DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) + +# TODO(jevinjiang): importing kernel from pltpu ops directly. No need +# to keep duplicated implementations. + + +class MultiPageAsyncCopyDescriptor: + """Descriptor for async copy of multiple K/V pages from HBM.""" + + def __init__( + self, + pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads_per_blk, head_dim] + vmem_buf, # [num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + sem, + page_indices_ref, # i32[num_seqs, pages_per_seq] + offset, # [seq_idx, kv_pages_start] + ): + self._vmem_buf = vmem_buf + seq_id, kv_pages_start = offset + self._async_copies = [ + pltpu.make_async_copy( + pages_hbm_ref.at[page_indices_ref[seq_id, kv_pages_start + i]], + vmem_buf.at[i], + sem, + ) for i in range(vmem_buf.shape[0]) + ] + + def start(self): + """Starts the async copies.""" + for async_copy in self._async_copies: + async_copy.start() + + def wait(self): + for async_copy in self._async_copies: + async_copy.wait() + return self._vmem_buf + + +def ref_ragged_paged_attention( + queries: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_lens: jax.Array, # i32[num_seqs] + page_indices: jax.Array, # i32[num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[num_seqs + 1] + num_seqs: int, + *, + sm_scale: float = 1.0, + mask_value: float = DEFAULT_MASK_VALUE, +): + _, _, num_kv_heads, head_dim = k_pages.shape + num_q_heads = queries.shape[1] + assert num_q_heads % num_kv_heads == 0 + num_query_per_kv = num_q_heads // num_kv_heads + outputs = [] + for i in range(num_seqs): + q_start = cu_q_lens[i] + q_end = cu_q_lens[i + 1] + q_len = q_end - q_start + kv_len = kv_lens[i] + indices = page_indices[i] + q = queries[q_start:q_end] + k = k_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + v = v_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + k = jnp.repeat(k, num_query_per_kv, axis=1) + v = jnp.repeat(v, num_query_per_kv, axis=1) + attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32) + attn *= sm_scale + q_span = (kv_len - q_len) + jax.lax.broadcasted_iota( + jnp.int32, attn.shape, 1) + kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2) + attn += jnp.where(q_span < kv_span, mask_value, 0.0) + attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) + out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype) + outputs.append(out) + + return jnp.concatenate(outputs, axis=0) + + +# Expect to run these checkes during runtime. 
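+# They depend on runtime data (kv_lens, cu_q_lens, and the dynamic num_seqs),
+# which is not available when shapes and dtypes are verified at trace time in
+# check_inputs_shapes below.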
+def validate_inputs_on_runtime( + q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs, # i32 +): + check_inputs_shapes(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens) + max_num_batched_tokens = q.shape[0] + page_size = k_pages.shape[1] + max_num_seqs, pages_per_seq = page_indices.shape + if num_seqs > max_num_seqs: + raise ValueError(f"{num_seqs=} must be less or equal to {max_num_seqs=}") + max_kv_len = jnp.max(kv_lens) + min_pages_per_seq = ceil_div(max_kv_len, page_size) + if pages_per_seq < min_pages_per_seq: + raise ValueError( + f"{pages_per_seq=} must be greater or equal to" + f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}.") + if cu_q_lens[num_seqs] > max_num_batched_tokens: + raise ValueError( + f"Total q tokens {cu_q_lens[num_seqs]} must be less or equal to" + f" {max_num_batched_tokens=}.") + for i in range(num_seqs): + q_len = cu_q_lens[i + 1] - cu_q_lens[i] + kv_len = kv_lens[i] + if q_len > kv_len: + raise ValueError( + f"{q_len=} must be less or equal to {kv_len=} at sequence {i}.") + + +# Expect to run these checks during compile time. +def check_inputs_shapes( + q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] +): + max_num_batched_tokens, num_q_heads, head_dim = q.shape + _, _, num_kv_heads, head_dim_k = k_pages.shape + max_num_seqs, _ = page_indices.shape + if k_pages.shape != v_pages.shape: + raise ValueError( + f"{k_pages.shape=} and {v_pages.shape=} must have the same shape.") + if head_dim_k != head_dim: + raise ValueError( + f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}.") + if kv_lens.shape != (max_num_seqs,): + raise ValueError(f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where" + " `max_num_seqs` is `page_indices.shape[0]`.") + if cu_q_lens.shape != (max_num_seqs + 1,): + raise ValueError( + f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where" + " `max_num_seqs` is `page_indices.shape[0]`.") + if max_num_seqs > max_num_batched_tokens: + raise ValueError( + f"{max_num_seqs=} must be less or equal to {max_num_batched_tokens=}.") + if (kv_lens.dtype != jnp.int32 or page_indices.dtype != jnp.int32 or + cu_q_lens.dtype != jnp.int32): + raise ValueError( + "The dtype of `kv_lens`, `page_indices`, and `cu_q_lens` must be" + f" int32. Got {kv_lens.dtype=}, {page_indices.dtype=}," + f" {cu_q_lens.dtype=}.") + if num_q_heads % num_kv_heads != 0: + raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}") + + +def ragged_paged_attention_kernel( + # Prefetch + kv_lens_ref, # [max_num_seqs] + page_indices_ref, # [max_num_seqs, pages_per_seq] + cu_q_lens_ref, # [max_num_seqs + 1] + seq_buf_idx_ref, + # TODO(jevinjiang): if OOM in SMEM, consider pack to other scalar refs. 
+ num_seqs_ref, + # Input + q_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] + k_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] + # Output + o_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] + # Scratch + k_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + v_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + sems, # [2, 2] + l_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] + m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] + *, + sm_scale: float, + mask_value: float, +): + num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape + num_seqs = num_seqs_ref[0] + _, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, _ = k_bufs.shape + num_kv_per_blk = num_kv_pages_per_blk * page_size + num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk + heads_blk_idx, q_blk_idx = ( + pl.program_id(0), + pl.program_id(1), + ) + num_heads_blks = pl.num_programs(0) + init_seq_idx = seq_buf_idx_ref[0] + init_buf_idx = seq_buf_idx_ref[1] + q_len_start = q_blk_idx * num_q_per_blk + q_len_end = q_len_start + num_q_per_blk + + def create_kv_async_copy_descriptors(heads_blk_idx, seq_idx, kv_blk_idx, + buf_idx): + offset = (seq_idx, kv_blk_idx * num_kv_pages_per_blk) + heads_start = heads_blk_idx * num_kv_heads_per_blk + async_copy_k = MultiPageAsyncCopyDescriptor( + k_pages_hbm_ref.at[:, :, + pl.ds(heads_start, num_kv_heads_per_blk), :], + k_bufs.at[buf_idx], + sems.at[buf_idx, 0], + page_indices_ref, + offset, + ) + async_copy_v = MultiPageAsyncCopyDescriptor( + v_pages_hbm_ref.at[:, :, + pl.ds(heads_start, num_kv_heads_per_blk), :], + v_bufs.at[buf_idx], + sems.at[buf_idx, 1], + page_indices_ref, + offset, + ) + return async_copy_k, async_copy_v + + # TODO(jevinjiang): Add these to Mosaic: + # 1. Support arbitrary strided load/store for any dtype. + # 2. Support arbitrary strided load/store for any last dimension. 
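+  # strided_load_kv gathers the rows of one KV head out of the interleaved
+  # [tokens * num_kv_heads_per_blk, head_dim] buffer. float32 rows can be
+  # sliced directly with a strided index. For bfloat16, the ref is bitcast to
+  # int32 so that two adjacent rows share one 32-bit word; the code then loads
+  # every (step // packing)-th packed row, shifts the 16-bit lane of the
+  # requested head into the upper half of each word, and bitcasts back through
+  # float32 before casting down to bfloat16.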
+ def strided_load_kv(ref, start, step): + if ref.dtype == jnp.float32: + return ref[start::step, :] + packing = get_dtype_packing(ref.dtype) + assert ref.dtype == jnp.bfloat16 + assert step % packing == 0 + b_start = start // packing + b_offset = start % packing + b_step = step // packing + b_ref = ref.bitcast(jnp.int32) + b = b_ref[b_start::b_step, :] + bw = 32 // packing + b = jnp.right_shift(b, bw * b_offset) + b = jnp.left_shift(b, bw * (packing - 1)) + return pltpu.bitcast(b, jnp.float32).astype(jnp.bfloat16) + + @pl.when(heads_blk_idx + q_blk_idx == 0) + def prefetch_first_kv_blk(): + async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + heads_blk_idx, init_seq_idx, 0, init_buf_idx) + async_copy_k.start() + async_copy_v.start() + + def is_cur_q_blk_needed(q_states): + done, cur_seq_idx, _ = q_states + return jnp.logical_and(done == 0, cur_seq_idx < num_seqs) + + def compute_with_cur_q_blk(q_states): + done, cur_seq_idx, cur_buf_idx = q_states + q_start = cu_q_lens_ref[cur_seq_idx] + q_end = cu_q_lens_ref[cur_seq_idx + 1] + q_len = q_end - q_start + kv_len = kv_lens_ref[cur_seq_idx] + + def get_next_prefetch_ids(heads_blk_idx, cur_seq_idx, kv_blk_idx, + cur_buf_idx): + next_kv_blk_idx = kv_blk_idx + 1 + is_last_kv_blk = next_kv_blk_idx * num_kv_per_blk >= kv_len + next_kv_blk_idx = lax.select( + is_last_kv_blk, + 0, + next_kv_blk_idx, + ) + is_cur_seq_end_in_cur_q_blk = q_end <= q_len_end + next_seq_idx = lax.select( + is_last_kv_blk, + lax.select(is_cur_seq_end_in_cur_q_blk, cur_seq_idx + 1, cur_seq_idx), + cur_seq_idx, + ) + is_last_seq = next_seq_idx == num_seqs + next_seq_idx = lax.select( + is_last_seq, + 0, + next_seq_idx, + ) + next_heads_blk_idx = lax.select( + is_last_seq, + heads_blk_idx + 1, + heads_blk_idx, + ) + next_buf_idx = lax.select(cur_buf_idx == 0, 1, 0) + return next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx + + def flash_attention( + q, # [num_q_per_blk * num_q_heads_per_kv_head, head_dim] + k, # [num_kv_per_blk, head_dim] + v, # [num_kv_per_blk, head_dim] + head_l_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] + head_m_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] + head_o_ref, # [num_q_per_blk, num_q_heads_per_kv_head, head_dim] + *, + kv_blk_idx, + ): + assert q.shape == ( + num_q_per_blk * num_q_heads_per_kv_head, + head_dim, + ) + assert k.shape == ( + num_kv_per_blk, + head_dim, + ), f"{k.shape=}, {(num_kv_per_blk, head_dim)=} {k.dtype=}" + assert v.shape == (num_kv_per_blk, head_dim) + assert head_m_ref.shape == ( + num_q_per_blk * num_q_heads_per_kv_head, + 128, + ) + assert head_l_ref.shape == ( + num_q_per_blk * num_q_heads_per_kv_head, + 128, + ) + assert head_o_ref.shape == ( + num_q_per_blk, + num_q_heads_per_kv_head, + head_dim, + ) + kv_len_start = kv_blk_idx * num_kv_per_blk + + def masked_store(ref, val, start, end, group=1): + iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group + mask = jnp.logical_and(iota >= start, iota < end) + pl.store(ref, tuple(slice(None) for _ in ref.shape), val, mask=mask) + + qk = ( + jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) * + sm_scale) + store_start = jnp.maximum(q_start - q_len_start, 0) + store_end = jnp.minimum(q_end - q_len_start, num_q_per_blk) + + @pl.when(kv_blk_idx == 0) + def init_scratch_ref(): + masked_store( + head_m_ref, + jnp.full_like(head_m_ref, -jnp.inf), + store_start, + store_end, + num_q_heads_per_kv_head, + ) + masked_store( + head_l_ref, + jnp.zeros_like(head_l_ref), + store_start, + store_end, + 
num_q_heads_per_kv_head, + ) + masked_store( + head_o_ref, + jnp.zeros_like(head_o_ref), + store_start, + store_end, + ) + + row_ids = ((kv_len - q_len) + q_len_start - q_start + + jax.lax.broadcasted_iota( + jnp.int32, + (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk), + 0, + ) // num_q_heads_per_kv_head) + col_ids = kv_len_start + jax.lax.broadcasted_iota( + jnp.int32, + (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk), + 1, + ) + causal_mask = row_ids < col_ids + qk += jnp.where(causal_mask, mask_value, 0.0) + m_curr = jnp.max(qk, axis=1, keepdims=True) + s_curr = jnp.exp(qk - m_curr) + qkv = jnp.dot(s_curr, v, preferred_element_type=jnp.float32) + lm_store_shape = head_m_ref.shape + m_curr = jnp.broadcast_to(m_curr, lm_store_shape) + l_curr = jnp.broadcast_to( + s_curr.sum(axis=1, keepdims=True), lm_store_shape) + m_prev = head_m_ref[...] + l_prev = head_l_ref[...] + m_next = jnp.maximum(m_prev, m_curr) + masked_store(head_m_ref, m_next, store_start, store_end, + num_q_heads_per_kv_head) + alpha = jnp.exp(m_prev - m_next) + beta = jnp.exp(m_curr - m_next) + l_alpha = alpha * l_prev + l_next = l_alpha + beta * l_curr + l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + masked_store( + head_l_ref, + l_next_safe, + store_start, + store_end, + num_q_heads_per_kv_head, + ) + + def broadcast_to_shape(arr, shape): + if arr.shape == shape: + return arr + assert len(arr.shape) == len(shape) + assert arr.shape[0] == shape[0] + assert shape[1] % arr.shape[1] == 0 + # no-op concatenation. + return jnp.concatenate([arr for _ in range(shape[1] // arr.shape[1])], + axis=1) + + o_curr = head_o_ref[...].reshape(-1, head_dim) + l_alpha = broadcast_to_shape(l_alpha, qkv.shape) + beta = broadcast_to_shape(beta, qkv.shape) + l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape) + out = lax.div( + l_alpha * o_curr + beta * qkv, + l_next_safe, + ).astype(head_o_ref.dtype) + masked_store( + head_o_ref, + out.reshape(head_o_ref.shape), + store_start, + store_end, + ) + + def is_valid_kv_blk_in_cur_seq(kv_states): + kv_blk_idx, _ = kv_states + return kv_blk_idx * num_kv_per_blk < kv_len + + def compute_with_kv_blk_in_cur_seq(kv_states): + kv_blk_idx, cur_buf_idx = kv_states + next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx = ( + get_next_prefetch_ids(heads_blk_idx, cur_seq_idx, kv_blk_idx, + cur_buf_idx)) + + @pl.when(next_heads_blk_idx < num_heads_blks) + def prefetch_next_kv_blk(): + # TODO(jevinjiang): reuse the same buffer if it is already prefetched! + # TODO(jevinjiang): only fetch effective dynamic size to hold kv_len and + # DMA to fixed size buffer! + next_async_copy_k, next_async_copy_v = create_kv_async_copy_descriptors( + next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx) + next_async_copy_k.start() + next_async_copy_v.start() + + cur_async_copy_k, cur_async_copy_v = create_kv_async_copy_descriptors( + heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx) + kv_to_load_shape = ( + num_kv_pages_per_blk * page_size * num_kv_heads_per_blk, + head_dim, + ) + k_ref = cur_async_copy_k.wait().reshape(kv_to_load_shape) + v_ref = cur_async_copy_v.wait().reshape(kv_to_load_shape) + for kv_head_idx in range(num_kv_heads_per_blk): + q_head_idx = kv_head_idx * num_q_heads_per_kv_head + # TODO(jevinjiang): extra handlig for packed type that can start at + # unaligned position! 
+ q = q_ref[:, + q_head_idx:q_head_idx + num_q_heads_per_kv_head, :].reshape( + -1, head_dim) + k = strided_load_kv(k_ref, kv_head_idx, num_kv_heads_per_blk) + v = strided_load_kv(v_ref, kv_head_idx, num_kv_heads_per_blk) + flash_attention( + q, + k, + v, + l_ref.at[kv_head_idx], + m_ref.at[kv_head_idx], + o_ref.at[:, q_head_idx:q_head_idx + num_q_heads_per_kv_head, :], + kv_blk_idx=kv_blk_idx, + ) + return kv_blk_idx + 1, next_buf_idx + + _, next_buf_idx = lax.while_loop( + is_valid_kv_blk_in_cur_seq, + compute_with_kv_blk_in_cur_seq, + (0, cur_buf_idx), + ) + next_seq_idx = lax.select(q_end <= q_len_end, cur_seq_idx + 1, cur_seq_idx) + done = lax.select(q_end < q_len_end, done, 1) + return done, next_seq_idx, next_buf_idx + + _, seq_idx, buf_idx = lax.while_loop( + is_cur_q_blk_needed, + compute_with_cur_q_blk, + (0, init_seq_idx, init_buf_idx), + ) + # Reset seq_idx for next kv_heads_blk if run out of seqs! + seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0) + seq_buf_idx_ref[1] = buf_idx + + +def ceil_div(a, b): + assert b != 0 + return (a + b - 1) // b + + +def get_dtype_packing(dtype): + if dtype == jnp.float32: + return 1 + if dtype == jnp.bfloat16: + return 2 + if dtype == jnp.int8: + return 4 + if dtype == jnp.int4: + return 8 + raise ValueError(f"Not implemented: unsupported {dtype=}") + + +def get_min_heads_per_blk(num_q_heads, num_kv_heads, q_dtype, kv_dtype): + q_packing = get_dtype_packing(q_dtype) + kv_packing = get_dtype_packing(kv_dtype) + + def can_be_xla_fully_tiled(x, packing): + if x % packing != 0: + return False + x //= packing + return x in (1, 2, 4, 8) or x % 8 == 0 + + # TODO(jevinjiang): support unaligned number of heads! + if not can_be_xla_fully_tiled(num_kv_heads, kv_packing): + raise ValueError( + f"Not implemented: {num_kv_heads=} can not be XLA fully tiled.") + assert num_q_heads % num_kv_heads == 0 + ratio = num_q_heads // num_kv_heads + # TODO(jevinjiang): we can choose smaller tiling for packed type if large + # second minor tiling is not on. + max_kv_tiling = 8 * kv_packing + min_kv_heads = ( + max_kv_tiling if num_kv_heads % max_kv_tiling == 0 else num_kv_heads) + min_q_heads = min_kv_heads * ratio + if can_be_xla_fully_tiled(min_q_heads, q_packing): + return min_q_heads, min_kv_heads + return num_q_heads, num_kv_heads + + +@functools.partial( + jax.jit, + static_argnames=[ + "sm_scale", + "mask_value", + "num_kv_pages_per_block", + "num_queries_per_block", + "vmem_limit_bytes", + ], +) +def ragged_paged_attention( + q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + # TODO(jevinjiang): create a write_to_kv_cache kernel! + k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs, # i32 + *, + sm_scale: float = 1.0, + mask_value: float = DEFAULT_MASK_VALUE, + num_kv_pages_per_block: int = 16, + num_queries_per_block: int = 128, + vmem_limit_bytes: int | None = None, +): + """Ragged paged attention that supports mixed prefill and decode. + + Args: + q: concatenated all sequences' queries. + k_pages: paged K cache. Normally in HBM. + v_pages: paged V cache. Normally in HBM. + kv_lens: padded kv lengths. Only the first num_seqs values are valid. + page_indices: the first index indicates which page to use in the kv cache + for each sequence. 
Only the first num_seqs values are valid. + cu_q_lens: the cumulative sum of the effective query lengths. Similar to + kv_lens, only the first num_seqs+1 values are valid. + num_seqs: the dynamic number of sequences. + sm_scale: the softmax scale which will be applied to the Q@K^T. + mask_value: mask value for causal mask. + num_kv_pages_per_block: number of kv pages to be processed in one flash + attention block in the pallas kernel. + num_queries_per_block: number of kv pages to be processed in one flash + attention block in the pallas kernel. + vmem_limit_bytes: the vmem limit for the pallas kernel. + + Returns: + The output of the attention. + """ + check_inputs_shapes(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens) + _, num_q_heads, head_dim = q.shape + _, page_size, num_kv_heads, _ = k_pages.shape + num_q_per_blk = num_queries_per_block + num_kv_pages_per_blk = num_kv_pages_per_block + num_q_heads_per_kv_head = num_q_heads // num_kv_heads + num_q_blks = ceil_div(cu_q_lens[num_seqs], num_q_per_blk) + num_q_heads_per_blk, num_kv_heads_per_blk = get_min_heads_per_blk( + num_q_heads, num_kv_heads, q.dtype, k_pages.dtype) + assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0 + num_heads_blks = num_q_heads // num_q_heads_per_blk + grid = (num_heads_blks, num_q_blks) + + def q_index_map(heads_blk_idx, q_blk_idx, *_): + return (q_blk_idx, heads_blk_idx, 0) + + q_block_spec = pl.BlockSpec( + (num_q_per_blk, num_q_heads_per_blk, head_dim), + q_index_map, + ) + in_specs = [ + q_block_spec, + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ] + out_specs = q_block_spec + lm_scratch = pltpu.VMEM( + # TODO(jevinjiang): use 128 instead of 1 is due to Mosaic does not support + # unaligned slicing! + (num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128), + jnp.float32, + ) + double_buf_scratch = pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + num_kv_pages_per_blk, + page_size, + num_kv_heads_per_blk, + head_dim, + ), + k_pages.dtype, + ) + scratch_shapes = [ + double_buf_scratch, # k_bufs + double_buf_scratch, # v_bufs + pltpu.SemaphoreType.DMA((2, 2)), + lm_scratch, # l_ref + lm_scratch, # m_ref + ] + scalar_prefetches = ( + kv_lens, + page_indices, + cu_q_lens, + jnp.array((0, 0), jnp.int32), # seq_idx, buf_idx + # Mosaic only takes dynamic scalar as ref, so we wrap it. + jnp.array([num_seqs], jnp.int32), # num_seqs + ) + kernel = pl.pallas_call( + functools.partial( + ragged_paged_attention_kernel, + sm_scale=sm_scale, + mask_value=mask_value, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=len(scalar_prefetches), + in_specs=in_specs, + out_specs=out_specs, + grid=grid, + scratch_shapes=scratch_shapes, + ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=( + "arbitrary", + "arbitrary", + ), + vmem_limit_bytes=vmem_limit_bytes, + ), + out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32), + name="ragged_paged_attention_kernel", + ) + # TODO(jevinjiang): Use f32 acc scratch for output! So we only need + # to transfer output with desired dtype back to HBM. + return kernel(*scalar_prefetches, q, k_pages, v_pages).astype(q.dtype)
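For reference, a minimal usage sketch of the new module, assuming this patch is applied. It calls the pure-JAX ref_ragged_paged_attention defined above, which shares the kernel's calling convention; the Pallas entry point takes the same tensors plus num_kv_pages_per_block, num_queries_per_block, and vmem_limit_bytes, and needs a TPU.

import jax.numpy as jnp
from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import (
    ref_ragged_paged_attention)

num_q_heads, num_kv_heads, head_dim = 32, 8, 128
page_size, total_num_pages, pages_per_seq = 16, 64, 8

q = jnp.ones((16, num_q_heads, head_dim), jnp.float32)  # max_num_batched_tokens = 16
k_pages = jnp.ones((total_num_pages, page_size, num_kv_heads, head_dim), jnp.float32)
v_pages = jnp.ones_like(k_pages)
kv_lens = jnp.array([12, 24], dtype=jnp.int32)  # per-sequence KV lengths
page_indices = jnp.zeros((2, pages_per_seq), dtype=jnp.int32)
cu_q_lens = jnp.array([0, 4, 10], dtype=jnp.int32)  # 4 + 6 query tokens

out = ref_ragged_paged_attention(
    q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs=2)
print(out.shape)  # (10, 32, 128): one row per real query token

Unlike the reference, the Pallas kernel writes a full max_num_batched_tokens-row float32 output (its out_shape follows q.shape), which is why the updated tests slice the result with [:cu_q_lens[num_seqs]] before comparing.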