Skip to content

Commit f506458

Browse files
committed
first pass at n>1
Signed-off-by: Andrew Feldman <[email protected]>
1 parent d927dbc commit f506458

File tree

1 file changed

+257
-3
lines changed

1 file changed

+257
-3
lines changed

vllm/v1/engine/llm_engine.py

Lines changed: 257 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import Dict, List, Mapping, Optional, Type, Union
2-
2+
from dataclasses import dataclass
33
from typing_extensions import TypeVar
44

55
from vllm.config import VllmConfig
@@ -26,6 +26,42 @@
2626

2727
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
2828

29+
def _none_safe_min(x,y):
30+
if x is None:
31+
return y
32+
if y is None:
33+
return x
34+
return min(x,y)
35+
36+
def _none_safe_max(x,y):
37+
if x is None:
38+
return y
39+
if y is None:
40+
return x
41+
return max(x,y)
42+
43+
def _none_safe_sum(x,y):
44+
if x is None:
45+
return y
46+
if y is None:
47+
return x
48+
return x+y
49+
50+
@dataclass
51+
class ParallelSampleChildRequestInfo:
52+
"""Info for aggregating parallel sampling child requests under parent"""
53+
parent_req_id: str
54+
index: int
55+
56+
@dataclass
57+
class ParallelSampleParentRequestInfo:
58+
"""Parallel sampling parent request info"""
59+
n: int
60+
n_finished: int = 0
61+
62+
def num_child_requests_remaining(self):
63+
assert self.n >= self.n_finished
64+
return self.n - self.n_finished
2965

3066
class LLMEngine:
3167
"""Legacy LLMEngine for backwards compatibility."""
@@ -46,6 +82,14 @@ def __init__(
4682
# TODO: Can we avoid this?
4783
self.model_config = vllm_config.model_config
4884

85+
# Parallel sampling metadata
86+
# - Metadata for aggregating the child requests associated with a parent request
87+
self.child_req_id_to_parent_req_info: Dict[
88+
str, ParallelSampleChildRequestInfo] = {}
89+
# - Parent request metadata i.e. degree of parallelism and other characteristics
90+
self.parent_req_id_info: Dict[str,
91+
ParallelSampleParentRequestInfo] = {}
92+
4993
# Tokenizer (+ ensure liveness if running in another process).
5094
self.tokenizer = init_tokenizer_from_configs(
5195
model_config=vllm_config.model_config,
@@ -117,8 +161,52 @@ def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
117161

118162
return executor_class
119163

120-
def get_num_unfinished_requests(self) -> int:
164+
def _get_num_core_unfinished_requests(self) -> int:
165+
"""Total number of unfinished requests in engine core
166+
167+
Does not account for parallel sampling, i.e. a request
168+
with `n=3` contributes `(3-n_complete)` to the total
169+
(the parent request
170+
does not count); an unfinished request with `n=1`
171+
contributes 1 to the total.
172+
173+
Returns:
174+
Total requests in engine core
175+
"""
121176
return self.detokenizer.get_num_unfinished_requests()
177+
178+
def _get_num_parallel_sampling_parent_unfinished_requests(self) -> int:
179+
"""Total number of requests with parallel sampling
180+
181+
i.e. an unfinished request with `n=<blah>` counts as 1,
182+
all other requests count a 0.
183+
184+
Returns:
185+
Number of parallel sampling parent requests
186+
"""
187+
return len(self.parent_req_id_info)
188+
189+
def _get_num_parallel_sampling_child_unfinished_requests(self) -> int:
190+
"""Total number of parallel sampling child requests.
191+
192+
i.e. an unfinished request with `n>1` counts as `(n-n_complete)`,
193+
all other requests count as 0.
194+
195+
Returns:
196+
Number of parallel sampling child requests
197+
"""
198+
return sum([preq_info.num_child_requests_remaining()
199+
for (_,preq_info) in self.parent_req_id_info.items()])
200+
201+
def get_num_unfinished_requests(self) -> int:
202+
"""Number of unfinished requests.
203+
204+
Each request submitted by the user counts as 1; the child requests
205+
spawned by parallel sampling requests are not reflected in this count.
206+
"""
207+
return (self._get_num_core_unfinished_requests() -
208+
self._get_num_parallel_sampling_child_unfinished_requests() +
209+
self._get_num_parallel_sampling_parent_unfinished_requests())
122210

123211
def has_unfinished_requests(self) -> bool:
124212
return self.detokenizer.has_unfinished_requests()
@@ -127,11 +215,78 @@ def has_unfinished_requests(self) -> bool:
127215
def validate_outputs(cls, outputs, output_type):
128216
return outputs
129217

218+
def _forget_parallel_sample_child_request_and_maybe_parent(
219+
self,
220+
child_request_id:str,
221+
) -> None:
222+
"""Forget child request parallel sampling metadata, & its' parent's metadata if necessary.
223+
224+
Parent request parallel sampling metadata is forgotten once all child requests have finished.
225+
226+
Args:
227+
child_request_id: id of finished child request
228+
"""
229+
# Forget child request metadata
230+
parent_req_id=self.child_req_id_to_parent_req_info[child_request_id].parent_req_id
231+
self.child_req_id_to_parent_req_info.pop(child_request_id, None)
232+
# Track parent request's remaining child requests & erase parent request metadata
233+
# if there are no remaining child requests
234+
self.parent_req_id_info[parent_req_id].n_finished+=1
235+
if self.parent_req_id_info[parent_req_id].num_child_requests_remaining() == 0:
236+
self.parent_req_id_info.pop(parent_req_id, None)
237+
238+
def _maybe_forget_parallel_sample_child_requests(
239+
self, possible_child_request_ids: List[str]) -> None:
240+
"""When a request aborts, if it is a child of a parallel sampling request,
241+
forget its parallel sampling metadata. Apply this to a list of possible child
242+
request ids. If the request is not associated with parallel sampling, this
243+
method has no effect on it.
244+
245+
Args:
246+
request_ids: list of request ids to possibly forget parallel sampling metadata for
247+
"""
248+
for possible_child_req_id in possible_child_request_ids:
249+
# Check if request is a parallel sampling child request
250+
if possible_child_req_id in self.child_req_id_to_parent_req_info:
251+
# If so, forget child request parallel sampling metadata
252+
self._forget_parallel_sample_child_request_and_maybe_parent(possible_child_req_id)
253+
254+
130255
def abort_request(self, request_ids: List[str]) -> None:
131256
"""Remove request_ids from EngineCore and Detokenizer."""
132257

133258
self.engine_core.abort_requests(request_ids)
134259
self.detokenizer.abort_requests(request_ids)
260+
self._maybe_forget_parallel_sample_child_requests(request_ids)
261+
262+
def _register_parallel_sampling_parent_request(
263+
self,
264+
parent_req_id: str,
265+
parallel_sample_parent_req_info: ParallelSampleParentRequestInfo,
266+
) -> None:
267+
"""Register the attributes associated with a parallel sampling request (i.e. the parent request)"""
268+
self.parent_req_id_info[
269+
parent_req_id] = parallel_sample_parent_req_info
270+
271+
def _register_parallel_sampling_child_request(
272+
self,
273+
parallel_sample_child_req_info: ParallelSampleChildRequestInfo,
274+
) -> str:
275+
"""Register the association of a parallel sampling child req with its parent req.
276+
277+
Generates a child request id
278+
279+
Side effect: internal mapping from child req id -> parent req info structure
280+
281+
Returns:
282+
Child request id
283+
"""
284+
parent_req_id = parallel_sample_child_req_info.parent_req_id
285+
index = parallel_sample_child_req_info.index
286+
child_req_id = f"{parent_req_id}_parallel_sample_{index}"
287+
self.child_req_id_to_parent_req_info[
288+
child_req_id] = parallel_sample_child_req_info
289+
return child_req_id
135290

136291
def add_request(
137292
self,
@@ -144,6 +299,29 @@ def add_request(
144299
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
145300
priority: int = 0,
146301
) -> None:
302+
if isinstance(params, SamplingParams) and params.n > 1:
303+
# Register parallel sampling request
304+
n = params.n
305+
self._register_parallel_sampling_parent_request(
306+
request_id, ParallelSampleParentRequestInfo(n))
307+
params.n = 1 # Engine core cannot see `n`
308+
for ndx in range(n):
309+
# Register child request with parent
310+
child_req_id = self._register_parallel_sampling_child_request(
311+
ParallelSampleChildRequestInfo(request_id, ndx))
312+
# Recurse to add child request; `n=1` prevents further recursion
313+
self.add_request(
314+
request_id=child_req_id,
315+
prompt=prompt,
316+
params=params,
317+
arrival_time=arrival_time,
318+
lora_request=lora_request,
319+
trace_headers=trace_headers,
320+
prompt_adapter_request=prompt_adapter_request,
321+
priority=priority,
322+
)
323+
# The top-level add_request call is done
324+
return
147325

148326
# 1) Process raw inputs into the request.
149327
detokenizer_req, engine_core_req = self.processor.process_inputs(
@@ -156,6 +334,80 @@ def add_request(
156334
# 3) Add the request to EngineCore.
157335
self.engine_core.add_request(engine_core_req)
158336

337+
def _is_parallel_sampling_child_request(
338+
self,
339+
possible_child_request_id:str,
340+
) -> bool:
341+
return possible_child_request_id in self.child_req_id_to_parent_req_info
342+
343+
def _maybe_get_parallel_sampling_child_request_info(
344+
self,
345+
possible_child_request_id: str,
346+
) -> Optional[ParallelSampleChildRequestInfo]:
347+
return self.child_req_id_to_parent_req_info.get(possible_child_request_id,None)
348+
349+
def _merge_parallel_sampling_child_request_output_in_place(
350+
self,
351+
parent_req_output: RequestOutput,
352+
child_req_output: RequestOutput,
353+
) -> None:
354+
# Parent is finished when all children are finished
355+
parent_req_output.finished=parent_req_output.finished and child_req_output.finished
356+
p_met=parent_req_output.metrics
357+
c_met=child_req_output.metrics
358+
if p_met is None:
359+
# If current parent request metrics are `None`, update with this child's metrics
360+
# (which may also be None)
361+
parent_req_output.metrics=c_met
362+
elif c_met is not None:
363+
# Only merge in child request output metrics if the child request output metrics
364+
# are not `None`
365+
p_met.last_token_time=max(p_met.last_token_time,c_met.last_token_time)
366+
p_met.first_scheduled_time=_none_safe_min(p_met.first_scheduled_time,
367+
c_met.first_scheduled_time)
368+
p_met.first_token_time=_none_safe_min(p_met.first_token_time,c_met.first_token_time)
369+
p_met.time_in_queue=_none_safe_sum(p_met.time_in_queue,c_met.time_in_queue)
370+
p_met.finished_time=_none_safe_max(p_met.finished_time,c_met.finished_time)
371+
p_met.last_token_time=max(p_met.last_token_time,c_met.last_token_time)
372+
p_met.model_execute_time=_none_safe_sum(p_met.model_execute_time,c_met.model_execute_time)
373+
p_met.model_forward_time=_none_safe_sum(p_met.model_forward_time,c_met.model_forward_time)
374+
p_met.scheduler_time=_none_safe_sum(p_met.scheduler_time,c_met.scheduler_time)
375+
p_met.time_in_queue=_none_safe_sum(p_met.time_in_queue,c_met.time_in_queue)
376+
parent_req_output.outputs.extend(child_req_output.outputs)
377+
parent_req_output.num_cached_tokens=_none_safe_sum(parent_req_output.num_cached_tokens,
378+
child_req_output.num_cached_tokens)
379+
380+
def _maybe_aggregate_parallel_sampling_child_requests(
381+
self,
382+
request_outputs: List[RequestOutput],
383+
) -> List[RequestOutput]:
384+
agg_request_outputs: List[RequestOutput]=[]
385+
parent_req_id_to_idx: Dict[str,int]={}
386+
for req_output in request_outputs:
387+
possible_child_req_id=req_output.request_id
388+
maybe_child_req_info = self._maybe_get_parallel_sampling_child_request_info(possible_child_req_id)
389+
if maybe_child_req_info:
390+
parent_req_id=maybe_child_req_info.parent_req_id
391+
if parent_req_id not in parent_req_id_to_idx:
392+
# For a particular parent id, this is the first child request output we have seen.
393+
# Repurpose the child request output structure to be the parent request output structure
394+
req_output.request_id=parent_req_id
395+
agg_request_outputs.append(req_output)
396+
# Remember where the parent request output structure resides in the output list
397+
parent_req_id_to_idx[parent_req_id]=len(agg_request_outputs)-1
398+
else:
399+
# Merge this child request output into the growing request output data structure associated
400+
# with its parent.
401+
parent_req_output=agg_request_outputs[parent_req_id_to_idx[parent_req_id]]
402+
self._merge_parallel_sampling_child_request_output_in_place(parent_req_output,req_output)
403+
else:
404+
# Not a parallel sampling request; don't touch it
405+
agg_request_outputs.append(req_output)
406+
return agg_request_outputs
407+
408+
409+
410+
159411
def step(self) -> List[RequestOutput]:
160412

161413
# 1) Get EngineCoreOutput from the EngineCore.
@@ -169,7 +421,9 @@ def step(self) -> List[RequestOutput]:
169421
if requests_to_abort:
170422
self.abort_request(requests_to_abort)
171423

172-
return request_outputs
424+
# 4) If necessary, aggregate outputs for parallel sampling child requests
425+
# to be associated with parent request
426+
return self._maybe_aggregate_parallel_sampling_child_requests(request_outputs)
173427

174428
# TODO(rob): Can we get rid of these?
175429

0 commit comments

Comments
 (0)