[flashinfer] [kernel] support for fp8 kv cache for trtllm prefill attention #24197
Conversation
This pull request was exported from Phabricator. Differential Revision: D81623711
Code Review
This pull request adds support for fp8 kv cache in trtllm prefill attention by introducing a dequantization kernel. The overall approach is sound and the implementation looks good, but there is a critical syntax error in a type hint that will cause a runtime failure.
Force-pushed c3ba521 to a1e931f
houseroad left a comment
Looks good to me. Can we create an issue (for tracking purposes) to have the trtllm kernel support fp8 KV by default, so we don't need the additional conversion on our side?
Btw, fix the pre-commit (linter) please.
@houseroad seems like there is a draft PR #23647 that can potentially get rid of the conversion. However, even so, the conversion may stay because it represents a different compute precision and may have accuracy implications.
Force-pushed 75c39e2 to 7ae2b84
Yeah, I think it makes sense to have this land. Fix the linter please. :-)
The failed test is test_cutlass_mla_decode.py; however, my change is not relevant to MLA (it is paged attention only). Is this an existing test failure?
Thanks for the contribution. However, if we agree to always quantize the query with KV_cache_dtype=FP8 after #23647, then we will not have opportunities to run these code paths anymore, right? cc @ProExpertProg
It may be good to have a path that does prefill with BF16 Q, FP8 KV, and BF16 compute (even though it does not have to be the default one), because FP8 QKV with FP8-compute attention does theoretically provide lower precision. I feel it would be good to have this as an option (maybe triggered by some env variable after FP8 QKV becomes the default for prefill attention) for debugging purposes.
This pull request has merge conflicts that must be resolved before it can be merged.
Head branch was pushed to by a user without write access
Force-pushed 94bd126 to f8f843e
This pull request was exported from Phabricator. Differential Revision: D81623711
Revised to use the env var VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION, which defaults to false.
Overall looks good to me, thanks. cc @ProExpertProg @mgoin to check if you have any comments on the env var or anything else. The last thing is I don't think we need so much unrelated formatting; it should be possible to pass pre-commit with the original state (like my previous PR #23647). Could you try reverting those changes and running pre-commit locally?
One more: add the env to environment_variables_to_hash (Line 1271 in e408272).
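For context, a rough sketch of what registering the flag could look like, assuming it follows the lambda-per-variable pattern used in vllm/envs.py; the surrounding code in that file is not shown here, and these lines are illustrative only:

```python
import os
from typing import Any, Callable

# Illustrative only: assumed to mirror the existing pattern in vllm/envs.py.
environment_variables: dict[str, Callable[[], Any]] = {
    # If set to 1, keep the query in bf16/fp16 for the FlashInfer/trtllm
    # prefill path instead of quantizing it to fp8 (defaults to false).
    "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION":
    lambda: bool(int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0"))),
}

# ...and the same key added to the list the reviewer references, so the flag
# participates in the config/compilation hash.
environment_variables_to_hash = [
    "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
]
```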
Signed-off-by: Xiaozhu <[email protected]>
Added the new env to the env hash and reverted most unrelated lint changes. My local linter still decided to make some changes, but the churn seems to be much smaller now.
The failed CI seems unrelated?
By the way, do you know when the trtllm-gen release will be available? |
Signed-off-by: Xiaozhu <[email protected]>
tl.store(mock_kv_cache_ptr + mock_cache_offset,
         dequantized_vals.to(tl.bfloat16))
Don't hardcode the dtype to bfloat16
@mgoin Triton does not support templated dtypes, so it is difficult to support generic types here. Do you have specific types you want to support beyond bfloat16?
Why can't you just use the dtype of mock_kv_cache?
In the kernel, we need to use a Triton dtype (tl.bfloat16 or tl.float16) instead of a tensor dtype (torch.bfloat16 or torch.float16). I may be wrong, but my understanding is that within a Triton kernel we cannot get the dtype of a tensor. The only way I am aware of is passing a bool variable as a kernel input argument to indicate whether a certain type is being used; we cannot pass a torch.dtype as a kernel input argument.
In my latest revision, I changed the code to support both bf16 and fp16.
(batch_idx * block_table_stride + mock_block_table_idx + 1) *
KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE))
tl.store(mock_kv_cache_ptr + mock_cache_offset,
         dequantized_vals.to(tl.bfloat16))
ditto
new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
# mock kv cache contains just the pages needed by this prefill
mock_kv_cache = torch.empty(new_s,
                            dtype=torch.bfloat16,
ditto
mock_kv_cache = kv_cache_permute
mock_block_table = block_tables_prefill
nit: I don't like using the term "mock" here. I realize we need to keep kv_cache_permute around for decode, so we can't just reuse the name, but it is misleading to pass it into the kernel.
Signed-off-by: Xiaozhu <[email protected]>
batch_size, num_of_page_per_token = block_tables_prefill.shape
s = kv_cache.shape
assert s[1] == 2
assert dequant_dtype in (torch.bfloat16, torch.float16)
@mgoin revised to pass in q_dtype, supporting bf16 and fp16 and asserting when encountering other types.
if dequant_to_bf16:
    dequantized_vals = dequantized_vals.to(tl.bfloat16)
else:
    dequantized_vals = dequantized_vals.to(tl.float16)
No need for this external dtype branching logic; just get the element type from the output directly, i.e.
dequantized_vals = dequantized_vals.to(mock_kv_cache_ptr.type.element_ty)
@mgoin Revised. Thanks!
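For illustration, a standalone sketch of the idiom adopted above (a toy kernel, not the PR's code): the store dtype is taken from the output pointer's element type, so a single kernel handles bf16 and fp16 outputs without any branching flag.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def scale_copy_kernel(x_ptr, out_ptr, scale, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    vals = tl.load(x_ptr + offs, mask=mask).to(tl.float32) * scale
    # Cast to whatever element type the output tensor was allocated with.
    tl.store(out_ptr + offs, vals.to(out_ptr.type.element_ty), mask=mask)


x = torch.randn(1024, device="cuda")
for out_dtype in (torch.bfloat16, torch.float16):
    out = torch.empty_like(x, dtype=out_dtype)
    grid = (triton.cdiv(x.numel(), 256),)
    scale_copy_kernel[grid](x, out, 2.0, x.numel(), BLOCK=256)
```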
Signed-off-by: Xiaozhu <[email protected]>
Can we create an issue to track moving this logic into trtllm?
…ention (vllm-project#24197) Signed-off-by: Xiaozhu <[email protected]>
…ention (vllm-project#24197) Signed-off-by: Xiaozhu <[email protected]> Signed-off-by: xuebwang-amd <[email protected]>
@functools.cache
def flashinfer_disable_q_quantization() -> bool:
    """Cache result which only depends on the environment"""
    return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
Ideally, we could get rid of this after #26146
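A note on the functools.cache choice (the snippet below is a self-contained illustration, not the PR's code, and reads the variable directly via os.getenv rather than through vllm's envs module): the environment is read once per process and memoized, so changing the variable after the first lookup has no effect.

```python
import functools
import os


@functools.cache
def flashinfer_disable_q_quantization() -> bool:
    # Evaluated once; later changes to the env var within the process are ignored.
    return os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0") == "1"


# Hypothetical call site: decide whether to quantize Q for the trtllm prefill path.
if not flashinfer_disable_q_quantization():
    ...  # quantize the query to fp8 here
```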
Summary:
trtllm prefill attention does not support bf16 Q + fp8 KV. To enable fp8 KV with trtllm, we add a dequant kernel before prefill attention that creates a mock KV cache and a mock block table containing only the tokens needed by the prefills.
Since the tokens involved in a prefill are few compared to what the KV cache holds overall, this dequant kernel adds relatively little overhead and lets us enable fp8 KV.
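To make the mechanism concrete, here is a simplified sketch of the idea (not the PR's actual kernel; the names, the per-tensor kv_scale, the flat page layout, and a Triton version that can load torch.float8_e4m3fn tensors directly are all assumptions for illustration): each Triton program gathers one fp8 page referenced by a prefill request, dequantizes it, and writes it into a freshly allocated bf16/fp16 mock cache whose block table is simply a contiguous arange.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def dequant_page_kernel(kv_cache_ptr, mock_kv_cache_ptr, src_page_ids_ptr,
                        kv_scale, PAGE_NUMEL: tl.constexpr):
    # One program copies and dequantizes one page. PAGE_NUMEL
    # (= 2 * page_size * num_kv_heads * head_dim) is assumed to be a power of
    # two so tl.arange can cover a whole page at once.
    pid = tl.program_id(0)
    src_page = tl.load(src_page_ids_ptr + pid).to(tl.int64)
    offs = tl.arange(0, PAGE_NUMEL)
    vals = tl.load(kv_cache_ptr + src_page * PAGE_NUMEL + offs)
    deq = vals.to(tl.float32) * kv_scale
    # Store in whatever dtype the mock cache was allocated with (bf16 or fp16).
    tl.store(mock_kv_cache_ptr + pid * PAGE_NUMEL + offs,
             deq.to(mock_kv_cache_ptr.type.element_ty))


def build_mock_kv_cache(kv_cache, block_tables_prefill, kv_scale, q_dtype):
    # kv_cache: (num_pages, 2, page_size, num_kv_heads, head_dim) in fp8.
    assert q_dtype in (torch.bfloat16, torch.float16)
    batch_size, pages_per_req = block_tables_prefill.shape
    num_mock_pages = batch_size * pages_per_req
    mock_kv_cache = torch.empty((num_mock_pages,) + kv_cache.shape[1:],
                                dtype=q_dtype, device=kv_cache.device)
    # Request i's pages land in rows [i * pages_per_req, (i + 1) * pages_per_req).
    mock_block_table = torch.arange(num_mock_pages, dtype=torch.int32,
                                    device=kv_cache.device).reshape(
                                        batch_size, pages_per_req)
    src_page_ids = block_tables_prefill.reshape(-1).contiguous()
    dequant_page_kernel[(num_mock_pages,)](kv_cache, mock_kv_cache,
                                           src_page_ids, kv_scale,
                                           PAGE_NUMEL=kv_cache[0].numel())
    return mock_kv_cache, mock_block_table
```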
Test Plan:
AIME (gpt-oss-120b):
Low: [{'eval_name': 'aime25', 'model_name': 'gpt-oss-120b-low_temp1.0_20250903_113216', 'metric': 0.5}]
Medium: [{'eval_name': 'aime25', 'model_name': 'gpt-oss-120b-medium_temp1.0_20250903_111346', 'metric': 0.7416666666666667}]
High: [{'eval_name': 'aime25', 'model_name': 'gpt-oss-120b-high_temp1.0_20250903_131208', 'metric': 0.9}]
Rollback Plan:
Differential Revision: D81623711