jinzhen-lin (Contributor) commented Apr 18, 2025

This PR optimizes the dense marlin kernel and the moe marlin kernel.

Summary:

  • (dense marlin only) Migrated the optimization methods introduced for the moe marlin kernel in [Kernel] moe wna16 marlin kernel #14447 to the dense marlin kernel, including:
    • Reworked the workspace usage logic to remove the max_par limitation, speeding up large batches (m > 1024).
    • Simulated the m8n16k16 MMA instruction with the m16n8k16 instruction via transposition. This improves performance for m <= 8.
    • For AWQ models, fused mul(sub(quantized_weight, zero_points), scale) into fma(quantized_weight, scale, -(zero_points * scale)), where -(zero_points * scale) can be precomputed. This saves some floating-point operations (see the first sketch after this list).
    • Removed some unused kernels to reduce wheel size.
    • Split the kernel into multiple files to speed up compilation.
    • etc.
  • (moe marlin only) Optimized the index calculation logic when reading A by caching row and column information as much as possible, achieving a performance improvement of up to 10%.
  • (moe marlin only) Used the available shared memory to cache matrix A as much as possible; when the same threadblock processes the same M but different N, this reduces the global-memory reads of A (see the second sketch after this list).
  • FP8 marlin. We can now run DeepSeek with W-FP8-A-FP16.
    • Merged fp8_marlin into the gptq_marlin kernel and added block quant support.
    • Added fp8 support for moe marlin.
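
To make the two sketches referenced above concrete: first, a minimal, hedged illustration of the AWQ dequant rewrite (not the kernel's actual code; `dequant_awq` and `neg_zp_s` are illustrative names). The baseline costs a subtract plus a multiply per weight element; the rewritten form is a single fused multiply-add against a precomputed `-(zero_points * scale)` term.

```cuda
#include <cuda_fp16.h>

// Hedged sketch of the AWQ dequant rewrite, not the actual Marlin code.
// Baseline:  w = (w_q - zp) * s          -> one sub + one mul per element
// Rewritten: w = fma(w_q, s, neg_zp_s)   -> one FMA per element, where
// neg_zp_s = -(zp * s) is precomputed once per quantization group.
__device__ __forceinline__ half2 dequant_awq(half2 w_q, half2 s,
                                             half2 neg_zp_s) {
  return __hfma2(w_q, s, neg_zp_s);  // w_q * s + (-(zp * s))
}
```

Second, a minimal sketch of the shared-memory reuse idea from the moe item (tile sizes and names are hypothetical; the real kernel's staging is more involved). A threadblock that walks several N-tiles for one M-tile stages A in shared memory once and reuses it for every tile, so the global-memory traffic for A is paid only once.

```cuda
#include <cuda_fp16.h>

constexpr int TILE_M = 16;  // hypothetical tile sizes, for illustration only
constexpr int TILE_K = 64;

__global__ void reuse_a_across_n_tiles(const half* __restrict__ A, int lda,
                                       int n_tiles_per_block) {
  __shared__ half smem_a[TILE_M * TILE_K];
  bool a_loaded = false;
  for (int t = 0; t < n_tiles_per_block; ++t) {
    if (!a_loaded) {
      // A depends only on the M-tile, so stage it in shared memory once ...
      for (int i = threadIdx.x; i < TILE_M * TILE_K; i += blockDim.x) {
        int row = i / TILE_K, col = i % TILE_K;
        smem_a[i] = A[(blockIdx.y * TILE_M + row) * lda + col];
      }
      __syncthreads();
      a_loaded = true;
    }
    // ... and every N-tile then reads A from smem_a instead of global
    // memory (the B load and MMA issue are omitted in this sketch).
  }
}
```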

Signed-off-by: Jinzhen Lin <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run the other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

mergify bot added the ci/build label Apr 18, 2025
mgoin requested a review from LucasWilkinson April 18, 2025 16:04
jinzhen-lin (Contributor, Author) commented Apr 19, 2025

dense marlin benchmark tests (on A800)

[three benchmark result images]

jinzhen-lin (Contributor, Author) commented Apr 19, 2025

moe marlin benchmark tests (on A800)

(NOTE 1: For cases where k <= 256, the optimization methods introduced in this PR were already implemented in #14447, so the performance improvement under such conditions is limited.)

(NOTE 2: The "main" baseline in the following results is inconsistent with the one posted in #14447, because after posting the benchmark results there I made several more rounds of optimizations.)

shapes of DeepSeek-V3-AWQ (with TP=8)

[benchmark result image]

shapes of Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4 (with TP=1)

[benchmark result image]

shapes of Mixtral-8x7B-Instruct-v0.1-AWQ (with TP=1)

[benchmark result image]

jinzhen-lin (Contributor, Author) commented:

@mgoin @LucasWilkinson

The benchmark results are posted.

BTW, should we change the default value of VLLM_MARLIN_USE_ATOMIC_ADD to 1 now? (I'm still not sure whether it could cause some bugs, though; see #14138.)
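
For context on the tradeoff, here is a hedged sketch of what an atomic-add epilogue does (illustrative only, not vLLM's actual code): with split-k style parallelism, each block folds its partial tile straight into the output with atomicAdd instead of staging partials in a workspace and running a separate reduction kernel, at the cost of a non-deterministic float summation order.

```cuda
// Illustrative split-k epilogue (hypothetical kernel, not vLLM's):
// block (x, y) owns the y-th partial copy of the output tile and
// accumulates it into C directly.
__global__ void atomic_add_epilogue(float* __restrict__ C,
                                    const float* __restrict__ partials,
                                    int tile_elems) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < tile_elems) {
    // Accumulation order across blocks is unordered, so results can vary
    // bitwise between runs, which is the usual correctness concern.
    atomicAdd(&C[i], partials[blockIdx.y * tile_elems + i]);
  }
}
```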

mgoin (Member) left a comment

This does increase the wheel size by about 10MB to 313MB, so we should try to trim it down a bit.

[2025-04-22T07:40:33Z] #32 0.707 Wheel dist/vllm-0.8.5.dev150+gfb8563602-cp38-abi3-linux_x86_64.whl is within the allowed size (313.19 MB).

I think there may be some compiled function overlap that I uncovered during review.

Comment on lines -395 to -405
// Instantiations with is_zp_float = true (the HQQ path); not enabled, per the discussion below.
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)              \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4,          \
           NUM_THREADS, true)                                            \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4,         \
           NUM_THREADS, true)                                            \
  __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4,         \
           NUM_THREADS, true)                                            \
  __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4,         \
           NUM_THREADS, true)                                            \
  __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4,         \
           NUM_THREADS, true)
Member:

Did we actually support HQQ for MoE before?

jinzhen-lin (Contributor, Author) commented Apr 23, 2025

The Marlin template supports is_zp_float = true (HQQ), but I don't enable it.
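
For reference, a hedged sketch of what the is_zp_float path computes (HQQ keeps float-valued zero-points rather than packed integer ones; `dequant_hqq` is an illustrative name, not the template's API):

```cuda
#include <cuda_fp16.h>

// Illustrative only: dequant with a float-valued zero-point (HQQ-style),
// as opposed to the packed integer zero-points of the AWQ/GPTQ paths.
__device__ __forceinline__ half dequant_hqq(half w_q, half scale, half zp_f) {
  return __hmul(__hsub(w_q, zp_f), scale);  // (w_q - zp) * s
}
```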

jinzhen-lin (Contributor, Author) commented:

@mgoin The remaining failing tests seem unrelated to this PR.


mergify bot commented May 2, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @jinzhen-lin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label May 2, 2025
mergify bot removed the needs-rebase label May 2, 2025
mgoin (Member) commented May 3, 2025

Looks like several of the failing tests are related to the merge 😞

[2025-05-02T21:38:44Z] FAILED kernels/quantization/test_awq_marlin.py::test_fused_marlin_moe_awq[128-6-64-1024-2048-64] - RuntimeError: vllm::fused_marlin_moe() is missing value for argument 'quant_type_id'. Declaration: vllm::fused_marlin_moe(Tensor hidden_states, Tensor w1, Tensor w2, Tensor w1_scale, Tensor w2_scale, Tensor gating_output, Tensor topk_weights, Tensor topk_ids, SymInt quant_type_id, SymInt global_num_experts=-1, Tensor? expert_map=None, Tensor? g_idx1=None, Tensor? g_idx2=None, Tensor? sort_indices1=None, Tensor? sort_indices2=None, Tensor? w1_zeros=None, Tensor? w2_zeros=None, Tensor? workspace=None, bool is_k_full=True, bool inplace=False) -> Tensor

jinzhen-lin (Contributor, Author) commented May 4, 2025

> Looks like several of the failing tests are related to the merge 😞
>
> [2025-05-02T21:38:44Z] FAILED kernels/quantization/test_awq_marlin.py::test_fused_marlin_moe_awq[128-6-64-1024-2048-64] - RuntimeError: vllm::fused_marlin_moe() is missing value for argument 'quant_type_id'. Declaration: vllm::fused_marlin_moe(Tensor hidden_states, Tensor w1, Tensor w2, Tensor w1_scale, Tensor w2_scale, Tensor gating_output, Tensor topk_weights, Tensor topk_ids, SymInt quant_type_id, SymInt global_num_experts=-1, Tensor? expert_map=None, Tensor? g_idx1=None, Tensor? g_idx2=None, Tensor? sort_indices1=None, Tensor? sort_indices2=None, Tensor? w1_zeros=None, Tensor? w2_zeros=None, Tensor? workspace=None, bool is_k_full=True, bool inplace=False) -> Tensor

@mgoin The error seems to have been introduced by the rebase. Fixed now (test_awq_marlin.py contained moe test cases and should be removed; the moe marlin test cases are already in test_moe.py).

simon-mo merged commit 1d0c9d6 into vllm-project:main May 5, 2025
77 of 80 checks passed
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
robertgshaw2-redhat (Collaborator) commented May 12, 2025

Hey @jinzhen-lin @mgoin - it looks like this PR may have broken marlin utilities for a couple of integrations. See the latest nightly runs:

Would you mind taking a peek at resolving this?

jinzhen-lin (Contributor, Author) commented May 12, 2025

> Hey @jinzhen-lin @mgoin - it looks like this PR may have broken marlin for FBGEMM integration
>
> Would you mind taking a peek at resolving this?

I will fix it later.

robertgshaw2-redhat (Collaborator) commented:

> > Hey @jinzhen-lin @mgoin - it looks like this PR may have broken marlin for FBGEMM integration
> >
> > Would you mind taking a peek at resolving this?
>
> I will fix it later.

Thank you. I posted the two failures I'm seeing in the comment.

mgoin (Member) commented May 12, 2025

I've resolved most of the model issues with the above-referenced PRs #18002 and #18017.

There is one outstanding issue that it would be useful to have you take a look at, @jinzhen-lin. Regarding the weight-loading Buildkite test, there is this failing case with Mixtral w8a16, group=128, desc_act=True:

=== FAILED MODEL: gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, gptq-8bit-128g-actorder_True ===

Locally I've been able to trigger the failure with this command, which fails in the moe_wna16_marlin_gemm call:

CUDA_LAUNCH_BLOCKING=1 vllm serve TheBloke/Mixtral-8x7B-v0.1-GPTQ -tp 2 --load-format dummy --enforce-eager --revision gptq-8bit-128g-actorder_True
...
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522]   File "/home/mgoin/code/vllm/vllm/model_executor/layers/quantization/gptq_marlin.py", line 630, in apply
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522]     return torch.ops.vllm.fused_marlin_moe(
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522]   File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/_ops.py", line 1158, in __call__
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522]     return self._op(*args, **(kwargs or {}))
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522]   File "/home/mgoin/code/vllm/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py", line 168, in fused_marlin_moe
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522]     intermediate_cache3 = ops.moe_wna16_marlin_gemm(
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522]   File "/home/mgoin/code/vllm/vllm/_custom_ops.py", line 1401, in moe_wna16_marlin_gemm
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522]     return torch.ops._moe_C.moe_wna16_marlin_gemm(
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522]   File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/_ops.py", line 1158, in __call__
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522]     return self._op(*args, **(kwargs or {}))
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522] RuntimeError: CUDA error: an illegal memory access was encountered
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522] 
(VllmWorker rank=0 pid=3000030) ERROR 05-12 19:38:02 [multiproc_executor.py:522] 

mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
Labels: ci/build, performance, quantization, ready
Projects: None yet
5 participants