diff --git a/benchmarks/benchmark_long_document_qa.py b/benchmarks/benchmark_long_document_qa.py new file mode 100644 index 00000000000..8d4425b6fb8 --- /dev/null +++ b/benchmarks/benchmark_long_document_qa.py @@ -0,0 +1,258 @@ +""" +Benchmark the efficiency of prefix caching. + +This script allows you to benchmark the performance of +a model with prefix-caching or cpu-offloading using fixed prompts + +Fixed example usage: + # This command run the vllm with 50GB CPU memory for offloading + # The workload samples 8 different prompts with a default input + # length of 20010 tokens, then replicates each prompt 2 times. + python benchmark_long_document_qa.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --block-allocator CpuOffloadingBlockAllocator \ + --num-documents 8 \ + --repeat-count 2 \ + --cpu-memory-gb 50 + +Commandline arguments: + + # Basic arguments + --model: The model to use for the benchmark. + + --enable-prefix-caching: Enable prefix caching or not. + + --block-allocator: The block allocator that vLLM uses. + - CpuGpuBlockAllocator: The default block allocator. + - CpuOffloadingBlockAllocator: The block allocator that supports + cpu offloading + + --gpu-memory-utilization: GPU memory utilization for vLLM. + + --cpu-memory-gb: The amount of CPU memory (GB) that is used by vLLM. + NOTE: CPU memory should be larger than GPU KV cache size when + using CpuOffloadingBlockAllocator. + + # Workload-related arguments + --num-documents: The number of documents to sample prompts from. + + --repeat-count: The number of times to repeat each prompt. + + # Other functionality + --seed: Random seed for reproducibility. + + --profile-swap-blocks: Profile the swap_blocks function in the custom ops. +""" + +import random +import time + +import torch + +from vllm import LLM, SamplingParams +from vllm.utils import FlexibleArgumentParser + +execution_times = {} + + +def build_result_dict(start_time, end_time, *args): + total_time = end_time - start_time + length = -1 + if len(args) > 1 and isinstance(args[1], torch.Tensor): + length = len(args[1]) + + return { + "start_time": start_time, + "total_time": total_time, + "swap_len": length + } + + +def timing_decorator(func): + + def wrapper(*args, **kwargs): + global execution_times + torch.cuda.synchronize() + start_time = time.time() # Record the start time + result = func(*args, **kwargs) # Call the wrapped function + torch.cuda.synchronize() + end_time = time.time() # Record the end time + if func.__name__ not in execution_times: + execution_times[func.__name__] = [] + + res = build_result_dict(start_time, end_time, *args) + execution_times[func.__name__].append(res) + return result # Return the result of the original function + + return wrapper + + +def process_timing_results(): + global execution_times + for key in execution_times: + len_to_time = {} + len_to_count = {} + for item in execution_times[key]: + swap_len = item["swap_len"] + if swap_len not in len_to_time: + len_to_time[swap_len] = 0 + len_to_time[swap_len] += item["total_time"] + + if swap_len not in len_to_count: + len_to_count[swap_len] = 0 + len_to_count[swap_len] += 1 + + for swap_len in len_to_time: + total_time = len_to_time[swap_len] + count = len_to_count[swap_len] + print(f"{key} on {swap_len} pages: " + f"{(count * swap_len) / total_time} pages per second") + + +def test_long_document_qa(llm=None, sampling_params=None, prompts=None): + + start_time = time.time() + llm.generate(prompts, sampling_params=sampling_params) + end_time = time.time() + 
print(f"cost time {end_time - start_time}") + + +def repeat_prompts(prompts, repeat_count): + repeated_prompts = prompts * repeat_count + random.shuffle(repeated_prompts) + return repeated_prompts + + +def main(args): + if args.profile_swap_blocks: + from vllm.worker.cache_engine import CacheEngine + CacheEngine.swap_out = timing_decorator(CacheEngine.swap_out) + CacheEngine.swap_in = timing_decorator(CacheEngine.swap_in) + + random.seed(args.seed) + + # append the document id at the beginning to avoid any of the document + # being the prefix of other documents + prompts = [ + str(i) + ' '.join(['hi'] * args.document_length) + for i in range(args.num_documents) + ] + + preemption_mode = "" + if args.block_allocator == "CpuOffloadingBlockAllocator": + preemption_mode = "recompute" + else: + preemption_mode = "swap" + + llm = LLM(model=args.model, + tokenizer_mode='auto', + trust_remote_code=True, + enforce_eager=True, + tensor_parallel_size=args.tensor_parallel_size, + enable_prefix_caching=args.enable_prefix_caching, + block_allocator=args.block_allocator, + preemption_mode=preemption_mode, + swap_space=args.cpu_memory_gb, + enable_chunked_prefill=False, + gpu_memory_utilization=args.gpu_memory_utilization, + max_model_len=30000) + + sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) + + prompts = repeat_prompts(prompts, args.repeat_count) + + print("------warm up------") + test_long_document_qa( + llm=llm, + prompts=prompts, + sampling_params=sampling_params, + ) + + random.shuffle(prompts) + + print("------start generating------") + test_long_document_qa( + llm=llm, + prompts=prompts, + sampling_params=sampling_params, + ) + + if args.profile_swap_blocks: + process_timing_results() + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description= + 'Benchmark the performance with or without automatic prefix caching.') + parser.add_argument( + '--model', + type=str, + # this test aims to test long document QA capability, + # so we use llama 3.1 8B as it can process long context + default='meta-llama/Llama-3.1-8B') + parser.add_argument("--dataset-path", + type=str, + default=None, + help="Path to the dataset.") + parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) + parser.add_argument('--output-len', type=int, default=10) + parser.add_argument('--enable-prefix-caching', + action='store_true', + help='enable prefix caching') + parser.add_argument('--repeat-count', + type=int, + default=2, + help='Number of times to repeat each prompt') + parser.add_argument( + '--document-length', + type=int, + # Roughly the number of tokens for a system paper, + # excluding images + default=20010, + help='Range of input lengths for sampling prompts,' + 'specified as "min:max" (e.g., "128:256").') + parser.add_argument('--num-documents', + type=int, + default=8, + help='Range of input lengths for sampling prompts,' + 'specified as "min:max" (e.g., "128:256").') + parser.add_argument("--seed", + type=int, + default=0, + help='Random seed for reproducibility') + parser.add_argument('--gpu-memory-utilization', + type=float, + default=0.9, + help='GPU memory utilization for vLLM. Should be a ' + 'float point number ranging from 0 to 1. For this ' + 'test please use a small value so that the GPU ' + 'cannot hold all KV caches of all documents, ' + 'and the effect of CPU offloading can be tested.') + parser.add_argument( + '--cpu-memory-gb', + type=float, + default=1, + help="The amount of CPU memory (GB) that is used by vLLM. 
Not very " + "useful for CpuGpuBlockAllocator, but useful for " + "CpuOffloadingBlockAllocator to have more CPU KV cache space") + parser.add_argument( + '--block-allocator', + type=str, + default='CpuGpuBlockAllocator', + choices=['CpuGpuBlockAllocator', 'CpuOffloadingBlockAllocator'], + help='The block allocator that vLLM uses. Currently' + ' can be CpuGpuBlockAllocator (the default) and ' + 'CpuOffloadingBlockAllocator (experimental) that ' + 'supports offloading the KV cache to CPU . ' + 'When using CpuOffloadingBlockAllocator, the ' + 'preemption mode must be recompute.') + + parser.add_argument( + '--profile-swap-blocks', + action='store_true', + help='Profile the swap_blocks function in the custom ops') + + args = parser.parse_args() + main(args) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 5e9381f712e..9a8ecae7b65 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -244,4 +244,4 @@ def main(args): parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 8a95279f9a2..dea40c1904a 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -11,6 +11,7 @@ #include "quantization/fp8/nvidia/quant_utils.cuh" #endif +#include #include #include #include @@ -21,8 +22,64 @@ typedef __hip_bfloat16 __nv_bfloat16; #endif -void swap_blocks(torch::Tensor& src, torch::Tensor& dst, - const torch::Tensor& block_mapping) { +namespace vllm { + +template +__global__ void paged_copy(T* __restrict__ dst, const T* __restrict__ src, + ACC_T src_to_dst, const int num_pages, + const int num_elements_per_page) { + const int64_t srcPageIdx = src_to_dst[blockIdx.x][0]; + const int64_t dstPageIdx = src_to_dst[blockIdx.x][1]; + + const int64_t srcPageOffset = srcPageIdx * num_elements_per_page; + const int64_t dstPageOffset = dstPageIdx * num_elements_per_page; + + for (int i = threadIdx.x; i < num_elements_per_page; i += blockDim.x) { + dst[dstPageOffset + i] = src[srcPageOffset + i]; + } +} + +} // namespace vllm + +template +void launch_swap_block_kernel(DTYPE* dst, const DTYPE* src, + const torch::Tensor& block_mapping, + const int num_blocks, + const int block_size_in_bytes) { + c10::cuda::CUDAGuard device_guard(block_mapping.device()); + auto block_mapping_accessor = + block_mapping.packed_accessor32(); + + int num_threads = 1024; + int grid_size = num_blocks; + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + vllm::paged_copy<<>>( + dst, src, block_mapping_accessor, num_blocks, + block_size_in_bytes / DTYPE_LEN); +} + +template +T* get_kernel_ptr(torch::Tensor& tensor) { + // Get the kernel-accessible pointer of the given type T + // Returns NULL if the tensor is on CPU and non-pinned + torch::Device device = tensor.device(); + if (device.is_cuda()) { + return static_cast(tensor.data_ptr()); + } else if (device.is_cpu() && tensor.is_pinned()) { + T* ptr; + cudaHostGetDevicePointer((void**)&ptr, static_cast(tensor.data_ptr()), + 0); + return ptr; + } else if (device.is_cpu()) { + return NULL; + } else { + TORCH_CHECK(false, "Invalid device"); + } +} + +void swap_blocks_slow(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { torch::Device src_device = src.device(); torch::Device dst_device = dst.device(); cudaMemcpyKind memcpy_type; @@ -62,6 +119,41 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, } } +void 
swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { + int64_t* src_ptr = get_kernel_ptr(src); + int64_t* dst_ptr = get_kernel_ptr(dst); + if (src_ptr == NULL || dst_ptr == NULL) { + // fall back to the slow implementation + swap_blocks_slow(src, dst, block_mapping.cpu()); + } else { + // Check the device + torch::Device src_device = src.device(); + torch::Device dst_device = dst.device(); + torch::Device block_mapping_device = block_mapping.device(); + TORCH_CHECK(block_mapping_device.is_cuda(), "block_mapping must be on GPU"); + if (src_device.is_cuda() && dst_device.is_cuda()) { + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); + } + if (src_device.is_cuda()) { + TORCH_CHECK(src_device.index() == block_mapping_device.index(), + "src and block_mapping must be on the same GPU"); + } + if (dst_device.is_cuda()) { + TORCH_CHECK(dst_device.index() == block_mapping_device.index(), + "src and block_mapping must be on the same GPU"); + } + + const int64_t num_blocks = block_mapping.size(0); + const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); + + launch_swap_block_kernel<8, int64_t>(dst_ptr, (const int64_t*)src_ptr, + block_mapping, num_blocks, + block_size_in_bytes); + } +} + namespace vllm { // Grid: (num_layers, num_pairs) diff --git a/tests/core/block/test_cpu_offloading_block_allocator.py b/tests/core/block/test_cpu_offloading_block_allocator.py new file mode 100644 index 00000000000..df4dbc40f12 --- /dev/null +++ b/tests/core/block/test_cpu_offloading_block_allocator.py @@ -0,0 +1,139 @@ +import pytest + +from vllm.core.block.cpu_offloading_block_allocator import ( + CpuOffloadingBlockAllocator) +from vllm.utils import Device, chunk_list + + +@pytest.mark.parametrize("num_cpu_blocks", [1024]) +@pytest.mark.parametrize("num_gpu_blocks", [256]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("allocator_type", ["prefix_caching"]) +def test_allocate_mutable_block(num_cpu_blocks: int, num_gpu_blocks: int, + block_size: int, allocator_type: str): + allocator = CpuOffloadingBlockAllocator.create( + allocator_type=allocator_type, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + block_size=block_size, + ) + + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks + + gpu_blocks = [ + allocator.allocate_mutable_block(prev_block=None, device=Device.GPU) + for _ in range(num_gpu_blocks) + ] + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == 0 + assert len(allocator._uncached_blocks) == num_gpu_blocks + + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(0.0) + assert len(blocks_to_swap_out) == 0 + assert len(blocks_to_swap_in) == 0 + assert len(allocator._uncached_blocks) == num_gpu_blocks + + _ = [allocator.free(block) for block in gpu_blocks] + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks + + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(1.0) + assert len(blocks_to_swap_out) == 0 + assert len(blocks_to_swap_in) == 0 + assert len(allocator._uncached_blocks) == 0 + + +@pytest.mark.parametrize("num_cpu_blocks", [1024]) +@pytest.mark.parametrize("num_gpu_blocks", [256]) +@pytest.mark.parametrize("block_size", [2]) +@pytest.mark.parametrize("allocator_type", 
["prefix_caching"]) +def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int, + block_size: int, allocator_type: str): + allocator = CpuOffloadingBlockAllocator.create( + allocator_type=allocator_type, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + block_size=block_size, + ) + + unique_token_ids = list( + range((num_cpu_blocks + num_gpu_blocks) * block_size)) + gpu_token_ids = list( + chunk_list(unique_token_ids[:num_gpu_blocks * block_size], block_size)) + gpu_token_ids2 = list( + chunk_list( + unique_token_ids[num_gpu_blocks * block_size:2 * num_gpu_blocks * + block_size], block_size)) + + gpu_blocks = [ + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids, + device=Device.GPU) + for token_ids in gpu_token_ids + ] + + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == 0 + assert len(allocator._uncached_blocks) == num_gpu_blocks + + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(0.0) + assert len(blocks_to_swap_out) == 0 + assert len(blocks_to_swap_in) == 0 + assert len(allocator._uncached_blocks) == num_gpu_blocks + + allocator.mark_blocks_as_computed([block.block_id for block in gpu_blocks]) + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(1.0) + assert len(blocks_to_swap_out) + len(blocks_to_swap_in) == num_gpu_blocks + assert len(allocator._uncached_blocks) == 0 + + _ = [allocator.free(block) for block in gpu_blocks] + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks + + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(1.0) + assert len(blocks_to_swap_out) == 0 + assert len(blocks_to_swap_in) == 0 + assert len(allocator._uncached_blocks) == 0 + + # allocate another gpu sequence to flush out the GPU cache + gpu_blocks = [ + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids, + device=Device.GPU) + for token_ids in gpu_token_ids2 + ] + + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == 0 + assert all([ + not allocator._allocators[Device.GPU].block_is_computed(block.block_id) + for block in gpu_blocks + ]) + + _ = [allocator.free(block) for block in gpu_blocks] + assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks + + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(2.0) + assert len(blocks_to_swap_out) == 0 + assert len(blocks_to_swap_in) == 0 + assert len(allocator._uncached_blocks) == 0 + + # allocate original gpu sequence. It should hit CPU cache. 
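+    # Each CPU cache hit pins the probed CPU block until the next
+    # get_and_reset_swaps() call, so the CPU free count should drop by
+    # num_gpu_blocks while the GPU blocks come back already marked computed.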
+ gpu_blocks = [ + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids, + device=Device.GPU) + for token_ids in gpu_token_ids + ] + + delta = num_cpu_blocks - num_gpu_blocks + assert allocator.get_num_free_blocks(Device.CPU) == delta + assert allocator.get_num_free_blocks(Device.GPU) == 0 + assert all([ + allocator._allocators[Device.GPU].block_is_computed(block.block_id) + for block in gpu_blocks + ]) + + blocks_to_swap_out, blocks_to_swap_in = allocator.get_and_reset_swaps(3.0) + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 40550ed51e2..ef90c36dd81 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -362,7 +362,7 @@ def test_swap_blocks( block_mapping = list(zip(src_blocks, dst_blocks)) block_mapping_tensor = torch.tensor(block_mapping, dtype=torch.int64, - device="cpu").view(-1, 2) + device=device).view(-1, 2) # Create the KV caches on the first device. src_key_caches, src_value_caches = kv_cache_factory( diff --git a/vllm/config.py b/vllm/config.py index 12ed80c366e..8ab11e2dc2c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -838,6 +838,7 @@ def __init__( sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, cpu_offload_gb: float = 0, + block_allocator: str = "CpuGpuBlockAllocator", ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization @@ -848,6 +849,7 @@ def __init__( self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching self.cpu_offload_gb = cpu_offload_gb + self.block_allocator = block_allocator self._verify_args() self._verify_cache_dtype() @@ -868,6 +870,13 @@ def _verify_args(self) -> None: "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") + if self.block_allocator not in [ + "CpuGpuBlockAllocator", "CpuOffloadingBlockAllocator" + ]: + raise ValueError( + "Only CpuGpuBlockAllocator and CpuOffloadingBlockAllocator is " + f"supported. Got {self.block_allocator}.") + def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 3a57487a6cd..75b49fb360d 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -339,17 +339,24 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: assert device in self._allocators return self._allocators[device].get_prefix_cache_hit_rate() - def get_and_reset_swaps(self) -> List[Tuple[int, int]]: + def get_and_reset_swaps(self, + now: float) -> Tuple[List[Tuple[int, int]], ...]: """Returns and clears the mapping of source to destination block IDs. Will be called after every swapping operations for now, and after every schedule when BlockManagerV2 become default. Currently not useful. + + Args: + now (float): The time stamp. Returns: - List[Tuple[int, int]]: A mapping of source to destination block IDs. + A tuple of two lists: (blocks_to_swap_out, blocks_to_swap_in). + Each list is a List[Tuple[int, int]], containing the mapping of + source to destination block IDs. The block IDs are physical block + IDs and it's expected to be used by the cache engine directly. 
""" - mapping = self._swap_mapping.copy() self._swap_mapping.clear() - return list(mapping.items()) + # return an empty list, to keep compatibility with previous behavior + return [], [] def find_cached_blocks_prefix( self, diff --git a/vllm/core/block/cpu_offloading_block_allocator.py b/vllm/core/block/cpu_offloading_block_allocator.py new file mode 100644 index 00000000000..a07289536c6 --- /dev/null +++ b/vllm/core/block/cpu_offloading_block_allocator.py @@ -0,0 +1,400 @@ +"""This file implement a block allocator that supports CPU KV cache offloading + +The key idea of this implementation is to maintain those allocated blocks +that didn't hit the cache, and constantly copy them into CPU after each +scheduler step. + +This idea is borrowed from ConServe +(paper link: https://arxiv.org/abs/2410.01228), based on the assumption +that the CPU-GPU bandwidth is much higher than GPU KV cache generation +throughput. Thanks Yifan for this idea. + +This implementation also allows vLLM to gracefully handle preemption by +recomputation. +""" +from collections import deque +from typing import Deque, Dict, List, Optional, Tuple + +from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator +from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator +from vllm.utils import Device + + +class CpuOffloadingBlockAllocator(CpuGpuBlockAllocator): + """A block allocator that supports CPU KV cache offloading + + This class extends the `CpuGpuBlockAllocator` so that the CPU can be used + for prefix caching. + + It will internally maintain uncached blocks, and trying to copy uncached + blocks into CPU upon the end of scheduler step (i.e. calling + `get_and_reset_swaps`). + + This implementation also allows vLLM to gracefully handle preemption by + recomputation. + """ + + allocators: Dict[Device, PrefixCachingBlockAllocator] + + @staticmethod + def create( + allocator_type: str, + num_gpu_blocks: int, + num_cpu_blocks: int, + block_size: int, + ) -> DeviceAwareBlockAllocator: + """Initiate CpuOffloadingBlockAllocator. Similar to + CpuGpuBlockAllocator.create() but only support prefix caching + + Args: + allocator_type (str): The type of block allocator to use for CPU + and GPU blocks. Currently supported values are "naive" and + "prefix_caching". + num_gpu_blocks (int): The number of blocks to allocate for GPU + memory. + num_cpu_blocks (int): The number of blocks to allocate for CPU + memory. + block_size (int): The size of each block in number of tokens. + + Returns: + DeviceAwareBlockAllocator: A CpuOffloadingBlockAllocator instance + with the specified configuration. + + Notes: + - The block IDs are assigned contiguously, with GPU block IDs coming + before CPU block IDs. + """ + assert num_gpu_blocks < num_cpu_blocks, "CPU offloading block "\ + "allocator requires the allocated CPU memory capacity to be larger"\ + " than GPU memory capacity." + block_ids = list(range(num_gpu_blocks + num_cpu_blocks)) + gpu_block_ids = block_ids[:num_gpu_blocks] + cpu_block_ids = block_ids[num_gpu_blocks:] + + assert allocator_type == "prefix_caching", "CpuOffloadingBlock"\ + "Allocator should be only used together with prefix caching." + + # prefix caching block is now the default. 
+ gpu_allocator = PrefixCachingBlockAllocator( + num_blocks=num_gpu_blocks, + block_size=block_size, + block_ids=gpu_block_ids, + ) + + cpu_allocator = PrefixCachingBlockAllocator( + num_blocks=num_cpu_blocks, + block_size=block_size, + block_ids=cpu_block_ids, + ) + + return CpuOffloadingBlockAllocator( + cpu_block_allocator=cpu_allocator, + gpu_block_allocator=gpu_allocator, + ) + + def __init__(self, cpu_block_allocator: PrefixCachingBlockAllocator, + gpu_block_allocator: PrefixCachingBlockAllocator): + assert not ( + cpu_block_allocator.all_block_ids + & gpu_block_allocator.all_block_ids + ), "cpu and gpu block allocators can't have intersection of block ids" + + super().__init__(cpu_block_allocator, gpu_block_allocator) + self._allocators: Dict[Device, + PrefixCachingBlockAllocator] = { # type: ignore + Device.CPU: cpu_block_allocator, + Device.GPU: gpu_block_allocator + } + """ + GPU block should only be in one of the following three status: + uncached: allocated blocks that didn't hit any cache + cached: allocated blocks that are cached, either in GPU or in CPU + free: the blocks are not allocated by block allocator + This implementation aims to transform uncached blocks to cached blocks + by performing GPU to CPU copy when calling `get_and_reset_swaps` + + As block allocator will automatically track free blocks, and we don't + need to specially handle cached blocks. So we only track uncached blocks + """ + self._uncached_blocks: Deque[Block] = deque() + """ + We probe CPU cache hit by trying to allocate a CPU + block and see if it is computed. + If we hit the CPU cache, we cannot free this CPU block until the end + of scheduler step, in order to avoid the CPU cache being overwritten. + so we track the cpu blocks we allocated, and free it after scheduler + step (i.e. calling `get_and_reset_swaps`). + """ + self._allocated_cpu_blocks: Deque[Block] = deque() + + self.num_gpu_blocks = gpu_block_allocator.get_num_total_blocks() + self.num_cpu_blocks = cpu_block_allocator.get_num_total_blocks() + + def allocate_mutable_block(self, + prev_block: Optional[Block], + device: Device, + extra_hash: Optional[int] = None) -> Block: + """Allocates a new mutable block on the specified device. + + Args: + prev_block (Optional[Block]): The previous block to in the sequence. + Used for prefix hashing. + device (Device): The device on which to allocate the new block. + + Returns: + Block: The newly allocated mutable block. + """ + assert device == Device.GPU, "Calls to CPU offloading block allocator "\ + "should always use Device.GPU --- CPU offloading block allocator "\ + "handles CPU offloading internally."\ + # mark this block as uncached + + block = self._allocators[device].allocate_mutable_block( + prev_block, extra_hash=extra_hash) + self._uncached_blocks.append(block) + return block + + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device, + extra_hash: Optional[int] = None) -> List[Block]: + """Allocates a new group of immutable blocks with the provided block + token IDs on the specified device. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + Used for prefix hashing. + block_token_ids (List[int]): The list of block token IDs to be + stored in the new blocks. + device (Device): The device on which to allocate the new block. + + Returns: + List[Block]: The newly allocated list of immutable blocks + containing the provided block token IDs. 
+ """ + assert device == Device.GPU, "Calls to CPU offloading block allocator "\ + "should always use Device.GPU --- CPU offloading block allocator"\ + "handles CPU offloading internally." + + # repeatedly call allocate_immutable_block + # because it handles CPU-GPU offloading related logics. + blocks = [] + for token_ids in block_token_ids: + prev_block = self.allocate_immutable_block(prev_block=prev_block, + token_ids=token_ids, + device=device, + extra_hash=extra_hash) + blocks.append(prev_block) + return blocks + + def allocate_immutable_block(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Device, + extra_hash: Optional[int] = None) -> Block: + """Allocates a new immutable block with the provided token IDs on the + specified device. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + Used for prefix hashing. + token_ids (List[int]): The list of token IDs to be stored in the new + block. + device (Device): The device on which to allocate the new block. + + Returns: + Block: The newly allocated immutable block containing the provided + token IDs. + """ + + assert device == Device.GPU, "Calls to CPU offloading block allocator"\ + " should always use Device.GPU --- CPU offloading block allocator"\ + " handles CPU offloading internally." + + # allocate a GPU block + block = self._allocators[device].allocate_immutable_block( + prev_block, token_ids, extra_hash=extra_hash) + block_id = block.block_id + assert block_id is not None + block_computed = self._allocators[device].block_is_computed(block_id) + + # deal with prefix caching, three cases in total: + # 1. cache hit on GPU + # 2. no cache hit on GPU but cache hit on CPU + # 3. no cache hit + if block_computed: + # cache hit on GPU, no need to put it into uncached blocks + pass + else: + # check if we can hit cache on CPU by trying to allocate CPU block + cpu_block = self._allocators[Device.CPU].allocate_immutable_block( + prev_block, token_ids, extra_hash=extra_hash) + cpu_block_id = cpu_block.block_id + assert cpu_block_id is not None + cpu_block_computed = self._allocators[ + Device.CPU].block_is_computed(cpu_block_id) + if cpu_block_computed: + # CPU cache hit + # mark the GPU block as computed + self._allocators[Device.GPU].mark_blocks_as_computed( + [block_id]) + # copy the CPU cache to GPU + self._swap_mapping[cpu_block_id] = block_id + # and don't free this block until `get_and_reset_swap` is called + self._allocated_cpu_blocks.append(cpu_block) + else: + # No cache hit + # mark the GPU block as uncached + self._uncached_blocks.append(block) + # and free cpu block + self._allocators[Device.CPU].free(cpu_block) + + return block + + def swap(self, blocks: List[Block], src_device: Device, + dst_device: Device) -> Dict[int, int]: + + raise NotImplementedError("CPU offloading block allocator only " + "support preemption by recomputation.") + + def _is_gpu_block(self, block_id: int) -> bool: + return block_id in self._allocators[Device.GPU].all_block_ids + + def _is_gpu_block_unsafe(self, block_id: int) -> bool: + """Faster version of `_is_gpu_block` that doesn't check the block ID. + But assumes the that the block IDs are assigned contiguously, with GPU + block IDs coming before the CPU block IDs. + """ + return block_id < self.num_gpu_blocks + + def _get_physical_block_id_unsafe(self, block_id: int) -> int: + """Returns the physical block ID of the given block ID. + + This function avoids using the `allocator.get_physical_block_id()` + which is slow (O(NlogN)). 
Instead, this is based on the assumption + that the block IDs are assigned contiguously, with GPU block IDs coming + before CPU block IDs. + + Args: + block_id (int): The block ID to get the physical block ID of. + + Returns: + int: The physical block ID of the given block ID. + + Note: + Please see the implementation of + `CpuOffloadingBlockAllocator.create` for how the block IDs are + assigned. + """ + if self._is_gpu_block_unsafe(block_id): + return block_id + else: + return block_id - self.num_gpu_blocks + + def get_and_reset_swaps(self, + now: float) -> Tuple[List[Tuple[int, int]], ...]: + """Returns and clears the mapping of source to destination block IDs. + Will be called right before scheduler step finishes. + + This function will do the following things: + 1. Iterate over uncached blocks and see if we can copy it to CPU + 2. Update all allocated CPU block time stamp + 3. Free CPU blocks + 4. Return and clear all swapping status + + Args: + now (float): The time stamp used to update CPU access time, so + that CPU evictor can work. + + Returns: + A tuple of two lists: (blocks_to_swap_out, blocks_to_swap_in). + Each list is a List[Tuple[int, int]], containing the mapping of + source to destination block IDs. The block IDs are physical block + IDs and it's expected to be used by the cache engine directly. + """ + + allocator = self._allocators[Device.GPU] + cpu_allocator = self._allocators[Device.CPU] + + new_uncached_blocks: Deque[Block] = deque() + + while self._uncached_blocks: + block = self._uncached_blocks.pop() + block_id = block.block_id + + # check if this block is freed + if block_id is None: + # this block is already freed, no longer need to copy it to CPU + continue + + refcount = allocator._refcounter.get(block_id) + assert refcount > 0, "A freed block should have block_id None" + + # check if this block is computed + computed = allocator.block_is_computed(block_id) + if computed: # This block is computed, copy it to CPU + # allocate a block on CPU + cpu_block = cpu_allocator.allocate_immutable_block( + prev_block=block.prev_block, + token_ids=block.token_ids, + extra_hash=block.extra_hash, + ) + assert cpu_block.block_id is not None + self._allocated_cpu_blocks.append(cpu_block) + + # mark CPU block as computed + cpu_allocator.mark_blocks_as_computed([cpu_block.block_id]) + + # copy the GPU block to CPU + assert cpu_block.block_id is not None + self._swap_mapping[block_id] = cpu_block.block_id + + continue + + # this block is neither freed nor computed + # keep marking it as uncached + new_uncached_blocks.append(block) + + # update uncached blocks + self._uncached_blocks = new_uncached_blocks + + # iterate over allocated CPU blocks, update access time and free them + # need to update access time so that CPU evictor can work + while self._allocated_cpu_blocks: + cpu_block = self._allocated_cpu_blocks.pop() + assert cpu_block.block_id is not None + # update the access time + cpu_allocator.mark_blocks_as_accessed([cpu_block.block_id], now) + # free the block + cpu_allocator.free(cpu_block) + + # populate the swap_out list and swap_in list + blocks_to_swap_out = [] + blocks_to_swap_in = [] + for src, dst in self._swap_mapping.items(): + # only two possible cases: CPU -> GPU, or GPU -> CPU + #if src in self._allocators[Device.GPU].all_block_ids: + if self._is_gpu_block_unsafe(src): + # swap out + src = self._get_physical_block_id_unsafe(src) + dst = self._get_physical_block_id_unsafe(dst) + blocks_to_swap_out.append((src, dst)) + else: + # swap in + src = 
self._get_physical_block_id_unsafe(src) + dst = self._get_physical_block_id_unsafe(dst) + blocks_to_swap_in.append((src, dst)) + self._swap_mapping.clear() + return blocks_to_swap_out, blocks_to_swap_in + + def will_swap_in_cpu_blocks(self): + """Check if there are CPU blocks that will be swapped in + + Returns: + bool: True if there are CPU blocks that will be swapped in, False + otherwise. + """ + return bool(self._swap_mapping) diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 985a1098b6c..3029f8837c5 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -304,3 +304,21 @@ def find_cached_blocks_prefix( device: Device = Device.GPU, ) -> List[int]: pass + + @abstractmethod + def get_and_reset_swaps(self, + now: float) -> Tuple[List[Tuple[int, int]], ...]: + """Returns and clears the mapping of source to destination block IDs. + Will be called after every swapping operations for now, and after every + schedule when BlockManagerV2 become default. Currently not useful. + + Args: + now (float): The time stamp. + + Returns: + A tuple of two lists: (blocks_to_swap_out, blocks_to_swap_in). + Each list is a List[Tuple[int, int]], containing the mapping of + source to destination block IDs. The block IDs are physical block + IDs and it's expected to be used by the cache engine directly. + """ + pass diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 1238303234d..7879be61f66 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -930,6 +930,8 @@ def __init__( # `get_num_cached_tokens` for more details. self._seq_id_to_num_tokens_computed: Dict[int, int] = {} + self._seq_id_has_cpu_blocks: Set[int] = set() + def _update_seq_hashes(self, seq: Sequence) -> None: """Incrementally update the sequence's block hashes and record them.""" assert self._enable_caching @@ -991,7 +993,8 @@ def get_num_cached_tokens(self, seq: Sequence) -> int: # TODO(rickyx): This hack could be removed once we mark blocks as # computed correctly with chunked prefills. - if num_computed_tokens_prev is not None and seq.is_prefill(): + if num_computed_tokens_prev is not None and seq.is_prefill() \ + and seq.seq_id not in self._seq_id_has_cpu_blocks: # For a sequence that is still in prefill, we don't # recompute the number of cached tokens. # This also handles correctly chunked prefill since currently @@ -1009,6 +1012,14 @@ def get_num_cached_tokens(self, seq: Sequence) -> int: self._seq_id_to_num_tokens_computed[seq.seq_id] = num_cached_tokens return num_cached_tokens + def on_swap_in_cpu_blocks(self, seq_id: int) -> None: + """Mark the sequence as having CPU blocks swapped in.""" + # NOTE(Yihua): This is a temporary solution to handle the case where + # the CPU offloading is enabled and the sequence has CPU blocks swapped + # in. In this case, the number in self._seq_id_to_num_tokens_computed + # should be invalidated and we need to re-compute it. 
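+        # get_num_cached_tokens() skips its prefill early-return for sequence
+        # IDs in this set, so the cached-token count is recomputed instead of
+        # reusing the stale prefill value.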
+ self._seq_id_has_cpu_blocks.add(seq_id) + def remove_seq(self, seq_id: int) -> None: """Stop tracking the sequence.""" if not self._enable_caching: @@ -1019,6 +1030,8 @@ def remove_seq(self, seq_id: int) -> None: assert seq_id in self._seq_id_to_num_tokens_computed del self._seq_id_to_num_tokens_computed[seq_id] + self._seq_id_has_cpu_blocks.discard(seq_id) + class LastAccessBlocksTracker: """Manages the last access time of the tracked sequences, in order to allow diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index b41e8482218..3d84d29dff1 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -5,6 +5,8 @@ from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.block.cpu_offloading_block_allocator import ( + CpuOffloadingBlockAllocator) from vllm.core.block.interfaces import Block from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, LastAccessBlocksTracker) @@ -16,6 +18,11 @@ SeqId = int EncoderSeqId = str +block_allocator_creator = { + "CpuGpuBlockAllocator": CpuGpuBlockAllocator.create, + "CpuOffloadingBlockAllocator": CpuOffloadingBlockAllocator.create, +} + class SelfAttnBlockSpaceManager(BlockSpaceManager): """BlockSpaceManager which manages the allocation of KV cache. @@ -65,6 +72,7 @@ def __init__( watermark: float = 0.01, sliding_window: Optional[int] = None, enable_caching: bool = False, + block_allocator: str = "CpuGpuBlockAllocator", ) -> None: self.block_size = block_size self.num_total_gpu_blocks = num_gpu_blocks @@ -90,7 +98,7 @@ def __init__( self.watermark_blocks = int(watermark * num_gpu_blocks) - self.block_allocator = CpuGpuBlockAllocator.create( + self.block_allocator = block_allocator_creator[block_allocator]( allocator_type="prefix_caching" if enable_caching else "naive", num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, @@ -159,6 +167,13 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: block_table.allocate(token_ids=seq.get_token_ids(), extra_hash=extra_hash) + # If the block allocator is CpuOffloadingBlockAllocator, we need to + # tell the computed_blocks_tracker to invalidate the previous computed + # num cached tokens + if isinstance(self.block_allocator, CpuOffloadingBlockAllocator) and \ + self.block_allocator.will_swap_in_cpu_blocks(): + self._computed_blocks_tracker.on_swap_in_cpu_blocks(seq.seq_id) + return block_table def allocate(self, seq_group: SequenceGroup) -> None: @@ -514,3 +529,20 @@ def get_num_cached_tokens(self, seq: Sequence) -> int: cached in the block manager for the sequence. """ return self._computed_blocks_tracker.get_num_cached_tokens(seq) + + def get_and_reset_swaps(self, + now: float) -> Tuple[List[Tuple[int, int]], ...]: + """Returns and clears the mapping of source to destination block IDs. + Will be called after every swapping operations for now, and after every + schedule when BlockManagerV2 become default. + + Args: + now (float): The time stamp. + + Returns: + A tuple of two lists: (blocks_to_swap_out, blocks_to_swap_in). + Each list is a List[Tuple[int, int]], containing the mapping of + source to destination block IDs. The block IDs are physical block + IDs and it's expected to be used by the cache engine directly. 
+ """ + return self.block_allocator.get_and_reset_swaps(now) diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index b10b8d3f4a5..948b2b63643 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -125,3 +125,20 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: @abstractmethod def get_num_cached_tokens(self, seq: Sequence) -> int: pass + + @abstractmethod + def get_and_reset_swaps(self, + now: float) -> Tuple[List[Tuple[int, int]], ...]: + """Returns and clears the mapping of source to destination block IDs. + Will be called after every swapping operations for now, and after every + schedule when BlockManagerV2 become default. + + Args: + now (float): The time stamp. + + Returns: + A tuple of two lists: (blocks_to_swap_out, blocks_to_swap_in). + Each list is a List[Tuple[int, int]], containing the mapping of + source to destination block IDs. + """ + pass diff --git a/vllm/core/placeholder_block_space_manager.py b/vllm/core/placeholder_block_space_manager.py index a47e5945185..73a1adb84c1 100644 --- a/vllm/core/placeholder_block_space_manager.py +++ b/vllm/core/placeholder_block_space_manager.py @@ -92,3 +92,7 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float: def get_num_cached_tokens(self, seq: Sequence) -> int: return 0 + + def get_and_reset_swaps(self, + now: float) -> Tuple[List[Tuple[int, int]], ...]: + return [], [] diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c3bc6becf09..31b55b388b6 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -152,7 +152,9 @@ class SchedulerOutputs: def __post_init__(self): # Swap in and swap out should never happen at the same time. - assert not (self.blocks_to_swap_in and self.blocks_to_swap_out) + # NOTE(Kuntai): in CpuOffloadingBlockAllocator swap in and swap out + # will happen at the same time. So we comment out the following line. + # assert not (self.blocks_to_swap_in and self.blocks_to_swap_out) self.num_loras: int = len(self.lora_requests) if self.num_loras > 0: @@ -358,7 +360,8 @@ def __init__( num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, sliding_window=self.cache_config.sliding_window, - enable_caching=self.cache_config.enable_prefix_caching) + enable_caching=self.cache_config.enable_prefix_caching, + block_allocator=self.cache_config.block_allocator) # Sequence groups in the WAITING state. # Contain new prefill or preempted requests. 
@@ -1131,6 +1134,17 @@ def _schedule_default(self) -> SchedulerOutputs: blocks_to_copy = running_scheduled.blocks_to_copy blocks_to_copy.extend(swapped_in.blocks_to_copy) + blocks_to_swap_in = swapped_in.blocks_to_swap_in + blocks_to_swap_out = running_scheduled.blocks_to_swap_out + + # NOTE(Kuntai): extend the swapping list for CPU offloading + new_swap_out, new_swap_in = \ + self.block_manager.get_and_reset_swaps(time.time()) + for src, dst in new_swap_out: + blocks_to_swap_out.extend((src, dst)) + for src, dst in new_swap_in: + blocks_to_swap_in.extend((src, dst)) + ignored_seq_groups = prefills.ignored_seq_groups ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) @@ -1139,8 +1153,8 @@ def _schedule_default(self) -> SchedulerOutputs: num_prefill_groups=num_prefill_groups, num_batched_tokens=budget.num_batched_tokens + budget.num_cached_tokens, - blocks_to_swap_in=swapped_in.blocks_to_swap_in, - blocks_to_swap_out=running_scheduled.blocks_to_swap_out, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ignored_seq_groups=ignored_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, @@ -1209,6 +1223,21 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # Update swapped requests. self.swapped.extend(running_scheduled.swapped_out) + + blocks_to_copy = running_scheduled.blocks_to_copy + blocks_to_copy.extend(swapped_in.blocks_to_copy) + + blocks_to_swap_in = swapped_in.blocks_to_swap_in + blocks_to_swap_out = running_scheduled.blocks_to_swap_out + + # NOTE(Kuntai): extend the swapping list for CPU offloading + new_swap_out, new_swap_in = \ + self.block_manager.get_and_reset_swaps(time.time()) + for src, dst in new_swap_out: + blocks_to_swap_out.extend((src, dst)) + for src, dst in new_swap_in: + blocks_to_swap_in.extend((src, dst)) + # Put prefills first due to Attention backend ordering assumption. 
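+        # The extend((src, dst)) calls above flatten each pair into
+        # [src0, dst0, src1, dst1, ...]; prepare_worker_input rebuilds the
+        # pairs with torch.tensor(...).view(-1, 2), so the flat layout is
+        # accepted there.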
scheduled_seq_groups = (prefills.seq_groups + running_scheduled.prefill_seq_groups + @@ -1231,10 +1260,9 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: num_prefill_groups=num_prefill_groups, num_batched_tokens=budget.num_batched_tokens + budget.num_cached_tokens, - blocks_to_swap_in=swapped_in.blocks_to_swap_in, - blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=running_scheduled.blocks_to_copy + - swapped_in.blocks_to_copy, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, ignored_seq_groups=prefills.ignored_seq_groups + swapped_in.infeasible_seq_groups, num_lookahead_slots=num_lookahead_slots, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0098648b1cd..3c7495a2209 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -112,6 +112,7 @@ class EngineArgs: pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None + block_allocator: str = "CpuGpuBlockAllocator" # NOTE(kzawora): default block size for Gaudi should be 128 # smaller sizes still work, but very inefficiently block_size: int = 16 if not current_platform.is_hpu() else 128 @@ -413,6 +414,17 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action='store_true', help='If specified, use nsight to profile Ray workers.') # KV cache arguments + parser.add_argument( + '--block-allocator', + type=str, + default='CpuGpuBlockAllocator', + choices=['CpuGpuBlockAllocator', 'CpuOffloadingBlockAllocator'], + help='The block allocator that vLLM uses. Currently' + ' can be CpuGpuBlockAllocator (the default) and ' + 'CpuOffloadingBlockAllocator (experimental) that ' + 'supports offloading the KV cache to CPU . ' + 'When using CpuOffloadingBlockAllocator, the ' + 'preemption mode must be recompute.') parser.add_argument('--block-size', type=int, default=EngineArgs.block_size, @@ -1015,6 +1027,14 @@ def create_engine_config(self, "CPU offload space must be non-negative" f", but got {self.cpu_offload_gb}") + if self.block_allocator == "CpuOffloadingBlockAllocator" and \ + self.preemption_mode == "swap": + raise ValueError( + "CpuOffloadingBlockAllocator only supports preemption by " + "recomputation as it internally offloads the request KV cache " + "to CPU. 
Please add `--preemption-mode recomputation` to vLLM " + "engine args") + device_config = DeviceConfig(device=self.device) model_config = self.create_model_config() @@ -1037,6 +1057,7 @@ def create_engine_config(self, sliding_window=model_config.get_sliding_window(), enable_prefix_caching=self.enable_prefix_caching, cpu_offload_gb=self.cpu_offload_gb, + block_allocator=self.block_allocator, ) parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 11b2574ce42..9d63e068998 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -164,6 +164,7 @@ def __init__( gpu_memory_utilization: float = 0.9, swap_space: float = 4, cpu_offload_gb: float = 0, + block_allocator: str = "CpuGpuBlockAllocator", enforce_eager: Optional[bool] = None, max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, @@ -212,6 +213,7 @@ def __init__( gpu_memory_utilization=gpu_memory_utilization, swap_space=swap_space, cpu_offload_gb=cpu_offload_gb, + block_allocator=block_allocator, enforce_eager=enforce_eager, max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index a368bb9ee9a..d418740449f 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -322,10 +322,10 @@ def prepare_worker_input( # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. # they contain parameters to launch cudamemcpyasync. blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, - device="cpu", + device="cuda", dtype=torch.int64).view(-1, 2) blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, - device="cpu", + device="cuda", dtype=torch.int64).view(-1, 2) # `blocks_to_copy` is a gpu tensor. The src and tgt of # blocks to copy are in the same device, and `blocks_to_copy`