33 changes: 29 additions & 4 deletions tests/v1/worker/test_gpu_input_batch.py
@@ -8,7 +8,7 @@
 import torch
 
 from vllm.sampling_params import SamplingParams
-from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+from vllm.utils import cdiv, is_pin_memory_available, make_tensor_with_pad
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheTensor)
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -232,6 +232,25 @@ def _construct_cached_request_state(req_id_suffix: int):
     )
 
 
+def setup_cpu_tensors(input_batch: InputBatch, kv_cache_config: KVCacheConfig):
+    block_table_cpu_tensor = torch.zeros(
+        (input_batch.max_num_reqs,
+         cdiv(input_batch.max_model_len,
+              kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size)),
+        device="cpu",
+        dtype=torch.int32,
+        pin_memory=input_batch.pin_memory,
+    )
+    slot_mapping_cpu_tensor = torch.zeros(
+        input_batch.max_num_batched_tokens,
+        device="cpu",
+        dtype=torch.int64,
+        pin_memory=input_batch.pin_memory,
+    )
+    input_batch.block_table.block_tables[0].init_block_table_cpu(
+        block_table_cpu_tensor, slot_mapping_cpu_tensor)
+
+
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("batch_size", [1, 2, 32, 64])
 def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
@@ -244,15 +263,17 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     output of `make_sampling_metadata` is then compared against the expected
     results to ensure correctness.
     """
+    kv_cache_config = get_kv_cache_config()
     input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
         max_model_len=1024,
         max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        kv_cache_config=get_kv_cache_config(),
+        kv_cache_config=kv_cache_config,
     )
+    setup_cpu_tensors(input_batch, kv_cache_config)
     reqs: list[CachedRequestState] = []
     req_id_reqs = {}
     req_id_output_token_ids = {}
@@ -334,24 +355,28 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
     output of `make_sampling_metadata` is then compared against the expected
     results to ensure correctness.
     """
+    kv_cache_config = get_kv_cache_config()
     input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
         max_model_len=1024,
         max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        kv_cache_config=get_kv_cache_config(),
+        kv_cache_config=kv_cache_config,
     )
+    ref_kv_cache_config = get_kv_cache_config()
     ref_input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
         max_model_len=1024,
         max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        kv_cache_config=get_kv_cache_config(),
+        kv_cache_config=ref_kv_cache_config,
    )
+    setup_cpu_tensors(input_batch, kv_cache_config)
+    setup_cpu_tensors(ref_input_batch, ref_kv_cache_config)
 
     reqs: list[CachedRequestState] = []
     req_id_reqs = {}
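A note for readers of the new `setup_cpu_tensors` helper: the CPU block table it allocates has one row per request and one column per block of the model's context window, sized with `cdiv` (ceiling division). A small worked example of that sizing; the block size of 16 is purely illustrative, since the tests read it from the KV cache spec:

```python
# Worked example of the block-table width computed via cdiv() above.
# Values are illustrative; the tests take block_size from the KV cache spec.
max_model_len = 1024   # matches the InputBatch constructed in these tests
block_size = 16        # assumed block size for this example

num_block_slots = (max_model_len + block_size - 1) // block_size  # cdiv
assert num_block_slots == 64  # one column per block a request can occupy
```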
3 changes: 3 additions & 0 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -13,6 +13,8 @@
 from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
+from .test_gpu_input_batch import setup_cpu_tensors
+
 
 def initialize_kv_cache(runner: GPUModelRunner):
     """
@@ -45,6 +47,7 @@ def initialize_kv_cache(runner: GPUModelRunner):
         vocab_size=runner.model_config.get_vocab_size(),
         kv_cache_config=kv_cache_config,
     )
+    setup_cpu_tensors(runner.input_batch, kv_cache_config)
     runner.initialize_attn_backend(kv_cache_config)
 
 
19 changes: 7 additions & 12 deletions vllm/v1/worker/block_table.py
@@ -31,24 +31,19 @@ def __init__(
             device=self.device,
             dtype=torch.int32,
         )
-        self.block_table_cpu = torch.zeros(
-            (max_num_reqs, max_num_blocks_per_req),
-            device="cpu",
-            dtype=torch.int32,
-            pin_memory=pin_memory,
-        )
-        self.block_table_np = self.block_table_cpu.numpy()
         self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
 
-        self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
-                                            dtype=torch.int64,
-                                            device="cpu",
-                                            pin_memory=self.pin_memory)
-        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
         self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
                                         dtype=torch.int64,
                                         device=self.device)
 
+    def init_block_table_cpu(self, block_table_cpu_tensor: torch.Tensor,
+                             slot_mapping_cpu_tensor: torch.Tensor):
+        self.block_table_cpu = block_table_cpu_tensor
+        self.block_table_np = self.block_table_cpu.numpy()
+        self.slot_mapping_cpu = slot_mapping_cpu_tensor
+        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
+
     def append_row(
         self,
         block_ids: list[int],
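For context on the new `init_block_table_cpu` hook: the pattern it wires up is a (optionally pinned, page-locked) CPU tensor plus a zero-copy NumPy view of it. A standalone sketch of that pattern, not vLLM code; shapes and values here are arbitrary:

```python
# Standalone sketch of the pinned-CPU-tensor + numpy-view pattern that
# init_block_table_cpu() sets up; shapes are arbitrary for illustration.
import torch

pin_memory = torch.cuda.is_available()  # pin only when a CUDA host is present
block_table_cpu = torch.zeros((4, 8), dtype=torch.int32, device="cpu",
                              pin_memory=pin_memory)
block_table_np = block_table_cpu.numpy()  # shares memory, no copy

block_table_np[0, :3] = [10, 11, 12]       # writes through the numpy view...
assert block_table_cpu[0, 2].item() == 12  # ...are visible from the tensor
```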
26 changes: 26 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -262,6 +262,23 @@ def __init__(
                                         device="cpu",
                                         pin_memory=self.pin_memory)
         self.seq_lens_np = self.seq_lens_cpu.numpy()
+        # NOTE(Chen): a temporary fix for quantization + cpu offload
+        # (https://github.com/vllm-project/vllm/issues/18425).
+        # Initializing block_table_cpu_tensor and slot_mapping_cpu_tensor
+        # before self.load_model fixes the issue; the root cause still
+        # needs investigation.
+        max_num_blocks_per_req = cdiv(self.max_model_len,
+                                      self.cache_config.block_size)
+        self.block_table_cpu_tensor = torch.zeros(
+            (self.max_num_reqs, max_num_blocks_per_req),
+            device="cpu",
+            dtype=torch.int32,
+            pin_memory=self.pin_memory,
+        )
+        self.slot_mapping_cpu_tensor = torch.zeros(self.max_num_tokens,
+                                                   dtype=torch.int64,
+                                                   device="cpu",
+                                                   pin_memory=self.pin_memory)
 
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
         """
@@ -1957,6 +1974,15 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             vocab_size=self.model_config.get_vocab_size(),
             kv_cache_config=kv_cache_config,
         )
+
+        # NOTE(Chen): a temporary fix for quantization + cpu offload
+        # (https://github.com/vllm-project/vllm/issues/18425).
+        # Initializing block_table_cpu_tensor and slot_mapping_cpu_tensor
+        # before self.load_model fixes the issue; the root cause still
+        # needs investigation.
+        assert len(self.input_batch.block_table.block_tables) == 1
+        self.input_batch.block_table.block_tables[0].init_block_table_cpu(
+            self.block_table_cpu_tensor, self.slot_mapping_cpu_tensor)
         self.initialize_attn_backend(kv_cache_config)
 
         kv_caches: dict[str, torch.Tensor] = {}
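To summarize the ordering these runner changes establish (pinned CPU tensors allocated in `__init__`, before model loading; wired into the block table later in `initialize_kv_cache`), here is a toy, heavily simplified mirror of that flow. Only the attribute and method names taken from the diff are real; everything else is an illustrative assumption:

```python
# Toy mirror of the initialization order this diff establishes in
# GPUModelRunner; not vLLM's real API beyond the names taken from the diff.
import torch


class ToyRunner:
    def __init__(self, max_num_reqs=4, blocks_per_req=8, max_num_tokens=64):
        # Step 1: allocate the CPU tensors up front, before load_model().
        self.block_table_cpu_tensor = torch.zeros(
            (max_num_reqs, blocks_per_req), dtype=torch.int32, device="cpu")
        self.slot_mapping_cpu_tensor = torch.zeros(
            max_num_tokens, dtype=torch.int64, device="cpu")

    def load_model(self):
        # Step 2: weight loading (where quantization + CPU offload happens).
        pass

    def initialize_kv_cache(self):
        # Step 3: hand the pre-allocated tensors to the block table,
        # analogous to init_block_table_cpu() in the diff.
        self.block_table_np = self.block_table_cpu_tensor.numpy()
        self.slot_mapping_np = self.slot_mapping_cpu_tensor.numpy()


runner = ToyRunner()
runner.load_model()
runner.initialize_kv_cache()
```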