Enable FlashInfer V1 FP8 kv cache #17005
Conversation
Signed-off-by: mgoin <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀
This pull request has merge conflicts that must be resolved before it can be merged.
I tried this PR out on my 3090 Ti on latest main and get an illegal memory access (Invocation and Log details collapsed). This is when trying to run ...
I don't get the invalid memory access with ... Really looking forward to fp8 on V1 for non-Hopper devices.
With the changes in this PR to force-enable FlashInfer on V1 with FP8 kv cache enabled, I'm seeing the error below in MoE FP8 models; for me, Llama 3 FP8 works fine. Script to reproduce it:
This PR tries to fix an issue that occurred while enabling fp8 kv-cache support for vLLM (vllm-project/vllm#17005). The issue was that in a generated inc file (e.g., in my case flashinfer/100/generated/batch_decode_with_kv_cache_dtype_q_bf16_dtype_kv_u8_dtype_o_bf16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False/batch_decode_config.inc) we declared DTypeKV to be uint8_t, as shown below:

```
using DTypeKV = uint8_t;
...
struct Params {
  ...
  using DTypeKV = DTypeKV;
  ...
};
```

Consequently, when we instantiate the vec_t from cast_load_impl defined in vec_dtypes.cuh, i.e.

```
vec_t<src_float_t, vec_size> tmp;
```

src_float_t is instantiated to uint8_t through the template parameter T = Params::DTypeKV, where Params::DTypeKV is uint8_t. Because vec_t doesn't have any specialization for uint8_t, we ended up with the following ptxas error:

```
ptxas fatal   : Unresolved extern function '_ZN10flashinfer5vec_tIhLm16EE4loadEPKh'
```

The fix is to add a specialization for uint8_t. However, this may not be the right fix, because the root cause might be that we shouldn't generate `using DTypeKV = uint8_t;` in the first place.
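For illustration, here is a minimal, self-contained sketch of the failure mode and the shape of the fix described above. It is not FlashInfer's actual vec_t (the real template lives in vec_dtypes.cuh and targets device code); the names and the memcpy-based load are assumptions made just for this example.

```
// Minimal sketch (not FlashInfer's actual vec_t): the primary template only
// declares load(), so instantiating it for a type with no specialization
// leaves an unresolved symbol -- analogous to the ptxas error above.
#include <cstddef>
#include <cstdint>
#include <cstring>

template <typename T, std::size_t N>
struct vec_t {
  T data[N];
  void load(const T* ptr);  // declared, but only specialized types define it
};

// Existing specializations cover float-like element types...
template <std::size_t N>
struct vec_t<float, N> {
  float data[N];
  void load(const float* ptr) { std::memcpy(data, ptr, sizeof(data)); }
};

// ...so a fix in the spirit of this PR adds a uint8_t specialization,
// giving vec_t<uint8_t, 16>::load() a definition.
template <std::size_t N>
struct vec_t<std::uint8_t, N> {
  std::uint8_t data[N];
  void load(const std::uint8_t* ptr) { std::memcpy(data, ptr, sizeof(data)); }
};

int main() {
  std::uint8_t raw[16] = {};
  vec_t<std::uint8_t, 16> v;
  v.load(raw);  // resolves only because the uint8_t specialization exists
  return 0;
}
```

Removing the uint8_t specialization from this sketch reproduces the same class of link-time failure that ptxas reports for the generated kernel.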
The ptxas error was fixed in flashinfer-ai/flashinfer#1234. However, the lm_eval result with gsm8k still looks very off. Looking into this.
This PR fixed fp8 kv-cache issues for the FlashInfer attn backend. Along with vllm-project#17005, got reasonable eval results on B200:

```
$ VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASHINFER lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct,kv_cache_dtype=fp8 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
...
vllm (pretrained=meta-llama/Llama-3.1-8B-Instruct,kv_cache_dtype=fp8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7779|±  |0.0114|
|     |       |strict-match    |     5|exact_match|↑  |0.7582|±  |0.0118|
```

compared with bf16 kv-cache:

```
$ VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASHINFER lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
...
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7756|±  |0.0115|
|     |       |strict-match    |     5|exact_match|↑  |0.7498|±  |0.0119|
```

Signed-off-by: Yang Chen <[email protected]>
@Daisy-Ma-coder I am able to repro the failure on H100. On B200, it looks like it works with the fix (#20746) now. I got the output below using your example:
Got it, thanks! I'm on H200s, so I'll likely still run into the same error with your fix, but I can try it out.
Yeah, it's very likely you will still see the same issue on H200. I will investigate it in a couple of days.
This is for resolving an issue encountered while enabling fp8 kv-cache support in the FlashInfer backend: vllm-project/vllm#17005 (comment). The root cause seems to be that we do not have native fp8 kv-cache support for prefill. The failure that we hit, i.e.

```
static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration.");
```

simply reflects that our prefill kernels do not instantiate cute::GMMA::rs_op_selector with the correct layout for fp8, which requires k-major for the B matrix: https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/attention/hopper/kernel_traits.cuh#L78

Note that we cannot simply assign k-major when DTypeKV is fp8; there is more to fix to correctly support an fp8 kv-cache in the kernel. Hence the workaround in this PR, where we convert k and v to q_data_type if they are fp8 but q is not. We could do this from vLLM, but it seems better to put it in FlashInfer, because then no changes to caller code are required once fp8 kv-cache for prefill is supported in a better way.

Also, please note that I am not 100% sure this is an appropriate fix, particularly since I am not familiar with FlashInfer's code base. Originally, I was a bit worried about the impact on other kv-cache related things such as _paged_kv_indptr and _kv_indptr_buf, but it seems fine to me after reading through the relevant code in prefill.py and hopper/prefill_sm90.cuh.

Last note: eventually, I think we need to support fp8 kv-cache for prefill more properly.
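To make the workaround concrete, here is a minimal CUDA sketch of the casting idea described above. It is not FlashInfer's actual code (the real conversion is applied on the wrapper side before the prefill kernel runs); the kernel name, the e4m3 format, bf16 as the query dtype, and the per-tensor kv_scale are all assumptions made for this example.

```
// Hypothetical sketch: dequantize an fp8 (e4m3) K/V buffer into Q's dtype
// (bf16 here) with a per-tensor scale, so an existing bf16 prefill path can
// consume it when native fp8 prefill support is unavailable.
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>

__global__ void cast_kv_to_q_dtype(const __nv_fp8_e4m3* kv_fp8,
                                   __nv_bfloat16* kv_out,
                                   float kv_scale, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // fp8 -> float -> bf16, applying the per-tensor KV scale
    kv_out[i] = __float2bfloat16(static_cast<float>(kv_fp8[i]) * kv_scale);
  }
}

int main() {
  const int n = 1024;
  __nv_fp8_e4m3* kv_fp8 = nullptr;
  __nv_bfloat16* kv_bf16 = nullptr;
  cudaMalloc(&kv_fp8, n * sizeof(__nv_fp8_e4m3));
  cudaMalloc(&kv_bf16, n * sizeof(__nv_bfloat16));
  cast_kv_to_q_dtype<<<(n + 255) / 256, 256>>>(kv_fp8, kv_bf16, /*kv_scale=*/1.0f, n);
  cudaDeviceSynchronize();
  cudaFree(kv_fp8);
  cudaFree(kv_bf16);
  return 0;
}
```

The trade-off of this approach is the extra memory traffic and the temporary buffer in Q's dtype, which is why native fp8 prefill support would still be preferable.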
This pull request has merge conflicts that must be resolved before it can be merged.
Unfortunately this seems to fail on B200