@tjtanaa tjtanaa commented Apr 24, 2025

Description

This PR introduces the use of ck_moe_2stages from ROCm/aiter.

Performance gain

meta-llama/Llama-4-Maverick-17B-128E-Instruct

V1 Engine
Input: Output = 1000:1000

| Metric | No AITER | AITER 1 Stage | AITER 2 Stage | % Gain (2 Stage vs Baseline) |
|---|---|---|---|---|
| Benchmark Duration (s) | 67.39 | 64.84 | 59.52 | 11.7% |
| Request Throughput (req/s) | 7.42 | 7.71 | 8.40 | 13.2% |
| Request Goodput (req/s) | 0.61 | 0.71 | 1.01 | 65.6% |
| Output Token Throughput (tok/s) | 6,135.14 | 6,437.38 | 6,898.15 | 12.4% |
| Total Token Throughput (tok/s) | 13,469.17 | 14,059.20 | 15,201.54 | 12.9% |
| Mean TTFT (ms) | 7,285.70 | 7,133.57 | 6,187.27 | 15.1% |
| Median TTFT (ms) | 7,231.37 | 6,962.17 | 6,006.79 | 16.9% |
| Mean TPOT (ms) | 52.90 | 50.15 | 47.16 | 10.9% |
| Median ITL (ms) | 34.30 | 31.46 | 29.33 | 14.5% |

amd/Mixtral-8x7B-Instruct-v0.1-FP8-KV

V1 Engine
Input: Output = 1000:1000

| Metric | No AITER | AITER 2-Stage | Improvement |
|---|---|---|---|
| Benchmark duration (s) | 80.38 | 76.48 | 4.85% faster |
| Total generated tokens | 436,389 | 482,066 | 10.47% more |
| Request throughput (req/s) | 6.22 | 6.54 | 5.14% higher |
| Output token throughput (tok/s) | 5,429.41 | 6,303.20 | 16.09% higher |
| Total token throughput (tok/s) | 11,578.38 | 12,765.36 | 10.25% higher |
| Median TTFT (ms) | 9,477.47 | 9,574.82 | 1.03% slower |
| Median TPOT (ms) | 50.20 | 46.73 | 6.91% faster |
| Median ITL (ms) | 39.45 | 36.85 | 6.59% faster |

mistralai/Mixtral-8x7B-v0.1

V1 Engine
Input: Output = 1000:1000

| Metric | No AITER | AITER 1-Stage | AITER 2-Stage | % Gain (2-Stage vs No AITER) | % Gain (2-Stage vs 1-Stage) |
|---|---|---|---|---|---|
| Benchmark Duration (s) | 108.38 | 122.76 | 100.61 | 7.2% | 18.0% |
| Request Throughput (req/s) | 4.61 | 4.07 | 4.97 | 7.8% | 22.1% |
| Request Goodput (req/s) | 0.02 | 0.01 | 0.02 | 0.0% | 100.0% |
| Output Token Throughput (tok/s) | 4445.47 | 3907.71 | 4773.95 | 7.4% | 22.2% |
| Total Token Throughput (tok/s) | 9005.44 | 7933.76 | 9686.44 | 7.6% | 22.1% |
| Mean TTFT (ms) | 13133.12 | 20749.40 | 12677.77 | 3.5% | 38.9% |
| Median TTFT (ms) | 12911.29 | 20420.47 | 12429.12 | 3.7% | 39.1% |
| P99 TTFT (ms) | 27070.03 | 41767.87 | 25574.30 | 5.5% | 38.8% |
| Mean TPOT (ms) | 83.69 | 104.94 | 78.76 | 5.9% | 24.9% |
| Median TPOT (ms) | 64.66 | 73.77 | 61.07 | 5.6% | 17.2% |
| P99 TPOT (ms) | 399.80 | 642.01 | 393.91 | 1.5% | 38.6% |
| Mean ITL (ms) | 62.21 | 70.75 | 58.46 | 6.0% | 17.4% |
| Median ITL (ms) | 52.15 | 55.18 | 48.90 | 6.2% | 11.4% |
| P99 ITL (ms) | 403.70 | 645.63 | 396.17 | 1.9% | 38.6% |
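For reference, a minimal sketch of how the AITER MoE path is enabled (the env var names come from the code in this PR; the offline-inference usage and model choice below are illustrative assumptions, since the numbers above were collected with the serving benchmark):

```python
import os

# Env flags taken from the code in this PR; setting them before engine
# construction is assumed, since they are read during initialization.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_MOE"] = "1"

from vllm import LLM, SamplingParams

# Illustrative offline-inference run; the benchmark numbers above come from
# the serving benchmark, not from this script.
llm = LLM(model="mistralai/Mixtral-8x7B-v0.1", tensor_parallel_size=2)
outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```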

lm_eval

Dataset: GSM8K
Engine: V1
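
The per-model results below were produced with lm-evaluation-harness; the exact command is not shown here, but a minimal sketch of an equivalent invocation via the harness's Python API (an assumption, matching the settings printed in each result header) would be:

```python
from lm_eval import simple_evaluate

# Settings mirror the headers below: 5-shot GSM8K, batch_size=auto, no limit.
results = simple_evaluate(
    model="vllm",
    model_args=(
        "pretrained=amd/Mixtral-8x7B-Instruct-v0.1-FP8-KV,"
        "tensor_parallel_size=2,max_model_len=10000,trust_remote_code=True"
    ),
    tasks=["gsm8k"],
    num_fewshot=5,
    batch_size="auto",
)
print(results["results"]["gsm8k"])
```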

AITER 1 stage amd/Mixtral-8x7B-Instruct-v0.1-FP8-KV (falls back to 2 stage MoE)

vllm (pretrained=amd/Mixtral-8x7B-Instruct-v0.1-FP8-KV,tensor_parallel_size=2,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.6528 | ± 0.0131 |
| | | strict-match | 5 | exact_match | 0.6490 | ± 0.0131 |

AITER 2 stage amd/Mixtral-8x7B-Instruct-v0.1-FP8-KV

vllm (pretrained=amd/Mixtral-8x7B-Instruct-v0.1-FP8-KV,tensor_parallel_size=2,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.6422 | ± 0.0132 |
| | | strict-match | 5 | exact_match | 0.6391 | ± 0.0132 |

AITER 1 stage meta-llama/Llama-4-Scout-17B-16E-Instruct

vllm (pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=8,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.9143 | ± 0.0077 |
| | | strict-match | 5 | exact_match | 0.8961 | ± 0.0084 |

AITER 2 stage meta-llama/Llama-4-Scout-17B-16E-Instruct

vllm (pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=8,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.9227 | ± 0.0074 |
| | | strict-match | 5 | exact_match | 0.9060 | ± 0.0080 |

AITER 1 stage mistralai/Mixtral-8x7B-Instruct-v0.1 FP16

vllm (pretrained=mistralai/Mixtral-8x7B-Instruct-v0.1,tensor_parallel_size=2,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.6429 | ± 0.0132 |
| | | strict-match | 5 | exact_match | 0.6399 | ± 0.0132 |

AITER 2 stage mistralai/Mixtral-8x7B-Instruct-v0.1 FP16

vllm (pretrained=mistralai/Mixtral-8x7B-Instruct-v0.1,tensor_parallel_size=2,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.6513 | ± 0.0131 |
| | | strict-match | 5 | exact_match | 0.6475 | ± 0.0132 |

AITER 1 stage mistralai/Mixtral-8x7B-Instruct-v0.1 Dynamic FP8

vllm (pretrained=mistralai/Mixtral-8x7B-Instruct-v0.1,tensor_parallel_size=2,max_model_len=10000,quantization=fp8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.5777 | ± 0.0136 |
| | | strict-match | 5 | exact_match | 0.5701 | ± 0.0136 |

AITER 2 stage mistralai/Mixtral-8x7B-Instruct-v0.1 Dynamic FP8 (falls back to 1 stage)

vllm (pretrained=mistralai/Mixtral-8x7B-Instruct-v0.1,tensor_parallel_size=2,max_model_len=10000,quantization=fp8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.5792 | ± 0.0136 |
| | | strict-match | 5 | exact_match | 0.5701 | ± 0.0136 |

Updated unit tests

The unit tests have been updated, given that rocm_aiter_fused_experts has been removed from dispatch_fused_experts_func in vllm/model_executor/layers/fused_moe/fused_moe.py.

Future TODO:

```python
# TODO: remove @cache to allow the use case
# where there are multiple instances of LLM
# running on the same process/program with different
# configurations. Blockage is that dispatch_topk_func()
# in vllm/model_executor/layers/fused_moe/fused_moe.py
# uses this function in critical path.
@cache
def is_rocm_aiter_moe_enabled() -> bool:
    return current_platform.is_rocm() \
        and envs.VLLM_ROCM_USE_AITER_MOE \
        and envs.VLLM_ROCM_USE_AITER
```
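
For context, a simplified sketch of the hot-path dispatch that makes the cached check necessary (only is_rocm_aiter_moe_enabled and dispatch_topk_func are names from this PR; the returned function names are illustrative placeholders, not the exact ones in fused_moe.py):

```python
def dispatch_topk_func():
    # Called on the critical path of every forward pass, so the enablement
    # check above must be cheap -- hence the @cache.
    if is_rocm_aiter_moe_enabled():
        return rocm_aiter_topk_softmax  # placeholder name for the AITER routing op
    return vllm_topk_softmax  # placeholder name for the default routing op
```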

Co-authored-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: tjtanaa <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@tjtanaa tjtanaa marked this pull request as ready for review May 5, 2025 13:01
@hongxiayang hongxiayang added the rocm Related to AMD ROCm label May 5, 2025

mergify bot commented May 5, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 5, 2025
@mergify mergify bot removed the needs-rebase label May 6, 2025
Contributor

@SageMoore SageMoore left a comment


In general, I would like us to contain all of the 1-stage vs 2-stage dispatching logic in the rocm_aiter_fused_experts function if possible.

```python
    is_rocm_aiter_moe_enabled, shuffle_weights)

self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_2stage_moe_enabled = is_rocm_aiter_2stage_moe_enabled()
```
Contributor


It looks like this variable is only used to warn that we are falling back to a different implementation. Let's remove it to simplify the code.

Contributor Author


Same comment as below
#17110 (comment)

There is an RFC that highlights that accessing envs.ENV is very costly (RFC issue #17067).
Thus, all the env variables are read only once and stored as properties of the class during the initialization stage.
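
A minimal sketch of that pattern, assuming a hypothetical layer-method class (only the env var names are taken from this PR; the class and attribute names are illustrative):

```python
import vllm.envs as envs
from vllm.platforms import current_platform


class ExampleMoEMethod:  # hypothetical class, for illustration only
    def __init__(self):
        # Read the (costly) env flags once at construction time and cache them
        # as instance attributes, so the forward pass never touches envs.*.
        self.rocm_aiter_moe_enabled = (
            current_platform.is_rocm()
            and envs.VLLM_ROCM_USE_AITER
            and envs.VLLM_ROCM_USE_AITER_MOE
        )
        self.use_ck_moe_2stages = (
            self.rocm_aiter_moe_enabled
            and envs.VLLM_ROCM_USE_AITER_2STAGE_MOE
        )
```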

```python
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
    allow_deep_gemm: bool = False,
    use_ck_moe_2stages: bool = False,
```
Contributor


Instead of passing this boolean around, can you check the environment variable here?

Contributor Author


@SageMoore There is an RFC that highlights that accessing envs.ENV is very costly (RFC issue #17067).
Thus, all the env variables are read only once and stored as properties of the class during the initialization stage.

```diff
     block_shape: Optional[List[int]] = None,
-    allow_deep_gemm: bool = False) -> torch.Tensor:
+    allow_deep_gemm: bool = False,
+    use_ck_moe_2stages: bool = False) -> torch.Tensor:
```
Contributor


Nit: It doesn't look like you need this?

Collaborator


Agreed, this is not needed.

Contributor Author


I have removed the argument allow_deep_gemm. use_ck_moe_2stages is kept, as there is an RFC that highlights that accessing envs.ENV is very costly (RFC issue #17067).
Thus, all the env variables are read only once and stored as properties of the class during the initialization stage.

Contributor Author


To pass the pre-commit checks for vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py, we have adjusted the assignment logic so that the AITER fused-experts function is assigned to rocm_aiter_fused_experts_func rather than to fused_experts_func, following the approach in vllm/attention/backends/rocm_flash_attn.py, where the attention functions are assigned to distinct property names: self.fa_attn_func, self.sdpa_attn_func, and self.triton_attn_func.

This also allows us to clean up the unused arguments of rocm_aiter_fused_experts (vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py). Note that expert_map has been removed as a bugfix: expert_map (integer IDs of the experts) in vLLM and expert_mask (a boolean mask of the active experts on the current GPU) in the AITER ops are different things. The current rocm_aiter_fused_experts no longer takes an expert_map argument; we encourage adding it back when enabling EP with AITER in a future PR.
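
A condensed sketch of the pattern described above, with simplified names and a simplified call signature (the class below is an illustrative stand-in, not the actual code in compressed_tensors_moe.py):

```python
# Module paths referenced in this PR description.
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    rocm_aiter_fused_experts)


class ExampleQuantMoEMethod:  # illustrative stand-in, not the actual class
    def __init__(self, rocm_aiter_moe_enabled: bool):
        # Select the experts implementation once at init, mirroring how
        # rocm_flash_attn.py stores self.fa_attn_func / self.triton_attn_func.
        if rocm_aiter_moe_enabled:
            self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
            self.fused_experts_func = None
        else:
            self.rocm_aiter_fused_experts_func = None
            self.fused_experts_func = fused_experts

    def apply(self, x, w1, w2, topk_weights, topk_ids):
        # Hot path: no env lookups, just call the pre-selected function.
        # The argument order here is simplified for illustration.
        func = self.rocm_aiter_fused_experts_func or self.fused_experts_func
        return func(x, w1, w2, topk_weights, topk_ids)
```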

Collaborator

@hongxiayang hongxiayang left a comment


Verified with llama4 bf16 128e model.
LGTM in general. Agreed the environment variable part can be simplified.

Comment on lines 19 to 23
```python
def is_rocm_aiter_2stage_moe_enabled() -> bool:
    return current_platform.is_rocm() \
        and envs.VLLM_ROCM_USE_AITER_2STAGE_MOE \
        and envs.VLLM_ROCM_USE_AITER_MOE \
        and envs.VLLM_ROCM_USE_AITER
```
Collaborator


Checking three environment variables here seems excessive. envs.VLLM_ROCM_USE_AITER_2STAGE_MOE is enough, as it is only used when the other two are already true.

Contributor Author


is_rocm_aiter_2stage_moe_enabled() has been removed, since envs.VLLM_ROCM_USE_AITER_2STAGE_MOE is now read in the layer class during initialization only, not in the forward pass.

Contributor

@SageMoore SageMoore left a comment


I'd like to propose an alternative implementation. Instead of passing this new environment variable around, let's just keep the existing VLLM_ROCM_USE_AITER_MOE variable and add heuristics to the code that decide when to use the one stage vs the two stage implementation.

It looks like the two stage kernel is generally faster but doesn't support dynamic quantization? Let's rework this PR so that we always use the two stage implementation when it's supported and fallback to the one stage implementation when it isn't. We don't need any logging when this fallback occurs.

If we make the change that I suggested, it looks like you can simplify the changes in fp8.py down to just some layout changes for the weights. fused_moe.py should get a bit simpler as well.

What do you think?
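
Roughly, a sketch of the heuristic I have in mind (the names below are illustrative placeholders, not code from this PR):

```python
def select_aiter_experts_func(supports_two_stage: bool):
    # Prefer the 2-stage kernel whenever the quantization scheme supports it
    # (e.g. not dynamic quantization); otherwise silently fall back to 1-stage.
    if supports_two_stage:
        return rocm_aiter_2stage_fused_experts  # placeholder name
    return rocm_aiter_1stage_fused_experts  # placeholder name
```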

@hongxiayang
Collaborator

> I'd like to propose an alternative implementation. Instead of passing this new environment variable around, let's just keep the existing VLLM_ROCM_USE_AITER_MOE variable and add heuristics to the code that decide when to use the one stage vs the two stage implementation. […]

@SageMoore Can we get this merged and do the rework as a follow up PR?

> I'd like to propose an alternative implementation. Instead of passing this new environment variable around, let's just keep the existing VLLM_ROCM_USE_AITER_MOE variable and add heuristics to the code that decide when to use the one stage vs the two stage implementation. […]

@SageMoore It would be better to do a different implementation in a follow-up PR. We would like to have this PR merged as this improves performance for a bunch of models.


mergify bot commented May 13, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 13, 2025
@mergify mergify bot removed the needs-rebase label May 13, 2025
@tjtanaa tjtanaa requested a review from SageMoore May 13, 2025 12:57
@tjtanaa
Contributor Author

tjtanaa commented May 13, 2025

> I'd like to propose an alternative implementation. Instead of passing this new environment variable around, let's just keep the existing VLLM_ROCM_USE_AITER_MOE variable and add heuristics to the code that decide when to use the one stage vs the two stage implementation. […]

@SageMoore Thank you for the suggestion; we have already removed this env variable and simplified the fused MoE functions.

Contributor

@SageMoore SageMoore left a comment


Looks reasonable. Thanks for cleaning up the dispatching logic!

@simon-mo simon-mo added the ready ONLY add when PR is ready to merge/full CI is needed label May 13, 2025
Signed-off-by: tjtanaa <[email protected]>
@tjtanaa tjtanaa force-pushed the aiter-ck-moe-2-stage branch from 262279a to cba7244 Compare May 13, 2025 14:50

mergify bot commented May 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 14, 2025
Signed-off-by: tjtanaa <[email protected]>
@tjtanaa tjtanaa requested a review from WoosukKwon as a code owner May 14, 2025 03:42
Signed-off-by: tjtanaa <[email protected]>
@mergify mergify bot removed the needs-rebase label May 14, 2025

mergify bot commented May 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 14, 2025
Signed-off-by: tjtanaa <[email protected]>
@mergify mergify bot removed the needs-rebase label May 14, 2025
@vllm-bot vllm-bot merged commit 612c2ed into vllm-project:main May 14, 2025
63 of 66 checks passed
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
Signed-off-by: tjtanaa <[email protected]>
Co-authored-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Yuqi Zhang <[email protected]>