
Conversation


@shivampr shivampr commented Oct 13, 2025

Purpose

This PR fixes #14397: triton_scaled_mm is never used on ROCm.

On ROCm, TritonScaledMMLinearKernel was never selected because:

  • Only the AITriton path was registered for ROCm
  • The CUTLASS path was skipped on ROCm
  • The Triton fallback was missing from the dispatch list
  • Kernel selection didn't check is_supported()
  • As a result, ROCm always fell back to the slow torch.scaled_mm path

This PR:

  • Enables the Triton fallback on ROCm when AITriton is not available
  • Adds the Triton fallback on CUDA after CUTLASS
  • Implements kernel support checks via is_supported() (see the selection sketch after this list)
  • Adds a minimal integration test (no GPU required)
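
A minimal, self-contained sketch of the selection order described above; the class names and the is_supported() signature here are illustrative stand-ins, not the real vLLM kernel classes:

# Sketch only: stand-in classes, not vLLM's actual scaled_mm kernels.
from typing import Type

class _KernelStub:
    @classmethod
    def is_supported(cls) -> bool:
        # Real kernels would check platform / compute capability here.
        return True

class AiterScaledMMStub(_KernelStub):
    @classmethod
    def is_supported(cls) -> bool:
        return False  # pretend AITriton is unavailable on this machine

class CutlassScaledMMStub(_KernelStub):
    pass

class TritonScaledMMStub(_KernelStub):
    pass

def choose_scaled_mm_kernel(is_rocm: bool) -> Type[_KernelStub]:
    # ROCm: AITriton first, then the Triton fallback.
    # CUDA: CUTLASS first, then the Triton fallback.
    candidates = (
        [AiterScaledMMStub, TritonScaledMMStub]
        if is_rocm
        else [CutlassScaledMMStub, TritonScaledMMStub]
    )
    for kernel in candidates:
        if kernel.is_supported():
            return kernel
    raise ValueError("no supported scaled_mm kernel for this platform")

assert choose_scaled_mm_kernel(is_rocm=True) is TritonScaledMMStub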

Test Plan (what I actually ran)

Minimal smoke test (ROCm-mocked, no native ops):

# From repo root
PYTHONPATH="$(pwd)" python3 mini_tests/select_triton_rocm.py

What the script does (in-process mocks):

  • Forces VLLM_TARGET_DEVICE=rocm
  • Mocks platform detection so ROCm is reported
  • Stubs vllm._custom_ops (no compiled extensions needed)
  • Uses torch CPU wheel only
  • Selects kernel via choose_scaled_mm_linear_kernel(...)
  • Asserts TritonScaledMMLinearKernel is chosen (a rough sketch of this mocking approach follows the list)
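
A rough sketch of the in-process mocking approach, not the actual contents of mini_tests/select_triton_rocm.py; the selector call shown in the comments, its import path, and its config argument are assumptions based on the PR description:

# Sketch only: the vLLM-specific names in the comments below are assumed.
import os
import sys
import types

# 1. Force the ROCm target via the environment.
os.environ["VLLM_TARGET_DEVICE"] = "rocm"

# 2. Stub the compiled extension module so code that does
#    `import vllm._custom_ops` works without any native build.
sys.modules["vllm._custom_ops"] = types.ModuleType("vllm._custom_ops")

# 3. Patch platform detection so ROCm is reported, then run the selector.
#    (Hypothetical call shown for shape only, e.g. with
#    unittest.mock.patch.object(current_platform, "is_rocm", return_value=True):
#        kernel = choose_scaled_mm_linear_kernel(config)
#        assert kernel.__name__ == "TritonScaledMMLinearKernel")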

Test Results

  • ROCm mocked + AITriton disabled → Triton fallback selected: PASS
    Output includes:

    Selected kernel: TritonScaledMMLinearKernel
    OK: TritonScaledMMLinearKernel chosen on ROCm fallback.
    

@mergify mergify bot added the rocm (Related to AMD ROCm) label Oct 13, 2025

mergify bot commented Oct 13, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @shivampr.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 13, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses an issue where triton_scaled_mm was not being used on ROCm by fixing the kernel selection logic. It correctly adds TritonScaledMMLinearKernel as a fallback for both ROCm and CUDA, and introduces an is_supported check to ensure kernels are compatible with the current platform. The changes are accompanied by a new integration test to verify the fix.

My review focuses on improving the robustness of the kernel selection. I've suggested making the get_min_capability check in the Triton kernel platform-aware to prevent it from being selected on unsupported ROCm hardware. Additionally, I've pointed out a confusing try-except block in the new test file that should be simplified for clarity and to avoid masking potential errors.

def get_min_capability(cls) -> int:
    return 75

high

The get_min_capability method returns a hardcoded value of 75, which corresponds to a CUDA compute capability (Turing architecture). This is confusing and potentially incorrect when running on ROCm, as it might allow this kernel to be selected on older ROCm hardware that doesn't fully support Triton, leading to runtime errors. To make the check more robust and clear, this method should be platform-aware, returning the appropriate minimum capability for both CUDA and ROCm.

    def get_min_capability(cls) -> int:
        if current_platform.is_rocm():
            # gfx90a (MI200) is the first ROCm GPU with widespread Triton support.
            return 90
        # Volta support is limited; Turing (SM 7.5) is a safer baseline for CUDA.
        return 75

Author


Ignoring this, as I did not touch this code; I can take it up in a follow-up PR.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


@shivampr shivampr force-pushed the rocm-triton-fallback branch 3 times, most recently from 99018da to 4d3a612 on October 13, 2025 05:09
@mergify mergify bot removed the needs-rebase label Oct 13, 2025
@shivampr shivampr force-pushed the rocm-triton-fallback branch 3 times, most recently from 4d3f427 to d0d088d on October 13, 2025 05:24
@shivampr shivampr force-pushed the rocm-triton-fallback branch from d0d088d to 9036316 on October 13, 2025 05:50