[ROCm] Enable Triton ScaledMM fallback + kernel selection fix #26668
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request addresses an issue where `triton_scaled_mm` was not being used on ROCm by fixing the kernel selection logic. It correctly adds `TritonScaledMMLinearKernel` as a fallback for both ROCm and CUDA, and introduces an `is_supported` check to ensure kernels are compatible with the current platform. The changes are accompanied by a new integration test to verify the fix.

My review focuses on improving the robustness of the kernel selection. I've suggested making the `get_min_capability` check in the Triton kernel platform-aware to prevent it from being selected on unsupported ROCm hardware. Additionally, I've pointed out a confusing `try-except` block in the new test file that should be simplified for clarity and to avoid masking potential errors.
```python
def get_min_capability(cls) -> int:
    return 75
```
The `get_min_capability` method returns a hardcoded value of `75`, which corresponds to a CUDA compute capability (Turing architecture). This is confusing and potentially incorrect when running on ROCm, as it might allow this kernel to be selected on older ROCm hardware that doesn't fully support Triton, leading to runtime errors. To make the check more robust and clear, this method should be platform-aware, returning the appropriate minimum capability for both CUDA and ROCm.
```python
@classmethod
def get_min_capability(cls) -> int:
    if current_platform.is_rocm():
        # gfx90a (MI200) is the first ROCm GPU with widespread Triton support.
        return 90
    # Volta support is limited; Turing (SM 7.5) is a safer baseline for CUDA.
    return 75
```
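To illustrate why a platform-aware gate matters, here is a minimal sketch of how such a capability threshold is typically consumed by kernel selection. `FakePlatform`, `get_min_capability`, and `kernel_selectable` are hypothetical stand-ins, not vLLM's actual API:

```python
# Hypothetical sketch: consuming a platform-aware capability gate.
# All names here are illustrative, not vLLM's real selection code.

class FakePlatform:
    def __init__(self, rocm: bool, capability: int):
        self.rocm = rocm
        self.capability = capability  # e.g. 75 = SM 7.5 (Turing), 90 = gfx90a-class

    def is_rocm(self) -> bool:
        return self.rocm


def get_min_capability(platform: FakePlatform) -> int:
    # Mirrors the platform-aware suggestion above: a single hardcoded 75
    # would wrongly admit older ROCm parts.
    if platform.is_rocm():
        return 90
    return 75


def kernel_selectable(platform: FakePlatform) -> bool:
    return platform.capability >= get_min_capability(platform)


# Ampere (SM 8.0) passes the CUDA gate; a ROCm part below the gfx90a-class
# threshold is rejected instead of failing later at runtime.
assert kernel_selectable(FakePlatform(rocm=False, capability=80))
assert not kernel_selectable(FakePlatform(rocm=True, capability=80))
```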
Ignoring this, as I did not touch this code; I can take it up in a follow-up PR.
Purpose

This PR fixes #14397: `triton_scaled_mm` was never used on ROCm. On ROCm, `TritonScaledMMLinearKernel` was never selected, because the `is_supported()` selection path never reached it and execution fell back to `torch.scaled_mm` instead.

This PR:

- adds `TritonScaledMMLinearKernel` as a fallback for both ROCm and CUDA, and
- introduces an `is_supported()` check so kernels are only selected on compatible platforms.
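The fallback-ordered selection described above can be sketched as follows. The class and function names are illustrative placeholders, not vLLM's exact API:

```python
# Hypothetical sketch of fallback-ordered kernel selection with an
# is_supported() check; names are illustrative, not vLLM's real classes.

class CutlassStyleKernel:
    """Stands in for a platform-specific kernel that may be unsupported."""

    @classmethod
    def is_supported(cls) -> bool:
        return False  # pretend this platform (e.g. ROCm) cannot use it


class TritonKernel:
    """Stands in for TritonScaledMMLinearKernel, the shared fallback."""

    @classmethod
    def is_supported(cls) -> bool:
        return True


def choose_kernel(candidates):
    # Walk candidates in priority order and return the first one whose
    # is_supported() check passes; the Triton kernel is last, acting as
    # the common fallback for both ROCm and CUDA.
    for kernel in candidates:
        if kernel.is_supported():
            return kernel
    raise RuntimeError("no supported scaled_mm kernel for this platform")


assert choose_kernel([CutlassStyleKernel, TritonKernel]) is TritonKernel
```

The key design point is that unsupported kernels are filtered out at selection time rather than failing later at dispatch.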
Test Plan (what I actually ran)

Minimal smoke test (ROCm-mocked, no native ops). What the script does (in-process mocks):

- sets `VLLM_TARGET_DEVICE=rocm`
- mocks `vllm._custom_ops` (no compiled extensions needed)
- uses the `torch` CPU wheel only
- calls `choose_scaled_mm_linear_kernel(...)` and verifies `TritonScaledMMLinearKernel` is chosen

Test Results

ROCm mocked + AITriton disabled → Triton fallback selected: PASS

Output includes: