Skip to content

Conversation

SolitaryThinker
Copy link
Contributor

@SolitaryThinker SolitaryThinker commented Jul 31, 2024

Adds initial multi step scheduling support to vLLM.
RFC: #6854

Current Status:

8/16: Initial support for chunked prefill thanks to @varun-sundar-rabindranath

8/14: Ready for another round of reviews! please review #7452
8/8: multi-node working
8/6: PP+TP working; PP+ray fixed; a few single GPU perf regressions (easy fix)
8/2 PP works with MP; Ready for initial pass on design
8/1 - PP is very close to working. We do get the desired interleaving of steps between microbatches which is great!
7/31 - Current branch is in very rough shape after getting the RFC design working. Will clean up after adding TP/PP support as there may be some refactors needed. However single GPU is ready for initial testing

Cmd:
python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B --swap-space 16 --disable-log-requests --use-v2-block-manager --tensor-parallel-size 1 --worker-use-ray --pipeline-parallel-size 1 --gpu-memory-utilization 0.90 --num-scheduler-steps 8

Benchmark (8/16)
See: #7528
CP_1: Force Single Step: We force single step when there are prefill requests in a batch. This may work well for offline batching, but not good for online serving because new requests keep coming.

CP_2: Ignore Prefill (WIP): We ignore prefill requests since the second step, meaning that prefill requests do nothing in (k-1) steps. This may work better for online serving.

Single GPU Baseline (Req/s) Baseline+CP (Req/s) MS-8 (Req/s) MS-8+CP_1 (Req/s)
A10G 8B Llama 6.21 - 6.63 -
H100 8B Llama 25.96 27.82 44.44 31.4
H100 30B Llama 10.38 11.01 14.27 12.31
PP=2 Baseline (Req/s) MS-4 (Req/s) MS-8 (Req/s) MS-12 (Req/s) MS-16 (Req/s)
A10G 8B Llama (microbatch=128) 8.98 - 9.99 - -
H100 8B Llama 23 - 31 - - `
H100 70B Llama 3.09 3.13 3.13 - -
TP=2 Baseline (Req/s) MS-4 (Req/s) MS-8 (Req/s) MS-12 (Req/s) MS-16 (Req/s)
A10G 8B Llama 6.11 - 7.02 - -
TP=2, PP=2 Baseline (Req/s) MS-4 (Req/s) MS-8 (Req/s) MS-12 (Req/s) MS-16 (Req/s)
A10G 8B Llama (microbatch=128) 5.99 - 7.15 - -

TODO:
Milestone 1: POC

  • Add --max_forward_calls_per_step to cli argument, engine args, and schedulerConfig
  • Changes to SequenceGroupState in sequence.py to track multi-step state.
  • Add MultiStepWorker in worker/ to cache multi-step state
  • Changes to ModelRunner to handle multi step state
  • Reorganize input preparation in ModelRunner to reduce duplicate code
  • Async GPU->CPU transfer for sampled token
  • Async pythonization
  • Flash Attn backend
  • Cudagraph
  • Benchmarks (Ongoing)
  • TP
  • PP (works with MP and Ray, mem leak somewhere with RAY)
  • PP+TP
  • multi-node

Milstone 2: Mergeable

Follow up work: Tracking Issue #7528

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@rkooo567
Copy link
Collaborator

rkooo567 commented Aug 1, 2024

QQ: do you plan to split PRs to smaller pieces?

@SolitaryThinker
Copy link
Contributor Author

@rkooo567 If there are splits that makes sense I will definitely do that. Currently working on a small part here #6971

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first batch of comments.

@SolitaryThinker SolitaryThinker changed the title [WIP] [core] Multi Step Scheduling [core] Multi Step Scheduling Aug 9, 2024
@SolitaryThinker
Copy link
Contributor Author

@zhuohan123 @rkooo567 @Yard1 @comaniac @alexm-neuralmagic rebased and ready for review

@SolitaryThinker
Copy link
Contributor Author

Working on a smaller PR that contains parts of this.

Copy link
Member

@zhuohan123 zhuohan123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First round of questions. Will add more tmrw.

Copy link
Member

@zhuohan123 zhuohan123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Second batch of reviews

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this assert mean the MultiStepModelRunner can only be run with one step? Can you elaborate on this?

Copy link
Contributor Author

@SolitaryThinker SolitaryThinker Aug 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MultiStepModelRunner only takes a single step internally before returning to AsyncLLMEngine. As the multi-step is done implicitly using stateful model inputs and SequenceGroup states.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the explaination!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit confusing tho. IIRC, this was introduced by me for multi-step draft model runner? We should remove this argument and use stateful model inputs as the unify representation. Also cc @alexm-neuralmagic

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah let's remove this argument

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do this in a follow up PR as it will involve spec decode as well. Will add to TODO tracker

Copy link
Contributor

@afeldman-nm afeldman-nm Aug 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SolitaryThinker no need to block on this feedback - but if you have time - I would propose adding an example/offline_inference_multi_step.py example which instantiates an engine instance with multi-step enabled. Similar in structure to example/offline_inference.py.

An example of why this is useful - as part of the logprobs workstream, I am trying to step through the multi-step model runner with the python debugger & examine the output logprobs. I am using your multi_step/test_correctness.py in order to set up a server with multi-step enabled.

However, multi_step/test_correctness.py is an end-to-end client/server test & it is not straightforward (although technically doable) to step through the server code with the debugger because the server is in another process.

I will get around this by writing a short script which sets up an engine instance with multi-step enabled.

However, for someone else who is approaching this code for the first time, it could be helpful to have an example file (or unit test) which just sets up an engine instance with multi-step enabled and invokes inference using LLM.generate(). This could be a good way to facilitate quick debugging & also gives insight into how the server works.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the offline_inference_multi_step.py script I wrote for myself to facilitate debugging, if you would like to use it.

'''
Example of setting up LLM with multi-step enabled.
In actuality, async engine would be a more sensible choice
for a real use-case. However this example is useful
for demonstration & debugging of multi-step code.
'''

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="JackFram/llama-160m",
          swap_space=16,
          tensor_parallel_size=1,
          gpu_memory_utilization=0.9,
          num_scheduler_steps=8,
          use_v2_block_manager=True,
          )
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Comment on lines +299 to +301
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do wonder if there's a more generic way of doing this. If this data structure gets modified somewhere else it will not be reflected here. Maybe a loop where we check the device if the object is a tensor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are optionals and only set if include_gpu_probs_tensor is set in the sampler.

SolitaryThinker and others added 2 commits August 19, 2024 10:55
remove some redundant test cases

set v2 blockmananger and fix rebase

Update vllm/engine/async_llm_engine.py

Co-authored-by: Zhuohan Li <[email protected]>

Update vllm/engine/async_llm_engine.py

Co-authored-by: Zhuohan Li <[email protected]>

Update vllm/worker/multi_step_model_runner.py

Co-authored-by: Zhuohan Li <[email protected]>

add comment

typo

rename to StatefulModelInput

renamed outputs to cached_outputs

Update vllm/worker/multi_step_model_runner.py

Co-authored-by: afeldman-nm <[email protected]>
@SolitaryThinker
Copy link
Contributor Author

Also before merge, can you please verify the throughput (tokens/sec) gain in the following settings to make sure the PR is good performance-wise:

ShareGPT + Llama 8B + 1x H100/A100
ShareGPT + Llama 70B + 8x H100/A100
Also, can you add what are the dataset you are using in your original benchmark? Thanks!

@zhuohan123
I'm using sharegpt for all the numbers. Benchmarked using the benchmark_serving.py script.
See below for single GPU numbers.
image

Copy link
Member

@zhuohan123 zhuohan123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for the hard work! Please make sure to keep track of the TODOs we discussed in this PR.

@Yard1 Yard1 merged commit 47b65a5 into vllm-project:main Aug 19, 2024
65 checks passed
@SolitaryThinker SolitaryThinker deleted the multi-step branch August 19, 2024 22:16
@WoosukKwon
Copy link
Collaborator

[rank0]: File "/data/woosuk/workspace/vllm/vllm/engine/output_processor/multi_step.py", line 88, in process_outputs
[rank0]: assert valid_samples

@SolitaryThinker Huge thanks for the PR! QQ: I got the above error when running benchmark scripts with num_scheduler_steps > 1. Is this expected?

@jiqing-feng
Copy link
Contributor

Hi @WoosukKwon . I see spec decode also has a class name MultiStepWorker, is there any relation with MultiStepWorker from vllm/worker/multi_step_worker.py in this PR?

Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
Co-authored-by: afeldman-nm <[email protected]>
Signed-off-by: LeiWang1999 <[email protected]>
@FerranAgulloLopez
Copy link

Hello:) Sorry to bother, what is the state of this? Is it planned for implementation in new v1 version? Thanks a lot!

@SolitaryThinker
Copy link
Contributor Author

SolitaryThinker commented Jul 23, 2025

My understanding is that this shouldn't be needed in v1, as this was a stopgap for the performance bottlenecks of the scheduler in v0. Someone more up-to-date can correct me though :)

@FerranAgulloLopez
Copy link

FerranAgulloLopez commented Jul 24, 2025

Iep:) Interesting, I see the gap still exists, but as you are saying may be way smaller than before.

My results in version 0.9.2 when using v1 scheduler with NVIDIA Nsight, batch size 128 and Llama-3.1-8B:

  • When just having decoding requests in the batch, aprox 1.2ms of gap vs 13ms of step ->
image
  • When having both decoding and prefill (chunked ones), aprox 1.4ms of gap vs 52ms of step ->
image

As you were saying, we may consider it negligible, or at least, there are other optimizations way more important than just this small optimization margin.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.