
Conversation

tjtanaa
Contributor

@tjtanaa tjtanaa commented Jan 23, 2025

Description

This PR implements a faster Custom Paged Attention (CPA) kernel based on mfma16x16x16 instructions.
The feature is ported from ROCm/vllm (ROCm#372).
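
As background, the mfma16x16x16 matrix-core primitive lets one 64-lane wavefront compute a full 16x16 fp32 tile of C = A x B (fp16 inputs) per instruction. Below is a minimal, self-contained HIP sketch of that primitive. It is illustrative only (the actual kernel lives in csrc/rocm/attention.cu); the fragment indexing follows the CDNA ISA layout for v_mfma_f32_16x16x16f16, so treat the names and layout details as assumptions.

    #include <hip/hip_runtime.h>

    typedef _Float16 f16x4 __attribute__((ext_vector_type(4)));
    typedef float f32x4 __attribute__((ext_vector_type(4)));

    // One wavefront (64 lanes) computes C = A * B for a single 16x16 tile.
    __global__ void mfma_16x16x16_demo(const _Float16* A,  // 16x16, row-major
                                       const _Float16* B,  // 16x16, row-major
                                       float* C) {         // 16x16, row-major
      const int lane = threadIdx.x & 63;  // lane id within the wavefront
      const int elem = lane % 16;         // row of A / column of B and C
      const int kgrp = lane / 16;         // which group of 4 k-values this lane owns

      f16x4 a, b;
      f32x4 acc = {0.f, 0.f, 0.f, 0.f};
      for (int i = 0; i < 4; ++i) {
        a[i] = A[elem * 16 + (kgrp * 4 + i)];  // A[row][k]
        b[i] = B[(kgrp * 4 + i) * 16 + elem];  // B[k][col]
      }
      // A single matrix-core instruction consumes the whole 16-wide K slice.
      acc = __builtin_amdgcn_mfma_f32_16x16x16f16(a, b, acc, 0, 0, 0);
      for (int i = 0; i < 4; ++i) {
        C[(kgrp * 4 + i) * 16 + elem] = acc[i];  // each lane writes 4 rows of one column
      }
    }

Launched with a single 64-thread block (e.g. hipLaunchKernelGGL(mfma_16x16x16_demo, dim3(1), dim3(64), 0, 0, dA, dB, dC)), one such instruction replaces the 16 scalar FMAs per output element a plain loop would need, which is the kind of throughput the optimized CPA kernel builds on.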

End-to-End Performance gain

Model: Llama-3.1-70B-Instruct
Tensor Parallelism: 1
GPU: MI300X

| CPA version | Input length | Output length | KV cache dtype | Quantization | Num prompts | Req/s | Total tokens/s | Output tokens/s |
|---|---|---|---|---|---|---|---|---|
| before changes | 128 | 128 | fp8_e4m3 | fp8 | 200 | 13.05 | 3340.6 | 1670.3 |
| before changes | 128 | 256 | fp8_e4m3 | fp8 | 200 | 7.56 | 2901.31 | 1934.21 |
| before changes | 128 | 2048 | fp8_e4m3 | fp8 | 200 | 0.78 | 1698.35 | 1598.45 |
| before changes | 512 | 128 | fp8_e4m3 | fp8 | 200 | 6.44 | 4122.57 | 824.51 |
| before changes | 512 | 256 | fp8_e4m3 | fp8 | 200 | 4.48 | 3443.46 | 1147.82 |
| before changes | 512 | 2048 | fp8_e4m3 | fp8 | 200 | 0.66 | 1696.64 | 1357.31 |
| before changes | ShareGPT | — | fp8_e4m3 | fp8 | 1000 | 6.22 | 2574.19 | 1234.64 |
| optimized | 128 | 128 | fp8_e4m3 | fp8 | 200 | 15.11 | 3867.75 | 1933.87 |
| optimized | 128 | 256 | fp8_e4m3 | fp8 | 200 | 9.01 | 3459.98 | 2306.65 |
| optimized | 128 | 2048 | fp8_e4m3 | fp8 | 200 | 1.2 | 2609.04 | 2455.57 |
| optimized | 512 | 128 | fp8_e4m3 | fp8 | 200 | 7.33 | 4694.05 | 938.81 |
| optimized | 512 | 256 | fp8_e4m3 | fp8 | 200 | 5.5 | 4223.29 | 1407.76 |
| optimized | 512 | 2048 | fp8_e4m3 | fp8 | 200 | 1.03 | 2648.55 | 2118.84 |
| optimized | ShareGPT | — | fp8_e4m3 | fp8 | 1000 | 7.45 | 3081.14 | 1477.79 |


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of it by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Jan 23, 2025
vllmellm and others added 2 commits January 23, 2025 08:50
Ported ROCm/vllm changes to upstream vLLM

This commit manually ports changes from ROCm/vllm (ROCm#372) to upstream vLLM.
The original work was done by sanyalington.

Co-authored-by: sanyalington <[email protected]>

Signed-off-by: vllmellm <[email protected]>
@tjtanaa tjtanaa force-pushed the port-rocm-cpa-credit branch 2 times, most recently from 9be5f70 to f57dcb9 on January 23, 2025 08:57
@tjtanaa tjtanaa force-pushed the port-rocm-cpa-credit branch from f57dcb9 to 4f71b54 on January 23, 2025 09:01
@tjtanaa tjtanaa changed the title [AMD] Faster Custom Paged Attention kernels [ROCm] Faster Custom Paged Attention kernels Jan 23, 2025
@tjtanaa
Contributor Author

tjtanaa commented Jan 23, 2025

Regarding the API changes to paged_attention in csrc/rocm/torch_bindings.cpp: this change only affects the ROCm code path and does not interfere with the code paths of other platforms.

 rocm_ops.def(
      "paged_attention(Tensor! out, Tensor exp_sums,"
      "                Tensor max_logits, Tensor tmp_out,"
      "                Tensor query, Tensor key_cache,"
      "                Tensor value_cache, int num_kv_heads,"
      "                float scale, Tensor block_tables,"
      "                Tensor context_lens, int block_size,"
      "                int max_context_len,"
      "                Tensor? alibi_slopes,"
      "                str kv_cache_dtype,"
      "                float k_scale, float v_scale,"
      "                Tensor? fp8_out_scale,"
      "                int partition_size) -> ()");

We are seeking advice on handling the variables fp8_out_scale and partition_size.

Situation: these two variables have been introduced in the ROCm Custom Paged Attention, but they are not yet used by higher-level abstractions. They are set to fp8_out_scale=None and partition_size=256; the value partition_size=256 was found experimentally to be a good value for MI300.

Option 1:

  • Remove fp8_out_scale from csrc/rocm/attention.cu
  • Hard-code partition_size to 256 in csrc/rocm/attention.cu.
    This avoids changing the paged_attention API in csrc/rocm/torch_bindings.cpp.

Option 2:

  • Keep the variables as is, and mark a TODO: so a future update remembers to introduce an fp8 output-scaling strategy for ROCm.
  • Set fp8_out_scale=None and partition_size=256 when calling ops.paged_attention_rocm in vllm/attention/backends/rocm_flash_attn.py.

We have implemented Option 1; a sketch of what it looks like is below.
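
For reference, here is a minimal sketch of Option 1 as it might look in csrc/rocm/attention.cu; the names (kPartitionSize, maxNumPartitions) are illustrative assumptions, not the PR's exact diff:

    #include <cstdint>

    // partition_size becomes an internal constant, so the torch binding in
    // csrc/rocm/torch_bindings.cpp no longer needs to expose it. 256 was
    // found experimentally to be a good value for MI300.
    constexpr int kPartitionSize = 256;

    // Number of partitions each sequence is split into, derived internally
    // instead of being passed in by the caller.
    // e.g. maxNumPartitions(4096) == 16, maxNumPartitions(4097) == 17.
    inline int maxNumPartitions(int64_t max_context_len) {
      return static_cast<int>(
          (max_context_len + kPartitionSize - 1) / kPartitionSize);
    }

The fp8_out_scale argument is simply dropped from the kernel entry point until an fp8 output-scaling strategy lands.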

@hongxiayang hongxiayang added the rocm (Related to AMD ROCm) label Jan 23, 2025
@hongxiayang
Collaborator

@tjtanaa Please fix the DCO error:

  • Ensure you have a local copy of your branch by checking out the pull request locally via the command line.
  • In your local branch, run: git rebase HEAD~4 --signoff
  • Force-push your changes to overwrite the branch: git push --force-with-lease origin port-rocm-cpa-credit

…iminate the need for additional arguments (partition_size and fp8_output_scale) in its API.

Signed-off-by: vllmellm <[email protected]>

mergify bot commented Jan 24, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 24, 2025
@mergify mergify bot removed the needs-rebase label Jan 24, 2025
… and code documentation. Updated its unit test to match the correct partition size based on the paged attention version as well as the platform type.

Signed-off-by: vllmellm <[email protected]>
@tjtanaa tjtanaa marked this pull request as ready for review January 27, 2025 12:27
@tjtanaa
Contributor Author

tjtanaa commented Jan 27, 2025

@tjtanaa Please fix the DCO error: Ensure you have a local copy of your branch by checking out the pull request locally via command line. In your local branch, run: git rebase HEAD~4 --signoff Force push your changes to overwrite the branch: git push --force-with-lease origin port-rocm-cpa-credit

@hongxiayang We find that rebasing is hard because we had merged from main. In the process of fixing the DCO we had to resolve merge conflicts twice, and it would require us to test everything again. It seems there are ways to override the DCO check during merge. Could we get more input from the vLLM maintainers on the DCO issue?

@mergify mergify bot removed the needs-rebase label Feb 20, 2025

mergify bot commented Feb 20, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 20, 2025
@hongxiayang
Collaborator

cc @houseroad This is an important feature that should be merged asap.

Comment on lines 470 to 472
// Use int64_t for arithmetic to prevent overflow
const int64_t vglobal_token_idx =
static_cast<int64_t>(partition_start_token_idx) + vlocal_token_idx;
Member

Small nit: this would be better as the following, per the conversation above.

Suggested change:

    - // Use int64_t for arithmetic to prevent overflow
    - const int64_t vglobal_token_idx =
    -     static_cast<int64_t>(partition_start_token_idx) + vlocal_token_idx;
    + // Safe to use an int32_t here assuming we are working with < 2 billion tokens
    + const int32_t vglobal_token_idx = partition_start_token_idx + vlocal_token_idx;

Contributor Author

@tlrmchlsmth We have updated this back to int32, but we are using the int type to represent int32_t. Is that okay?

Member

Yep, I think that's fine, especially as the rest of the file uses int.

Member

@tlrmchlsmth tlrmchlsmth left a comment

LGTM now, thanks for the contribution! Please merge latest main as there is a conflict.

@mergify mergify bot removed the needs-rebase label Feb 27, 2025
@tjtanaa
Contributor Author

tjtanaa commented Feb 28, 2025

LGTM now, thanks for the contribution! Please merge latest main as there is a conflict.

@tlrmchlsmth We have resolved the merge conflict. Thank you.

@tlrmchlsmth
Member

@tjtanaa could you take a look at the pre-commit? It's failing as well

Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
@vllm-bot vllm-bot merged commit 848a643 into vllm-project:main Mar 3, 2025
57 of 60 checks passed
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
@tjtanaa tjtanaa deleted the port-rocm-cpa-credit branch May 16, 2025 16:27

Labels

ci/build, documentation, frontend, ready, rocm, speculative-decoding
