tjtanaa
Contributor

@tjtanaa tjtanaa commented Jul 20, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Introduce AITER HIP Block Scale Quantization kernel.
Verified on AITER Commit: 916bf3c
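
For readers unfamiliar with the term, here is a minimal pure-Python sketch of what "per 1x128" block-scale quantization usually means: every run of 128 contiguous values in a row gets its own scale. It is a reference illustration only and not the AITER HIP kernel. The function name, the choice of 240 as the FP8 maximum (the largest finite e4m3fnuz value), and the omission of rounding onto the FP8 grid are all assumptions made for this sketch.

```python
FP8_E4M3FNUZ_MAX = 240.0  # assumption: largest finite value of the e4m3fnuz format
GROUP_SIZE = 128          # the "128" in per_1x128


def quantize_row_1x128(row, fp8_max=FP8_E4M3FNUZ_MAX, group=GROUP_SIZE):
    """Hypothetical reference: scale each group of `group` values independently.

    Returns the scaled values (not rounded to real FP8) and one scale per group.
    """
    if len(row) % group != 0:
        raise ValueError("row length must be a multiple of the group size")
    q_row, scales = [], []
    for start in range(0, len(row), group):
        block = row[start:start + group]
        amax = max(abs(v) for v in block)
        # Map the largest magnitude in the block onto the FP8 maximum.
        scale = amax / fp8_max if amax > 0 else 1.0
        scales.append(scale)
        # Clamp to guard against floating-point overshoot after division.
        q_row.extend(max(-fp8_max, min(fp8_max, v / scale)) for v in block)
    return q_row, scales


q, s = quantize_row_1x128([1.0] * 128 + [-2.0] * 128)
```

Multiplying each scaled value by its group's scale recovers the original input, which is why the scales are returned alongside the quantized tensor.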

Test Plan

  • Accuracy: Run lm_eval on the gsm8k dataset
  • Perf gain: Compare DeepSeek-R1 performance before and after this change
    • ISL (input sequence length): 1000
    • OSL (output sequence length): 1000
    • Dataset: Random

Test Result

Accuracy

vllm (pretrained=deepseek-ai/DeepSeek-R1,tensor_parallel_size=8,max_model_len=32768,block_size=1,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.9545 | ± 0.0057 |
| | | strict-match | 5 | exact_match | 0.9553 | ± 0.0057 |

Perf Gain

| Metric | Before Block Scale Quant | After Block Scale Quant | Change | % Improvement |
|---|---|---|---|---|
| General Performance | | | | |
| Successful requests | 500 | 500 | 0 | 0% |
| Benchmark duration (s) | 112.24 | 106.62 | -5.62 | 5.0% |
| Total input tokens | 492,837 | 492,837 | 0 | 0% |
| Request throughput (req/s) | 4.45 | 4.69 | +0.24 | 5.4% |
| Output token throughput (tok/s) | 1,182.46 | 1,190.88 | +8.42 | 0.7% |
| Total token throughput (tok/s) | 5,573.19 | 5,813.09 | +239.90 | 4.3% |
| Time to First Token (TTFT) | | | | |
| Mean TTFT (ms) | 18,676.18 | 17,337.52 | -1,338.66 | 7.2% |
| Median TTFT (ms) | 19,345.05 | 17,662.82 | -1,682.23 | 8.7% |
| P99 TTFT (ms) | 34,291.45 | 31,478.22 | -2,813.23 | 8.2% |
| Time per Output Token (TPOT) | | | | |
| Mean TPOT (ms) | 1,290.57 | 1,158.51 | -132.06 | 10.2% |
| Median TPOT (ms) | 2,051.98 | 1,880.27 | -171.71 | 8.4% |
| P99 TPOT (ms) | 2,406.51 | 2,358.88 | -47.63 | 2.0% |
| Inter-token Latency (ITL) | | | | |
| Mean ITL (ms) | 67.74 | 65.33 | -2.41 | 3.6% |
| Median ITL (ms) | 43.08 | 42.12 | -0.96 | 2.2% |
| P99 ITL (ms) | 2,047.79 | 1,880.06 | -167.73 | 8.2% |
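
The PR does not state how the "% Improvement" column is derived. The figures are consistent with the absolute change divided by the "before" value, rounded to one decimal. The helper below is a hypothetical sketch of that assumed formula, checked against a few rows of the table.

```python
def pct_improvement(before: float, after: float) -> float:
    """Assumed formula: |after - before| relative to `before`, in percent."""
    return round(abs(after - before) / before * 100, 1)


# (before, after) pairs copied from the table above.
rows = {
    "Benchmark duration (s)": (112.24, 106.62),
    "Request throughput (req/s)": (4.45, 4.69),
    "Mean TTFT (ms)": (18676.18, 17337.52),
    "Mean TPOT (ms)": (1290.57, 1158.51),
}

for name, (before, after) in rows.items():
    print(f"{name}: {pct_improvement(before, after)}%")
```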

(Optional) Documentation Update

Signed-off-by: tjtanaa <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the rocm (Related to AMD ROCm) label Jul 20, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a HIP block quantization kernel from AITER for ROCm to improve performance. The changes are mainly in fp8_utils.py to dispatch to this new kernel. My review identifies a critical issue with the implementation's robustness. The new code for the AITER kernel is added at the module level with a conditional import that can lead to runtime crashes (NameError or ImportError) if misconfigured or if the optional aiter dependency is missing. I've recommended a more robust approach to handle the optional dependency and variable definitions.

Comment on lines 85 to 91
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
        and current_platform.supports_fp8()):

    import aiter as rocm_aiter
    from aiter import get_hip_quant

    aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
Contributor

critical

This logic for conditionally defining aiter_per1x128_quant and rocm_aiter at the module level is fragile and can lead to runtime errors.

  1. NameError Risk: If the condition on line 85 is false, aiter_per1x128_quant and rocm_aiter are never defined. However, the function apply_w8a8_block_fp8_linear uses these variables, guarded by the use_aiter_and_is_supported flag. If the logic to set this flag at the call site diverges even slightly from the condition here, it will result in a NameError at runtime. It's much safer to unconditionally define these variables (e.g., to None) before this if block.

  2. Unguarded Import: The import aiter on line 88 is not wrapped in a try...except ImportError. If VLLM_ROCM_USE_AITER is enabled but aiter is not installed, the application will crash on import. This should be handled gracefully to prevent the entire service from failing.

A robust implementation would define aiter_per1x128_quant and rocm_aiter unconditionally (e.g., as None) and use a try...except block to handle the optional import, updating the variables upon success. This would prevent both NameError and ImportError crashes.
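
The pattern described in this comment can be sketched as follows. This is a minimal illustration under stated assumptions, not the code that was merged: `_env_flag` is a hypothetical stand-in for vLLM's `envs` module, `load_aiter_quant` and its `module_name` parameter are invented for the example, and the platform FP8 check is left out. Only `get_hip_quant` and `QuantType.per_1x128` come from the diff under review.

```python
import importlib
import os

# Defined unconditionally so later references can never raise NameError.
rocm_aiter = None
aiter_per1x128_quant = None


def _env_flag(name: str) -> bool:
    """Hypothetical stand-in for reading a boolean flag from vLLM's `envs`."""
    return os.environ.get(name, "0").lower() in ("1", "true")


def load_aiter_quant(module_name: str = "aiter") -> bool:
    """Resolve the AITER kernel once; return True only if it was loaded."""
    global rocm_aiter, aiter_per1x128_quant
    if not (_env_flag("VLLM_ROCM_USE_AITER")
            and _env_flag("VLLM_ROCM_USE_AITER_LINEAR")):
        return False
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Optional dependency missing: keep the None placeholders.
        return False
    rocm_aiter = module
    aiter_per1x128_quant = module.get_hip_quant(module.QuantType.per_1x128)
    return True


# Evaluated once at import time; call sites can branch on this flag.
AITER_AVAILABLE = load_aiter_quant()
```

With this shape, a call site that checks `AITER_AVAILABLE` (or `aiter_per1x128_quant is not None`) falls back to another path instead of crashing when the dependency is absent.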

Comment on lines +188 to +190
if use_aiter_and_is_supported:
    q_input, x_scale = aiter_per1x128_quant(
        input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
Contributor

critical

The call to aiter_per1x128_quant is not guarded by a try-except block, which can lead to a crash if the aiter module fails to load. Wrap this call in a try-except block to handle potential import errors gracefully.

try:
    q_input, x_scale = aiter_per1x128_quant(
        input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
except NameError as e:
    raise ImportError("AITER is not properly installed or configured.") from e

Comment on lines 85 to 86
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
        and current_platform.supports_fp8()):
Member

Should this also be calling what sets use_aiter_and_is_supported?

Contributor

How would you feel about moving this import into the block below that does the call?

Contributor Author

@tjtanaa tjtanaa Jul 24, 2025

@mgoin @SageMoore Thank you so much for the feedback. I think I will start implementing #21504 here so that we can reuse the conditions. One problem is that use_aiter_and_is_supported cannot be used to gate the import.

I would rather not move the import into the block, because that would invoke the kernel getter get_hip_quant on every call.

I will let you know once I have fixed the conditions.

Contributor Author

Update: I will implement #21504 after this PR.

@mgoin
Thank you. Following how use_aiter_and_is_supported is set in fp8.py, I have updated the condition to:

    if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
-            and current_platform.supports_fp8()):
+            and current_platform.is_fp8_fnuz()):

@SageMoore
I would prefer to keep the current implementation, since avoiding repeated dictionary lookups is important for performance.
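
For context on the lookup-cost concern, one possible middle ground is to resolve the kernel lazily but cache the result, so the registry is consulted only on the first call. The sketch below is an illustration of that alternative, not what this PR does: `_KERNEL_REGISTRY`, `get_kernel`, and `quantize` are hypothetical stand-ins, with the registry playing the role of whatever `get_hip_quant` looks up internally.

```python
from functools import lru_cache

# Hypothetical registry standing in for the lookup behind get_hip_quant.
_KERNEL_REGISTRY = {"per_1x128": lambda x: ("quantized", x)}

lookup_count = 0  # counts real registry lookups, for demonstration only


@lru_cache(maxsize=None)
def get_kernel(name: str):
    """Resolve a kernel by name; the body runs once per distinct name."""
    global lookup_count
    lookup_count += 1
    return _KERNEL_REGISTRY[name]


def quantize(x):
    # Deferred resolution: nothing is looked up until the first call.
    kernel = get_kernel("per_1x128")
    return kernel(x)


for value in range(3):
    quantize(value)
```

After the loop, `lookup_count` is 1. The cached call still costs a hash lookup inside `lru_cache`, so a module-level variable, as kept in this PR, remains the cheapest option on the hot path.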

Member

@mgoin mgoin left a comment

Okay, thanks for the update. I think this is okay for now. I would like these conditions to be less local and more centralized, as in having one function to check for this block linear case, but we can refactor that later.
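
A centralized check of the kind suggested here might look like the sketch below. All names are hypothetical: `RocmAiterFlags` stands in for the values read from `envs` and `current_platform` in the diff, and `use_aiter_block_fp8_linear` is an invented name for the single predicate that both the import site and the call site would share.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RocmAiterFlags:
    """Hypothetical snapshot of the three conditions used in this PR."""
    use_aiter: bool          # envs.VLLM_ROCM_USE_AITER
    use_aiter_linear: bool   # envs.VLLM_ROCM_USE_AITER_LINEAR
    is_fp8_fnuz: bool        # current_platform.is_fp8_fnuz()


def use_aiter_block_fp8_linear(flags: RocmAiterFlags) -> bool:
    """Single source of truth for the block-scale FP8 linear AITER path."""
    return flags.use_aiter and flags.use_aiter_linear and flags.is_fp8_fnuz


enabled = use_aiter_block_fp8_linear(
    RocmAiterFlags(use_aiter=True, use_aiter_linear=True, is_fp8_fnuz=True))
disabled = use_aiter_block_fp8_linear(
    RocmAiterFlags(use_aiter=True, use_aiter_linear=False, is_fp8_fnuz=True))
```

Because the import guard and the dispatch flag would both call the same function, they could not drift apart in the way the earlier review comment warns about.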

@mgoin mgoin added the ready (ONLY add when PR is ready to merge/full CI is needed) label Jul 27, 2025
@mgoin mgoin enabled auto-merge (squash) July 27, 2025 16:53
@mgoin mgoin merged commit e626d28 into vllm-project:main Jul 28, 2025
69 checks passed
liuyumoye pushed a commit to liuyumoye/vllm that referenced this pull request Jul 31, 2025
HsChen-sys pushed a commit to HsChen-sys/vllm that referenced this pull request Aug 1, 2025
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025

Labels

ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm)

3 participants