
Commit befc402

afeldman-nm and njhill authored
[V1] V1 engine implements parallel sampling (AsyncLLM and LLMEngine) (#10980)
Signed-off-by: Andrew Feldman <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent 444b0f0 commit befc402
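
For orientation, the sketch below shows the offline workflow this commit enables (illustrative only; the model, prompt, and settings are arbitrary and not part of the change). Setting `n > 1` on SamplingParams makes a single request return `n` completions, each carrying its own `index`, as exercised by test_parallel_sampling further down.

import os

os.environ["VLLM_USE_V1"] = "1"  # opt in to the V1 engine, as the tests below do

from vllm import LLM, SamplingParams  # noqa: E402

llm = LLM(model="facebook/opt-125m", enforce_eager=True)
params = SamplingParams(n=3, temperature=0.95, top_p=0.95, seed=42)

# One prompt, three sampled completions in the single returned RequestOutput.
for request_output in llm.generate(["What is an LLM?"], params):
    for completion in request_output.outputs:
        print(completion.index, repr(completion.text))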

File tree

5 files changed: +641, -9 lines

tests/v1/engine/test_llm_engine.py

Lines changed: 98 additions & 5 deletions
@@ -1,21 +1,114 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import random
+from typing import Dict, List, Optional, Tuple
+
 import pytest
 
 from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
 from vllm import LLM, SamplingParams
 
+MODEL = "facebook/opt-125m"
+DTYPE = "half"
 
-def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch):
-    """Test passes if LLMEngine raises an exception when it is configured
-    for automatic prefix caching and it receives a request with
-    prompt_logprobs enabled, which is incompatible."""
 
+def _vllm_model(apc: bool, vllm_runner, monkeypatch):
+    """Set up VllmRunner instance."""
     monkeypatch.setenv("VLLM_USE_V1", "1")
     # TODO(nick): Single-proc to work around a ZMQ shutdown hang for now.
     monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+    return vllm_runner(
+        MODEL,
+        dtype=DTYPE,
+        max_model_len=128,
+        enforce_eager=True,
+        enable_prefix_caching=apc,
+        gpu_memory_utilization=0.5,
+    )
+
+
+@pytest.fixture(
+    # Function scope decouples tests & allows
+    # env var adjustment via monkeypatch
+    scope="function",
+    # Prefix caching
+    params=[False, True])
+def vllm_model(vllm_runner, request, monkeypatch):
+    """VllmRunner test fixture parameterized by APC True/False."""
+    with _vllm_model(request.param, vllm_runner, monkeypatch) as vllm_model:
+        yield vllm_model
+
+
+@pytest.fixture(scope="function")
+def vllm_model_apc(vllm_runner, monkeypatch):
+    """VllmRunner test fixture with APC."""
+    with _vllm_model(True, vllm_runner, monkeypatch) as vllm_model:
+        yield vllm_model
+
+
+def _get_test_sampling_params(
+    prompt_list: List[str],
+    seed: Optional[int] = 42,
+) -> Tuple[List[SamplingParams], List[int]]:
+    """Generate random sampling params for a batch."""
+
+    def get_mostly_n_gt1() -> int:
+        """Mostly n \in [2,20], ~1/3 n=1"""
+        x = random.randint(0, 28)
+        if x < 10:
+            return 1
+        else:
+            return x - 8
+
+    n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))]
+    # High temperature to maximize the chance of unique completions
+    return [
+        SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed)
+        for n in n_list
+    ], n_list
+
+
+def test_parallel_sampling(vllm_model, example_prompts) -> None:
+    """Test passes if parallel sampling `n>1` yields `n` unique completions.
+
+    Args:
+      vllm_model: VllmRunner instance under test.
+      example_prompt: test fixture providing prompts for testing.
+    """
+    sampling_params_list, n_list = _get_test_sampling_params(example_prompts)
+    model: LLM = vllm_model.model
+    outputs = model.generate(example_prompts, sampling_params_list)
+
+    # Validate each request response
+    for out, n in zip(outputs, n_list):
+        completion_counts: Dict[str, int] = {}
+        # Assert correct number of completions
+        assert len(out.outputs) == n, (
+            f"{len(out.outputs)} completions; {n} expected.")
+        for idx in range(n):
+            comp = out.outputs[idx]
+            # Assert correct completion indices
+            assert comp.index == idx, (f"Index {comp.index}; expected {idx}.")
+            text = comp.text
+            completion_counts[text] = completion_counts.get(text, 0) + 1
+        # Assert unique completions
+        if len(completion_counts) != n:
+            repeats = {
+                txt: num
+                for (txt, num) in completion_counts.items() if num > 1
+            }
+            raise AssertionError(
+                f"{len(completion_counts)} unique completions; expected"
+                f" {n}. Repeats: {repeats}")
+
+
+def test_llm_engine_refuses_prompt_logprobs_with_apc(vllm_model_apc):
+    """Test passes if LLMEngine raises an exception when it is configured
+    for automatic prefix caching and it receives a request with
+    prompt_logprobs enabled, which is incompatible."""
+    model: LLM = vllm_model_apc.model
     with pytest.raises(ValueError) as excinfo:
-        LLM(model="facebook/opt-125m", enable_prefix_caching=True).generate(
+        model.generate(
             "Hello, my name is",
             SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5))

tests/v1/entrypoints/openai/test_completion.py

Lines changed: 102 additions & 0 deletions
@@ -250,6 +250,108 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
     assert "".join(chunks) == single_output
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
+                                     model_name: str):
+    """Parallel sampling without streaming.
+    A single request output contains a list of completions.
+    """
+
+    prompt = "What is an LLM?"
+    n = 3
+    max_tokens = 5
+
+    # High temperature to maximize chance of unique completions.
+    completion = await client.completions.create(model=model_name,
+                                                 prompt=prompt,
+                                                 max_tokens=max_tokens,
+                                                 n=n,
+                                                 temperature=0.95,
+                                                 stream=False,
+                                                 seed=42)
+
+    # Assert `n` completions
+    num_completions = len(completion.choices)
+    assert num_completions == n, (
+        f"Num completions {num_completions} but expected {n}.")
+    completion_repeats: Dict[str, int] = {}
+    for idx, choice in enumerate(completion.choices):
+        # Assert correct completion index & some finish reason.
+        assert choice.index == idx, (
+            f"Index {choice.index} but expected {idx}.")
+        assert choice.finish_reason is not None, (
+            "None finish_reason is invalid.")
+        text = choice.text
+        completion_repeats[text] = completion_repeats.get(text, 0) + 1
+    # Assert `n` unique completions
+    num_unique = len(completion_repeats)
+    if num_unique != n:
+        repeats = {
+            txt: num
+            for (txt, num) in completion_repeats.items() if num > 1
+        }
+        raise AssertionError(
+            f"Expected {n} unique completions, got {num_unique};"
+            f" repeats: {repeats}.")
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
+    """Streaming for parallel sampling.
+    The tokens from multiple samples, are flattened into a single stream,
+    with an index to indicate which sample the token belongs to.
+    """
+
+    prompt = "What is an LLM?"
+    n = 3
+    max_tokens = 5
+
+    stream = await client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             max_tokens=max_tokens,
+                                             n=n,
+                                             temperature=0.95,
+                                             stream=True,
+                                             seed=42)
+    chunks: List[List[str]] = [[] for i in range(n)]
+    finish_reason_count = 0
+    async for chunk in stream:
+        index = chunk.choices[0].index
+        text = chunk.choices[0].text
+        chunks[index].append(text)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    # Assert `n` completions with correct finish reasons
+    assert finish_reason_count == n, (
+        f"Expected {n} completions with valid indices and finish_reason.")
+    completion_repeats: Dict[str, int] = {}
+    for chunk in chunks:
+        chunk_len = len(chunk)
+        # Assert correct number of completion tokens
+        assert chunk_len == max_tokens, (
+            f"max_tokens={max_tokens} but chunk len is {chunk_len}.")
+        text = "".join(chunk)
+        completion_repeats[text] = completion_repeats.get(text, 0) + 1
+        print(text)
+    # Assert `n` unique completions
+    num_unique = len(completion_repeats)
+    if num_unique != n:
+        repeats = {
+            txt: num
+            for (txt, num) in completion_repeats.items() if num > 1
+        }
+        raise AssertionError(f"{num_unique} unique completions, expected {n};"
+                             f" repeats: {repeats}")
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",

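The two tests above exercise the server path end to end. For reference, a minimal client-side sketch (server URL, API key, and model name are placeholders, not part of this commit) requests n=3 completions in one call and reads them back by choice index:

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(model="facebook/opt-125m",
                                       prompt="What is an LLM?",
                                       n=3,
                                       max_tokens=5,
                                       temperature=0.95)
for choice in completion.choices:
    print(choice.index, repr(choice.text))
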
vllm/v1/engine/async_llm.py

Lines changed: 26 additions & 1 deletion
@@ -24,6 +24,7 @@
 from vllm.utils import cdiv, kill_process_tree
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
+from vllm.v1.engine.parallel_sampling import generate_parallel_sampling_async
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
@@ -170,7 +171,7 @@ async def add_request(
     # requests we don't need to send multiple messages to core proc,
     # and so we don't need multiple streams which then get
     # re-multiplexed in the API server anyhow.
-    async def generate(
+    async def _generate(
         self,
         prompt: PromptType,
         sampling_params: SamplingParams,
@@ -241,6 +242,30 @@ async def generate(
             await self.abort(request_id)
             raise
 
+    def generate(
+        self,
+        prompt: PromptType,
+        sampling_params: SamplingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
+    ) -> AsyncGenerator[RequestOutput, None]:
+        kwargs = dict(prompt=prompt,
+                      sampling_params=sampling_params,
+                      request_id=request_id,
+                      lora_request=lora_request,
+                      trace_headers=trace_headers,
+                      prompt_adapter_request=prompt_adapter_request,
+                      priority=priority)
+        if sampling_params.n is None or sampling_params.n == 1:
+            return self._generate(**kwargs)
+        else:
+            # Special handling for parallel sampling requests
+            return generate_parallel_sampling_async(generate=self._generate,
+                                                    **kwargs)
+
     async def _run_output_handler(self):
         """Background loop: pulls from EngineCore and pushes to AsyncStreams."""

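The hunk above only adds the dispatch: generate() forwards n=1 requests to _generate() and hands n>1 requests to generate_parallel_sampling_async(), which lives in vllm/v1/engine/parallel_sampling.py and is not reproduced in this section. A rough consumer-side sketch (engine construction details are assumptions; only the generate() signature comes from the diff) collects the streamed outputs by completion index:

import asyncio
import os
from typing import Dict

os.environ["VLLM_USE_V1"] = "1"  # mirror the env var the V1 tests set

from vllm.engine.arg_utils import AsyncEngineArgs  # noqa: E402
from vllm.sampling_params import SamplingParams  # noqa: E402
from vllm.v1.engine.async_llm import AsyncLLM  # noqa: E402


async def main() -> None:
    # Assumption: AsyncLLM.from_engine_args builds a standalone V1 async engine.
    engine = AsyncLLM.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
    params = SamplingParams(n=3, temperature=0.95, max_tokens=5)
    texts: Dict[int, str] = {}
    # generate() transparently takes the parallel-sampling path because n > 1.
    async for request_output in engine.generate(prompt="What is an LLM?",
                                                sampling_params=params,
                                                request_id="parallel-0"):
        for completion in request_output.outputs:
            texts[completion.index] = completion.text
    for index in sorted(texts):
        print(index, repr(texts[index]))


asyncio.run(main())
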
vllm/v1/engine/llm_engine.py

Lines changed: 40 additions & 3 deletions
@@ -21,6 +21,7 @@
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
+from vllm.v1.engine.parallel_sampling import SyncParallelSamplingManager
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
 
@@ -48,6 +49,9 @@ def __init__(
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
 
+        # Bookkeeping for parallel sampling requests
+        self.parallel_manager = SyncParallelSamplingManager()
+
         # important: init dp group before init the engine_core
         self.parallel_config = vllm_config.parallel_config
         self.dp_enabled = self.parallel_config.data_parallel_size > 1  # noqa
@@ -115,7 +119,8 @@ def from_engine_args(
             multiprocess_mode=enable_multiprocessing)
 
     def get_num_unfinished_requests(self) -> int:
-        return self.output_processor.get_num_unfinished_requests()
+        return self.parallel_manager.get_num_unfinished_requests(
+            self.output_processor.get_num_unfinished_requests())
 
     def has_unfinished_requests(self) -> bool:
         has_unfinished = self.output_processor.has_unfinished_requests()
@@ -151,7 +156,36 @@ def add_request(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> None:
-
+        """Add request."""
+        kwargs = dict(request_id=request_id,
+                      prompt=prompt,
+                      params=params,
+                      arrival_time=arrival_time,
+                      lora_request=lora_request,
+                      trace_headers=trace_headers,
+                      prompt_adapter_request=prompt_adapter_request,
+                      priority=priority)
+        # Handle parallel sampling requests differently.
+        if params is None or isinstance(params,
+                                        PoolingParams) or params.n == 1:
+            self._add_request(**kwargs)
+        else:
+            # Special handling for parallel sampling requests
+            self.parallel_manager.add_request_parallel_sampling(
+                add_request=self._add_request, **kwargs)
+
+    def _add_request(
+        self,
+        request_id: str,
+        prompt: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
+    ) -> None:
+        """Add request, `n=1`"""
         # 1) Process raw inputs into the request.
         request = self.processor.process_inputs(request_id, prompt, params,
                                                 arrival_time, lora_request,
@@ -182,7 +216,10 @@ def step(self) -> List[RequestOutput]:
         # 3) Abort any reqs that finished due to stop strings.
         self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
 
-        return processed_outputs.request_outputs
+        request_outputs = processed_outputs.request_outputs
+
+        # 4) Process unfinished parallel sampling requests
+        return self.parallel_manager.step(request_outputs)
 
     def get_model_config(self):
         return self.model_config

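SyncParallelSamplingManager is likewise defined in vllm/v1/engine/parallel_sampling.py, which is not shown in this section; the diff only wires it into add_request(), get_num_unfinished_requests(), and step(). A usage-level sketch of the resulting behavior (engine construction is an assumption; the add_request()/step() signatures come from the diff) is that a single added request with n=3 eventually comes back from step() as one finished RequestOutput carrying three indexed completions:

import os

os.environ["VLLM_USE_V1"] = "1"  # mirror the env var the V1 tests set

from vllm import EngineArgs, SamplingParams  # noqa: E402
from vllm.v1.engine.llm_engine import LLMEngine  # noqa: E402

engine = LLMEngine.from_engine_args(
    EngineArgs(model="facebook/opt-125m", enforce_eager=True))
engine.add_request(request_id="parallel-0",
                   prompt="What is an LLM?",
                   params=SamplingParams(n=3, temperature=0.95, max_tokens=5))

# Drive the engine to completion; the parallel-sampling manager folds the n
# child requests back into one parent output once all samples have finished.
while engine.has_unfinished_requests():
    for request_output in engine.step():
        if request_output.finished:
            for completion in request_output.outputs:
                print(completion.index, repr(completion.text))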