[ROCm]: W8A8BlockFp8LinearOp should use AITER on MI355X #26701
Conversation
Code Review
This pull request correctly addresses a performance issue on MI355X by using the appropriate architecture check for AITER kernel support, instead of relying on the FP8 format. It also proactively fixes a Torch Dynamo compilation crash by registering the AITER JIT op. The changes are well-reasoned and improve both performance and compatibility. I've identified one area for improvement regarding code duplication in the platform-specific checks, which can be refactored for better maintainability.
- Note that the `is_fp8_fnuz` condition should only be used to check for the FP8 format, not as a generic check for the MI300 platform.
- Register the AITER JIT-compiled op using `direct_register_custom_op`, to prevent Dynamo from trying to compile it.
- Avoid using `functools.partial`, because `direct_register_custom_op` attempts to infer the schema, leading to `AttributeError: 'functools.partial' object has no attribute '__globals__'`.
- Rename `aiter_per1x128_quant` to `rocm_aiter_per1x128_quant` for consistency.
- Rename the kernel selection check to `use_aiter_fp8_linear`. The function checks whether FP8 is both supported and enabled by env flags. Renaming from `check_aiter_fp8_linear_support` to `use_aiter_fp8_linear` distinguishes it from `current_platform.supports_aiter_w8a8_block_fp8_linear`, which only checks for support but does not check the env flags.
- Remove duplication between architecture check methods. There was code duplication between `use_custom_allreduce` and the newly added `supports_aiter_w8a8_block_fp8_linear`: both methods performed the same check for MI300/MI350-series GPUs. To improve maintainability and avoid potential inconsistencies in the future, this logic was extracted into a shared private helper method `arch_is_in`, which was then applied to the other architecture checks of the same form (see the sketch below). Thanks @gemini-code-assist for the suggestion!
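A minimal sketch of that shared helper, assuming vLLM's ROCm platform class and the `gcnArchName` device property; the class skeleton, constant name, and exact architecture list are illustrative, not taken from the PR:

```python
import torch

# Assumed arch names: gfx942 covers the MI300 series, gfx950 the MI350/MI355 series.
MI300_MI350_ARCHS = ("gfx942", "gfx950")

class RocmPlatform:
    @classmethod
    def arch_is_in(cls, archs: tuple[str, ...]) -> bool:
        """Check whether the current GPU's arch name matches any entry in `archs`."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(arch in gcn_arch for arch in archs)

    @classmethod
    def supports_aiter_w8a8_block_fp8_linear(cls) -> bool:
        return cls.arch_is_in(MI300_MI350_ARCHS)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        # Previously repeated the same membership check inline.
        return cls.arch_is_in(MI300_MI350_ARCHS)
```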
Signed-off-by: Stig-Arne Grönroos <[email protected]>
Force-pushed from 07ccfb3 to 0baa31d.
Looks reasonable. Thanks for the contribution!
Purpose
Closes: #26700
On MI355X, `W8A8BlockFp8LinearOp` chooses the fallback branch `_run_triton` instead of the more performant branch `_run_aiter`. This results in suboptimal performance.

Root cause
Incorrect use of the `is_fp8_fnuz` condition. The `is_fp8_fnuz` condition should only be used to check whether FP8 indeed uses the FP8-FNUZ (finite and NaN only, unsigned zero) format. It should not be used as a generic condition to disable kernels on platforms older than MI300X. MI355X does not use the FP8-FNUZ format, and thus `is_fp8_fnuz` returns False.

https://rocm.docs.amd.com/projects/HIP/en/latest/reference/low_fp_types.html#fp8-quarter-precision
The issue was introduced in PR #21242, commit [e626d28](https://github.com/vllm-project/vllm/commit/e626d28).
`vllm/model_executor/layers/quantization/utils/fp8_utils.py`, line 86
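An illustrative before/after of that selection condition (simplified; the surrounding code and the `VLLM_ROCM_USE_AITER` gating are assumptions, and the branch names come from the description above):

```python
import vllm.envs as envs
from vllm.platforms import current_platform

def _choose_w8a8_block_fp8_impl() -> str:
    # Before (buggy): is_fp8_fnuz() used as a proxy for "MI300 or newer".
    # MI355X uses OCP FP8 (e4m3fn), so is_fp8_fnuz() is False there and the
    # Triton fallback was always selected:
    #
    #     if envs.VLLM_ROCM_USE_AITER and current_platform.is_fp8_fnuz(): ...
    #
    # After: an explicit architecture-support check.
    if (envs.VLLM_ROCM_USE_AITER
            and current_platform.supports_aiter_w8a8_block_fp8_linear()):
        return "_run_aiter"
    return "_run_triton"
```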
Implementation details
- Replace `is_fp8_fnuz` with an explicit architecture check.
- Register the AITER JIT-compiled op `rocm_aiter_per1x128_quant` using `direct_register_custom_op`, to prevent Dynamo from trying to compile it (see the sketch after this list). Explanation: Dynamo does not know how to trace the method `__contains__` of class `frozenset`.
- Avoid using `functools.partial`, because `direct_register_custom_op` attempts to infer the schema, leading to `AttributeError: 'functools.partial' object has no attribute '__globals__'`.
- Rename `aiter_per1x128_quant` to `rocm_aiter_per1x128_quant` for consistency.
- Rename `check_aiter_fp8_linear_support` to `use_aiter_fp8_linear`. The function checks whether FP8 is both supported and enabled by env flags. Renaming distinguishes it from `current_platform.supports_aiter_w8a8_block_fp8_linear`, which only checks for support but does not check the env flags.
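A minimal sketch of the registration pattern, assuming vLLM's `direct_register_custom_op` utility (its exact signature can vary across versions); the quantization body is a naive stand-in for the AITER kernel, with assumed 1x128 group semantics:

```python
import torch
from vllm.utils import direct_register_custom_op

def rocm_aiter_per1x128_quant(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # A plain module-level function, not a functools.partial: schema
    # inference reads the function signature (and __globals__), which a
    # partial object does not have.
    # Naive stand-in for the AITER kernel: per-(1x128)-group max scaling.
    x_g = x.view(-1, 128)
    scale = x_g.abs().amax(dim=-1, keepdim=True).float()
    scale = scale.clamp(min=1e-12) / torch.finfo(torch.float8_e4m3fn).max
    q = (x_g / scale).to(torch.float8_e4m3fn).view_as(x)
    return q, scale.view(*x.shape[:-1], -1)

def rocm_aiter_per1x128_quant_fake(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Fake (meta) implementation so torch.compile/Dynamo can trace shapes
    # and dtypes without running the kernel.
    scale = torch.empty((*x.shape[:-1], x.shape[-1] // 128),
                        dtype=torch.float32, device=x.device)
    return torch.empty_like(x, dtype=torch.float8_e4m3fn), scale

direct_register_custom_op(
    op_name="rocm_aiter_per1x128_quant",
    op_func=rocm_aiter_per1x128_quant,
    mutates_args=[],
    fake_impl=rocm_aiter_per1x128_quant_fake,
)

# Call sites then go through the registered op, which Dynamo treats as opaque:
#   x_q, x_scale = torch.ops.vllm.rocm_aiter_per1x128_quant(x)
```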
Test Plan

Test Result