From f8a92a9c9e5f59ed23963a8a6da35725d4f04524 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Fri, 7 Feb 2025 17:33:02 -0800
Subject: [PATCH 1/6] work

Signed-off-by: Cody Yu
---
 tests/v1/engine/test_engine_core.py | 87 ++++++++++++++++++++++++++++-
 vllm/v1/core/scheduler.py           | 19 ++++++-
 vllm/v1/engine/core.py              | 81 +++++++++++++++++++++++++--
 vllm/v1/executor/abstract.py        |  9 ++-
 4 files changed, 186 insertions(+), 10 deletions(-)

diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py
index 36b31550dc0e..6bd3004ae070 100644
--- a/tests/v1/engine/test_engine_core.py
+++ b/tests/v1/engine/test_engine_core.py
@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import copy
+import threading
 import time
 import uuid
+from concurrent.futures import Future
 
 import pytest
 from transformers import AutoTokenizer
@@ -12,7 +15,9 @@
 from vllm.platforms import current_platform
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core import EngineCore
-from vllm.v1.executor.abstract import Executor
+from vllm.v1.executor.abstract import Executor, UniProcExecutor
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.outputs import ModelRunnerOutput
 
 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -191,3 +196,83 @@ def _check_engine_state():
     )
     engine_core.add_request(request2)
     _check_engine_state()
+
+
+@fork_new_process_for_each_test
+def test_engine_core_concurrent_batches(monkeypatch):
+    """
+    Test that the engine can handle multiple concurrent batches.
+    """
+
+    class DummyExecutor(UniProcExecutor):
+
+        def initialize(self, kv_cache_config: KVCacheConfig) -> None:
+            super().initialize(kv_cache_config)
+
+            # This executor can actually only run one batch at a time.
+            self.semaphore = threading.Semaphore(1)
+
+        def execute_model(
+            self,
+            scheduler_output,
+        ) -> Future[ModelRunnerOutput]:
+            """Make execute_model non-blocking."""
+            future: Future[ModelRunnerOutput] = Future()
+
+            def _thread_wrapper(scheduler_output, future):
+                with self.semaphore:
+                    output = self.collective_rpc("execute_model",
+                                                 args=(scheduler_output, ))
+                    # Make a copy because output[0] may be reused
+                    # by the next batch.
+                    output = copy.deepcopy(output[0])
+                    future.set_result(output)
+
+            threading.Thread(target=_thread_wrapper,
+                             args=(scheduler_output, future)).start()
+            return future
+
+        @property
+        def max_concurrent_batches(self) -> int:
+            return 2
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        engine_args = EngineArgs(
+            model=MODEL_NAME,
+            # To test concurrent batches.
+            max_num_seqs=2,
+            # Avoid all requests being scheduled at once.
+            enable_prefix_caching=False,
+            max_num_batched_tokens=10,
+        )
+        vllm_config = engine_args.create_engine_config()
+        engine_core = EngineCore(vllm_config=vllm_config,
+                                 executor_class=DummyExecutor)
+        assert engine_core.batch_queue is not None
+
+        # Add two requests in a row.
+        req = make_request()
+        print(f"Adding request: {req.request_id}")
+        req.sampling_params.max_tokens = 5
+        engine_core.add_request(req)
+        req = make_request()
+        print(f"Adding request: {req.request_id}")
+        req.sampling_params.max_tokens = 5
+        engine_core.add_request(req)
+
+        # First saturate the batch queue.
+        assert engine_core.step_with_batch_queue() is None
+        assert engine_core.batch_queue.qsize() == 1
+        assert engine_core.step_with_batch_queue() is None
+        assert engine_core.batch_queue.qsize() == 2
+        assert engine_core.scheduler.get_num_unfinished_requests() == 2
+
+        # Loop through both requests.
+        while engine_core.scheduler.get_num_unfinished_requests() == 2:
+            engine_core.step_with_batch_queue()
+
+        # We reach here once the first request has finished.
+        while engine_core.scheduler.get_num_unfinished_requests() == 1:
+            engine_core.step_with_batch_queue()
diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index e32e557ae232..436674076bb7 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -58,6 +58,9 @@ def __init__(
         # Priority queues for requests.
         self.waiting: Deque[Request] = deque()
         self.running: List[Request] = []
+        # The requests that have been scheduled and are being executed
+        # by the executor.
+        self.scheduled_req_ids: Set[str] = set()
 
         # The request IDs that are finished in between the previous and the
         # current steps. This is used to notify the workers about the finished
@@ -118,6 +121,11 @@ def schedule(self) -> "SchedulerOutput":
         req_index = 0
         while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]
+            if request.request_id in self.scheduled_req_ids:
+                # This request has already been scheduled.
+                req_index += 1
+                continue
+
             num_new_tokens = request.num_tokens - request.num_computed_tokens
             num_new_tokens = min(num_new_tokens, token_budget)
             assert num_new_tokens > 0
@@ -164,6 +172,7 @@ def schedule(self) -> "SchedulerOutput":
 
             # Schedule the request.
             scheduled_running_reqs.append(request)
+            self.scheduled_req_ids.add(request.request_id)
             req_to_new_block_ids[request.request_id] = [
                 b.block_id for b in new_blocks
             ]
@@ -251,6 +260,7 @@ def schedule(self) -> "SchedulerOutput":
 
                 self.waiting.popleft()
                 self.running.append(request)
+                self.scheduled_req_ids.add(request.request_id)
                 if request.status == RequestStatus.WAITING:
                     scheduled_new_reqs.append(request)
                     self.request_scheduled(request, scheduled_timestamp)
@@ -292,6 +302,7 @@ def schedule(self) -> "SchedulerOutput":
 
         # Get the longest common prefix among all requests in the running queue.
         # This can be potentially used for cascade attention.
+        # FIXME: This is not correct.
         num_common_prefix_blocks = 0
         if self.running:
             any_request = self.running[0]
@@ -452,7 +463,7 @@ def update_from_output(
             req_id = request.request_id
             num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
             if num_tokens_scheduled == 0:
-                # The request was not scheduled in this step.
+                # The request was not scheduled in this batch.
                 new_running.append(request)
                 continue
 
@@ -519,6 +530,7 @@ def update_from_output(
                         stop_reason=request.stop_reason,
                         events=request.take_events()))
 
+            self.scheduled_req_ids.remove(request.request_id)
             if not stopped:
                 new_running.append(request)
 
@@ -575,6 +587,8 @@ def finish_requests(
 
             if request.status == RequestStatus.RUNNING:
                 self.running.remove(request)
+                if request.request_id in self.scheduled_req_ids:
+                    self.scheduled_req_ids.remove(request.request_id)
             else:
                 self.waiting.remove(request)
             request.status = finished_status
@@ -595,6 +609,9 @@ def get_num_unfinished_requests(self) -> int:
     def has_unfinished_requests(self) -> bool:
         return self.get_num_unfinished_requests() > 0
 
+    def get_num_unscheduled_requests(self) -> int:
+        return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)
+
     def reset_prefix_cache(self) -> bool:
         return self.kv_cache_manager.reset_prefix_cache()
 
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 4642ac1778ed..921cda044d0b 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -4,8 +4,9 @@
 import signal
 import threading
 import time
+from concurrent.futures import Future
 from multiprocessing.connection import Connection
-from typing import Any, List, Tuple, Type
+from typing import Any, List, Optional, Tuple, Type
 
 import psutil
 import zmq
@@ -17,11 +18,12 @@
                                         maybe_register_config_serialize_by_value)
 from vllm.utils import get_exception_traceback, zmq_socket_ctx
 from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
-from vllm.v1.core.scheduler import Scheduler
+from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType)
 from vllm.v1.engine.mm_input_cache import MMInputCacheServer
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.version import __version__ as VLLM_VERSION
@@ -65,9 +67,24 @@ def __init__(
             log_stats=self.log_stats,
         )
 
+        # Setup MM Input Mapper.
         self.mm_input_cache_server = MMInputCacheServer(
             vllm_config.model_config)
 
+        # Setup batch queue for pipeline parallelism.
+        # Batch queue for scheduled batches. This enables us to asynchronously
+        # schedule and execute batches, and is required by pipeline parallelism
+        # to eliminate pipeline bubbles.
+        self.batch_queue_size = self.model_executor.max_concurrent_batches
+        self.batch_queue: Optional[queue.Queue[Tuple[Future[ModelRunnerOutput],
+                                                     SchedulerOutput]]] = None
+        if self.batch_queue_size > 0:
+            if self.batch_queue_size == 1:
+                logger.warning(
+                    "batch_queue_size=1 may result in suboptimal "
+                    "performance and should only be used for testing")
+            self.batch_queue = queue.Queue(self.batch_queue_size)
+
     def _initialize_kv_caches(self,
                               vllm_config: VllmConfig) -> Tuple[int, int]:
         start = time.time()
@@ -134,7 +151,55 @@ def step(self) -> EngineCoreOutputs:
         scheduler_output = self.scheduler.schedule()
         output = self.model_executor.execute_model(scheduler_output)
         engine_core_outputs = self.scheduler.update_from_output(
-            scheduler_output, output)
+            scheduler_output, output)  # type: ignore
         return engine_core_outputs
 
+    def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
+        """Schedule and execute batches with the batch queue.
+        Note that None is returned if there is nothing to output in this step.
+
+        The execution flow is as follows:
+        1. Try to schedule a new batch if there are unscheduled requests
+        and the job queue is not full. If a new batch is scheduled, directly
+        return None. In other words, we won't check and return model outputs
+        before the batch queue is full.
+        2. If there is no new scheduled batch, meaning that the batch queue
+        is full or no other requests can be scheduled, we block until the first
+        batch in the job queue is finished.
+        3. Update the scheduler from the output.
+        """
+        assert self.batch_queue is not None
+
+        engine_core_outputs = None
+        scheduler_output = None
+        # If there are unscheduled requests and the job queue
+        # is not full, schedule a new batch. Note that this is not blocking.
+        if (self.scheduler.get_num_unscheduled_requests() > 0
+                and not self.batch_queue.full()):
+            scheduler_output = self.scheduler.schedule()
+            if scheduler_output.total_num_scheduled_tokens > 0:
+                future = self.model_executor.execute_model(scheduler_output)
+                self.batch_queue.put_nowait(
+                    (future, scheduler_output))  # type: ignore
+
+        # If all requests are scheduled or the job queue is full,
+        # block until the first batch in the job queue is finished.
+        if (scheduler_output is None
+                or scheduler_output.total_num_scheduled_tokens == 0):
+            try:
+                future, scheduler_output = self.batch_queue.get(
+                    timeout=POLLING_TIMEOUT_S)
+                # Blocking until the first result is available.
+                model_output = future.result()
+                self.batch_queue.task_done()
+                engine_core_outputs = self.scheduler.update_from_output(
+                    scheduler_output, model_output)
+            except queue.Empty:
+                # If the queue is empty (timeout at .get), return
+                # an empty EngineCoreOutputs for logging.
+                engine_core_outputs = EngineCoreOutputs(
+                    outputs=[], scheduler_stats=self.scheduler.make_stats())
+
+        return engine_core_outputs
 
     def shutdown(self):
@@ -222,6 +287,9 @@ def signal_handler(signum, frame):
     def run_busy_loop(self):
         """Core busy loop of the EngineCore."""
 
+        step_fn = (self.step
+                   if self.batch_queue is None else self.step_with_batch_queue)
+
         # Loop until process is sent a SIGINT or SIGTERM
         while True:
             # 1) Poll the input queue until there is work to do.
@@ -245,10 +313,11 @@ def run_busy_loop(self):
                     self._handle_client_request(*req)
 
             # 3) Step the engine core.
-            outputs = self.step()
+            outputs = step_fn()
 
-            # 5) Put EngineCoreOutputs into the output queue.
-            self.output_queue.put_nowait(outputs)
+            # 4) Put EngineCoreOutputs into the output queue.
+            if outputs is not None:
+                self.output_queue.put_nowait(outputs)
 
     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py
index d1ffc891ad69..488a7e00d87f 100644
--- a/vllm/v1/executor/abstract.py
+++ b/vllm/v1/executor/abstract.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Type
+from typing import List, Type, Union
+from concurrent.futures import Future
 
 from vllm.config import VllmConfig
 from vllm.executor.executor_base import ExecutorBase
@@ -70,11 +71,15 @@ def get_kv_cache_specs(self) -> List[KVCacheSpec]:
     def execute_model(
         self,
         scheduler_output,
-    ) -> ModelRunnerOutput:
+    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
         output = self.collective_rpc("execute_model",
                                      args=(scheduler_output, ))
         return output[0]
 
+    @property
+    def max_concurrent_batches(self) -> int:
+        return 1
+
     def profile(self, is_start: bool = True):
         self.collective_rpc("profile", args=(is_start, ))
 

From 6eca4a4020e09c79b2afd951cadd468bd5895f17 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Mon, 10 Feb 2025 16:49:48 -0800
Subject: [PATCH 2/6] test

Signed-off-by: Cody Yu
---
 tests/v1/engine/test_engine_core.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py
index 6bd3004ae070..2ad8beab5f77 100644
--- a/tests/v1/engine/test_engine_core.py
+++ b/tests/v1/engine/test_engine_core.py
@@ -204,6 +204,11 @@ def test_engine_core_concurrent_batches(monkeypatch):
     Test that the engine can handle multiple concurrent batches.
     """
 
+    def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest:
+        request = make_request()
+        request.sampling_params.max_tokens = max_tokens
+        return request
+
     class DummyExecutor(UniProcExecutor):
 
         def initialize(self, kv_cache_config: KVCacheConfig) -> None:
             super().initialize(kv_cache_config)
@@ -253,13 +258,9 @@ def max_concurrent_batches(self) -> int:
         assert engine_core.batch_queue is not None
 
         # Add two requests in a row.
-        req = make_request()
-        print(f"Adding request: {req.request_id}")
-        req.sampling_params.max_tokens = 5
+        req = make_request_with_max_tokens(5)
         engine_core.add_request(req)
-        req = make_request()
-        print(f"Adding request: {req.request_id}")
-        req.sampling_params.max_tokens = 5
+        req = make_request_with_max_tokens(5)
         engine_core.add_request(req)
 
         # First saturate the batch queue.

From 5e42beaf732f4e82be5fd82703afd55456737635 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Mon, 10 Feb 2025 17:31:43 -0800
Subject: [PATCH 3/6] done

Signed-off-by: Cody Yu
---
 tests/v1/core/test_scheduler.py | 51 +++++++++++++++++++++++++++++
 1 file changed, 51 insertions(+)

diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 8aba46aec477..97f75d0fd70c 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -213,3 +213,54 @@ def test_schedule_partial_requests():
     assert output.num_scheduled_tokens[requests[0].request_id] == 1
     assert output.num_scheduled_tokens[requests[1].request_id] == 700
     assert requests[2].request_id not in output.num_scheduled_tokens
+
+
+def test_schedule_concurrent_batches():
+    scheduler = create_scheduler(
+        max_num_batched_tokens=1024,
+        max_num_seqs=2,
+    )
+    requests = create_requests(
+        num_requests=2,
+        num_tokens=512,
+    )
+
+    # Schedule the first request.
+    scheduler.add_request(requests[0])
+    scheduler_output0 = scheduler.schedule()
+    assert len(scheduler_output0.scheduled_new_reqs) == 1
+    assert scheduler_output0.num_scheduled_tokens[
+        requests[0].request_id] == 512
+
+    # The first request is still running, so only schedule the second request.
+    scheduler.add_request(requests[1])
+    scheduler_output1 = scheduler.schedule()
+    assert len(scheduler_output1.scheduled_new_reqs) == 1
+    assert scheduler_output1.num_scheduled_tokens[
+        requests[1].request_id] == 512
+
+    # Model output of the first request.
+    model_runner_output = ModelRunnerOutput(
+        req_ids=[requests[0].request_id],
+        req_id_to_index={requests[0].request_id: 0},
+        sampled_token_ids=[0],
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    scheduler.update_from_output(scheduler_output0, model_runner_output)
+
+    # Schedule the next step.
+    # The first request can be scheduled again while the second
+    # request is still running.
+    scheduler_output2 = scheduler.schedule()
+    assert scheduler_output2.num_scheduled_tokens[requests[0].request_id] == 1
+
+    # Model output of the second request.
+    model_runner_output = ModelRunnerOutput(
+        req_ids=[requests[1].request_id],
+        req_id_to_index={requests[1].request_id: 0},
+        sampled_token_ids=[0],
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    scheduler.update_from_output(scheduler_output1, model_runner_output)

From 568be950b9de599ec413c7218c824132648d9f91 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Thu, 13 Feb 2025 13:07:00 -0800
Subject: [PATCH 4/6] work

Signed-off-by: Cody Yu
---
 tests/v1/engine/test_engine_core.py          |  1 +
 vllm/v1/engine/core.py                       |  6 +-
 vllm/v1/executor/abstract.py                 | 12 ++--
 vllm/v1/executor/ray_distributed_executor.py | 61 ++++++++++++++++++++
 4 files changed, 67 insertions(+), 13 deletions(-)
 create mode 100644 vllm/v1/executor/ray_distributed_executor.py

diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py
index 2ad8beab5f77..d035668098eb 100644
--- a/tests/v1/engine/test_engine_core.py
+++ b/tests/v1/engine/test_engine_core.py
@@ -254,6 +254,7 @@ def max_concurrent_batches(self) -> int:
         )
         vllm_config = engine_args.create_engine_config()
         engine_core = EngineCore(vllm_config=vllm_config,
+                                 log_stats=False,
                                  executor_class=DummyExecutor)
         assert engine_core.batch_queue is not None
 
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 921cda044d0b..cc9d7172218e 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -78,11 +78,7 @@ def __init__(
         self.batch_queue_size = self.model_executor.max_concurrent_batches
         self.batch_queue: Optional[queue.Queue[Tuple[Future[ModelRunnerOutput],
                                                      SchedulerOutput]]] = None
-        if self.batch_queue_size > 0:
-            if self.batch_queue_size == 1:
-                logger.warning(
-                    "batch_queue_size=1 may result in suboptimal "
-                    "performance and should only be used for testing")
+        if self.batch_queue_size > 1:
             self.batch_queue = queue.Queue(self.batch_queue_size)
 
     def _initialize_kv_caches(self,
diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py
index 488a7e00d87f..304115a122d8 100644
--- a/vllm/v1/executor/abstract.py
+++ b/vllm/v1/executor/abstract.py
@@ -1,12 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Type, Union
 from concurrent.futures import Future
+from typing import List, Type, Union
 
 from vllm.config import VllmConfig
 from vllm.executor.executor_base import ExecutorBase
-from vllm.executor.ray_distributed_executor import (  # noqa
-    RayDistributedExecutor as RayDistributedExecutorV0)
 from vllm.executor.uniproc_executor import (  # noqa
     ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
 from vllm.executor.uniproc_executor import (  # noqa
@@ -34,7 +32,9 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
                     f"ExecutorBase. Got {distributed_executor_backend}.")
             executor_class = distributed_executor_backend
         elif distributed_executor_backend == "ray":
-            executor_class = RayDistributedExecutor
+            from vllm.v1.executor.ray_distributed_executor import (  # noqa
+                RayDistributedExecutor as RayDistributedExecutorV0)
+            executor_class = RayDistributedExecutorV0
         elif distributed_executor_backend == "mp":
             from vllm.v1.executor.multiproc_executor import MultiprocExecutor
             executor_class = MultiprocExecutor
@@ -90,7 +90,3 @@ class UniProcExecutor(UniProcExecutorV0, Executor):
 
 class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
     pass
-
-
-class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
-    pass
diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py
new file mode 100644
index 000000000000..3378381c7ade
--- /dev/null
+++ b/vllm/v1/executor/ray_distributed_executor.py
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from concurrent.futures import Future
+from typing import Union
+
+from vllm.executor.ray_distributed_executor import (  # noqa
+    RayDistributedExecutor as RayDistributedExecutorV0)
+from vllm.v1.executor.abstract import Executor
+from vllm.v1.outputs import ModelRunnerOutput
+
+
+class FutureWrapper(Future):
+    """A wrapper around a Ray output reference to meet the interface
+    of .execute_model().
+    """
+
+    def __init__(self, ref):
+        super().__init__()
+        self.ref = ref
+
+    def result(self, timeout=None):
+        if timeout is not None:
+            raise NotImplementedError("timeout is not supported")
+        return self.ref.get()
+
+
+class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
+    """Ray distributed executor using Ray Compiled Graphs."""
+
+    @property
+    def max_concurrent_batches(self) -> int:
+        """Ray distributed executor supports pipeline parallelism,
+        meaning that it allows PP size batches to be executed concurrently.
+        """
+        return self.vllm_config.parallel_config.pipeline_parallel_size
+
+    def execute_model(
+        self,
+        scheduler_output,
+    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
+        """Execute the model on the Ray workers.
+
+        Args:
+            scheduler_output: The scheduler output to execute.
+
+        Returns:
+            The model runner output.
+        """
+        # Build the compiled DAG for the first time.
+        if self.forward_dag is None:  # type: ignore
+            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
+
+        refs = self.forward_dag.execute(scheduler_output)  # type: ignore
+
+        # When PP is not used, we block here until the result is available.
+        if self.max_concurrent_batches == 1:
+            return refs[0].get()
+
+        # When PP is used, we return a FutureWrapper immediately so that
+        # the scheduler can yield to the next batch.
+        return FutureWrapper(refs[0])

From 82567a36b2841d546bcd33d469a58b9b309384a6 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Thu, 13 Feb 2025 13:08:39 -0800
Subject: [PATCH 5/6] doc

Signed-off-by: Cody Yu
---
 vllm/v1/core/scheduler.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index 436674076bb7..2d5a1192c227 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -302,7 +302,6 @@ def schedule(self) -> "SchedulerOutput":
 
         # Get the longest common prefix among all requests in the running queue.
         # This can be potentially used for cascade attention.
-        # FIXME: This is not correct.
         num_common_prefix_blocks = 0
         if self.running:
             any_request = self.running[0]
@@ -463,7 +462,7 @@ def update_from_output(
             req_id = request.request_id
             num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
             if num_tokens_scheduled == 0:
-                # The request was not scheduled in this batch.
+                # The request was not scheduled in this step.
                 new_running.append(request)
                 continue
 
@@ -610,6 +609,7 @@ def has_unfinished_requests(self) -> bool:
         return self.get_num_unfinished_requests() > 0
 
     def get_num_unscheduled_requests(self) -> int:
+        """Number of requests that are not being processed by the executor."""
         return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)
 
     def reset_prefix_cache(self) -> bool:
         return self.kv_cache_manager.reset_prefix_cache()

From ed57d1d273221dcbb9835e1ae8a847302151d2a3 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Thu, 13 Feb 2025 14:24:17 -0800
Subject: [PATCH 6/6] fix

Signed-off-by: Cody Yu
---
 vllm/v1/engine/core.py                       | 2 ++
 vllm/v1/executor/abstract.py                 | 4 ++--
 vllm/v1/executor/ray_distributed_executor.py | 2 +-
 3 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index cc9d7172218e..bda343bdb7a6 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -79,6 +79,8 @@ def __init__(
         self.batch_queue: Optional[queue.Queue[Tuple[Future[ModelRunnerOutput],
                                                      SchedulerOutput]]] = None
         if self.batch_queue_size > 1:
+            logger.info("Batch queue is enabled with size %d",
+                        self.batch_queue_size)
             self.batch_queue = queue.Queue(self.batch_queue_size)
 
     def _initialize_kv_caches(self,
diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py
index 304115a122d8..3663cbd08aec 100644
--- a/vllm/v1/executor/abstract.py
+++ b/vllm/v1/executor/abstract.py
@@ -33,8 +33,8 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
             executor_class = distributed_executor_backend
         elif distributed_executor_backend == "ray":
             from vllm.v1.executor.ray_distributed_executor import (  # noqa
-                RayDistributedExecutor as RayDistributedExecutorV0)
-            executor_class = RayDistributedExecutorV0
+                RayDistributedExecutor)
+            executor_class = RayDistributedExecutor
         elif distributed_executor_backend == "mp":
             from vllm.v1.executor.multiproc_executor import MultiprocExecutor
             executor_class = MultiprocExecutor
diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py
index 3378381c7ade..53548610adf6 100644
--- a/vllm/v1/executor/ray_distributed_executor.py
+++ b/vllm/v1/executor/ray_distributed_executor.py
@@ -32,7 +32,7 @@ def max_concurrent_batches(self) -> int:
         """Ray distributed executor supports pipeline parallelism,
         meaning that it allows PP size batches to be executed concurrently.
         """
-        return self.vllm_config.parallel_config.pipeline_parallel_size
+        return 1  # self.vllm_config.parallel_config.pipeline_parallel_size
 
     def execute_model(
         self,
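
Illustrative sketch (not part of the patches above): the schedule-ahead control flow that step_with_batch_queue() and Executor.max_concurrent_batches introduce can be pictured with plain Python. The names run_batch, MAX_CONCURRENT_BATCHES, and the integer batch ids below are made up for illustration; only the pattern mirrors the series: the executor hands back a Future instead of blocking, and the loop waits on the oldest batch only when the queue is full or nothing new can be scheduled.

import queue
from concurrent.futures import Future, ThreadPoolExecutor

MAX_CONCURRENT_BATCHES = 2  # plays the role of Executor.max_concurrent_batches
pool = ThreadPoolExecutor(max_workers=MAX_CONCURRENT_BATCHES)
batch_queue = queue.Queue(MAX_CONCURRENT_BATCHES)  # holds (Future, batch_id) pairs


def run_batch(batch_id: int) -> str:
    # Stand-in for the executor running one scheduled batch.
    return f"output of batch {batch_id}"


pending = list(range(5))  # stand-in for unscheduled requests
while pending or not batch_queue.empty():
    if pending and not batch_queue.full():
        # Schedule ahead without blocking: the "executor" returns a Future.
        batch_id = pending.pop(0)
        batch_queue.put_nowait((pool.submit(run_batch, batch_id), batch_id))
        continue
    # Queue full or nothing left to schedule: block on the oldest batch,
    # mirroring batch_queue.get() followed by future.result() above.
    future, batch_id = batch_queue.get()
    print(batch_id, future.result())

pool.shutdown()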