
Conversation

@billishyahao (Contributor) commented Sep 28, 2025

Purpose

Previously, the logic for generating batch_size_capture_list always seeded the list with [1, 2, 4] by default:

if len(cuda_graph_sizes) == 1:
    batch_size_capture_list = [1, 2, 4] + [
        i for i in range(8, cuda_graph_sizes[0] + 1, 8)
    ]

This was misleading when cuda_graph_sizes[0] < 4, because the list then contained batch sizes that exceeded the actual maximum.
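For illustration, the misleading result can be reproduced with a minimal standalone sketch of the old logic (illustrative only, not the actual vLLM helper):

# Sketch of the previous behavior for --cuda-graph-sizes 2 (illustrative only).
cuda_graph_sizes = [2]
batch_size_capture_list = [1, 2, 4] + [
    i for i in range(8, cuda_graph_sizes[0] + 1, 8)
]
print(batch_size_capture_list)  # [1, 2, 4] -- 4 exceeds the requested maximum of 2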

How this patch fixes it

The default [1, 2, 4] values that are larger than cuda_graph_sizes[0] are now filtered out, so the list correctly reflects the allowed batch sizes even for small maximums (see the sketch after the examples below).

Examples:

  • cuda_graph_sizes = [1] → [1]
  • cuda_graph_sizes = [2] → [1, 2]
  • cuda_graph_sizes = [17] → [1, 2, 4, 8, 16]
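A minimal sketch of the patched single-value path (illustrative only; the helper name build_capture_list is hypothetical, the real change lives in vLLM's config code):

def build_capture_list(max_graph_size: int) -> list[int]:
    # Keep only the default seeds [1, 2, 4] that fit within the maximum,
    # then append multiples of 8 up to and including the maximum.
    seeds = [s for s in (1, 2, 4) if s <= max_graph_size]
    return seeds + list(range(8, max_graph_size + 1, 8))

print(build_capture_list(1))   # [1]
print(build_capture_list(2))   # [1, 2]
print(build_capture_list(17))  # [1, 2, 4, 8, 16]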

Test Result

Before fix:

vllm serve /models/DeepSeek-R1 --trust-remote-code --tensor-parallel-size 8 --cuda-graph-sizes 2

(EngineCore_DP0 pid=904192) INFO 09-28 02:11:36 [core.py:78] Initializing a V1 LLM engine (v0.11.0rc2.dev34+gda63274d9) with config: model='/models/DeepSeek-R1', speculative_config=None, tokenizer='/models/DeepSeek-R1', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=163840, download_dir=None, load_format=auto, tensor_parallel_size=8, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=fp8, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=/models/DeepSeek-R1, enable_prefix_caching=True, chunked_prefill_enabled=True, pooler_config=None, compilation_config={"level":3,"debug_dump_path":null,"cache_dir":"","backend":"","custom_ops":[],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output","vllm.mamba_mixer2","vllm.mamba_mixer","vllm.short_conv","vllm.linear_attention","vllm.plamo2_mamba_mixer","vllm.gdn_attention"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"cudagraph_mode":[2,1],"use_cudagraph":true,"cudagraph_num_of_warmups":1,
"cudagraph_capture_sizes":[4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":true,"use_inductor_graph_partition":false,"pass_config":{},"max_capture_size":4,"local_cache_dir":null}

After fix:

vllm serve /models/DeepSeek-R1 --trust-remote-code --tensor-parallel-size 8 --cuda-graph-sizes 2

(EngineCore_DP0 pid=885683) INFO 09-28 01:52:41 [core.py:78] Initializing a V1 LLM engine (v0.11.0rc2.dev34+gda63274d9) with config: model='/models/DeepSeek-R1', speculative_config=None, tokenizer='/models/DeepSeek-R1', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=163840, download_dir=None, load_format=auto, tensor_parallel_size=8, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=fp8, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=/models/DeepSeek-R1, enable_prefix_caching=True, chunked_prefill_enabled=True, pooler_config=None, compilation_config={"level":3,"debug_dump_path":null,"cache_dir":"","backend":"","custom_ops":[],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output","vllm.mamba_mixer2","vllm.mamba_mixer","vllm.short_conv","vllm.linear_attention","vllm.plamo2_mamba_mixer","vllm.gdn_attention"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"cudagraph_mode":[2,1],"use_cudagraph":true,"cudagraph_num_of_warmups":1,
"cudagraph_capture_sizes":[2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":true,"use_inductor_graph_partition":false,"pass_config":{},"max_capture_size":2,"local_cache_dir":null}

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request correctly addresses an issue where the batch_size_capture_list for CUDA graphs was generated with misleading values when the maximum graph size was small. The fix is clear and effective, filtering the default batch sizes [1, 2, 4] against the specified maximum. The introduction of the max_graph_size variable improves readability. The change is well-contained and the logic is sound. I have no further suggestions.

@yewentao256 (Member) left a comment

Thanks for the work!

@yewentao256 (Member) left a comment

LGTM, thanks for the work!

@yewentao256 added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) on Sep 30, 2025
@ProExpertProg merged commit 2518230 into vllm-project:main on Oct 1, 2025
42 checks passed

@mgoin (Member) commented Oct 1, 2025

Actually, I would expect cuda_graph_sizes = [17] to produce [1, 2, 4, 8, 16, 17] rather than [1, 2, 4, 8, 16]. What do you think @billishyahao @yewentao256 @ProExpertProg?

@ProExpertProg (Collaborator) commented

Yeah, I think that would be fine too; feel free to ask for it in #26016.

@billishyahao (Contributor, Author) commented

@mgoin Good catch. I am willing to add this logic after #26016.
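A rough sketch of what that follow-up could look like (illustrative only, not part of this merged PR; it builds on the hypothetical helper sketched above):

def build_capture_list_with_max(max_graph_size: int) -> list[int]:
    sizes = [s for s in (1, 2, 4) if s <= max_graph_size]
    sizes += list(range(8, max_graph_size + 1, 8))
    # Also capture the requested maximum itself when the pattern above skips it.
    if not sizes or sizes[-1] != max_graph_size:
        sizes.append(max_graph_size)
    return sizes

print(build_capture_list_with_max(17))  # [1, 2, 4, 8, 16, 17]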
