From c727e1fb8d37ffbabc75112735c83f3647601147 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 6 Mar 2025 16:24:13 -0800 Subject: [PATCH 1/3] [V1] Eagerly remove finished requests from the batch Currently, if all in-flight requests are aborted, the core engine loop will pause and only remove the completed requests from the batch the next time there are new requests to schedule. We might as well do this early in this case so that it's not on the later critical path. Other advantages of this: - It makes it easy to return a final CoreRequestOutputs with stats that reflect that all requests are finished, so that they can be updated without having to poll the engine. - It will help with the data parallel implementation where we need to coordinate when to pause multiple engines based on there being no in-flight requests. Signed-off-by: Nick Hill --- tests/v1/engine/test_engine_core.py | 10 ++++++++++ vllm/v1/core/scheduler.py | 6 ++++++ vllm/v1/engine/core.py | 4 ++-- vllm/v1/outputs.py | 10 ++++++++++ vllm/v1/worker/gpu_model_runner.py | 9 ++++++--- vllm/v1/worker/tpu_model_runner.py | 6 +++++- 6 files changed, 39 insertions(+), 6 deletions(-) diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 11c22effb122..5fdbcf5b9963 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -102,14 +102,24 @@ def test_engine_core(monkeypatch): engine_core.add_request(req) assert len(engine_core.scheduler.waiting) == 1 assert len(engine_core.scheduler.running) == 0 + assert engine_core.scheduler.has_unfinished_requests() + assert not engine_core.scheduler.has_finished_requests() _ = engine_core.step() assert len(engine_core.scheduler.waiting) == 0 assert len(engine_core.scheduler.running) == 1 + assert engine_core.scheduler.has_unfinished_requests() + assert not engine_core.scheduler.has_finished_requests() engine_core.abort_requests([request_id]) assert len(engine_core.scheduler.waiting) == 0 assert len(engine_core.scheduler.running) == 0 + assert not engine_core.scheduler.has_unfinished_requests() + assert engine_core.scheduler.has_finished_requests() + + _ = engine_core.step() + assert not engine_core.scheduler.has_unfinished_requests() + assert not engine_core.scheduler.has_finished_requests() # Add, step, abort 1 of the 3. req0 = make_request() diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index db14c9455a1f..71d2849c330b 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -657,6 +657,12 @@ def get_num_unfinished_requests(self) -> int: def has_unfinished_requests(self) -> bool: return self.get_num_unfinished_requests() > 0 + def has_finished_requests(self) -> bool: + return len(self.finished_req_ids) > 0 + + def has_requests(self): + return self.has_unfinished_requests() or self.has_finished_requests() + def get_num_unscheduled_requests(self) -> int: """Number of requests that are not being processed by the executor.""" return self.get_num_unfinished_requests() - len(self.scheduled_req_ids) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 671a72e2112d..9d22fd0c70fe 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -146,7 +146,7 @@ def abort_requests(self, request_ids: list[str]): def step(self) -> EngineCoreOutputs: """Schedule, execute, and make output.""" - if not self.scheduler.has_unfinished_requests(): + if not self.scheduler.has_requests(): return EngineCoreOutputs( outputs=[], scheduler_stats=self.scheduler.make_stats()) scheduler_output = self.scheduler.schedule() @@ -315,7 +315,7 @@ def run_busy_loop(self): # Loop until process is sent a SIGINT or SIGTERM while True: # 1) Poll the input queue until there is work to do. - while not self.scheduler.has_unfinished_requests(): + while not self.scheduler.has_requests(): logger.debug("EngineCore busy loop waiting.") req = self.input_queue.get() self._handle_client_request(*req) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index dc3ad402e066..7ba83ba7b3bf 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -80,3 +80,13 @@ class ModelRunnerOutput: # [prompt_len, num_prompt_logprobs] # [prompt_len] prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] + + +EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=[], + req_id_to_index={}, + sampled_token_ids=[], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, +) \ No newline at end of file diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 519f38cb0b72..9a630f604177 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -31,7 +31,8 @@ from vllm.v1.engine.mm_input_cache import MMInputCacheClient from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) -from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, + ModelRunnerOutput) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler from vllm.v1.spec_decode.ngram_proposer import NgramProposer @@ -867,6 +868,9 @@ def execute_model( intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, torch.Tensor]: self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: + # Return empty ModelRunnerOuptut if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT if self.is_multimodal_model: # Run the multimodal encoder if any. @@ -1013,7 +1017,7 @@ def execute_model( spec_token_ids = self.generate_draft_token_ids( valid_sampled_token_ids) - model_runner_output = ModelRunnerOutput( + return ModelRunnerOutput( req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=valid_sampled_token_ids, @@ -1021,7 +1025,6 @@ def execute_model( logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, ) - return model_runner_output def generate_draft_token_ids( self, diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index f9a3217fbef3..868935ad6928 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -29,7 +29,8 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) -from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, + ModelRunnerOutput) from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -547,6 +548,9 @@ def execute_model( ) -> ModelRunnerOutput: # Update cached state self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: + # Return empty ModelRunnerOuptut if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT if self.is_multimodal_model: # Run the multimodal encoder if any. From 952d8372bfa590fd39331579f8e9a366551107ba Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 6 Mar 2025 18:48:25 -0800 Subject: [PATCH 2/3] Address review comments; adjust iteration stats Signed-off-by: Nick Hill --- vllm/v1/core/scheduler.py | 5 ++++- vllm/v1/engine/async_llm.py | 6 +++--- vllm/v1/engine/core.py | 2 ++ vllm/v1/metrics/loggers.py | 12 ++++++++---- vllm/v1/outputs.py | 2 +- 5 files changed, 18 insertions(+), 9 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 71d2849c330b..d9a9583a1d4f 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -625,7 +625,8 @@ def finish_requests( assert RequestStatus.is_finished(finished_status) if isinstance(request_ids, str): request_ids = (request_ids, ) - request_ids = set(request_ids) + else: + request_ids = set(request_ids) for req_id in request_ids: request = self.requests.get(req_id) @@ -661,6 +662,8 @@ def has_finished_requests(self) -> bool: return len(self.finished_req_ids) > 0 def has_requests(self): + """Returns True if there are unfinished requests, or finished requests + not yet returned in SchedulerOutputs.""" return self.has_unfinished_requests() or self.has_finished_requests() def get_num_unscheduled_requests(self) -> int: diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 4c9d4cb467ae..1277e71d23da 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -255,13 +255,14 @@ async def _run_output_handler(self): while True: # 1) Pull EngineCoreOutputs from the EngineCore. outputs = await self.engine_core.get_output_async() + num_outputs = len(outputs.outputs) - iteration_stats = IterationStats() if self.log_stats else None + iteration_stats = IterationStats() if ( + self.log_stats and num_outputs) else None # Split outputs into chunks of at most # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the # event loop for too long. - num_outputs = len(outputs.outputs) if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: slices = (outputs.outputs, ) else: @@ -315,7 +316,6 @@ def _record_stats( return assert scheduler_stats is not None - assert iteration_stats is not None for stat_logger in self.stat_loggers: stat_logger.record(scheduler_stats=scheduler_stats, iteration_stats=iteration_stats) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 9d22fd0c70fe..0c6bbd2a18c4 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -146,6 +146,8 @@ def abort_requests(self, request_ids: list[str]): def step(self) -> EngineCoreOutputs: """Schedule, execute, and make output.""" + # Check for any requests remaining in the scheduler - unfinished, + # or finished and not yet removed from the batch. if not self.scheduler.has_requests(): return EngineCoreOutputs( outputs=[], scheduler_stats=self.scheduler.make_stats()) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 44493709b639..fcb4d4f5a25a 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -22,7 +22,7 @@ class StatLoggerBase(ABC): @abstractmethod def record(self, scheduler_stats: SchedulerStats, - iteration_stats: IterationStats): + iteration_stats: Optional[IterationStats]): ... def log(self): # noqa @@ -56,10 +56,11 @@ def _get_throughput(self, tracked_stats: list[int], now: float) -> float: return float(np.sum(tracked_stats) / (now - self.last_log_time)) def record(self, scheduler_stats: SchedulerStats, - iteration_stats: IterationStats): + iteration_stats: Optional[IterationStats]): """Log Stats to standard output.""" - self._track_iteration_stats(iteration_stats) + if iteration_stats: + self._track_iteration_stats(iteration_stats) self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) @@ -319,7 +320,7 @@ def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): info_gauge.set(1) def record(self, scheduler_stats: SchedulerStats, - iteration_stats: IterationStats): + iteration_stats: Optional[IterationStats]): """Log to prometheus.""" self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) @@ -331,6 +332,9 @@ def record(self, scheduler_stats: SchedulerStats, self.counter_gpu_prefix_cache_hits.inc( scheduler_stats.prefix_cache_stats.hits) + if iteration_stats is None: + return + self.counter_num_preempted_reqs.inc(iteration_stats.num_preempted_reqs) self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens) self.counter_generation_tokens.inc( diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 7ba83ba7b3bf..edae654b5d33 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -89,4 +89,4 @@ class ModelRunnerOutput: spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, -) \ No newline at end of file +) From c62bfb818bad76b24e7f83a5fae6abe91eee90f6 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 7 Mar 2025 08:37:29 -0800 Subject: [PATCH 3/3] Adjust test Signed-off-by: Nick Hill --- tests/v1/engine/test_engine_core_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 3880a3dd9b8a..e646ccbd4603 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -50,7 +50,7 @@ def loop_until_done(client: EngineCoreClient, outputs: dict): engine_core_outputs = client.get_output().outputs if len(engine_core_outputs) == 0: - break + continue all_finished = True for out in engine_core_outputs: @@ -68,7 +68,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict): engine_core_outputs = (await client.get_output_async()).outputs if len(engine_core_outputs) == 0: - break + continue all_finished = True for out in engine_core_outputs: