Description
TL;DR:
- The core speculative decoding logic is an algorithmic vLLM feature, and should not have different implementations for different backends.
- This RFC proposes a BaseWorker interface, so all workers are compatible with speculative decoding.
- This RFC proposes adding init_workers, init_cache, and profile_num_available_blocks to the ExecutorBase interface.
- The modification to ExecutorBase will require common logic in NeuronExecutor, GPUExecutor, and RayGPUExecutor to move to the LLMEngine.
Motivation
At a high level, all speculative decoding consists of three phases: propose tokens, score tokens, and verify tokens. There are various implementations of proposal and verification; for example, we can swap a draft model out for prompt-lookup decoding, or rejection sampling for typical acceptance.
This level of abstraction (proposer, scorer, verifier) means that we can implement the speculative decoding framework above the Worker level, such that non-GPU workers can fit within the framework. This RFC proposes the interfaces necessary to make this happen.
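As a rough illustration of what such a backend-agnostic framework could look like, here is a minimal sketch; the class and function names (Proposer, Scorer, Verifier, speculative_step) are hypothetical and not actual vLLM classes.
from abc import ABC, abstractmethod
from typing import List


class Proposer(ABC):
    """Proposes candidate tokens, e.g. via a draft model or prompt lookup."""

    @abstractmethod
    def propose(self, context: List[int], k: int) -> List[int]:
        raise NotImplementedError


class Scorer(ABC):
    """Scores proposed tokens with the target model."""

    @abstractmethod
    def score(self, context: List[int], proposals: List[int]) -> List[float]:
        raise NotImplementedError


class Verifier(ABC):
    """Accepts or rejects proposals, e.g. rejection sampling or typical acceptance."""

    @abstractmethod
    def verify(self, proposals: List[int], scores: List[float]) -> List[int]:
        raise NotImplementedError


def speculative_step(proposer: Proposer, scorer: Scorer, verifier: Verifier,
                     context: List[int], k: int) -> List[int]:
    """One propose -> score -> verify iteration, independent of the backend."""
    proposals = proposer.propose(context, k)
    scores = scorer.score(context, proposals)
    return verifier.verify(proposals, scores)
Because this loop only calls the proposer/scorer/verifier interfaces, any worker that can execute a model step can participate, regardless of the underlying hardware.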
Is speculative decoding the only feature that can live at this level?
Jump decoding is a technique which allows skipping generation of tokens when parts of the output can be known beforehand with 100% certainty. For example, given a large templated file, an LLM could fill out the templated parts while keeping the other parts unmodified.
This feature is not implemented in vLLM today, but it is an example of a feature that would live at the same level as speculative decoding. In fact, jump decoding can be seen as a special case of speculative decoding -- but this is left as future work and is out of scope here.
Interface modifications
BaseWorker
# Import paths assumed from vLLM at the time of this RFC.
from abc import ABC, abstractmethod
from typing import Dict, List

from vllm.config import CacheConfig
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata


class WorkerBase(ABC):

    @abstractmethod
    def init_device(self) -> None:
        """Initialize device state, such as loading the model or other
        on-device memory allocations.
        """
        raise NotImplementedError

    @abstractmethod
    def profile_num_available_blocks(self, block_size: int,
                                     gpu_memory_utilization: float,
                                     cpu_swap_space: float,
                                     cache_dtype: str) -> tuple[int, int]:
        """Profile the model on-device to determine the maximum number of KV
        blocks that can be allocated.

        Returns a tuple [num_device_blocks, num_cpu_blocks], where
        num_device_blocks refers to the number of blocks in the "active" KV
        cache (e.g. where blocks are appended to), and num_cpu_blocks refers
        to the number of blocks in the "passive" KV cache (e.g. where blocks
        are swapped to).

        Examples:
        - The GPUExecutor will return [num_gpu_blocks, num_cpu_blocks].
        - A future CPUExecutor can return [num_cpu_blocks, 0] or
          [num_cpu_blocks, num_swap_cpu_blocks].
        """
        raise NotImplementedError

    @abstractmethod
    def init_cache(self, cache_config: CacheConfig) -> None:
        """Given a fully-specified cache config, initialize the KV cache. This
        is separate from init_device as profiling may be required to determine
        the maximum allowed KV cache size.
        """
        raise NotImplementedError

    @abstractmethod
    def execute_model(self,
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      blocks_to_swap_in: Dict[int, int],
                      blocks_to_swap_out: Dict[int, int],
                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
        """Executes one model step on the given sequences."""
        raise NotImplementedError

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def check_health(self) -> None:
        """Checks if the worker is healthy. If not, it should raise an
        exception.
        """
        raise NotImplementedError
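For context, a rough sketch of how an engine could drive a WorkerBase implementation through this lifecycle (hypothetical driver code, not actual LLMEngine code; the CacheConfig attribute names and the profiling values are assumptions for illustration):
def initialize_worker(worker: WorkerBase, cache_config: CacheConfig) -> None:
    # 1. Load the model and set up on-device state.
    worker.init_device()
    # 2. Profile to find how many KV blocks fit on the device and in CPU swap.
    num_device_blocks, num_cpu_blocks = worker.profile_num_available_blocks(
        block_size=cache_config.block_size,
        gpu_memory_utilization=0.90,  # assumed value
        cpu_swap_space=4 * 2**30,     # assumed 4 GiB of swap space
        cache_dtype="auto")
    # 3. Record the profiled sizes and allocate the KV cache.
    cache_config.num_gpu_blocks = num_device_blocks  # attribute names assumed
    cache_config.num_cpu_blocks = num_cpu_blocks
    worker.init_cache(cache_config)
    # The worker is now ready for execute_model() calls.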
ExecutorBase
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index 55180d6110b..bd77e5eb892 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -28,6 +28,42 @@ class ExecutorBase(ABC):
) -> None:
raise NotImplementedError
+ @abstractmethod
+ def init_workers(self) -> None:
+ """Initialize workers, such as loading the model or preparing on-device
+ tensors.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def profile_num_available_blocks(self, block_size: int,
+ gpu_memory_utilization: float,
+ cpu_swap_space: float,
+ cache_dtype: str) -> tuple[int, int]:
+ """Profile the model on-device to determine the maximum number of KV
+ blocks that can be allocated.
+
+ Returns a tuple[num_device_blocks, num_cpu_blocks], where
+ num_device_blocks refers to the number of blocks in the "active" KV
+ cache (e.g. where blocks are appended to), and num_cpu_blocks refers
+ to the number of blocks in the "passive" KV cache (e.g. where blocks
+ are swapped to).
+
+ Examples:
+ - The GPUExecutor will return [num_gpu_blocks, num_cpu_blocks].
+ - A future CPUExecutor can return [num_cpu_blocks, 0] or
+ [num_cpu_blocks, num_swap_cpu_blocks].
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def init_cache(self, cache_config: CacheConfig) -> None:
+ """Given a fully-specified cache config, initialize the KV cache. This
+ is separate from init_workers as profiling may be required to determine
+ the maximum allowed KV cache size.
+ """
+ raise NotImplementedError
+
@abstractmethod
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
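As an illustration of how an executor might satisfy the new interface, the sketch below forwards each call to a single worker. This is illustrative only: the real GPUExecutor and RayGPUExecutor also handle distributed setup, and the remaining ExecutorBase methods are omitted.
class SingleWorkerExecutor(ExecutorBase):
    """Illustrative executor that delegates the new methods to one worker."""

    def __init__(self, worker: WorkerBase) -> None:
        self.worker = worker

    def init_workers(self) -> None:
        self.worker.init_device()

    def profile_num_available_blocks(self, block_size: int,
                                     gpu_memory_utilization: float,
                                     cpu_swap_space: float,
                                     cache_dtype: str) -> tuple[int, int]:
        return self.worker.profile_num_available_blocks(
            block_size, gpu_memory_utilization, cpu_swap_space, cache_dtype)

    def init_cache(self, cache_config: CacheConfig) -> None:
        self.worker.init_cache(cache_config)

    def execute_model(self, seq_group_metadata_list, blocks_to_swap_in,
                      blocks_to_swap_out, blocks_to_copy) -> SamplerOutput:
        return self.worker.execute_model(seq_group_metadata_list,
                                         blocks_to_swap_in,
                                         blocks_to_swap_out, blocks_to_copy)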
Existing implementation modifications
The existing ExecutorBase implementations need to be modified to conform to the interface. The changes will look approximately like the following:
- The profiling orchestration logic in GPUExecutor and RayGPUExecutor will move back to the LLMEngine.
- The NeuronExecutor will override profile_num_available_blocks to return [max_num_seq, 0] (sketched below).
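A rough sketch of that override (the scheduler_config.max_num_seqs attribute path is an assumption about where max_num_seq would come from):
class NeuronExecutor(ExecutorBase):
    ...

    def profile_num_available_blocks(self, block_size: int,
                                     gpu_memory_utilization: float,
                                     cpu_swap_space: float,
                                     cache_dtype: str) -> tuple[int, int]:
        # The Neuron backend has no paged "active"/"passive" KV split in this
        # design, so report one block per sequence slot and no swap blocks.
        return self.scheduler_config.max_num_seqs, 0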