Commit 226ced9

benchislett authored and skyloevil committed
[Perf][V1] Fully overlap model execution (vllm-project#23569)
Signed-off-by: Benjamin Chislett <[email protected]>
1 parent: 290d183

File tree: 5 files changed, +252 −31

vllm/v1/executor/multiproc_executor.py

Lines changed: 45 additions & 5 deletions
@@ -3,6 +3,7 @@
 import multiprocessing
 import os
 import pickle
+import queue
 import signal
 import threading
 import time
@@ -33,7 +34,8 @@
                        get_loopback_ip, get_mp_context, get_open_port,
                        set_process_title)
 from vllm.v1.executor.abstract import Executor, FailureCallback
-from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
+from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds,
+                             ModelRunnerOutput)
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -414,6 +416,16 @@ def __init__(
         # Initializes a message queue for sending the model output
         self.worker_response_mq = MessageQueue(1, 1)
 
+        scheduler_config = vllm_config.scheduler_config
+        self.use_async_scheduling = scheduler_config.async_scheduling
+        if self.use_async_scheduling:
+            self.async_output_queue: queue.Queue = queue.Queue()
+            self.async_output_copy_thread = Thread(
+                target=self.async_output_busy_loop,
+                daemon=True,
+                name="WorkerAsyncOutputCopy")
+            self.async_output_copy_thread.start()
+
         # Initialize device and loads weights
         self.worker.init_device()
         self.worker.load_model()
@@ -595,6 +607,36 @@ class ResponseStatus(Enum):
         SUCCESS = auto()
         FAILURE = auto()
 
+    def enqueue_output(self, output: Any):
+        """Prepares output from the worker and enqueues it to the
+        worker_response_mq. If the output is an Exception, it is
+        converted to a FAILURE response.
+        """
+        if isinstance(output, AsyncModelRunnerOutput):
+            output = output.get_output()
+
+        if isinstance(output, Exception):
+            result = (WorkerProc.ResponseStatus.FAILURE, str(output))
+        else:
+            result = (WorkerProc.ResponseStatus.SUCCESS, output)
+        self.worker_response_mq.enqueue(result)
+
+    def handle_output(self, output: Any):
+        """Handles output from the worker. If async scheduling is enabled,
+        it is passed to the async_output_busy_loop thread. Otherwise, it is
+        enqueued directly to the worker_response_mq.
+        """
+        if self.use_async_scheduling:
+            self.async_output_queue.put(output)
+        else:
+            self.enqueue_output(output)
+
+    def async_output_busy_loop(self):
+        """Entrypoint for the thread which handles outputs asynchronously."""
+        while True:
+            output = self.async_output_queue.get()
+            self.enqueue_output(output)
+
     def worker_busy_loop(self):
         """Main busy loop for Multiprocessing Workers"""
         while True:
@@ -614,10 +656,8 @@ def worker_busy_loop(self):
                 # exception might not be serializable, so we convert it to
                 # string, only for logging purpose.
                 if output_rank is None or self.rank == output_rank:
-                    self.worker_response_mq.enqueue(
-                        (WorkerProc.ResponseStatus.FAILURE, str(e)))
+                    self.handle_output(e)
                 continue
 
             if output_rank is None or self.rank == output_rank:
-                self.worker_response_mq.enqueue(
-                    (WorkerProc.ResponseStatus.SUCCESS, output))
+                self.handle_output(output)
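
The core of the change is a producer/consumer handoff: the worker's busy loop hands each (possibly still device-resident) output to a dedicated copy thread, so the blocking get_output() call never stalls the next model step. Below is a minimal, self-contained sketch of that pattern, not vLLM code; FakeAsyncOutput and its 10 ms delay are stand-ins for an AsyncModelRunnerOutput whose resolution requires a device-to-host copy.

import queue
import threading
import time


class FakeAsyncOutput:
    """Stand-in for AsyncModelRunnerOutput whose resolution is slow,
    like a device-to-host copy of sampled token IDs."""

    def __init__(self, step: int):
        self.step = step

    def get_output(self) -> int:
        time.sleep(0.01)  # simulate the blocking D2H copy
        return self.step


def copy_loop(q: queue.Queue, results: list) -> None:
    # Mirrors async_output_busy_loop: drain the queue and resolve each
    # output off the worker's critical path.
    while True:
        item = q.get()
        if item is None:  # shutdown sentinel (the real loop runs forever)
            return
        results.append(item.get_output())


out_q: queue.Queue = queue.Queue()
results: list = []
copier = threading.Thread(target=copy_loop, args=(out_q, results),
                          daemon=True, name="WorkerAsyncOutputCopy")
copier.start()

# The busy loop can hand off step N and immediately start step N + 1;
# it never blocks on get_output() itself.
for step in range(3):
    out_q.put(FakeAsyncOutput(step))

out_q.put(None)
copier.join()
print(results)  # [0, 1, 2]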

vllm/v1/outputs.py

Lines changed: 15 additions & 0 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import NamedTuple, Optional
 
@@ -114,6 +115,20 @@ class ModelRunnerOutput:
     num_nans_in_logits: Optional[dict[str, int]] = None
 
 
+# ModelRunnerOutput wrapper for async scheduling.
+class AsyncModelRunnerOutput(ABC):
+
+    @abstractmethod
+    def get_output(self) -> ModelRunnerOutput:
+        """Get the ModelRunnerOutput for this async output.
+
+        This is a blocking call that waits until the results are ready, which
+        might involve copying device tensors to the host.
+        This method should only be called once per AsyncModelRunnerOutput.
+        """
+        pass
+
+
 @dataclass
 class DraftTokenIds:
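
This file only defines the abstract interface; a concrete subclass could fence a non-blocking device-to-host copy behind a CUDA event and defer the wait to get_output(). The sketch below illustrates that under stated assumptions: the class name AsyncGPUModelRunnerOutput, the pinned-buffer field, and the copy_stream parameter are hypothetical, not the implementation shipped in this commit.

# Hypothetical subclass, for illustration only; not part of this commit.
import torch

from vllm.v1.outputs import AsyncModelRunnerOutput, ModelRunnerOutput


class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
    """Fences a non-blocking D2H copy of the sampled token IDs behind
    a CUDA event, deferring the wait until get_output() is called."""

    def __init__(self, output: ModelRunnerOutput,
                 sampled_token_ids: torch.Tensor,
                 copy_stream: torch.cuda.Stream):
        self._output = output
        # Pinned host buffer so the copy can truly run asynchronously.
        self._sampled_cpu = torch.empty(sampled_token_ids.shape,
                                        dtype=sampled_token_ids.dtype,
                                        device="cpu",
                                        pin_memory=True)
        self._copy_event = torch.cuda.Event()
        # Order the copy after the sampling kernel on the compute stream.
        copy_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(copy_stream):
            self._sampled_cpu.copy_(sampled_token_ids, non_blocking=True)
            self._copy_event.record()

    def get_output(self) -> ModelRunnerOutput:
        # Blocks only the caller (the worker's copy thread), not the
        # stream running the next forward pass.
        self._copy_event.synchronize()
        self._output.sampled_token_ids = self._sampled_cpu.tolist()
        return self._output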

vllm/v1/worker/gpu_input_batch.py

Lines changed: 5 additions & 0 deletions
@@ -250,6 +250,11 @@ def __init__(
 
         self.pooling_params: dict[str, PoolingParams] = {}
 
+        # Cached reference to the GPU tensor of previously sampled tokens
+        self.prev_sampled_token_ids: Optional[torch.Tensor] = None
+        self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
+        self.prev_req_id_to_index: Optional[dict[str, int]] = None
+
     @property
     def req_ids(self) -> list[str]:
         # None elements should only be present transiently