[Kernel] moe wna16 cuda kernel #13321
Conversation
Signed-off-by: Jinzhen Lin <[email protected]>
I merged this with the main branch but got the following error with the startup command below; it doesn't happen without this PR. Without --enforce-eager there is also an error, but from cuBLAS, and I didn't save that error message. Sorry for the messy error log; it is interleaved with output from the other processes spawned by tensor parallelism. python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 12345 --max-model-len 65536 --trust-remote-code --tensor-parallel-size 8 --quantization moe_wna16 --gpu-memory-utilization 0.99 --served-model-name gpt-4-1106-preview --model /root/data/llms/DeepSeek-R1-AWQ --distributed-executor-backend ray --enforce-eager ERROR 02-15 10:48:11 engine.py:139] Traceback (most recent call last):
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 137, in start
ERROR 02-15 10:48:11 engine.py:139] self.run_engine_loop()
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 200, in run_engine_loop
ERROR 02-15 10:48:11 engine.py:139] request_outputs = self.engine_step()
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 218, in engine_step
ERROR 02-15 10:48:11 engine.py:139] raise e
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 209, in engine_step
ERROR 02-15 10:48:11 engine.py:139] return self.engine.step()
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 1386, in step
ERROR 02-15 10:48:11 engine.py:139] outputs = self.model_executor.execute_model(
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/executor/ray_distributed_executor.py", line 408, in execute_model
ERROR 02-15 10:48:11 engine.py:139] return super().execute_model(execute_model_req)
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/executor/executor_base.py", line 284, in execute_model
ERROR 02-15 10:48:11 engine.py:139] driver_outputs = self._driver_execute_model(execute_model_req)
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/executor/ray_distributed_executor.py", line 401, in _driver_execute_model
ERROR 02-15 10:48:11 engine.py:139] return self.driver_worker.execute_method("execute_model",
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/worker/worker_base.py", line 580, in execute_method
ERROR 02-15 10:48:11 engine.py:139] raise e
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/worker/worker_base.py", line 571, in execute_method
ERROR 02-15 10:48:11 engine.py:139] return run_method(target, method, args, kwargs)
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/utils.py", line 2196, in run_method
ERROR 02-15 10:48:11 engine.py:139] return func(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/worker/worker_base.py", line 418, in execute_model
ERROR 02-15 10:48:11 engine.py:139] output = self.model_runner.execute_model(
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 02-15 10:48:11 engine.py:139] return func(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1718, in execute_model
ERROR 02-15 10:48:11 engine.py:139] hidden_or_intermediate_states = model_executable(
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 02-15 10:48:11 engine.py:139] return self._call_impl(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 02-15 10:48:11 engine.py:139] return forward_call(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/models/deepseek_v2.py", line 677, in forward
ERROR 02-15 10:48:11 engine.py:139] hidden_states = self.model(input_ids, positions, kv_caches,
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/compilation/decorators.py", line 172, in __call__
ERROR 02-15 10:48:11 engine.py:139] return self.forward(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/models/deepseek_v2.py", line 633, in forward
ERROR 02-15 10:48:11 engine.py:139] hidden_states, residual = layer(positions, hidden_states,
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 02-15 10:48:11 engine.py:139] return self._call_impl(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 02-15 10:48:11 engine.py:139] return forward_call(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/models/deepseek_v2.py", line 560, in forward
ERROR 02-15 10:48:11 engine.py:139] hidden_states = self.mlp(hidden_states)
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 02-15 10:48:11 engine.py:139] return self._call_impl(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 02-15 10:48:11 engine.py:139] return forward_call(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/models/deepseek_v2.py", line 162, in forward
ERROR 02-15 10:48:11 engine.py:139] final_hidden_states = self.experts(
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 02-15 10:48:11 engine.py:139] return self._call_impl(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 02-15 10:48:11 engine.py:139] return forward_call(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/layers/fused_moe/layer.py", line 586, in forward
ERROR 02-15 10:48:11 engine.py:139] final_hidden_states = self.quant_method.apply(
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/layers/quantization/moe_wna16.py", line 311, in apply
ERROR 02-15 10:48:11 engine.py:139] return fused_experts(x,
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py", line 1203, in fused_experts
ERROR 02-15 10:48:11 engine.py:139] torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/torch/_ops.py", line 1123, in __call__
ERROR 02-15 10:48:11 engine.py:139] return self._op(*args, **(kwargs or {}))
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py", line 1105, in inplace_fused_experts
ERROR 02-15 10:48:11 engine.py:139] fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py", line 1364, in fused_experts_impl
ERROR 02-15 10:48:11 engine.py:139] ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/vllm/_custom_ops.py", line 988, in moe_sum
ERROR 02-15 10:48:11 engine.py:139] torch.ops._moe_C.moe_sum(input, output)
ERROR 02-15 10:48:11 engine.py:139] File "/usr/local/lib/python3.12/site-packages/torch/_ops.py", line 1123, in __call__
ERROR 02-15 10:48:11 engine.py:139] return self._op(*args, **(kwargs or {}))
ERROR 02-15 10:48:11 engine.py:139] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] RuntimeError: CUDA error: an illegal memory access was encountered
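Because CUDA kernels launch asynchronously, a traceback like the one above typically blames whichever op happened to synchronize, not the kernel that actually faulted. A common way to localize this kind of error (offered here as a general debugging tip, not something verified in this thread) is to rerun with synchronous launches and, optionally, under NVIDIA's memory checker:

```shell
# Force synchronous kernel launches so the Python traceback points at the
# kernel that actually faulted (much slower; for debugging only).
export CUDA_LAUNCH_BLOCKING=1

# Optionally run the server under compute-sanitizer to pinpoint the exact
# out-of-bounds access inside the kernel.
compute-sanitizer --tool memcheck \
    python -m vllm.entrypoints.openai.api_server --model /root/data/llms/DeepSeek-R1-AWQ ...
```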
@LagPixelLOL Fixed. Note that the benchmark results were obtained without MLA.
Can you show the command you run vLLM with? This is my command:
@sunjianxide try to run with

This PR only optimizes the MoE GEMM operator; if the inference bottleneck is not in the MoE GEMM, the overall performance improvement will be relatively small. My command:
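The remark about the bottleneck is just Amdahl's law: if the MoE GEMM takes a fraction f of step time and the new kernel accelerates it by a factor s, the end-to-end speedup is 1 / ((1 - f) + f / s). A small sketch with hypothetical numbers (the fractions below are illustrative, not measurements from this PR):

```python
def overall_speedup(f: float, s: float) -> float:
    """Amdahl's law: f = fraction of runtime spent in the optimized part,
    s = speedup of that part. Returns the whole-run speedup."""
    return 1.0 / ((1.0 - f) + f / s)

# If MoE GEMM were 60% of decode time and the CUDA kernel made it 2x faster,
# the end-to-end gain would be modest:
print(round(overall_speedup(0.6, 2.0), 3))  # 1.429

# If MoE GEMM were only 20% of the time, the gain nearly vanishes:
print(round(overall_speedup(0.2, 2.0), 3))  # 1.111
```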
@mgoin Could you review this PR? I have been testing it for the last two weeks.
Nice work keeping the kernel clean! I do think we definitely need to refactor fused_moe.py to give a better interface for plugging in multiple kernels, similar to vllm/model_executor/layers/quantization/kernels/ for Linear modules.
Just a few questions
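For context, the Linear-side interface the reviewer refers to selects among kernel implementations via a capability check. A rough sketch of what a MoE analogue could look like (all class and function names here are hypothetical, not vLLM's actual API):

```python
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Type

class MoEKernel(ABC):
    """Hypothetical plug-in interface for fused-MoE kernels."""

    @classmethod
    @abstractmethod
    def can_implement(cls, num_bits: int) -> Tuple[bool, Optional[str]]:
        """Return (supported, reason_if_unsupported) for this config."""

    @abstractmethod
    def apply(self, hidden_states):
        """Run the fused-MoE computation."""

class TritonWNA16Kernel(MoEKernel):
    @classmethod
    def can_implement(cls, num_bits: int) -> Tuple[bool, Optional[str]]:
        if num_bits not in (4, 8):
            return False, "only 4- and 8-bit weights are supported"
        return True, None

    def apply(self, hidden_states):
        return hidden_states  # stand-in for the real kernel launch

def choose_kernel(candidates: List[Type[MoEKernel]], num_bits: int) -> Type[MoEKernel]:
    """Pick the first kernel that reports it can handle the config."""
    failures = []
    for kernel in candidates:
        ok, reason = kernel.can_implement(num_bits)
        if ok:
            return kernel
        failures.append(f"{kernel.__name__}: {reason}")
    raise ValueError("no usable kernel: " + "; ".join(failures))
```

Collecting the per-kernel rejection reasons makes configuration errors self-explanatory, which is one benefit of the Linear-layer design this mirrors.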
if (blockIdx.z == 0 && offset_n < size_n)
    output[token_index * size_n + offset_n] = Dtype::int2num(0);
atomicAdd(&output[token_index * size_n + offset_n],
          Dtype::float2num(res[m]));

Hi @jinzhen-lin, the writes to output in these two places cannot guarantee the order of execution, which may lead to potential correctness issues. Should we move the initialization of output outside the kernel, for example by using
I'm not entirely sure about the execution order issue, but I've observed that both the native GPTQ GEMM and the exllama GPTQ GEMM use similar logic. Therefore, I assume that the execution finish time for blocks with

Have you noticed any numerical anomalies? The additional
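The concern can be illustrated with a toy pure-Python model of the split-K reduction: if the zero-initialization is done by block 0 inside the kernel as a plain (non-atomic) store, the result depends on which block reaches the output element first, whereas host-side zeroing makes the accumulation order-independent. This models only the scheduling question; it is not the kernel code:

```python
def split_k_accumulate(partials, init_in_kernel, order):
    """Toy model: partials[z] is block z's partial sum for one output
    element; `order` is the order in which blocks touch that element."""
    output = None if init_in_kernel else 0.0  # host pre-zeroes the buffer
    for z in order:
        if init_in_kernel and z == 0:
            # plain store, not atomic: clobbers any earlier atomicAdds
            output = 0.0
        # atomicAdd (uninitialized memory modeled here as empty)
        output = partials[z] if output is None else output + partials[z]
    return output

partials = [1.0, 2.0, 3.0]

# Host-side zeroing: any block order yields the same sum.
for order in ([0, 1, 2], [2, 1, 0], [1, 0, 2]):
    assert split_k_accumulate(partials, False, order) == 6.0

# In-kernel init by block 0: if block 0 arrives last, earlier adds are lost.
print(split_k_accumulate(partials, True, [2, 1, 0]))  # 1.0 instead of 6.0
```

Whether block 0 can actually arrive last on real hardware is exactly the scheduling assumption being debated above.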
Signed-off-by: Jinzhen Lin <[email protected]>
Excellent work, thank you
In vLLM 0.8.0, I started the deepseek-r1-awq model with the following command: A torch._dynamo.exc.InternalTorchDynamoError is reported. The error is as follows:
Machine information: @jinzhen-lin please take a look given the hardware and driver environment above, thank you very much.
Try running with
Signed-off-by: Jinzhen Lin <[email protected]> Co-authored-by: mgoin <[email protected]> Signed-off-by: Louis Ulmer <[email protected]>
#12185 added a triton moe wna16 kernel, but triton cannot reach the best performance when m is small. This PR adds a moe wna16 CUDA kernel, which has better generation speed than the triton version when `num_seqs * top_k / num_experts` is small (for deepseek-v3/deepseek-r1, it performs better when num_seqs <= 256). This PR also provides a better block config for the triton moe wna16 GEMM kernel.

The generation speed on 8*A100 + deepseek-v3-awq (tokens/s):

Note:
- `group_size` may slightly affect generation speed.
- `atomicAdd` of bf16; it should have similar speed to fp16 on sm90+ devices.
- The results were measured with MLA disabled (`export VLLM_MLA_DISABLE=1`).
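The `num_seqs * top_k / num_experts` criterion from the description could be expressed as a dispatch heuristic roughly like the following. The function name is hypothetical, and the threshold is an assumption extrapolated from the single data point given above (deepseek-v3/r1 uses top_k = 8 with 256 routed experts, so the stated num_seqs <= 256 crossover corresponds to a ratio of about 8):

```python
def prefer_cuda_wna16(num_seqs: int, top_k: int, num_experts: int) -> bool:
    """Heuristic sketch: choose the CUDA wna16 MoE kernel over the triton
    one while the average number of tokens routed to each expert is small.
    THRESHOLD is inferred from the deepseek-r1 crossover point, not tuned."""
    THRESHOLD = 8.0
    return num_seqs * top_k / num_experts <= THRESHOLD

# deepseek-r1-like config: CUDA kernel wins up to 256 sequences.
print(prefer_cuda_wna16(256, 8, 256))  # True
print(prefer_cuda_wna16(512, 8, 256))  # False
```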