From 1c511eb4bad5cc8216a6a9ce1102fefa5dcce51a Mon Sep 17 00:00:00 2001
From: Xiongfei Wei
Date: Mon, 3 Mar 2025 20:07:52 +0000
Subject: [PATCH 1/3] Resolve Alex's remaining PR

Signed-off-by: Xiongfei Wei
---
 vllm/v1/attention/backends/pallas.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index bf4a05daf2d5..f48b30dcbe89 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -11,6 +11,7 @@
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 
+# These are the 2 tunable parameters of the paged attention Pallas kernel.
 NUM_QUERIES_PER_BLOCK = 16
 NUM_KV_PAGES_PER_BLOCK = 128
 
@@ -164,6 +165,8 @@ def forward(
             attn_metadata.num_seqs,
             num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
             num_queries_per_block=NUM_QUERIES_PER_BLOCK,
+            # use_kernel switches between using kernel or the reference implementation
+            # (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890)
             use_kernel=False,
         )
 

From e28828d858582be71aa00812a282452a6085388e Mon Sep 17 00:00:00 2001
From: Xiongfei Wei
Date: Mon, 3 Mar 2025 20:16:46 +0000
Subject: [PATCH 2/3] fix linter

Signed-off-by: Xiongfei Wei
---
 vllm/v1/attention/backends/pallas.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index f48b30dcbe89..194c4cfce84c 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -165,8 +165,8 @@ def forward(
             attn_metadata.num_seqs,
             num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
             num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-            # use_kernel switches between using kernel or the reference implementation
-            # (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890)
+            # use_kernel switches between using kernel or the reference implementation
+            # (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890).
             use_kernel=False,
         )
 

From 1ce8b587f2122a376e2a71526558309301885f98 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Mon, 3 Mar 2025 13:44:07 -0700
Subject: [PATCH 3/3] Fix comment

---
 vllm/v1/attention/backends/pallas.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index 194c4cfce84c..543e8487e28b 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -155,6 +155,9 @@ def forward(
         write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
 
         query = query * self.scale
+        # use_kernel switches between using kernel or reference implementation
+        # (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890).
+        use_kernel = False
         output = torch.ops.xla.ragged_paged_attention(
             query,
             key_cache,
@@ -165,9 +168,7 @@ def forward(
             attn_metadata.num_seqs,
             num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
             num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-            # use_kernel switches between using kernel or the reference implementation
-            # (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890).
-            use_kernel=False,
+            use_kernel=use_kernel,
         )
 
         return output.reshape(num_tokens, hidden_size)
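
For context, a minimal sketch of how the two tunable block-size parameters from PATCH 1/3 would carve a request into kernel tiles. The kernel_grid helper and its ceil-division grid formula are illustrative assumptions for this note, not code from vllm or torch_xla:

# Illustrative sketch only: kernel_grid is a hypothetical helper, not part of
# vllm/v1/attention/backends/pallas.py or the torch_xla kernel API.
NUM_QUERIES_PER_BLOCK = 16    # query tokens per kernel block (from the patch)
NUM_KV_PAGES_PER_BLOCK = 128  # KV-cache pages per kernel block (from the patch)

def kernel_grid(num_query_tokens: int, num_kv_pages: int) -> tuple[int, int]:
    # Each grid cell covers one (query-block, kv-page-block) tile; ceil-divide
    # so a partial block at the end still gets its own tile.
    q_blocks = -(-num_query_tokens // NUM_QUERIES_PER_BLOCK)
    kv_blocks = -(-num_kv_pages // NUM_KV_PAGES_PER_BLOCK)
    return q_blocks, kv_blocks

# e.g. 50 query tokens against 300 KV pages -> a 4 x 3 grid of tiles
print(kernel_grid(50, 300))  # (4, 3)

Presumably the best values depend on the TPU generation and workload shape, which is why the PATCH 1/3 comment calls them the kernel's tunable parameters.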