[V1] AsyncLLM data parallel #13923
Changes from 58 commits
New test file (108 added lines):

```python
# SPDX-License-Identifier: Apache-2.0

import asyncio
import os
from contextlib import ExitStack
from typing import Optional

import pytest

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import DPAsyncMPClient

if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.",
                allow_module_level=True)


async def generate(engine: AsyncLLM,
                   request_id: str,
                   prompt: PromptType,
                   output_kind: RequestOutputKind,
                   max_tokens: int,
                   prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
    # Ensure generate doesn't complete too fast for cancellation test.
    await asyncio.sleep(0.2)

    count = 0
    sampling_params = SamplingParams(max_tokens=max_tokens,
                                     ignore_eos=True,
                                     output_kind=output_kind,
                                     temperature=0,
                                     prompt_logprobs=prompt_logprobs)
    async for out in engine.generate(request_id=request_id,
                                     prompt=prompt,
                                     sampling_params=sampling_params):

        num_tokens = len(out.outputs[0].token_ids)
        if output_kind == RequestOutputKind.DELTA:
            count += num_tokens
        else:
            count = num_tokens

        await asyncio.sleep(0.)

    return count, request_id


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_load(monkeypatch, output_kind: RequestOutputKind):
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        engine_args = AsyncEngineArgs(
            model="ibm-research/PowerMoE-3b",
            enforce_eager=True,
            disable_log_requests=True,
            tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
            data_parallel_size=int(os.getenv("DP_SIZE", 2)),
        )

        prompt = "This is a test of data parallel"

        engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 10

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks = []
        for request_id in request_ids:
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, prompt, output_kind,
                             NUM_EXPECTED_TOKENS)))

        # Confirm that we got all the EXPECTED tokens from the requests.
        done, pending = await asyncio.wait(tasks,
                                           return_when=asyncio.FIRST_EXCEPTION)
        for task in pending:
            task.cancel()
        for task in done:
            num_generated_tokens, request_id = await task
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}")

        assert not engine.output_processor.has_unfinished_requests()

        # Testing internals here which may break.
        core_client: DPAsyncMPClient = engine.engine_core
        # The engines only synchronize stopping every N steps, so
        # allow a small amount of time here.
        for _ in range(10):
            if core_client.num_engines_running == 0:
                break
            await asyncio.sleep(0.5)
        assert core_client.num_engines_running == 0
        assert not core_client.reqs_in_flight
```
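For orientation, here is a minimal standalone sketch (not part of the PR) of the same data-parallel AsyncLLM setup the test exercises. The model name, prompt, and DP/TP sizes mirror the test above and are illustrative; `VLLM_USE_V1` is set the same way the test does via monkeypatch.

```python
# Hedged sketch (not from the PR): drive a data-parallel AsyncLLM directly,
# mirroring the test above. Model, prompt, and parallel sizes are illustrative.
import asyncio
import os

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM


async def main() -> None:
    os.environ["VLLM_USE_V1"] = "1"  # the test sets this via monkeypatch

    engine = AsyncLLM.from_engine_args(
        AsyncEngineArgs(
            model="ibm-research/PowerMoE-3b",  # same model as the test
            enforce_eager=True,
            tensor_parallel_size=1,
            data_parallel_size=2,  # two engine-core ranks
        ))
    try:
        sampling_params = SamplingParams(
            max_tokens=10,
            temperature=0,
            output_kind=RequestOutputKind.FINAL_ONLY)
        async for out in engine.generate(
                request_id="demo-0",
                prompt="This is a test of data parallel",
                sampling_params=sampling_params):
            if out.finished:
                print(out.outputs[0].text)
    finally:
        engine.shutdown()


if __name__ == "__main__":
    asyncio.run(main())
```

As in the test, the parallel sizes could instead be read from `TP_SIZE`/`DP_SIZE` environment variables to vary the topology.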
Changes to the scheduler:

```diff
@@ -36,13 +36,15 @@ def __init__(
         cache_config: CacheConfig,
         lora_config: Optional[LoRAConfig],
         speculative_config: Optional[SpeculativeConfig],
-        log_stats: bool,
         structured_output_manager: StructuredOutputManager,
+        include_finished_set: bool = False,
+        log_stats: bool = False,
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
         self.lora_config = lora_config
         self.speculative_config = speculative_config
+        self.include_finished_set = include_finished_set
         self.log_stats = log_stats
         self.structured_output_manager = structured_output_manager
```

```diff
@@ -647,10 +649,16 @@ def update_from_output(
             new_running.append(request)
 
         self.running = new_running
-        return EngineCoreOutputs(
+        engine_core_outputs = EngineCoreOutputs(
             outputs=outputs,
             scheduler_stats=self.make_stats(),
         )
+        if self.include_finished_set:
+            #TODO currently sending duplicates here, improve this
+            engine_core_outputs.finished_requests = (
+                scheduler_output.finished_req_ids | self.finished_req_ids)
+
+        return engine_core_outputs
 
     def add_request(self, request: Request) -> None:
         self.waiting.append(request)
```
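The new `finished_requests` field is what lets a data-parallel front-end clean up its per-request routing state; the test above checks this indirectly via `core_client.reqs_in_flight`. Below is an illustrative sketch of that consumption pattern, not the PR's actual `DPAsyncMPClient` code: the class and method names are hypothetical, and the `EngineCoreOutputs` import path is assumed.

```python
# Illustrative sketch only -- not the PR's DPAsyncMPClient implementation.
# Shows how a DP-aware client could use EngineCoreOutputs.finished_requests
# to drop per-request routing state once an engine reports requests finished.
from vllm.v1.engine import EngineCoreOutputs  # import path assumed


class DPClientSketch:

    def __init__(self) -> None:
        # request_id -> index of the engine currently serving that request
        self.reqs_in_flight: dict[str, int] = {}

    def handle_outputs(self, outputs: EngineCoreOutputs) -> None:
        # The scheduler may report the same id more than once across steps
        # (see the TODO in the diff), so removal must tolerate missing keys.
        if outputs.finished_requests:
            for req_id in outputs.finished_requests:
                self.reqs_in_flight.pop(req_id, None)
```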
Review comments on the test's platform skip guard:

- Not sure if DP works on TPU or AMD GPUs, but modify this reason string, since V1 works there at least experimentally? (referencing `vllm/vllm/engine/arg_utils.py`, lines 1669 to 1675 at d0cfec7)
- We could actually use `supports_v1` now that this PR has landed (probably only want to turn the tests on for CUDA and ROCm though): #15417
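A hedged sketch of what the suggested skip guard could look like, limiting the test to CUDA and ROCm as the second comment proposes: `current_platform.is_rocm()` is assumed to be available alongside `is_cuda()`, and the reworded reason string is illustrative rather than taken from the PR.

```python
# Sketch of the suggested skip guard (not in the diff): allow CUDA and ROCm,
# and stop claiming V1 is CUDA-only in the reason string.
import pytest

from vllm.platforms import current_platform

if not (current_platform.is_cuda() or current_platform.is_rocm()):
    pytest.skip(reason="V1 DP test currently enabled only on CUDA and ROCm.",
                allow_module_level=True)
```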