Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions tests/v1/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,54 @@ def test_schedule_partial_requests():
assert output.num_scheduled_tokens[requests[0].request_id] == 1
assert output.num_scheduled_tokens[requests[1].request_id] == 700
assert requests[2].request_id not in output.num_scheduled_tokens


def test_schedule_concurrent_batches():
scheduler = create_scheduler(
max_num_batched_tokens=1024,
max_num_seqs=2,
)
requests = create_requests(
num_requests=2,
num_tokens=512,
)

# Schedule the first request.
scheduler.add_request(requests[0])
scheduler_output0 = scheduler.schedule()
assert len(scheduler_output0.scheduled_new_reqs) == 1
assert scheduler_output0.num_scheduled_tokens[
requests[0].request_id] == 512

# The first request is still running, so only schedule the second request.
scheduler.add_request(requests[1])
scheduler_output1 = scheduler.schedule()
assert len(scheduler_output1.scheduled_new_reqs) == 1
assert scheduler_output1.num_scheduled_tokens[
requests[1].request_id] == 512

# Model output of the first request.
model_runner_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[0],
logprobs=None,
prompt_logprobs_dict={},
)
scheduler.update_from_output(scheduler_output0, model_runner_output)

# Schedule the next step.
# The first request can be scheduled again while the second
# request is still running.
scheduler_output2 = scheduler.schedule()
assert scheduler_output2.num_scheduled_tokens[requests[0].request_id] == 1

# Model output of the second request.
model_runner_output = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[0],
logprobs=None,
prompt_logprobs_dict={},
)
scheduler.update_from_output(scheduler_output1, model_runner_output)
89 changes: 88 additions & 1 deletion tests/v1/engine/test_engine_core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0

import copy
import threading
import time
import uuid
from concurrent.futures import Future

import pytest
from transformers import AutoTokenizer
Expand All @@ -12,7 +15,9 @@
from vllm.platforms import current_platform
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.abstract import Executor, UniProcExecutor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput

if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.",
Expand Down Expand Up @@ -191,3 +196,85 @@ def _check_engine_state():
)
engine_core.add_request(request2)
_check_engine_state()


@fork_new_process_for_each_test
def test_engine_core_concurrent_batches(monkeypatch):
"""
Test that the engine can handle multiple concurrent batches.
"""

def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest:
request = make_request()
request.sampling_params.max_tokens = max_tokens
return request

class DummyExecutor(UniProcExecutor):

def initialize(self, kv_cache_config: KVCacheConfig) -> None:
super().initialize(kv_cache_config)

# This executor actually can only run 1 batch at a time
self.semaphore = threading.Semaphore(1)

def execute_model(
self,
scheduler_output,
) -> Future[ModelRunnerOutput]:
"""Make execute_model non-blocking."""
future: Future[ModelRunnerOutput] = Future()

def _thread_wrapper(scheduler_output, future):
with self.semaphore:
output = self.collective_rpc("execute_model",
args=(scheduler_output, ))
# Make a copy because output[0] may be reused
# by the next batch.
output = copy.deepcopy(output[0])
future.set_result(output)

threading.Thread(target=_thread_wrapper,
args=(scheduler_output, future)).start()
return future

@property
def max_concurrent_batches(self) -> int:
return 2

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

engine_args = EngineArgs(
model=MODEL_NAME,
# To test concurrent batches.
max_num_seqs=2,
# Avoid all requests being scheduled once.
enable_prefix_caching=False,
max_num_batched_tokens=10,
)
vllm_config = engine_args.create_engine_config()
engine_core = EngineCore(vllm_config=vllm_config,
log_stats=False,
executor_class=DummyExecutor)
assert engine_core.batch_queue is not None

# Add two requests in a row.
req = make_request_with_max_tokens(5)
engine_core.add_request(req)
req = make_request_with_max_tokens(5)
engine_core.add_request(req)

# First saturate the batch queue.
assert engine_core.step_with_batch_queue() is None
assert engine_core.batch_queue.qsize() == 1
assert engine_core.step_with_batch_queue() is None
assert engine_core.batch_queue.qsize() == 2
assert engine_core.scheduler.get_num_unfinished_requests() == 2

# Loop through both requests.
while engine_core.scheduler.get_num_unfinished_requests() == 2:
engine_core.step_with_batch_queue()

# Reaching here when got the result of the first request.
while engine_core.scheduler.get_num_unfinished_requests() == 1:
engine_core.step_with_batch_queue()
17 changes: 17 additions & 0 deletions vllm/v1/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def __init__(
# Priority queues for requests.
self.waiting: Deque[Request] = deque()
self.running: List[Request] = []
# The requests that have been scheduled and are being executed
# by the executor.
self.scheduled_req_ids: Set[str] = set()

# The request IDs that are finished in between the previous and the
# current steps. This is used to notify the workers about the finished
Expand Down Expand Up @@ -118,6 +121,11 @@ def schedule(self) -> "SchedulerOutput":
req_index = 0
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
if request.request_id in self.scheduled_req_ids:
# This request has already been scheduled.
req_index += 1
continue

num_new_tokens = request.num_tokens - request.num_computed_tokens
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
Expand Down Expand Up @@ -164,6 +172,7 @@ def schedule(self) -> "SchedulerOutput":

# Schedule the request.
scheduled_running_reqs.append(request)
self.scheduled_req_ids.add(request.request_id)
req_to_new_block_ids[request.request_id] = [
b.block_id for b in new_blocks
]
Expand Down Expand Up @@ -251,6 +260,7 @@ def schedule(self) -> "SchedulerOutput":

self.waiting.popleft()
self.running.append(request)
self.scheduled_req_ids.add(request.request_id)
if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request)
self.request_scheduled(request, scheduled_timestamp)
Expand Down Expand Up @@ -519,6 +529,7 @@ def update_from_output(
stop_reason=request.stop_reason,
events=request.take_events()))

self.scheduled_req_ids.remove(request.request_id)
if not stopped:
new_running.append(request)

Expand Down Expand Up @@ -575,6 +586,8 @@ def finish_requests(

if request.status == RequestStatus.RUNNING:
self.running.remove(request)
if request.request_id in self.scheduled_req_ids:
self.scheduled_req_ids.remove(request.request_id)
else:
self.waiting.remove(request)
request.status = finished_status
Expand All @@ -595,6 +608,10 @@ def get_num_unfinished_requests(self) -> int:
def has_unfinished_requests(self) -> bool:
return self.get_num_unfinished_requests() > 0

def get_num_unscheduled_requests(self) -> int:
"""Number of requests that are not being processed by the executor."""
return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)

def reset_prefix_cache(self) -> bool:
return self.kv_cache_manager.reset_prefix_cache()

Expand Down
77 changes: 71 additions & 6 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import signal
import threading
import time
from concurrent.futures import Future
from multiprocessing.connection import Connection
from typing import Any, List, Tuple, Type
from typing import Any, List, Optional, Tuple, Type

import psutil
import zmq
Expand All @@ -17,11 +18,12 @@
maybe_register_config_serialize_by_value)
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.version import __version__ as VLLM_VERSION
Expand Down Expand Up @@ -65,9 +67,20 @@ def __init__(
log_stats=self.log_stats,
)

# Setup MM Input Mapper.
self.mm_input_cache_server = MMInputCacheServer(
vllm_config.model_config)

# Setup batch queue for pipeline parallelism.
# Batch queue for scheduled batches. This enables us to asynchronously
# schedule and execute batches, and is required by pipeline parallelism
# to eliminate pipeline bubbles.
self.batch_queue_size = self.model_executor.max_concurrent_batches
self.batch_queue: Optional[queue.Queue[Tuple[Future[ModelRunnerOutput],
SchedulerOutput]]] = None
if self.batch_queue_size > 1:
self.batch_queue = queue.Queue(self.batch_queue_size)

def _initialize_kv_caches(self,
vllm_config: VllmConfig) -> Tuple[int, int]:
start = time.time()
Expand Down Expand Up @@ -134,7 +147,55 @@ def step(self) -> EngineCoreOutputs:
scheduler_output = self.scheduler.schedule()
output = self.model_executor.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, output)
scheduler_output, output) # type: ignore
return engine_core_outputs

def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
"""Schedule and execute batches with the batch queue.
Note that if nothing to output in this step, None is returned.

The execution flow is as follows:
1. Try to schedule a new batch if there are unscheduled requests
and the job queue is not full. If a new batch is scheduled, directly
return an empty engine core output. In other words, we won't check
and return model outputs before the batch queue is full.
2. If there is no new scheduled batch, meaning that the batch queue
is full or no other requests can be scheduled, we block until the first
batch in the job queue is finished.
3. Update the scheduler from the output.
"""
assert self.batch_queue is not None

engine_core_outputs = None
scheduler_output = None
# If there are unscheduled requests and the job queue
# is not full, schedule a new batch. Note that this is not blocking.
if (self.scheduler.get_num_unscheduled_requests() > 0
and not self.batch_queue.full()):
scheduler_output = self.scheduler.schedule()
if scheduler_output.total_num_scheduled_tokens > 0:
future = self.model_executor.execute_model(scheduler_output)
self.batch_queue.put_nowait(
(future, scheduler_output)) # type: ignore

# If all requests are scheduled or the job queue is full,
# block until the first batch in the job queue is finished.
if (scheduler_output is None
or scheduler_output.total_num_scheduled_tokens == 0):
try:
future, scheduler_output = self.batch_queue.get(
timeout=POLLING_TIMEOUT_S)
# Blocking until the first result is available.
model_output = future.result()
self.batch_queue.task_done()
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output)
except queue.Empty:
# If the queue is empty (timeout at .get), return
# an empty EngineCoreOutputs for logging.
engine_core_outputs = EngineCoreOutputs(
outputs=[], scheduler_stats=self.scheduler.make_stats())

return engine_core_outputs

def shutdown(self):
Expand Down Expand Up @@ -222,6 +283,9 @@ def signal_handler(signum, frame):
def run_busy_loop(self):
"""Core busy loop of the EngineCore."""

step_fn = (self.step
if self.batch_queue is None else self.step_with_batch_queue)

# Loop until process is sent a SIGINT or SIGTERM
while True:
# 1) Poll the input queue until there is work to do.
Expand All @@ -245,10 +309,11 @@ def run_busy_loop(self):
self._handle_client_request(*req)

# 3) Step the engine core.
outputs = self.step()
outputs = step_fn()

# 5) Put EngineCoreOutputs into the output queue.
self.output_queue.put_nowait(outputs)
# 4) Put EngineCoreOutputs into the output queue.
if outputs is not None:
self.output_queue.put_nowait(outputs)

def _handle_client_request(self, request_type: EngineCoreRequestType,
request: Any) -> None:
Expand Down
Loading