diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index 12e79ff165f4..5cf963b6b0d6 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -3,6 +3,7 @@
 import multiprocessing
 import os
 import pickle
+import queue
 import signal
 import threading
 import time
@@ -33,7 +34,8 @@
                         get_loopback_ip, get_mp_context, get_open_port,
                         set_process_title)
 from vllm.v1.executor.abstract import Executor, FailureCallback
-from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
+from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds,
+                             ModelRunnerOutput)
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -412,6 +414,16 @@ def __init__(
         # Initializes a message queue for sending the model output
         self.worker_response_mq = MessageQueue(1, 1)
 
+        scheduler_config = vllm_config.scheduler_config
+        self.use_async_scheduling = scheduler_config.async_scheduling
+        if self.use_async_scheduling:
+            self.async_output_queue: queue.Queue = queue.Queue()
+            self.async_output_copy_thread = Thread(
+                target=self.async_output_busy_loop,
+                daemon=True,
+                name="WorkerAsyncOutputCopy")
+            self.async_output_copy_thread.start()
+
         # Initialize device and loads weights
         self.worker.init_device()
         self.worker.load_model()
@@ -593,6 +605,36 @@ class ResponseStatus(Enum):
         SUCCESS = auto()
         FAILURE = auto()
 
+    def enqueue_output(self, output: Any):
+        """Prepares output from the worker and enqueues it to the
+        worker_response_mq. If the output is an Exception, it is
+        converted to a FAILURE response.
+        """
+        if isinstance(output, AsyncModelRunnerOutput):
+            output = output.get_output()
+
+        if isinstance(output, Exception):
+            result = (WorkerProc.ResponseStatus.FAILURE, str(output))
+        else:
+            result = (WorkerProc.ResponseStatus.SUCCESS, output)
+        self.worker_response_mq.enqueue(result)
+
+    def handle_output(self, output: Any):
+        """Handles output from the worker. If async scheduling is enabled,
+        it is passed to the async_output_busy_loop thread. Otherwise, it is
+        enqueued directly to the worker_response_mq.
+        """
+        if self.use_async_scheduling:
+            self.async_output_queue.put(output)
+        else:
+            self.enqueue_output(output)
+
+    def async_output_busy_loop(self):
+        """Entrypoint for the thread which handles outputs asynchronously."""
+        while True:
+            output = self.async_output_queue.get()
+            self.enqueue_output(output)
+
     def worker_busy_loop(self):
         """Main busy loop for Multiprocessing Workers"""
         while True:
@@ -612,10 +654,8 @@ def worker_busy_loop(self):
                 # exception might not be serializable, so we convert it to
                 # string, only for logging purpose.
                 if output_rank is None or self.rank == output_rank:
-                    self.worker_response_mq.enqueue(
-                        (WorkerProc.ResponseStatus.FAILURE, str(e)))
+                    self.handle_output(e)
                 continue
 
             if output_rank is None or self.rank == output_rank:
-                self.worker_response_mq.enqueue(
-                    (WorkerProc.ResponseStatus.SUCCESS, output))
+                self.handle_output(output)
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index f8d6b24702f3..1b2da8addb19 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import NamedTuple, Optional
 
@@ -114,6 +115,20 @@ class ModelRunnerOutput:
     num_nans_in_logits: Optional[dict[str, int]] = None
 
 
+# ModelRunnerOutput wrapper for async scheduling.
+class AsyncModelRunnerOutput(ABC):
+
+    @abstractmethod
+    def get_output(self) -> ModelRunnerOutput:
+        """Get the ModelRunnerOutput for this async output.
+
+        This is a blocking call that waits until the results are ready, which
+        might involve copying device tensors to the host.
+        This method should only be called once per AsyncModelRunnerOutput.
+        """
+        pass
+
+
 @dataclass
 class DraftTokenIds:
 
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index ad70d9efaaaa..83fc821b8494 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -250,6 +250,11 @@ def __init__(
 
         self.pooling_params: dict[str, PoolingParams] = {}
 
+        # Cached reference to the GPU tensor of previously sampled tokens
+        self.prev_sampled_token_ids: Optional[torch.Tensor] = None
+        self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
+        self.prev_req_id_to_index: Optional[dict[str, int]] = None
+
     @property
     def req_ids(self) -> list[str]:
         # None elements should only be present transiently
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 42baf020e9dc..7859e966b04f 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -67,8 +67,8 @@
                                         FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheSpec,
                                         MambaSpec, SlidingWindowSpec)
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
-                             LogprobsTensors, ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
+                             DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -100,6 +100,53 @@
 logger = init_logger(__name__)
 
 
+# Wrapper for ModelRunnerOutput to support overlapped execution.
+class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
+
+    def __init__(
+        self,
+        model_runner_output: ModelRunnerOutput,
+        sampled_token_ids: torch.Tensor,
+        invalid_req_indices: list[int],
+        async_output_copy_stream: torch.cuda.Stream,
+    ):
+        self._model_runner_output = model_runner_output
+        self._invalid_req_indices = invalid_req_indices
+
+        # Event on the copy stream so we can synchronize the non-blocking copy.
+        self._async_copy_ready_event = torch.cuda.Event()
+
+        # Keep a reference to the device tensor to avoid it being
+        # deallocated until we finish copying it to the host.
+        self._sampled_token_ids = sampled_token_ids
+
+        # Initiate the copy on a separate stream, but do not synchronize it.
+        default_stream = torch.cuda.current_stream()
+        with torch.cuda.stream(async_output_copy_stream):
+            async_output_copy_stream.wait_stream(default_stream)
+            self._sampled_token_ids_cpu = self._sampled_token_ids.to(
+                'cpu', non_blocking=True)
+            self._async_copy_ready_event.record()
+
+    def get_output(self) -> ModelRunnerOutput:
+        """Copy the device tensors to the host and return a ModelRunnerOutput.
+
+        This function blocks until the copy is finished.
+ """ + self._async_copy_ready_event.synchronize() + + # Release the device tensor once the copy has completed + del self._sampled_token_ids + + valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist() + for i in self._invalid_req_indices: + valid_sampled_token_ids[i].clear() + + output = self._model_runner_output + output.sampled_token_ids = valid_sampled_token_ids + return output + + class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def __init__( @@ -230,6 +277,10 @@ def __init__( is_pooling_model=self.is_pooling_model, ) + self.use_async_scheduling = self.scheduler_config.async_scheduling + self.async_output_copy_stream = torch.cuda.Stream() if \ + self.use_async_scheduling else None + # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. @@ -654,6 +705,73 @@ def _get_cumsum_and_arange( return cu_num_tokens, arange + def _prepare_input_ids(self, total_num_scheduled_tokens: int, + cu_num_tokens: np.ndarray) -> None: + """Prepare the input IDs for the current batch. + + Carefully handles the `prev_sampled_token_ids` which can be cached + from the previous engine iteration, in which case those tokens on the + GPU need to be copied into the corresponding slots into input_ids.""" + + if self.input_batch.prev_sampled_token_ids is None: + # Normal scheduling case + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + return + + # Async scheduling case, where some decode requests from the previous + # iteration won't have entries in input_ids_cpu and need to be copied + # on the GPU from prev_sampled_token_ids. + prev_req_id_to_index = self.input_batch.prev_req_id_to_index + assert prev_req_id_to_index is not None + flattened_indices = [] + prev_common_req_indices = [] + indices_match = True + max_flattened_index = -1 + for req_id, cur_index in self.input_batch.req_id_to_index.items(): + if (prev_index := prev_req_id_to_index.get(req_id)) is not None: + prev_common_req_indices.append(prev_index) + # We need to compute the flattened input_ids index of the + # last token in each common request. + flattened_index = cu_num_tokens[cur_index].item() - 1 + flattened_indices.append(flattened_index) + indices_match &= (prev_index == flattened_index) + max_flattened_index = max(max_flattened_index, flattened_index) + num_commmon_tokens = len(flattened_indices) + if num_commmon_tokens < total_num_scheduled_tokens: + # If not all requests are decodes from the last iteration, + # We need to copy the input_ids_cpu to the GPU first. + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + if num_commmon_tokens == 0: + # No requests in common with the previous iteration + # So input_ids_cpu will have all the input ids. + return + if indices_match and max_flattened_index == (num_commmon_tokens - 1): + # Common-case optimization: the batch is unchanged + # and no reordering happened. + # The indices are both the same permutation of 0..N-1 so + # we can copy directly using a single slice. + self.input_ids.gpu[:num_commmon_tokens].copy_( + self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, + 0], + non_blocking=True) + return + # Upload the index tensors asynchronously + # so the scatter can be non-blocking. 
+        input_ids_index_tensor = torch.tensor(flattened_indices,
+                                              dtype=torch.int64,
+                                              pin_memory=self.pin_memory).to(
+                                                  self.device,
+                                                  non_blocking=True)
+        prev_common_req_indices_tensor = torch.tensor(
+            prev_common_req_indices,
+            dtype=torch.int64,
+            pin_memory=self.pin_memory).to(self.device, non_blocking=True)
+        self.input_ids.gpu.scatter_(
+            dim=0,
+            index=input_ids_index_tensor,
+            src=self.input_batch.prev_sampled_token_ids[
+                prev_common_req_indices_tensor, 0])
+
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
@@ -740,7 +858,8 @@ def _prepare_inputs(
         max_seq_len = self.seq_lens.np[:num_reqs].max().item()
 
         # Copy the tensors to the GPU.
-        self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
+        self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
+
         if self.uses_mrope:
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
@@ -1458,7 +1577,7 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
+    ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
         self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
             if not has_kv_transfer_group():
@@ -1673,6 +1792,12 @@ def execute_model(
                 # so that we could clear the sampled tokens before returning.
                 discard_sampled_tokens_req_indices.append(i)
 
+        # Copy some objects so they don't get modified after returning.
+        # This is important when using async scheduling.
+        req_ids_output_copy = self.input_batch.req_ids.copy()
+        req_id_to_index_output_copy = \
+            self.input_batch.req_id_to_index.copy()
+
         # NOTE: GPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
         logprobs_tensors = sampler_output.logprobs_tensors
@@ -1685,21 +1810,41 @@ def execute_model(
                 scheduler_output.num_scheduled_tokens,
             )
 
-        # Get the valid generated tokens.
+        num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
         sampled_token_ids = sampler_output.sampled_token_ids
-        max_gen_len = sampled_token_ids.shape[-1]
-        if max_gen_len == 1:
-            # No spec decode tokens.
-            valid_sampled_token_ids = self._to_list(sampled_token_ids)
+        if not self.use_async_scheduling:
+            # Get the valid generated tokens.
+            max_gen_len = sampled_token_ids.shape[-1]
+            if max_gen_len == 1:
+                # No spec decode tokens.
+                valid_sampled_token_ids = self._to_list(sampled_token_ids)
+            else:
+                # Includes spec decode tokens.
+                valid_sampled_token_ids = self.rejection_sampler.parse_output(
+                    sampled_token_ids,
+                    self.input_batch.vocab_size,
+                )
+            # Mask out the sampled tokens that should not be sampled.
+            for i in discard_sampled_tokens_req_indices:
+                valid_sampled_token_ids[i].clear()
         else:
-            # Includes spec decode tokens.
-            valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids,
-                self.input_batch.vocab_size,
-            )
-        # Mask out the sampled tokens that should not be sampled.
-        for i in discard_sampled_tokens_req_indices:
-            valid_sampled_token_ids[i].clear()
+            valid_sampled_token_ids = []
+            invalid_req_indices = list(discard_sampled_tokens_req_indices)
+            invalid_req_indices_set = set(invalid_req_indices)
+            assert sampled_token_ids.shape[-1] == 1
+
+            # Cache the sampled tokens on the GPU and avoid CPU sync.
+            # These will be copied into input_ids in the next step
+            # when preparing inputs.
+            self.input_batch.prev_sampled_token_ids = \
+                sampled_token_ids
+            self.input_batch.prev_sampled_token_ids_invalid_indices = \
+                invalid_req_indices_set
+            self.input_batch.prev_req_id_to_index = {
+                req_id: i
+                for i, req_id in enumerate(self.input_batch.req_ids)
+                if i not in invalid_req_indices_set
+            }
 
         # Cache the sampled tokens in the model runner, so that the scheduler
         # doesn't need to send them back.
@@ -1707,7 +1852,12 @@ def execute_model(
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
         req_ids = self.input_batch.req_ids
-        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
+        for req_idx in range(num_sampled_tokens):
+            if self.use_async_scheduling:
+                sampled_ids = [-1] if \
+                    req_idx not in invalid_req_indices_set else None
+            else:
+                sampled_ids = valid_sampled_token_ids[req_idx]
             if not sampled_ids:
                 continue
 
@@ -1722,6 +1872,7 @@ def execute_model(
                                            start_idx:end_idx] = sampled_ids
             self.input_batch.num_tokens_no_spec[req_idx] = end_idx
             self.input_batch.num_tokens[req_idx] = end_idx
+
             req_id = req_ids[req_idx]
             req_state = self.requests[req_id]
             req_state.output_token_ids.extend(sampled_ids)
@@ -1741,9 +1892,9 @@ def execute_model(
 
         self.eplb_step()
 
-        return ModelRunnerOutput(
-            req_ids=self.input_batch.req_ids,
-            req_id_to_index=self.input_batch.req_id_to_index,
+        output = ModelRunnerOutput(
+            req_ids=req_ids_output_copy,
+            req_id_to_index=req_id_to_index_output_copy,
             sampled_token_ids=valid_sampled_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
@@ -1752,6 +1903,16 @@ def execute_model(
             num_nans_in_logits=num_nans_in_logits,
         )
 
+        if not self.use_async_scheduling:
+            return output
+
+        return AsyncGPUModelRunnerOutput(
+            model_runner_output=output,
+            sampled_token_ids=sampled_token_ids,
+            invalid_req_indices=invalid_req_indices,
+            async_output_copy_stream=self.async_output_copy_stream,
+        )
+
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         if self._draft_token_ids is None:
             return None
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index affba877ecf9..99c805a3e949 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -5,7 +5,7 @@
 import gc
 import os
 from contextlib import AbstractContextManager, nullcontext
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
 import torch.distributed
@@ -28,8 +28,8 @@
 from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
-                             ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
+                             DraftTokenIds, ModelRunnerOutput)
 from vllm.v1.utils import report_usage_stats
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 from vllm.v1.worker.worker_base import WorkerBase
@@ -355,7 +355,7 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> Optional[ModelRunnerOutput]:
+    ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
         intermediate_tensors = None
         forward_pass = scheduler_output.total_num_scheduled_tokens > 0
         if forward_pass and not get_pp_group().is_first_rank:
@@ -365,7 +365,7 @@ def execute_model(
 
         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)
-        if isinstance(output, ModelRunnerOutput):
+        if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
             return output
 
         assert isinstance(output, IntermediateTensors)
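
Note (not part of the patch): the sketch below isolates the stream/event pattern that
AsyncGPUModelRunnerOutput relies on above, i.e. launching the device-to-host copy of the
sampled tokens on a side stream and deferring the blocking synchronization until
get_output() is called. The AsyncCopy class, the demo tensor shape, and the __main__
driver are illustrative assumptions, not vLLM code; a CUDA-capable device is assumed.

    # async_copy_sketch.py -- illustrative only
    import torch


    class AsyncCopy:
        """Start a non-blocking GPU->CPU copy now; collect the result later."""

        def __init__(self, device_tensor: torch.Tensor,
                     copy_stream: torch.cuda.Stream):
            # Keep the device tensor alive until the copy has completed.
            self._device_tensor = device_tensor
            self._ready = torch.cuda.Event()
            producer = torch.cuda.current_stream()
            with torch.cuda.stream(copy_stream):
                # The copy stream waits for the producer stream so it reads
                # fully written data, then issues the async D2H copy.
                copy_stream.wait_stream(producer)
                self._host_tensor = device_tensor.to("cpu", non_blocking=True)
                # Recorded on the copy stream: synchronize() below blocks only
                # until this copy is done, not until all GPU work is done.
                self._ready.record()

        def get(self) -> list:
            self._ready.synchronize()
            return self._host_tensor.tolist()


    if __name__ == "__main__":
        copy_stream = torch.cuda.Stream()
        sampled = torch.randint(0, 32_000, (8, 1), device="cuda")
        pending = AsyncCopy(sampled, copy_stream)
        # The next model step could be launched here before blocking.
        print(pending.get())

Using a dedicated copy stream plus an event keeps the transfer off the compute stream,
so only the worker thread that later calls get_output() ever blocks on it.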
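
Note (not part of the patch): a second sketch, in the spirit of _prepare_input_ids,
showing how the previous step's sampled tokens can be written into the flattened
input_ids buffer without a host sync, either via the single-slice fast path or via
scatter_. The buffers and index lists below (input_ids, prev_sampled, dest_index,
src_rows) are made-up stand-ins; a CUDA device is assumed.

    # prev_token_scatter_sketch.py -- illustrative only
    import torch

    device = "cuda"
    input_ids = torch.zeros(16, dtype=torch.int64, device=device)
    # One token per request sampled in the previous step, shape [num_reqs, 1],
    # still on the GPU (no .tolist() / host sync has happened).
    prev_sampled = torch.tensor([[11], [22], [33]], device=device)

    # Flattened input_ids slot of each carried-over request's next token and
    # the matching row of prev_sampled, both computed on the CPU.
    dest_index = [0, 5, 9]
    src_rows = [0, 1, 2]

    if dest_index == list(range(len(dest_index))):
        # Fast path: same decode batch, same order -> one sliced copy,
        # no index upload needed.
        n = len(dest_index)
        input_ids[:n].copy_(prev_sampled[:n, 0], non_blocking=True)
    else:
        # General path: upload the indices and scatter each request's token
        # into its flattened slot, still without reading tokens on the host.
        index = torch.tensor(dest_index, dtype=torch.int64, device=device)
        rows = torch.tensor(src_rows, dtype=torch.int64, device=device)
        input_ids.scatter_(dim=0, index=index, src=prev_sampled[rows, 0])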