[flashinfer] [kernel] support for fp8 kv cache for trtllm prefill attention #24197
Changes from all commits: 39c037a, 6b171ed, 5689bf8, 65f33bd, aa888d9, 11cdce7, 16235d2, b09a187, 363a52f
@@ -25,7 +25,8 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, is_pin_memory_available
-from vllm.utils.flashinfer import (supports_trtllm_attention,
+from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
+                                   supports_trtllm_attention,
                                    use_trtllm_attention)
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 # yapf conflicts with isort for this block
@@ -48,8 +49,89 @@
 logger = init_logger(__name__)
 
 
+@triton.jit
+def _trtllm_prefill_attn_kvfp8_dequant(
+    kv_cache_ptr,
+    block_tables_prefill_ptr,
+    block_table_stride,
+    mock_kv_cache_ptr,
+    k_scale_ptr,
+    v_scale_ptr,
+    K_CACHE_STRIDE: tl.constexpr,
+    KV_CACHE_STRIDE: tl.constexpr,
+):
+    batch_idx = tl.program_id(0).to(tl.int64)
+    mock_block_table_idx = tl.program_id(1).to(tl.int64)
+    orig_page_num = tl.load(block_tables_prefill_ptr +
+                            batch_idx * block_table_stride +
+                            mock_block_table_idx).to(tl.int64)
+    if orig_page_num <= 0:
+        return
+    dequant_dtype = mock_kv_cache_ptr.dtype.element_ty
+
+    # Dequantize K
+    k_scale_val = tl.load(k_scale_ptr)
+    offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
+    fp8_vals = tl.load(kv_cache_ptr + offset)
+    dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val
+    mock_cache_offset = (batch_idx * block_table_stride + mock_block_table_idx
+                         + 1) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
+    dequantized_vals = dequantized_vals.to(dequant_dtype)
+    tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
+
+    # Dequantize V
+    v_scale_val = tl.load(v_scale_ptr)
+    offset = (orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE +
+              tl.arange(0, K_CACHE_STRIDE))
+    fp8_vals = tl.load(kv_cache_ptr + offset)
+    dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val
+    mock_cache_offset = (
+        (batch_idx * block_table_stride + mock_block_table_idx + 1) *
+        KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE))
+    dequantized_vals = dequantized_vals.to(dequant_dtype)
+    tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
+
+
+def trtllm_prefill_attn_kvfp8_dequant(
+    kv_cache: torch.Tensor,
+    block_tables_prefill: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    dequant_dtype: torch.dtype,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    batch_size, num_of_page_per_token = block_tables_prefill.shape
+    s = kv_cache.shape
+    assert s[1] == 2
+    assert dequant_dtype in (torch.bfloat16, torch.float16)
Contributor (Author)
@mgoin Revised to pass in q_dtype; the helper now supports bf16 and fp16 and asserts when it encounters any other dtype.
+    k_cache_stride = s[2] * s[3] * s[4]
+    kv_cache_stride = k_cache_stride * s[1]
+    new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
+    # mock kv cache contains just the pages needed by this prefill
+    mock_kv_cache = torch.empty(new_s,
+                                dtype=dequant_dtype,
+                                device=kv_cache.device)
+    # we simply sequentially index the pages needed by this prefill
+    mock_block_table = torch.arange(
+        start=1,
+        end=batch_size * num_of_page_per_token + 1,
+        dtype=torch.int32,
+        device=block_tables_prefill.device,
+    ).reshape(batch_size, num_of_page_per_token)
+    grid = (batch_size, num_of_page_per_token)
+    _trtllm_prefill_attn_kvfp8_dequant[grid](
+        kv_cache,
+        block_tables_prefill,
+        num_of_page_per_token,
+        mock_kv_cache,
+        k_scale,
+        v_scale,
+        k_cache_stride,
+        kv_cache_stride,
+    )
+    return mock_kv_cache, mock_block_table
+
+
 class FlashInferBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
     @classmethod
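For readers following the diff, here is a minimal CPU sketch of what the kernel and wrapper above compute. It is illustrative only: `dequant_pages_reference` and the toy shapes are hypothetical, the KV cache is assumed to be laid out as `(num_pages, 2, num_kv_heads, page_size, head_dim)` with per-tensor `k_scale`/`v_scale`, page id 0 is treated as padding (matching the `orig_page_num <= 0` check), and it needs a PyTorch build with `float8_e4m3fn` support.

```python
import torch

def dequant_pages_reference(kv_cache, block_tables, k_scale, v_scale, out_dtype):
    """Gather the pages a prefill needs and dequantize them (pure-PyTorch sketch)."""
    batch_size, pages_per_req = block_tables.shape
    num_pages, two, h, p, d = kv_cache.shape
    assert two == 2
    # Slot 0 of the output is reserved so that page id 0 can stay "padding".
    out = torch.empty(batch_size * pages_per_req + 1, 2, h, p, d, dtype=out_dtype)
    scales = torch.stack([k_scale.float(), v_scale.float()]).view(2, 1, 1, 1)
    for b in range(batch_size):
        for slot in range(pages_per_req):
            page = int(block_tables[b, slot])
            if page <= 0:  # padding entry, skipped just like the kernel does
                continue
            out[b * pages_per_req + slot + 1] = (
                kv_cache[page].to(torch.float32) * scales).to(out_dtype)
    # Pages are stored sequentially, so an arange block table addresses them.
    new_block_table = torch.arange(
        1, batch_size * pages_per_req + 1,
        dtype=torch.int32).reshape(batch_size, pages_per_req)
    return out, new_block_table

# Toy example: 8 fp8 pages, 4 KV heads, page size 16, head dim 8.
kv_fp8 = (torch.randn(8, 2, 4, 16, 8) / 4).to(torch.float8_e4m3fn)
tables = torch.tensor([[3, 5, 0], [1, 2, 4]], dtype=torch.int32)
cache, table = dequant_pages_reference(
    kv_fp8, tables, torch.tensor(0.5), torch.tensor(0.25), torch.bfloat16)
print(cache.shape, table)
```

In the actual wrapper, grid = (batch_size, num_of_page_per_token), so the two Python loops above collapse into a single Triton launch with one program per (batch, page slot).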
@@ -122,7 +204,6 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
 @dataclass
 class FlashInferMetadata:
 
     num_actual_tokens: int  # Number of tokens excluding padding.
 
     # The data type of the query
@@ -175,8 +256,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                                          self.kv_cache_spec.block_size)
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
-        self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
-            decode_mode() == CUDAGraphMode.FULL
+        self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\
+            decode_mode() == CUDAGraphMode.FULL)
         if self.enable_cuda_graph:
             # For full cudagraph capture, one `decode_wrapper` for each batch
             # size is needed for FlashInfer.
@@ -201,7 +282,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             assert self.kv_cache_spec.dtype == self.model_config.dtype
             self.kv_cache_dtype = self.kv_cache_spec.dtype
 
-        if supports_trtllm_attention()[0]:
+        if supports_trtllm_attention()[0] and \
+                not flashinfer_disable_q_quantization():
             self.q_data_type = self.kv_cache_dtype
         else:
             self.q_data_type = self.model_config.dtype
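The effect of the new condition is that the query is down-cast to the KV-cache dtype only when TRTLLM attention is actually usable and the user has not opted out of query quantization. A small, hedged sketch of that selection logic is below; the environment-variable name is hypothetical and only stands in for whatever `flashinfer_disable_q_quantization()` checks internally.

```python
import os
import torch

def select_q_dtype(kv_cache_dtype: torch.dtype,
                   model_dtype: torch.dtype,
                   trtllm_supported: bool) -> torch.dtype:
    # Hypothetical opt-out flag; the real check is
    # vllm.utils.flashinfer.flashinfer_disable_q_quantization().
    disable_q_quant = os.environ.get(
        "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0") == "1"
    if trtllm_supported and not disable_q_quant:
        # Quantize Q to match the (possibly fp8) KV-cache dtype.
        return kv_cache_dtype
    # Otherwise keep the query in the model's compute dtype.
    return model_dtype

print(select_q_dtype(torch.float8_e4m3fn, torch.bfloat16, True))
```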
@@ -795,11 +877,29 @@ def forward(
             assert self.o_sf_scale is None
             out = output[num_decode_tokens:]
 
+            if attn_metadata.q_data_type != FP8_DTYPE \
+                    and self.kv_cache_dtype.startswith("fp8"):
+                # TRTLLM prefill attention does not support BF16 Q
+                # and fp8 kv cache. So to enable prefill attention
+                # with fp8 kv cache, we can construct a mock block
+                # and mock kv cache with BF16 KV involved in the prefill
+                mock_kv_cache, mock_block_table = (
+                    trtllm_prefill_attn_kvfp8_dequant(
+                        kv_cache_permute,
+                        block_tables_prefill,
+                        layer._k_scale,
+                        layer._v_scale,
+                        attn_metadata.q_data_type,
+                    ))
+            else:
+                mock_kv_cache = kv_cache_permute
+                mock_block_table = block_tables_prefill
+
Comment on lines +895 to +896
Member
nit: I don't like using the term "mock" here. I realize we need to keep around
             trtllm_batch_context_with_kv_cache(
                 query=prefill_query,
-                kv_cache=kv_cache_permute,
+                kv_cache=mock_kv_cache,
                 workspace_buffer=workspace_buffer,
-                block_tables=block_tables_prefill,
+                block_tables=mock_block_table,
                 seq_lens=seq_lens_prefill,
                 max_q_len=attn_metadata.max_q_len,
                 max_kv_len=attn_metadata.max_seq_len,
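The code comment in the hunk above motivates the fallback: when the query stays in bf16/fp16 but the KV cache is stored in fp8, the prefill path first rebuilds a higher-precision copy of the needed pages. The snippet below is a hedged illustration of the per-tensor scale arithmetic that makes this round-trip work (dequantized value = stored fp8 value * scale); the 448 constant and the toy tensors are assumptions for illustration, and it requires a PyTorch build with `float8_e4m3fn`.

```python
import torch

# Toy K tile in the model's compute dtype.
k = torch.randn(4, 8, dtype=torch.bfloat16)

# Per-tensor scale, as implied by the single scalar the kernel loads
# from k_scale_ptr. 448 is roughly the largest finite float8_e4m3fn value.
k_scale = k.float().abs().max() / 448.0

# Store path: scale down and cast to fp8 (what the fp8 KV cache holds).
k_fp8 = (k.float() / k_scale).to(torch.float8_e4m3fn)

# Load path: what _trtllm_prefill_attn_kvfp8_dequant computes per element.
k_dequant = (k_fp8.to(torch.float32) * k_scale).to(torch.bfloat16)

print("max abs error:", (k_dequant.float() - k.float()).abs().max().item())
```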
@@ -837,7 +937,7 @@ def forward(
                 decode_query = decode_query.contiguous()
                 workspace_buffer = decode_wrapper._float_workspace_buffer
                 block_tables_decode = attn_metadata.\
-                    block_table_tensor[:num_decode_tokens]
+                    block_table_tensor[:num_decode_tokens]
                 seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
 
                 # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
Ideally, we could get rid of this after #26146