Skip to content

Conversation

mgoin
Copy link
Member

@mgoin mgoin commented Apr 3, 2025

Similar to #13236, but now to bring the moe_wna16 kernel to models in compressed-tensors format. This is crucial for models with many experts to have good performance or require bfloat16 dtype. We should still move to using the quantization/kernel/ interface to implement this kernel selection properly, but this should be enough to unblock evals on large moes for ct format models

Validated on DeepSeek-R1:

lm_eval --model vllm --model_args pretrained=daslab-testing/DeepSeek-R1-GPTQ-4b-128g-act_order-mse_scale,add_bos_token=True,tensor_parallel_size=8,max_model_len=10000 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

Processed prompts: 100%|██████████████| 1319/1319 [02:38<00:00,  8.33it/s, est. speed input: 7271.10 toks/s, output: 845.18 toks/s]
vllm (pretrained=daslab-testing/DeepSeek-R1-GPTQ-4b-128g-act_order-mse_scale,add_bos_token=True,tensor_parallel_size=8,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9568|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.9553|±  |0.0057|

Comparative scores with Marlin and triton kernels on Mixtral (with support for bfloat16 now):

CompressedTensorsWNA16MarlinMoEMethod
Processed prompts: 100%|██████████████| 1319/1319 [03:44<00:00,  5.88it/s, est. speed input: 6254.14 toks/s, output: 776.47 toks/s
vllm (pretrained=neuralmagic-ent/Mixtral-8x7B-v0.1-quantized.w4a16,add_bos_token=True,dtype=float16,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5754|±  |0.0136|
|     |       |strict-match    |     5|exact_match|↑  |0.5762|±  |0.0136|


CompressedTensorsWNA16MoEMethod
Processed prompts: 100%|██████████████| 1319/1319 [06:07<00:00,  3.59it/s, est. speed input: 3818.91 toks/s, output: 471.17 toks/s]
vllm (pretrained=neuralmagic-ent/Mixtral-8x7B-v0.1-quantized.w4a16,add_bos_token=True,dtype=float16,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5785|±  |0.0136|
|     |       |strict-match    |     5|exact_match|↑  |0.5777|±  |0.0136|


CompressedTensorsWNA16MoEMethod
Processed prompts: 100%|█████████████████████████| 1319/1319 [06:07<00:00,  3.59it/s, est. speed input: 3815.63 toks/s, output: 472.43 toks/s]
vllm (pretrained=neuralmagic-ent/Mixtral-8x7B-v0.1-quantized.w4a16,add_bos_token=True,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5747|±  |0.0136|
|     |       |strict-match    |     5|exact_match|↑  |0.5724|±  |0.0136|

Copy link

github-actions bot commented Apr 3, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mgoin mgoin changed the title Use moe_wna16 kernel for compressed tensors wna16 moe models [Kernel] Use moe_wna16 kernel for compressed tensors wna16 moe models Apr 5, 2025
@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 5, 2025
@mergify mergify bot added documentation Improvements or additions to documentation ci/build frontend multi-modality Related to multi-modality (#4194) structured-output v1 tpu Related to Google TPUs labels Apr 9, 2025
@mgoin mgoin force-pushed the compressed-tensors-moe-wna16 branch from ccc1cf5 to 094dd92 Compare April 9, 2025 18:54
@mgoin mgoin removed documentation Improvements or additions to documentation structured-output frontend tpu Related to Google TPUs ci/build labels Apr 9, 2025
@mergify mergify bot removed the tpu Related to Google TPUs label Apr 9, 2025
@mgoin mgoin added performance Performance-related issues quantization and removed v1 multi-modality Related to multi-modality (#4194) labels Apr 9, 2025
@luccafong luccafong self-requested a review April 9, 2025 19:27
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
return CompressedTensorsWNA16MoEMethod(quant_config)
# Prefer to use the non-marlin kernel when:
# 1. Many experts (MarlinMoE gives poor performance when >= 16)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the perf gap on marlin when >=16

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It starts off low, basically equal at 16 experts but exponentially gets worse as experts increase such that at >100 experts it is essentially unusable. For the Marlin kernel, a single marlin_gemm_moe would launching num_experts CUDA kernels at least, while the fused_moe triton kernel only needs to launch one cuda kernel. This makes the Marlin kernel significantly slower than the fused_moe triton kernel. We will improve that marlin kernel soon!

Copy link
Collaborator

@houseroad houseroad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't see obvious wrong things. Is it possible to add some unittest?

@mergify mergify bot added the ci/build label Apr 9, 2025
Copy link
Collaborator

@luccafong luccafong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, thanks for adding the integration!

@DarkLight1337 DarkLight1337 merged commit c70cf0f into vllm-project:main Apr 10, 2025
44 of 45 checks passed
yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Apr 21, 2025
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build performance Performance-related issues quantization ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants