
Conversation

@mxz297
Contributor

@mxz297 mxz297 commented Sep 3, 2025

Summary:
trtllm prefill attention does not support bf16 Q + fp8 KV. To enable fp8 KV with trtllm, we can add a dequant kernel before prefill attention that creates a mock KV cache and a mock block table containing only the tokens needed by the prefills.

Since the tokens involved in prefill are few compared to what the KV cache holds overall, this dequant kernel adds relatively little overhead and allows us to enable fp8 KV.
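For readers unfamiliar with the trick, here is a minimal, hypothetical sketch of the idea (illustrative names and layout, not this PR's actual kernel): gather only the pages referenced by the prefill block table, dequantize them from fp8 into a small bf16 buffer, and hand the prefill kernel a compacted block table pointing into that buffer. It assumes a contiguous page-major cache layout and a Triton build that can load fp8 tensors directly.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _dequant_pages_kernel(kv_cache_ptr,       # fp8 paged cache, page-major (assumed layout)
                          block_table_ptr,    # flattened prefill block table
                          mock_kv_cache_ptr,  # bf16 output, one slot per gathered page
                          kv_scale,           # scalar dequantization scale
                          page_numel,         # elements per page (K and V together)
                          BLOCK: tl.constexpr):
    page_idx = tl.program_id(0)   # slot in the compacted mock cache
    chunk = tl.program_id(1)      # which chunk of the page this program handles
    src_page = tl.load(block_table_ptr + page_idx).to(tl.int64)
    offs = chunk * BLOCK + tl.arange(0, BLOCK)
    mask = offs < page_numel
    vals = tl.load(kv_cache_ptr + src_page * page_numel + offs,
                   mask=mask, other=0.0).to(tl.float32) * kv_scale
    tl.store(mock_kv_cache_ptr + page_idx * page_numel + offs,
             vals.to(tl.bfloat16), mask=mask)


def build_mock_prefill_cache(kv_cache: torch.Tensor,
                             block_tables_prefill: torch.Tensor,
                             kv_scale: float):
    """Dequantize only the prefill pages; return (mock_kv_cache, mock_block_table)."""
    num_prefills, pages_per_req = block_tables_prefill.shape
    num_pages = num_prefills * pages_per_req
    page_numel = kv_cache[0].numel()
    mock_kv_cache = torch.empty((num_pages, *kv_cache.shape[1:]),
                                dtype=torch.bfloat16, device=kv_cache.device)
    # The compacted block table simply enumerates the gathered pages in order.
    mock_block_table = torch.arange(num_pages, device=kv_cache.device,
                                    dtype=block_tables_prefill.dtype
                                    ).view(num_prefills, pages_per_req)
    BLOCK = 1024
    grid = (num_pages, triton.cdiv(page_numel, BLOCK))
    _dequant_pages_kernel[grid](kv_cache, block_tables_prefill.reshape(-1),
                                mock_kv_cache, kv_scale, page_numel, BLOCK=BLOCK)
    return mock_kv_cache, mock_block_table
```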

Test Plan:
AIME

Low:
[{'eval_name': 'aime25', 'model_name': 'gpt-oss-120b-low_temp1.0_20250903_113216', 'metric': 0.5}]

Medium:
[{'eval_name': 'aime25', 'model_name': 'gpt-oss-120b-medium_temp1.0_20250903_111346', 'metric': 0.7416666666666667}]

High:
[{'eval_name': 'aime25', 'model_name': 'gpt-oss-120b-high_temp1.0_20250903_131208', 'metric': 0.9}]

Rollback Plan:

Differential Revision: D81623711

@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D81623711

@mergify mergify bot added the v1 label Sep 3, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for fp8 kv cache in trtllm prefill attention by introducing a dequantization kernel. The overall approach is sound and the implementation looks good, but there is a critical syntax error in a type hint that will cause a runtime failure.

@mxz297 mxz297 changed the title from "support for fp8 kv cache for trtllm prefill attention" to "[flashinfer] [kernel] support for fp8 kv cache for trtllm prefill attention" Sep 3, 2025
Collaborator

@houseroad houseroad left a comment


Looks good to me. Can we create an issue (for tracking purposes) to have the trtllm kernel support fp8 KV by default, so the additional conversion on our side is no longer needed?

@houseroad
Collaborator

Btw, fix the pre-commit (linter) please

@mxz297
Contributor Author

mxz297 commented Sep 4, 2025

@houseroad it seems like there is a draft PR #23647 that can potentially get rid of the conversion. However, even so, the conversion may stay, because it represents a different compute precision and may have accuracy implications.

@houseroad
Collaborator

Yeah, I think it makes sense to have this land. Please fix the linter. :-)

@houseroad houseroad added the ready (ONLY add when PR is ready to merge/full CI is needed) label Sep 4, 2025
@houseroad houseroad enabled auto-merge (squash) September 4, 2025 17:56
@mxz297
Contributor Author

mxz297 commented Sep 4, 2025

The failed test is test_cutlass_mla_decode.py; however, my change is not relevant to MLA (it touches paged attention only). Is this an existing test failure?

@elvischenv
Contributor

> @houseroad it seems like there is a draft PR #23647 that can potentially get rid of the conversion. However, even so, the conversion may stay, because it represents a different compute precision and may have accuracy implications.

Thanks for the contribution. However, if we agree to always quantize the query with KV_cache_dtype=FP8 after #23647, then we will not have an opportunity to run these code paths anymore, right? cc @ProExpertProg

@mxz297
Contributor Author

mxz297 commented Sep 5, 2025

It may be good to have a path that does prefill with BF16 Q, FP8 KV, and BF16 compute (even if it is not the default), because FP8 QKV + FP8-compute attention does theoretically provide lower precision.

I feel it would be good to have this as an option (maybe triggered by some env variable once FP8 QKV becomes the default for prefill attention) for debugging purposes.

@mergify

mergify bot commented Sep 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mxz297.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 9, 2025
@facebook-github-bot

@mxz297 has imported this pull request. If you are a Meta employee, you can view this in D81623711.

auto-merge was automatically disabled September 9, 2025 22:39

Head branch was pushed to by a user without write access

@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D81623711

@mergify mergify bot removed the needs-rebase label Sep 9, 2025
@mxz297
Contributor Author

mxz297 commented Sep 10, 2025

Revised to use the env var VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION, which defaults to false.

@elvischenv
Contributor

Overall looks good to me, thanks. cc @ProExpertProg @mgoin to check if you have any comments on the env var or anything else.

The last thing: I don't think we need so many unrelated formatting changes. It should be possible to pass pre-commit with the original state (like my previous PR #23647). Could you try reverting those and running pre-commit locally?

@elvischenv
Contributor

One more: add the env var to environment_variables_to_hash:

vllm/vllm/envs.py, line 1271 (at e408272):

environment_variables_to_hash = [
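For illustration, a self-contained sketch of how that registration might look, assuming the usual vllm/envs.py pattern of one parser lambda per variable plus a list of names that feed the hash (surrounding entries elided; the exact form in the file may differ):

```python
import os

# Each vLLM env var is declared with a small parser; this flag defaults to off.
environment_variables = {
    "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION":
        lambda: bool(int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0"))),
}

# Variables listed here contribute to the config hash, so flipping the flag
# invalidates anything keyed on the previous value.
environment_variables_to_hash = [
    "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
]
```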

@mxz297
Contributor Author

mxz297 commented Sep 10, 2025

Added the new env var to the env hash and reverted most unrelated lint changes. My local linter still decided to make some changes, but the churn is much smaller now.

@mxz297
Contributor Author

mxz297 commented Sep 10, 2025

The failed CI seems unrelated?

@mxz297
Contributor Author

mxz297 commented Sep 10, 2025

Expected. We need another release for trtllm-gen before that: https://github.com/flashinfer-ai/flashinfer/blob/7c587eb16c9236b7797525411fe8aedea0c9ec05/flashinfer/artifacts.py#L111-L120

By the way, do you know when the trtllm-gen release will be available?

Comment on lines 78 to 79
tl.store(mock_kv_cache_ptr + mock_cache_offset,
         dequantized_vals.to(tl.bfloat16))
Member


Don't hardcode the dtype to bfloat16

Contributor Author


@mgoin Triton does not support templated dtypes, so it is difficult to support generic types here. Are there specific types you want to support beyond bfloat16?

Member


Why can't you just use the dtype of mock_kv_cache?

Contributor Author

@mxz297 mxz297 Sep 11, 2025


In the kernel, we need to use Triton dtypes (tl.bfloat16 or tl.float16) instead of tensor dtypes (torch.bfloat16 or torch.float16). I may be wrong, but my understanding is that within a Triton kernel we cannot query the dtype of a tensor; the only way I am aware of is to pass a bool flag as a kernel argument to indicate which type is in use, since we cannot pass a torch.dtype as a kernel argument.

In my latest revision, I changed the code to support both bf16 and fp16.

(batch_idx * block_table_stride + mock_block_table_idx + 1) *
KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE))
tl.store(mock_kv_cache_ptr + mock_cache_offset,
         dequantized_vals.to(tl.bfloat16))
Member


ditto

new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
# mock kv cache contains just the pages needed by this prefill
mock_kv_cache = torch.empty(new_s,
                            dtype=torch.bfloat16,
Member


ditto

Comment on lines +901 to +902
mock_kv_cache = kv_cache_permute
mock_block_table = block_tables_prefill
Member


nit: I don't like using the term "mock" here. I realize we need to keep kv_cache_permute around for decode, so we can't just reuse the name, but it is misleading to pass it into the kernel.

batch_size, num_of_page_per_token = block_tables_prefill.shape
s = kv_cache.shape
assert s[1] == 2
assert dequant_dtype in (torch.bfloat16, torch.float16)
Contributor Author


@mgoin Revised to pass in q_dtype, support bf16 and fp16, and assert when encountering other types.

Comment on lines 79 to 82
if dequant_to_bf16:
    dequantized_vals = dequantized_vals.to(tl.bfloat16)
else:
    dequantized_vals = dequantized_vals.to(tl.float16)
Member

@mgoin mgoin Sep 11, 2025


There is no need for this external dtype branching logic; just get the element type from the output pointer directly, i.e.

dequantized_vals = dequantized_vals.to(mock_kv_cache_ptr.type.element_ty)
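A self-contained toy example of that idiom (hypothetical kernel, not this PR's code): the caller picks the output dtype when allocating the tensor, and the kernel reads it from the pointer, so one kernel body covers bf16 and fp16 without a flag.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _scale_cast_kernel(x_ptr, out_ptr, scale, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    vals = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32) * scale
    # The store dtype comes from the output pointer, so no bf16/fp16 branch is needed.
    tl.store(out_ptr + offs, vals.to(out_ptr.type.element_ty), mask=mask)


def scale_cast(x: torch.Tensor, scale: float, out_dtype: torch.dtype) -> torch.Tensor:
    out = torch.empty_like(x, dtype=out_dtype)
    n = x.numel()
    _scale_cast_kernel[(triton.cdiv(n, 1024),)](x, out, scale, n, BLOCK=1024)
    return out

# scale_cast(x, 2.0, torch.bfloat16) and scale_cast(x, 2.0, torch.float16)
# both work with the same kernel.
```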

Contributor Author

@mxz297 mxz297 Sep 11, 2025


@mgoin Revised. Thanks!

@houseroad
Collaborator

houseroad commented Sep 11, 2025

Can we create an issue to track moving this logic into trtllm?

@simon-mo simon-mo merged commit e42af78 into vllm-project:main Sep 11, 2025
39 of 41 checks passed
@mxz297 mxz297 deleted the export-D81623711 branch September 11, 2025 21:29
skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
dsxsteven pushed a commit to dsxsteven/vllm_splitPR that referenced this pull request Sep 15, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
Comment on lines +351 to +354
@functools.cache
def flashinfer_disable_q_quantization() -> bool:
    """Cache result which only depends on the environment"""
    return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
Collaborator


Ideally, we could get rid of this after #26146

xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025