@@ -1,5 +1,5 @@
from typing import Dict, List, Mapping, Optional, Type, Union
-
+ from dataclasses import dataclass
from typing_extensions import TypeVar

from vllm.config import VllmConfig
@@ -26,6 +26,42 @@

_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)

+ def _none_safe_min(x, y):
+     if x is None:
+         return y
+     if y is None:
+         return x
+     return min(x, y)
+
+ def _none_safe_max(x, y):
+     if x is None:
+         return y
+     if y is None:
+         return x
+     return max(x, y)
+
+ def _none_safe_sum(x, y):
+     if x is None:
+         return y
+     if y is None:
+         return x
+     return x + y
+
+ @dataclass
+ class ParallelSampleChildRequestInfo:
+     """Info for aggregating parallel sampling child requests under parent"""
+     parent_req_id: str
+     index: int
+
+ @dataclass
+ class ParallelSampleParentRequestInfo:
+     """Parallel sampling parent request info"""
+     n: int
+     n_finished: int = 0
+
+     def num_child_requests_remaining(self):
+         assert self.n >= self.n_finished
+         return self.n - self.n_finished


class LLMEngine:
    """Legacy LLMEngine for backwards compatibility."""
@@ -46,6 +82,14 @@ def __init__(
        # TODO: Can we avoid this?
        self.model_config = vllm_config.model_config

+         # Parallel sampling metadata
+         # - Metadata for aggregating the child requests associated with a parent request
+         self.child_req_id_to_parent_req_info: Dict[
+             str, ParallelSampleChildRequestInfo] = {}
+         # - Parent request metadata, i.e. degree of parallelism and other characteristics
+         self.parent_req_id_info: Dict[str,
+                                       ParallelSampleParentRequestInfo] = {}
+
        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
@@ -117,8 +161,52 @@ def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:

        return executor_class

-     def get_num_unfinished_requests(self) -> int:
+     def _get_num_core_unfinished_requests(self) -> int:
+         """Total number of unfinished requests in the engine core.
+
+         Parallel sampling is counted from the engine core's perspective:
+         an unfinished request with `n=3` contributes `(3 - n_complete)`
+         child requests to the total (the parent request itself is not
+         counted), and an unfinished request with `n=1` contributes 1.
+
+         Returns:
+             Total number of requests in the engine core
+         """
        return self.detokenizer.get_num_unfinished_requests()
+
+     def _get_num_parallel_sampling_parent_unfinished_requests(self) -> int:
+         """Total number of unfinished parallel sampling parent requests.
+
+         An unfinished request with `n>1` counts as 1; all other
+         requests count as 0.
+
+         Returns:
+             Number of parallel sampling parent requests
+         """
+         return len(self.parent_req_id_info)
+
+     def _get_num_parallel_sampling_child_unfinished_requests(self) -> int:
+         """Total number of unfinished parallel sampling child requests.
+
+         An unfinished request with `n>1` counts as `(n - n_complete)`;
+         all other requests count as 0.
+
+         Returns:
+             Number of parallel sampling child requests
+         """
+         return sum(info.num_child_requests_remaining()
+                    for info in self.parent_req_id_info.values())
+
+     def get_num_unfinished_requests(self) -> int:
+         """Number of unfinished requests.
+
+         Each request submitted by the user counts as 1; the child requests
+         spawned by parallel sampling requests are not reflected in this count.
+         """
+         return (self._get_num_core_unfinished_requests() -
+                 self._get_num_parallel_sampling_child_unfinished_requests() +
+                 self._get_num_parallel_sampling_parent_unfinished_requests())

    def has_unfinished_requests(self) -> bool:
        return self.detokenizer.has_unfinished_requests()
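As a worked example of the counting arithmetic above (hypothetical numbers, not taken from the patch): with one ordinary request plus one parallel sampling request with `n=3` of which one child has finished, the engine core tracks 1 + (3 - 1) = 3 unfinished requests, while the user-facing count comes out to 3 - 2 + 1 = 2.

```python
# Hypothetical snapshot: one plain request plus one parallel sampling request
# with n=3, where one child sequence has already finished.
num_core_unfinished = 1 + (3 - 1)    # the engine core only sees child requests: 3
num_children_remaining = 3 - 1       # unfinished parallel sampling children: 2
num_parallel_parents = 1             # the n=3 request counts once for the user

# Mirrors get_num_unfinished_requests(): core total minus children plus parents.
user_facing = num_core_unfinished - num_children_remaining + num_parallel_parents
assert user_facing == 2              # the user submitted 2 requests, both unfinished
```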
@@ -127,11 +215,78 @@ def has_unfinished_requests(self) -> bool:
    def validate_outputs(cls, outputs, output_type):
        return outputs

+     def _forget_parallel_sample_child_request_and_maybe_parent(
+         self,
+         child_request_id: str,
+     ) -> None:
+         """Forget a child request's parallel sampling metadata, and its parent's metadata if necessary.
+
+         Parent request parallel sampling metadata is forgotten once all child requests have finished.
+
+         Args:
+             child_request_id: id of the finished child request
+         """
+         # Forget child request metadata
+         parent_req_id = self.child_req_id_to_parent_req_info[child_request_id].parent_req_id
+         self.child_req_id_to_parent_req_info.pop(child_request_id, None)
+         # Track the parent request's remaining child requests & erase the parent
+         # request metadata if there are no remaining child requests
+         self.parent_req_id_info[parent_req_id].n_finished += 1
+         if self.parent_req_id_info[parent_req_id].num_child_requests_remaining() == 0:
+             self.parent_req_id_info.pop(parent_req_id, None)
+
+     def _maybe_forget_parallel_sample_child_requests(
+             self, possible_child_request_ids: List[str]) -> None:
+         """When a request aborts, if it is a child of a parallel sampling request,
+         forget its parallel sampling metadata. This is applied to a list of possible
+         child request ids; requests that are not associated with parallel sampling
+         are left untouched.
+
+         Args:
+             possible_child_request_ids: request ids to possibly forget parallel
+                 sampling metadata for
+         """
+         for possible_child_req_id in possible_child_request_ids:
+             # Check if the request is a parallel sampling child request
+             if possible_child_req_id in self.child_req_id_to_parent_req_info:
+                 # If so, forget the child request's parallel sampling metadata
+                 self._forget_parallel_sample_child_request_and_maybe_parent(possible_child_req_id)
+
+
    def abort_request(self, request_ids: List[str]) -> None:
        """Remove request_ids from EngineCore and Detokenizer."""

        self.engine_core.abort_requests(request_ids)
        self.detokenizer.abort_requests(request_ids)
+         self._maybe_forget_parallel_sample_child_requests(request_ids)
+
+     def _register_parallel_sampling_parent_request(
+         self,
+         parent_req_id: str,
+         parallel_sample_parent_req_info: ParallelSampleParentRequestInfo,
+     ) -> None:
+         """Register the attributes associated with a parallel sampling request (i.e. the parent request)."""
+         self.parent_req_id_info[
+             parent_req_id] = parallel_sample_parent_req_info
+
+     def _register_parallel_sampling_child_request(
+         self,
+         parallel_sample_child_req_info: ParallelSampleChildRequestInfo,
+     ) -> str:
+         """Register the association of a parallel sampling child request with its parent request.
+
+         Generates a child request id.
+
+         Side effect: records the child request in the internal child-id -> child-info mapping.
+
+         Returns:
+             Child request id
+         """
+         parent_req_id = parallel_sample_child_req_info.parent_req_id
+         index = parallel_sample_child_req_info.index
+         child_req_id = f"{parent_req_id}_parallel_sample_{index}"
+         self.child_req_id_to_parent_req_info[
+             child_req_id] = parallel_sample_child_req_info
+         return child_req_id

    def add_request(
        self,
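The two `_register_*` helpers and the forget-on-abort path above amount to reference-counting children under their parent. A minimal standalone sketch of that bookkeeping (not part of the patch), using made-up request ids and local dictionaries in place of the engine's attributes:

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class ParallelSampleChildRequestInfo:
    parent_req_id: str
    index: int


@dataclass
class ParallelSampleParentRequestInfo:
    n: int
    n_finished: int = 0


# Stand-ins for self.child_req_id_to_parent_req_info / self.parent_req_id_info.
child_info: Dict[str, ParallelSampleChildRequestInfo] = {}
parent_info: Dict[str, ParallelSampleParentRequestInfo] = {
    "req-42": ParallelSampleParentRequestInfo(n=2)
}

# Child ids follow the f"{parent_req_id}_parallel_sample_{index}" pattern from the diff.
for i in range(2):
    child_info[f"req-42_parallel_sample_{i}"] = ParallelSampleChildRequestInfo("req-42", i)

# Forgetting a child mirrors _forget_parallel_sample_child_request_and_maybe_parent:
# drop the child entry, bump n_finished, and drop the parent once nothing remains.
for child_id in ("req-42_parallel_sample_0", "req-42_parallel_sample_1"):
    parent_id = child_info.pop(child_id).parent_req_id
    parent_info[parent_id].n_finished += 1
    if parent_info[parent_id].n == parent_info[parent_id].n_finished:
        del parent_info[parent_id]

assert not child_info and not parent_info
```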
@@ -144,6 +299,29 @@ def add_request(
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
+         if isinstance(params, SamplingParams) and params.n > 1:
+             # Register parallel sampling request
+             n = params.n
+             self._register_parallel_sampling_parent_request(
+                 request_id, ParallelSampleParentRequestInfo(n))
+             params.n = 1  # Engine core cannot see `n`
+             for ndx in range(n):
+                 # Register child request with parent
+                 child_req_id = self._register_parallel_sampling_child_request(
+                     ParallelSampleChildRequestInfo(request_id, ndx))
+                 # Recurse to add child request; `n=1` prevents further recursion
+                 self.add_request(
+                     request_id=child_req_id,
+                     prompt=prompt,
+                     params=params,
+                     arrival_time=arrival_time,
+                     lora_request=lora_request,
+                     trace_headers=trace_headers,
+                     prompt_adapter_request=prompt_adapter_request,
+                     priority=priority,
+                 )
+             # The top-level add_request call is done
+             return

        # 1) Process raw inputs into the request.
        detokenizer_req, engine_core_req = self.processor.process_inputs(
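From the caller's point of view, the branch above means a single `add_request` call with `n > 1` fans out into `n` engine-core requests, each submitted with `n=1` under a derived request id. A small illustrative helper (hypothetical, not part of the patch) that reproduces just the id fan-out:

```python
from typing import List


def fan_out_child_request_ids(parent_req_id: str, n: int) -> List[str]:
    # One child request per sample, named with the same scheme as the diff.
    return [f"{parent_req_id}_parallel_sample_{ndx}" for ndx in range(n)]


# A request "hello" with n=3 becomes three engine-core requests.
assert fan_out_child_request_ids("hello", 3) == [
    "hello_parallel_sample_0",
    "hello_parallel_sample_1",
    "hello_parallel_sample_2",
]
```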
@@ -156,6 +334,80 @@ def add_request(
        # 3) Add the request to EngineCore.
        self.engine_core.add_request(engine_core_req)

+     def _is_parallel_sampling_child_request(
+         self,
+         possible_child_request_id: str,
+     ) -> bool:
+         return possible_child_request_id in self.child_req_id_to_parent_req_info
+
+     def _maybe_get_parallel_sampling_child_request_info(
+         self,
+         possible_child_request_id: str,
+     ) -> Optional[ParallelSampleChildRequestInfo]:
+         return self.child_req_id_to_parent_req_info.get(possible_child_request_id, None)
+
+     def _merge_parallel_sampling_child_request_output_in_place(
+         self,
+         parent_req_output: RequestOutput,
+         child_req_output: RequestOutput,
+     ) -> None:
+         # Parent is finished when all children are finished
+         parent_req_output.finished = (parent_req_output.finished
+                                       and child_req_output.finished)
+         p_met = parent_req_output.metrics
+         c_met = child_req_output.metrics
+         if p_met is None:
+             # If the current parent request metrics are `None`, update with this
+             # child's metrics (which may also be None)
+             parent_req_output.metrics = c_met
+         elif c_met is not None:
+             # Only merge in the child request output metrics if they are not `None`
+             p_met.last_token_time = max(p_met.last_token_time, c_met.last_token_time)
+             p_met.first_scheduled_time = _none_safe_min(p_met.first_scheduled_time,
+                                                         c_met.first_scheduled_time)
+             p_met.first_token_time = _none_safe_min(p_met.first_token_time,
+                                                     c_met.first_token_time)
+             p_met.time_in_queue = _none_safe_sum(p_met.time_in_queue,
+                                                  c_met.time_in_queue)
+             p_met.finished_time = _none_safe_max(p_met.finished_time,
+                                                  c_met.finished_time)
+             p_met.model_execute_time = _none_safe_sum(p_met.model_execute_time,
+                                                       c_met.model_execute_time)
+             p_met.model_forward_time = _none_safe_sum(p_met.model_forward_time,
+                                                       c_met.model_forward_time)
+             p_met.scheduler_time = _none_safe_sum(p_met.scheduler_time,
+                                                   c_met.scheduler_time)
+         parent_req_output.outputs.extend(child_req_output.outputs)
+         parent_req_output.num_cached_tokens = _none_safe_sum(
+             parent_req_output.num_cached_tokens, child_req_output.num_cached_tokens)
+
+     def _maybe_aggregate_parallel_sampling_child_requests(
+         self,
+         request_outputs: List[RequestOutput],
+     ) -> List[RequestOutput]:
+         agg_request_outputs: List[RequestOutput] = []
+         parent_req_id_to_idx: Dict[str, int] = {}
+         for req_output in request_outputs:
+             possible_child_req_id = req_output.request_id
+             maybe_child_req_info = self._maybe_get_parallel_sampling_child_request_info(
+                 possible_child_req_id)
+             if maybe_child_req_info:
+                 parent_req_id = maybe_child_req_info.parent_req_id
+                 if parent_req_id not in parent_req_id_to_idx:
+                     # For a particular parent id, this is the first child request output
+                     # seen; repurpose it as the parent request output structure
+                     req_output.request_id = parent_req_id
+                     agg_request_outputs.append(req_output)
+                     # Remember where the parent request output resides in the output list
+                     parent_req_id_to_idx[parent_req_id] = len(agg_request_outputs) - 1
+                 else:
+                     # Merge this child request output into the growing request output
+                     # associated with its parent
+                     parent_req_output = agg_request_outputs[parent_req_id_to_idx[parent_req_id]]
+                     self._merge_parallel_sampling_child_request_output_in_place(
+                         parent_req_output, req_output)
+             else:
+                 # Not a parallel sampling request; don't touch it
+                 agg_request_outputs.append(req_output)
+         return agg_request_outputs
+
    def step(self) -> List[RequestOutput]:

        # 1) Get EngineCoreOutput from the EngineCore.
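To see what the aggregation helper does to a batch of step outputs, here is a standalone sketch that uses a simplified stand-in for `RequestOutput` (a hypothetical `FakeOutput` class with made-up data); the control flow mirrors `_maybe_aggregate_parallel_sampling_child_requests` above, without the metrics merging.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class FakeOutput:
    # Simplified stand-in for vllm's RequestOutput: an id, completions, and a flag.
    request_id: str
    outputs: List[str] = field(default_factory=list)
    finished: bool = True


child_to_parent = {
    "req-7_parallel_sample_0": "req-7",
    "req-7_parallel_sample_1": "req-7",
}
step_outputs = [
    FakeOutput("req-7_parallel_sample_0", ["completion A"]),
    FakeOutput("req-7_parallel_sample_1", ["completion B"], finished=False),
    FakeOutput("other-req", ["completion C"]),
]

# The first child seen is relabeled as the parent, later children are merged into
# it, and outputs that are not parallel sampling children pass through untouched.
aggregated: List[FakeOutput] = []
parent_idx: Dict[str, int] = {}
for out in step_outputs:
    parent_id: Optional[str] = child_to_parent.get(out.request_id)
    if parent_id is None:
        aggregated.append(out)
    elif parent_id not in parent_idx:
        out.request_id = parent_id
        aggregated.append(out)
        parent_idx[parent_id] = len(aggregated) - 1
    else:
        parent = aggregated[parent_idx[parent_id]]
        parent.outputs.extend(out.outputs)
        parent.finished = parent.finished and out.finished

assert [o.request_id for o in aggregated] == ["req-7", "other-req"]
assert aggregated[0].outputs == ["completion A", "completion B"]
assert aggregated[0].finished is False
```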
@@ -169,7 +421,9 @@ def step(self) -> List[RequestOutput]:
        if requests_to_abort:
            self.abort_request(requests_to_abort)

-         return request_outputs
+         # 4) If necessary, aggregate outputs for parallel sampling child requests
+         #    to be associated with the parent request
+         return self._maybe_aggregate_parallel_sampling_child_requests(request_outputs)

    # TODO(rob): Can we get rid of these?