Conversation

@gronsti-amd commented Oct 13, 2025

Purpose

Closes: #26700

On MI355X, W8A8BlockFp8LinearOp selects the fallback branch _run_triton instead of the more performant _run_aiter branch, resulting in suboptimal performance.

Root cause

Incorrect use of the is_fp8_fnuz condition.
The is_fp8_fnuz condition should only be used to check whether FP8 actually uses the FP8-FNUZ (finite and NaN only, unsigned zero) format. It should not be used as a generic condition to disable kernels on platforms older than MI300X.
MI355X does not use the FP8-FNUZ format, so is_fp8_fnuz returns False.
https://rocm.docs.amd.com/projects/HIP/en/latest/reference/low_fp_types.html#fp8-quarter-precision
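
For illustration, a minimal sketch of the distinction, with hypothetical helper names (the real dispatch lives in W8A8BlockFp8LinearOp and differs in detail):

```python
# Hypothetical sketch, not the actual vLLM code.
# is_fp8_fnuz answers "does this GPU use the FNUZ FP8 encoding?":
# True on MI300X (gfx942), False on MI355X (gfx950), even though
# MI355X supports the AITER kernels.

def choose_branch_buggy(is_fp8_fnuz: bool) -> str:
    # BUG: uses the FP8 storage format as a proxy for "AITER-capable GPU".
    return "_run_aiter" if is_fp8_fnuz else "_run_triton"

def choose_branch_fixed(gcn_arch: str) -> str:
    # FIX: gate on the actual GPU architecture instead.
    supported_archs = ("gfx942", "gfx950")  # MI300/MI350 series (assumed list)
    return "_run_aiter" if any(a in gcn_arch for a in supported_archs) else "_run_triton"
```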

The issue was introduced in PR #21242, commit [e626d28](https://github.com/vllm-project/vllm/commit/e626d28), in vllm/model_executor/layers/quantization/utils/fp8_utils.py, line 86.

Implementation details

  • Replaces the use of is_fp8_fnuz with an explicit architecture check.
  • When the AITER code branch is enabled, a secondary issue is triggered: Torch Dynamo attempts to compile the AITER JIT wrapper, crashing with `Dynamo does not know how to trace method __contains__ of class frozenset`. To prevent Dynamo from trying to compile it, we register the AITER JIT-compiled op rocm_aiter_per1x128_quant using direct_register_custom_op (see the sketch after this list).
  • Avoids functools.partial, because direct_register_custom_op attempts to infer the schema, leading to `AttributeError: 'functools.partial' object has no attribute '__globals__'`.
  • Renames aiter_per1x128_quant to rocm_aiter_per1x128_quant for consistency.
  • Renames the kernel selection check check_aiter_fp8_linear_support to use_aiter_fp8_linear. The function checks whether FP8 is both supported and enabled by env flags; the new name distinguishes it from current_platform.supports_aiter_w8a8_block_fp8_linear, which only checks for support and does not consult the env flags.
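
As a rough sketch of the registration pattern described above (function names follow this PR; the AITER import path, the fake implementation, and the exact direct_register_custom_op parameters are assumptions, not the actual vLLM code):

```python
import torch

from vllm.utils import direct_register_custom_op


def rocm_aiter_per1x128_quant(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Import inside the op so the AITER JIT wrapper is only touched at
    # runtime, behind the custom-op boundary that Dynamo does not trace.
    from aiter import per_1x128_quant  # hypothetical import path
    return per_1x128_quant(x)


def rocm_aiter_per1x128_quant_fake(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Fake (meta) implementation so torch.compile can infer output
    # shapes/dtypes without running the real kernel; shapes illustrative.
    quant = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    scale = torch.empty(
        (x.shape[0], x.shape[-1] // 128), dtype=torch.float32, device=x.device
    )
    return quant, scale


# A plain module-level function is registered rather than a
# functools.partial: direct_register_custom_op inspects attributes such
# as __globals__ while inferring the schema, and partial objects lack them.
direct_register_custom_op(
    op_name="rocm_aiter_per1x128_quant",
    op_func=rocm_aiter_per1x128_quant,
    mutates_args=[],
    fake_impl=rocm_aiter_per1x128_quant_fake,
)
```

Once registered, the op is invoked through torch.ops (e.g. torch.ops.vllm.rocm_aiter_per1x128_quant(x)), which Dynamo treats as an opaque call instead of tracing into the AITER wrapper.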

Test Plan

```bash
export HIP_VISIBLE_DEVICES=0

export VLLM_ROCM_QUICK_REDUCE_QUANTIZATION=FP
export VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1
export VLLM_USE_TRITON_FLASH_ATTN=0
export VLLM_ROCM_USE_AITER=1
export VLLM_ROCM_USE_AITER_MHA=1
export SAFETENSORS_FAST_GPU=1

vllm serve --no-enable-prefix-caching -tp 1 "Qwen/Qwen3-30B-A3B-Instruct-2507-FP8"
```

Test Result

| input_seq_length | output_seq_length | max_concurrency | tensor_parallel | device | mean_e2el_ms (baseline) | mean_e2el_ms (fix-fnuz) | mean_itl_ms (baseline) | mean_itl_ms (fix-fnuz) | mean_tpot_ms (baseline) | mean_tpot_ms (fix-fnuz) | mean_ttft_ms (baseline) | mean_ttft_ms (fix-fnuz) | output_throughput (baseline) | output_throughput (fix-fnuz) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1000 | 500 | 1 | 1 | MI355X | 4291.45 | 3631.47 | 7.74421 | 6.54753 | 7.74757 | 6.55164 | 39.3975 | 36.4627 | 128.168 | 151.46 |
| 1000 | 500 | 4 | 1 | MI355X | 4631.57 | 3911.22 | 9.30521 | 7.85047 | 9.31005 | 7.85374 | 51.3681 | 47.0733 | 405.734 | 480.557 |
| 1000 | 500 | 8 | 1 | MI355X | 5144.99 | 4407.81 | 9.95725 | 8.52531 | 9.9824 | 8.54841 | 54.6616 | 49.5116 | 723.618 | 845.489 |
| 1000 | 500 | 16 | 1 | MI355X | 5982.12 | 4949.31 | 10.8214 | 8.93642 | 10.9008 | 8.99791 | 77.0072 | 72.8162 | 1322.58 | 1596.36 |
| 1000 | 500 | 32 | 1 | MI355X | 6999.54 | 5936.8 | 12.9079 | 10.9287 | 12.9849 | 10.9916 | 142.539 | 131.182 | 2042.66 | 2410.07 |
| 1000 | 500 | 64 | 1 | MI355X | 9123.87 | 8111.46 | 17.1726 | 15.231 | 17.7794 | 15.763 | 192.252 | 189.664 | 3154.04 | 3566.56 |
| 1000 | 500 | 128 | 1 | MI355X | 13427.6 | 12387.1 | 25.854 | 23.8544 | 27.1634 | 25.0466 | 449.528 | 412.507 | 4081.78 | 4455.09 |
| 10000 | 1000 | 1 | 1 | MI355X | 8997.24 | 7713.68 | 7.93931 | 6.77895 | 7.93079 | 6.7713 | 273.915 | 265.307 | 122.228 | 142.566 |
| 10000 | 1000 | 4 | 1 | MI355X | 10680.6 | 9223.76 | 10.4663 | 8.99885 | 10.4429 | 8.96562 | 376.494 | 365.231 | 353.429 | 409.812 |
| 10000 | 1000 | 8 | 1 | MI355X | 13416.7 | 11882.2 | 12.7211 | 11.2308 | 12.7905 | 11.2978 | 405.996 | 395.786 | 563.347 | 637.685 |
| 10000 | 1000 | 16 | 1 | MI355X | 19008.1 | 16842.2 | 16.8767 | 14.9087 | 17.6137 | 15.5733 | 585.094 | 565.107 | 854.97 | 967.078 |
| 10000 | 1000 | 32 | 1 | MI355X | 28084.1 | 25756.5 | 25.3686 | 23.1897 | 26.3407 | 24.0103 | 1138.13 | 1113.81 | 1090.52 | 1196.1 |
| 10000 | 1000 | 64 | 1 | MI355X | 44788 | 42010.9 | 41.2436 | 38.5243 | 45.8095 | 42.4373 | 1864.18 | 1917.11 | 1373.15 | 1471.11 |
| 10000 | 1000 | 128 | 1 | MI355X | 77734.9 | 74216.8 | 72.6504 | 69.3913 | 80.5413 | 76.9592 | 4759.14 | 4514.68 | 1520.69 | 1599.34 |


@mergify bot added the rocm (Related to AMD ROCm) label Oct 13, 2025
@gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses a performance issue on MI355X by using the appropriate architecture check for AITER kernel support, instead of relying on the FP8 format. It also proactively fixes a Torch Dynamo compilation crash by registering the AITER JIT op. The changes are well-reasoned and improve both performance and compatibility. I've identified one area for improvement regarding code duplication in the platform-specific checks, which can be refactored for better maintainability.

@chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


- Note that the `is_fp8_fnuz` condition should only be used to check for the FP8 format, not as a generic check for the MI300 platform.
- Register the AITER JIT-compiled op using direct_register_custom_op, to prevent Dynamo from trying to compile it.
- Avoid using functools.partial, because direct_register_custom_op attempts to infer the schema, leading to `AttributeError: 'functools.partial' object has no attribute '__globals__'`.
- Rename aiter_per1x128_quant to rocm_aiter_per1x128_quant for consistency.
- Rename the kernel selection check from check_aiter_fp8_linear_support to use_aiter_fp8_linear. The function checks whether FP8 is both supported and enabled by env flags; the new name distinguishes it from current_platform.supports_aiter_w8a8_block_fp8_linear, which only checks for support and does not check the env flags.
- Remove duplication between architecture check methods. There was code duplication between use_custom_allreduce and the newly added supports_aiter_w8a8_block_fp8_linear: both perform the same check for MI300/MI350-series GPUs. To improve maintainability and avoid future inconsistencies, this logic was extracted into a shared private helper method `arch_is_in`, which was also applied to the other architecture checks of the same form (a sketch follows below). Thanks @gemini-code-assist for the suggestion!

Signed-off-by: Stig-Arne Grönroos <[email protected]>
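
The `arch_is_in` helper described in the last bullet might look roughly like this (a sketch under assumptions: the method's exact signature, its location on vLLM's ROCm platform class, and the architecture list are inferred from the commit message and general ROCm conventions, not taken from the code):

```python
import torch


def arch_is_in(archs: list[str]) -> bool:
    # On ROCm builds of PyTorch, gcnArchName looks like
    # "gfx942:sramecc+:xnack-"; match the base architecture token
    # against the allowed list.
    gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
    return any(arch in gcn_arch for arch in archs)


def supports_aiter_w8a8_block_fp8_linear() -> bool:
    # MI300/MI350 series (assumed gfx list).
    return arch_is_in(["gfx942", "gfx950"])
```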
@gronsti-amd force-pushed the fix/incorrect-is-fp8-fnuz branch from 07ccfb3 to 0baa31d on October 13, 2025 12:49
@SageMoore left a comment


Looks reasonable. Thanks for the contribution!

@gshtras added the ready (ONLY add when PR is ready to merge/full CI is needed) label Oct 24, 2025
@gshtras enabled auto-merge (squash) October 24, 2025 14:05

Successfully merging this pull request may close these issues: [Bug][ROCm]: W8A8BlockFp8LinearOp does not use AITER on MI355X
