[FEAT] [ROCm]: Add AITER CK 2 Stages MoE support #17110
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.
Just a reminder: PRs would not trigger a full CI run by default. Instead, it would only run … Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀 …
This pull request has merge conflicts that must be resolved before it can be merged.
In general I would like us to contain all of the 1-stage vs 2-stage dispatching logic in the rocm_aiter_fused_experts function if possible.
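As a rough sketch of what containing that logic could look like (the helper and argument names below are illustrative, not the actual vLLM code):

```python
def _ck_2stage_supported(use_fp8_w8a8: bool, per_channel_quant: bool) -> bool:
    # Hypothetical support check; the real condition belongs with the kernels.
    return not (use_fp8_w8a8 and per_channel_quant)


def rocm_aiter_fused_experts(hidden_states, w1, w2, topk_weights, topk_ids,
                             use_fp8_w8a8: bool = False,
                             per_channel_quant: bool = False):
    # All 1-stage vs 2-stage selection lives in this single entry point,
    # so callers never need to pass a 2-stage flag around.
    if _ck_2stage_supported(use_fp8_w8a8, per_channel_quant):
        return ...  # placeholder: invoke the CK 2-stage MoE kernel
    return ...      # placeholder: invoke the existing 1-stage MoE kernel
```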
is_rocm_aiter_moe_enabled, shuffle_weights)

self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_2stage_moe_enabled = is_rocm_aiter_2stage_moe_enabled()
It looks like this variable is only used to warn that we are falling back to a different implementation. Let's remove it to simplify the code.
Same comment as below: #17110 (comment)
There is an RFC (Issue #17067) highlighting that accessing envs.ENV is very costly. Thus, all of the environment variables are read only once and stored as properties of the class during the initialization stage.
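A minimal sketch of that pattern, assuming a layer-method-style class (the class itself is illustrative; the env flags are the ones appearing in this diff):

```python
import vllm.envs as envs


class ExampleMoEMethod:
    def __init__(self):
        # Read the (expensive) environment variables once at construction
        # time and cache them as plain attributes.
        self.rocm_aiter_moe_enabled = envs.VLLM_ROCM_USE_AITER_MOE
        self.rocm_aiter_2stage_moe_enabled = envs.VLLM_ROCM_USE_AITER_2STAGE_MOE

    def apply(self, x):
        # The hot path only touches the cached booleans and never reads
        # envs.* directly.
        if self.rocm_aiter_2stage_moe_enabled:
            return ...  # placeholder: 2-stage path
        if self.rocm_aiter_moe_enabled:
            return ...  # placeholder: 1-stage path
        return x
```

This keeps the forward pass free of environment lookups while the flags can still be toggled per process before the engine is constructed.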
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False,
use_ck_moe_2stages: bool = False,
Instead of passing this boolean around, can you check the environment variable here?
@SageMoore There is an RFC (Issue #17067) highlighting that accessing envs.ENV is very costly. Thus, all of the environment variables are read only once and stored as properties of the class during the initialization stage.
  block_shape: Optional[List[int]] = None,
- allow_deep_gemm: bool = False) -> torch.Tensor:
+ allow_deep_gemm: bool = False,
+ use_ck_moe_2stages: bool = False) -> torch.Tensor:
Nit: It doesn't look like you need this?
Agreed, this is not needed.
I have removed the allow_deep_gemm argument. use_ck_moe_2stages is kept because there is an RFC (Issue #17067) highlighting that accessing envs.ENV is very costly; all of the environment variables are therefore read only once and stored as properties of the class during the initialization stage.
To pass the pre-commit tests for vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py, we have adjusted the assignment of the fused-experts function so that fused_experts_func becomes rocm_aiter_fused_experts_func, following the approach in vllm/attention/backends/rocm_flash_attn.py, where the attention functions are assigned to distinct property names: self.fa_attn_func, self.sdpa_attn_func, and self.triton_attn_func.
This also allows us to clean up the unused arguments of rocm_aiter_fused_experts (vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py). Note that the expert_map argument has been removed as a bugfix: expert_map in vLLM (integer IDs of the experts) and expert_mask in the AITER ops (a boolean mask of the experts active on the current GPU) are different things. The current rocm_aiter_fused_experts therefore no longer takes an expert_map argument; we encourage adding it back when enabling expert parallelism with AITER in a future PR.
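A rough illustration of the binding pattern described above, loosely modeled on the rocm_flash_attn.py approach (the class name and constructor argument are placeholders, not the actual PR code):

```python
class ExampleFusedMoEMethod:
    def __init__(self, use_rocm_aiter: bool):
        # Bind the concrete fused-experts implementation to a descriptive
        # attribute once, instead of re-deciding it on every forward call.
        if use_rocm_aiter:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
                rocm_aiter_fused_experts)
            self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
            self.fused_experts_func = None
        else:
            from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
            self.fused_experts_func = fused_experts
            self.rocm_aiter_fused_experts_func = None
```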
Verified with the Llama 4 BF16 128E model.
LGTM in general. Agreed the environment variable part can be simplified.
def is_rocm_aiter_2stage_moe_enabled() -> bool:
    return current_platform.is_rocm() \
        and envs.VLLM_ROCM_USE_AITER_2STAGE_MOE \
        and envs.VLLM_ROCM_USE_AITER_MOE \
        and envs.VLLM_ROCM_USE_AITER
Checking three environment variables seems like too much. envs.VLLM_ROCM_USE_AITER_2STAGE_MOE is enough, as it is only used when the other two are already true.
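For illustration, a sketch of the suggested simplification, assuming the 2-stage flag is only consulted on code paths where the other two flags have already been verified:

```python
import vllm.envs as envs


def is_rocm_aiter_2stage_moe_enabled() -> bool:
    # VLLM_ROCM_USE_AITER and VLLM_ROCM_USE_AITER_MOE are already required
    # for this path to be reached, so only the 2-stage flag needs checking.
    return bool(envs.VLLM_ROCM_USE_AITER_2STAGE_MOE)
```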
The is_rocm_aiter_2stage_moe_enabled() helper has been removed, since envs.VLLM_ROCM_USE_AITER_2STAGE_MOE is now read in the layer class during initialization only, not in the forward pass.
I'd like to propose an alternative implementation. Instead of passing this new environment variable around, let's just keep the existing VLLM_ROCM_USE_AITER_MOE variable and add heuristics to the code that decide when to use the one-stage vs the two-stage implementation.
It looks like the two-stage kernel is generally faster but doesn't support dynamic quantization? Let's rework this PR so that we always use the two-stage implementation when it's supported and fall back to the one-stage implementation when it isn't. We don't need any logging when this fallback occurs.
If we make the change that I suggested, it looks like you can simplify the changes in fp8.py down to just some layout changes for the weights. fused_moe.py should get a bit simpler as well.
What do you think?
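A minimal sketch of the proposed heuristic (the function and wrapper names are hypothetical; the support condition is the dynamic-quantization constraint mentioned above):

```python
def _run_ck_moe_2stages(*args, **kwargs):
    ...  # placeholder for the AITER CK 2-stage kernel call


def _run_fused_moe_1stage(*args, **kwargs):
    ...  # placeholder for the existing 1-stage kernel call


def select_aiter_moe_impl(dynamic_quant: bool):
    # Prefer the (generally faster) 2-stage kernel whenever the quantization
    # scheme allows it; otherwise fall back silently to the 1-stage kernel,
    # with no new environment variable and no warning log.
    return _run_fused_moe_1stage if dynamic_quant else _run_ck_moe_2stages
```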
@SageMoore Can we get this merged and do the rework as a follow-up PR?
@SageMoore It would be better to do a different implementation in a follow-up PR. We would like to have this PR merged, as it improves performance for a bunch of models.
This pull request has merge conflicts that must be resolved before it can be merged.
@SageMoore Thank you for the suggestion. We have already cleaned up this environment variable and simplified the fused MoE functions.
Looks reasonable. Thanks for cleaning up the dispatching logic!
This pull request has merge conflicts that must be resolved before it can be merged.
This pull request has merge conflicts that must be resolved before it can be merged.
Description
This PR introduces the use of ck_moe_2stages from ROCm/aiter.
Performance gain
- meta-llama/Llama-4-Maverick-17B-128E-Instruct (V1 Engine, Input:Output = 1000:1000)
- amd/Mixtral-8x7B-Instruct-v0.1-FP8-KV (V1 Engine, Input:Output = 1000:1000)
- mistralai/Mixtral-8x7B-v0.1 (V1 Engine, Input:Output = 1000:1000)
Lm_eval
Dataset: GSM8K
Engine: V1
AITER 1 stage amd/Mixtral-8x7B-Instruct-v0.1-FP8-KV (fallback to 2 stage MoE)
vllm (pretrained=amd/Mixtral-8x7B-Instruct-v0.1-FP8-KV,tensor_parallel_size=2,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
AITER 2 stage amd/Mixtral-8x7B-Instruct-v0.1-FP8-KV
vllm (pretrained=amd/Mixtral-8x7B-Instruct-v0.1-FP8-KV,tensor_parallel_size=2,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
AITER 1 stage meta-llama/Llama-4-Scout-17B-16E-Instruct
vllm (pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=8,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
AITER 2 stage meta-llama/Llama-4-Scout-17B-16E-Instruct
vllm (pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=8,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
AITER 1 stage mistralai/Mixtral-8x7B-Instruct-v0.1 FP16
vllm (pretrained=mistralai/Mixtral-8x7B-Instruct-v0.1,tensor_parallel_size=2,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
AITER 2 stage mistralai/Mixtral-8x7B-Instruct-v0.1 FP16
vllm (pretrained=mistralai/Mixtral-8x7B-Instruct-v0.1,tensor_parallel_size=2,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
AITER 1 stage mistralai/Mixtral-8x7B-Instruct-v0.1 Dynamic FP8
vllm (pretrained=mistralai/Mixtral-8x7B-Instruct-v0.1,tensor_parallel_size=2,max_model_len=10000,quantization=fp8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
AITER 2 stage mistralai/Mixtral-8x7B-Instruct-v0.1 Dynamic FP8 (fall back to 1 stage)
vllm (pretrained=mistralai/Mixtral-8x7B-Instruct-v0.1,tensor_parallel_size=2,max_model_len=10000,quantization=fp8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
Updated unit tests
The unit tests have been updated, given that rocm_aiter_fused_experts has been removed from the dispatch_fused_experts_func function in vllm/model_executor/layers/fused_moe/fused_moe.py.
Future TODO: