
Conversation

jinzhen-lin
Contributor

@jinzhen-lin jinzhen-lin commented Feb 15, 2025

#12185 added the triton moe wna16 kernel, but Triton cannot reach the best performance when m is small. This PR adds a moe wna16 CUDA kernel, which has better generation speed than the Triton version when num_seqs * top_k / num_experts is small (for deepseek-v3/deepseek-r1, it performs better when num_seqs <= 256). This PR also provides a better block config for the triton moe wna16 GEMM kernel.
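For illustration only, a minimal sketch of the kind of dispatch heuristic described above (function names such as `moe_wna16_gemm_cuda` and `moe_wna16_gemm_triton` are hypothetical placeholders, not the actual vLLM entry points):

```python
def choose_moe_wna16_backend(num_seqs: int, top_k: int, num_experts: int) -> str:
    """Pick the MoE WNA16 GEMM backend based on how many tokens land on each expert.

    When num_seqs * top_k / num_experts is small, each expert only sees a few
    tokens (m is small), so the CUDA kernel tends to win; otherwise the Triton
    kernel is preferred. The threshold below is illustrative, not the value
    used by vLLM.
    """
    tokens_per_expert = num_seqs * top_k / num_experts
    if tokens_per_expert < 1.0:
        return "cuda"    # would dispatch to moe_wna16_gemm_cuda (hypothetical name)
    return "triton"      # would dispatch to moe_wna16_gemm_triton (hypothetical name)
```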

The generation speed on 8*A100 + deepseek-v3-awq (tokens/s):

| bs | triton (main) | cuda fp16 (PR) | cuda fp16 vs triton | cuda bf16 (PR) |
|---:|--------------:|---------------:|--------------------:|---------------:|
| 1  | 34.1  | 48.1  | 141.06% | 45.2  |
| 2  | 43.9  | 80.7  | 183.83% | 78.1  |
| 4  | 66.6  | 134.9 | 202.55% | 131.4 |
| 8  | 93.7  | 200.5 | 213.98% | 199.7 |
| 12 | 104.8 | 230.6 | 220.04% | 225.8 |
| 16 | 125.1 | 291.4 | 232.93% | 285.7 |
| 24 | 150.7 | 327.8 | 217.52% | 325.8 |
| 32 | 171.9 | 377.3 | 219.49% | 374.1 |
| 48 | 222.4 | 449.8 | 202.25% | 442.3 |
| 64 | 262.7 | 495.7 | 188.69% | 491.6 |

Note:

  1. The generation speed is copied from the vLLM log.
  2. group_size may slightly affect generation speed.
  3. bf16 is slower than fp16 because the A100 (sm80) doesn't natively support atomicAdd on bf16; it should have similar speed to fp16 on sm90+ devices.
  4. The results were measured with MLA disabled: export VLLM_MLA_DISABLE=1.

Signed-off-by: Jinzhen Lin <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Feb 15, 2025
Signed-off-by: Jinzhen Lin <[email protected]>
Signed-off-by: Jinzhen Lin <[email protected]>
Signed-off-by: Jinzhen Lin <[email protected]>
@LagPixelLOL

LagPixelLOL commented Feb 15, 2025

I merged this with the main branch but received the following error with the startup command below; the error doesn't happen without this PR. Without enforce-eager there's also an error, but from cuBLAS, and I didn't save that error message. Sorry for the messed-up error log, it has some race conditions with all the other processes spawned by tensor parallel.

python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 12345 --max-model-len 65536 --trust-remote-code --tensor-parallel-size 8 --quantization moe_wna16 --gpu-memory-utilization 0.99 --served-model-name gpt-4-1106-preview --model /root/data/llms/DeepSeek-R1-AWQ --distributed-executor-backend ray --enforce-eager
ERROR 02-15 10:48:11 engine.py:139] Traceback (most recent call last):
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 137, in start
ERROR 02-15 10:48:11 engine.py:139]     self.run_engine_loop()
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 200, in run_engine_loop
ERROR 02-15 10:48:11 engine.py:139]     request_outputs = self.engine_step()
ERROR 02-15 10:48:11 engine.py:139]                       ^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 218, in engine_step
ERROR 02-15 10:48:11 engine.py:139]     raise e
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 209, in engine_step
ERROR 02-15 10:48:11 engine.py:139]     return self.engine.step()
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 1386, in step
ERROR 02-15 10:48:11 engine.py:139]     outputs = self.model_executor.execute_model(
ERROR 02-15 10:48:11 engine.py:139]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/executor/ray_distributed_executor.py", line 408, in execute_model
ERROR 02-15 10:48:11 engine.py:139]     return super().execute_model(execute_model_req)
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/executor/executor_base.py", line 284, in execute_model
ERROR 02-15 10:48:11 engine.py:139]     driver_outputs = self._driver_execute_model(execute_model_req)
ERROR 02-15 10:48:11 engine.py:139]                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/executor/ray_distributed_executor.py", line 401, in _driver_execute_model
ERROR 02-15 10:48:11 engine.py:139]     return self.driver_worker.execute_method("execute_model",
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/worker/worker_base.py", line 580, in execute_method
ERROR 02-15 10:48:11 engine.py:139]     raise e
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/worker/worker_base.py", line 571, in execute_method
ERROR 02-15 10:48:11 engine.py:139]     return run_method(target, method, args, kwargs)
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/utils.py", line 2196, in run_method
ERROR 02-15 10:48:11 engine.py:139]     return func(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/worker/worker_base.py", line 418, in execute_model
ERROR 02-15 10:48:11 engine.py:139]     output = self.model_runner.execute_model(
ERROR 02-15 10:48:11 engine.py:139]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 02-15 10:48:11 engine.py:139]     return func(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1718, in execute_model
ERROR 02-15 10:48:11 engine.py:139]     hidden_or_intermediate_states = model_executable(
ERROR 02-15 10:48:11 engine.py:139]                                     ^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 02-15 10:48:11 engine.py:139]     return self._call_impl(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 02-15 10:48:11 engine.py:139]     return forward_call(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/models/deepseek_v2.py", line 677, in forward
ERROR 02-15 10:48:11 engine.py:139]     hidden_states = self.model(input_ids, positions, kv_caches,
ERROR 02-15 10:48:11 engine.py:139]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/compilation/decorators.py", line 172, in __call__
ERROR 02-15 10:48:11 engine.py:139]     return self.forward(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/models/deepseek_v2.py", line 633, in forward
ERROR 02-15 10:48:11 engine.py:139]     hidden_states, residual = layer(positions, hidden_states,
ERROR 02-15 10:48:11 engine.py:139]                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 02-15 10:48:11 engine.py:139]     return self._call_impl(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 02-15 10:48:11 engine.py:139]     return forward_call(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/models/deepseek_v2.py", line 560, in forward
ERROR 02-15 10:48:11 engine.py:139]     hidden_states = self.mlp(hidden_states)
ERROR 02-15 10:48:11 engine.py:139]                     ^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 02-15 10:48:11 engine.py:139]     return self._call_impl(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 02-15 10:48:11 engine.py:139]     return forward_call(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/models/deepseek_v2.py", line 162, in forward
ERROR 02-15 10:48:11 engine.py:139]     final_hidden_states = self.experts(
ERROR 02-15 10:48:11 engine.py:139]                           ^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 02-15 10:48:11 engine.py:139]     return self._call_impl(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 02-15 10:48:11 engine.py:139]     return forward_call(*args, **kwargs)
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/layers/fused_moe/layer.py", line 586, in forward
ERROR 02-15 10:48:11 engine.py:139]     final_hidden_states = self.quant_method.apply(
ERROR 02-15 10:48:11 engine.py:139]                           ^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/layers/quantization/moe_wna16.py", line 311, in apply
ERROR 02-15 10:48:11 engine.py:139]     return fused_experts(x,
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py", line 1203, in fused_experts
ERROR 02-15 10:48:11 engine.py:139]     torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/torch/_ops.py", line 1123, in __call__
ERROR 02-15 10:48:11 engine.py:139]     return self._op(*args, **(kwargs or {}))
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py", line 1105, in inplace_fused_experts
ERROR 02-15 10:48:11 engine.py:139]     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py", line 1364, in fused_experts_impl
ERROR 02-15 10:48:11 engine.py:139]     ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/vllm/_custom_ops.py", line 988, in moe_sum
ERROR 02-15 10:48:11 engine.py:139]     torch.ops._moe_C.moe_sum(input, output)
ERROR 02-15 10:48:11 engine.py:139]   File "/usr/local/lib/python3.12/site-packages/torch/_ops.py", line 1123, in __call__
ERROR 02-15 10:48:11 engine.py:139]     return self._op(*args, **(kwargs or {}))
ERROR 02-15 10:48:11 engine.py:139]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-15 10:48:11 engine.py:139] RuntimeError: CUDA error: an illegal memory access was encountered

Signed-off-by: Jinzhen Lin <[email protected]>
@jinzhen-lin
Contributor Author

@LagPixelLOL Fixed. Note that the benchmark results were measured without MLA.

@mgoin mgoin self-requested a review February 17, 2025 00:25
@sunjianxide

Can you share the command you ran with vLLM? I ran it on 8*A800 + deepseek-r1-awq, but the speedup ratio typically ranges from 1.05 to 1.40.

This is my command:
vllm serve /shared/weights/DeepSeek-R1-AWQ --host 0.0.0.0 --port 12345 --max-model-len 65536 --trust-remote-code --tensor-parallel-size 8 --quantization moe_wna16 --gpu-memory-utilization 0.9 --served-model-name deepseek-chat

@jinzhen-lin
Contributor Author

jinzhen-lin commented Feb 18, 2025

@sunjianxide Try running with VLLM_MLA_DISABLE=1, reduce max_model_len, and increase gpu_memory_utilization.

This PR only optimizes the MoE GEMM operator; if the inference bottleneck is not in the MoE GEMM, the overall performance improvement will be relatively small.

My command

VLLM_MLA_DISABLE=1 python -m vllm.entrypoints.openai.api_server --served-model-name model --model /root/DeepSeek-R1-AWQ/ --tensor-parallel-size 8 --trust-remote-code --gpu-memory-utilization 0.98 --max-model-len 32768 --max-num-seqs 64 --quantization moe_wna16  --dtype half

@jinzhen-lin
Contributor Author

@mgoin Could you review this PR? I have been testing it over the last two weeks.

Member

@mgoin mgoin left a comment


Nice work keeping the kernel clean! I do think we definitely need to refactor fused_moe.py to give a better interface for plugging in multiple kernels, similar to vllm/model_executor/layers/quantization/kernels/ for Linear modules.

Just a few questions
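For illustration, a minimal sketch of the kind of pluggable MoE-kernel interface suggested above, mirroring the Linear-kernel layout in vllm/model_executor/layers/quantization/kernels/ (class and method names here are hypothetical, not existing vLLM APIs):

```python
from abc import ABC, abstractmethod

import torch


class FusedMoEKernel(ABC):
    """Hypothetical interface for swapping MoE GEMM backends in fused_moe.py."""

    @abstractmethod
    def can_implement(self, num_tokens: int, top_k: int, num_experts: int,
                      weight_bits: int) -> bool:
        """Return True if this backend supports the given problem shape."""

    @abstractmethod
    def apply(self, hidden_states: torch.Tensor, w1: torch.Tensor,
              w2: torch.Tensor, topk_weights: torch.Tensor,
              topk_ids: torch.Tensor) -> torch.Tensor:
        """Run the fused MoE computation and return the output."""


def select_kernel(kernels: list["FusedMoEKernel"], num_tokens: int, top_k: int,
                  num_experts: int, weight_bits: int) -> "FusedMoEKernel":
    # Pick the first backend that claims it can handle this shape, e.g. the
    # WNA16 CUDA kernel for small batches and the Triton kernel otherwise.
    for kernel in kernels:
        if kernel.can_implement(num_tokens, top_k, num_experts, weight_bits):
            return kernel
    raise RuntimeError("No suitable fused MoE kernel found")
```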

@jinzhen-lin jinzhen-lin force-pushed the moe_wna16_cuda_kernel branch from a7df800 to 7287aa0 February 28, 2025 05:41
Signed-off-by: Jinzhen Lin <[email protected]>
@jinzhen-lin jinzhen-lin force-pushed the moe_wna16_cuda_kernel branch from 2b4becc to be03899 February 28, 2025 16:03
@jinzhen-lin jinzhen-lin force-pushed the moe_wna16_cuda_kernel branch from be03899 to 38a0eed February 28, 2025 16:10
@jinzhen-lin
Contributor Author

@mgoin Can we merge this? Although #14447 provides a better kernel for sm80+ devices, this kernel can be used on sm70/sm75 devices.

@mgoin mgoin added performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed quantization labels Mar 9, 2025
@GreyZzzzzzXh

      if (blockIdx.z == 0 && offset_n < size_n)
        output[token_index * size_n + offset_n] = Dtype::int2num(0);
      atomicAdd(&output[token_index * size_n + offset_n],
                Dtype::float2num(res[m]));

Hi @jinzhen-lin, the writes to output in these two places cannot guarantee the order of execution, which may lead to potential correctness issues. Should we move the initialization of output outside the kernel, for example by using C.zero_()?
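For reference, a minimal sketch of what initializing the output outside the kernel could look like on the Python side (the kernel launch is only a placeholder comment here; the surrounding allocation and zeroing are the point):

```python
import torch


def run_moe_wna16(hidden_states: torch.Tensor, size_n: int) -> torch.Tensor:
    num_tokens = hidden_states.shape[0]
    output = torch.empty(num_tokens, size_n,
                         dtype=hidden_states.dtype,
                         device=hidden_states.device)
    # Zero the accumulator before the kernel launch so every block can simply
    # atomicAdd into it, with no ordering assumption between the
    # blockIdx.z == 0 initialization and other blocks' accumulation.
    output.zero_()
    # The real WNA16 CUDA op would be launched here, e.g. via torch.ops (placeholder).
    return output
```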

@jinzhen-lin
Contributor Author

> Hi @jinzhen-lin, the writes to output in these two places cannot guarantee the order of execution, which may lead to potential correctness issues. Should we move the initialization of output outside the kernel, for example by using C.zero_()?

I'm not entirely sure about the execution-order issue, but I've observed that both the native GPTQ GEMM and the exllama GPTQ GEMM use similar logic. Therefore, I assume that blocks with blockIdx.z > 0 will not finish executing before the block with blockIdx.z == 0 has initialized the output.

Have you noticed any numerical anomalies? The additional C.zero_() may significantly increase the runtime, so I hope to avoid it if possible.
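As a side note, a quick way to look for such anomalies is to run the CUDA path against a reference implementation on random inputs and compare; a minimal sketch (the two callables are placeholders for the real Triton and CUDA WNA16 paths, which are not defined here):

```python
import torch


def check_outputs_match(reference_fn, candidate_fn, shape=(64, 7168), iters=10) -> bool:
    """Run two MoE implementations on identical random inputs and compare results."""
    for _ in range(iters):
        x = torch.randn(*shape, dtype=torch.float16, device="cuda")
        ref = reference_fn(x)
        out = candidate_fn(x)
        # Loose tolerances: atomicAdd accumulation order is nondeterministic,
        # so bit-exact equality is not expected even when the kernel is correct.
        if not torch.allclose(ref, out, atol=5e-2, rtol=5e-2):
            max_err = (ref - out).abs().max().item()
            print(f"mismatch detected, max abs error = {max_err:.4f}")
            return False
    return True
```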

Signed-off-by: Jinzhen Lin <[email protected]>
Member

@mgoin mgoin left a comment


Excellent work, thank you

@mgoin mgoin merged commit 90e88ab into vllm-project:main Mar 11, 2025
57 checks passed
@lizongyao123

In vLLM 0.8.0, I started the deepseek-r1-awq model with the following command:
VLLM_WORKER_MULTIPROC_METHOD=spawn VLLM_MLA_DISABLE=1 python -m vllm.entrypoints.openai.api_server --served-model-name model --model /home/work/transformer_models/DeepSeek-V3-AWQ/ --tensor-parallel-size 8 --trust-remote-code --gpu-memory-utilization 0.98 --max-model-len 32768 --max-num-seqs 64 --quantization moe_wna16 --dtype half --host 0.0.0.0 --port 12345

A torch._dynamo.exc.InternalTorchDynamoError is raised. The error is as follows:

(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375] WorkerProc hit an exception: %s
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375] Traceback (most recent call last):
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/vllm/v1/executor/multiproc_executor.py", line 371, in worker_busy_loop
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     output = func(*args, **kwargs)
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     return func(*args, **kwargs)
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/vllm/v1/worker/gpu_worker.py", line 157, in determine_available_memory
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     self.model_runner.profile_run()
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1401, in profile_run
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     hidden_states = self._dummy_run(self.max_num_tokens)
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     return func(*args, **kwargs)
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1265, in _dummy_run
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     hidden_states = model(
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     return self._call_impl(*args, **kwargs)
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     return forward_call(*args, **kwargs)
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 688, in forward
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     hidden_states = self.model(input_ids, positions, intermediate_tensors,
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/vllm/compilation/decorators.py", line 238, in __call__
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     output = self.compiled_callable(*args, **kwargs)
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     return fn(*args, **kwargs)
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     return self._torchdynamo_orig_callable(
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     return _compile(
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1036, in _compile
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     raise InternalTorchDynamoError(
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     guarded_code = compile_inner(code, one_graph, hooks, transform)
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     return _compile_inner(code, one_graph, hooks, transform)
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     return function(*args, **kwargs)
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 794, in _compile_inner
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     hook_output = hook(code, out_code)
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]   File "/usr/local/lib/python3.10/dist-packages/vllm/compilation/wrapper.py", line 115, in bytecode_hook
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     raise RuntimeError(msg)
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375] torch._dynamo.exc.InternalTorchDynamoError: RuntimeError: Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375] def forward(self, input_ids, positions, intermediate_tensors, inputs_embeds):
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]     graph_out_0 = __compiled_fn_1(__import_torch_dot__dynamo_dot_utils.
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         call_size(input_ids, 0), input_ids, self._modules['embed_tokens'].
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         _parameters['weight'], self._modules['layers']._modules['0']._modules[
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         'input_layernorm']._parameters['weight'], self._modules['layers'].
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         _modules['0']._modules['self_attn']._modules['q_a_proj']._parameters[
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         'qweight'], self._modules['layers']._modules['0']._modules['self_attn']
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         ._modules['q_a_proj']._parameters['scales'], self._modules['layers'].
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         _modules['0']._modules['self_attn']._modules['q_a_proj']._parameters[
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         'qzeros'], self._modules['layers']._modules['0']._modules['self_attn'].
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         _modules['q_a_proj']._parameters['g_idx'], self._modules['layers'].
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         _modules['0']._modules['self_attn']._modules['q_a_proj']._parameters[
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         'g_idx_sort_indices'], self._modules['layers']._modules['0']._modules[
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         'self_attn']._modules['q_a_proj'].workspace, self._modules['layers'].
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         _modules['0']._modules['self_attn']._modules['q_a_layernorm'].
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         _parameters['weight'], self._modules['layers']._modules['0']._modules[
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         'self_attn']._modules['q_b_proj']._parameters['qweight'], self._modules
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         ['layers']._modules['0']._modules['self_attn']._modules['q_b_proj'].
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         _parameters['scales'], self._modules['layers']._modules['0']._modules[
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         'self_attn']._modules['q_b_proj']._parameters['qzeros'], self._modules[
(VllmWorker rank=3 pid=16239) ERROR 03-19 10:02:05 [multiproc_executor.py:375]         'layers']._modules['0']._modules['self_attn']._modules['q_b_proj'].

Machine Information
A800 * 8
NVIDIA-SMI 525.125.06
Driver Version: 525.125.06
CUDA Version: 12.2
nvidia-cusparselt-cu12-0.6.2
pytorch-quantization 2.1.2
torch 2.6.0
torch-tensorrt 0.0.0
torchaudio 2.6.0
torchdata 0.7.0a0
torchtext 0.16.0a0
torchvision 0.21.0

@jinzhen-lin Could you share the hardware and driver environment you used for debugging? Thank you very much.

@jinzhen-lin
Contributor Author

> In vLLM 0.8.0, I started the deepseek-r1-awq model with the following command ... [full quoted command, InternalTorchDynamoError traceback, and machine information omitted; identical to the comment above]

Try running with VLLM_USE_V1=0, or use v0.7.3.
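For example, this is what prepending the environment variable to the startup command from the quoted report would look like (everything after VLLM_USE_V1=0 is unchanged from that command):

VLLM_USE_V1=0 VLLM_WORKER_MULTIPROC_METHOD=spawn VLLM_MLA_DISABLE=1 python -m vllm.entrypoints.openai.api_server --served-model-name model --model /home/work/transformer_models/DeepSeek-V3-AWQ/ --tensor-parallel-size 8 --trust-remote-code --gpu-memory-utilization 0.98 --max-model-len 32768 --max-num-seqs 64 --quantization moe_wna16 --dtype half --host 0.0.0.0 --port 12345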

lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: Jinzhen Lin <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025