Skip to content
16 changes: 16 additions & 0 deletions vllm/array_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from collections import defaultdict
from typing import Dict, List

import numpy as np

# Free lists of reusable 1-D int64 arrays, keyed by array length.
# Arrays handed back via del_array() keep whatever contents they had;
# only arrays freshly created by alloc_array() are zero-filled.
# NOTE(review): no locking — presumably used from a single thread; confirm.
_POOL: Dict[int, List[np.ndarray]] = defaultdict(list)


def alloc_array(max_tokens: int) -> np.ndarray:
    """Return a 1-D int64 array of length ``max_tokens``.

    Pops a recycled array of the requested size from the pool when one
    is available (its previous contents are NOT cleared); otherwise
    allocates a new zero-initialized array.
    """
    free_list = _POOL.get(max_tokens)
    if free_list:
        return free_list.pop()
    return np.zeros((max_tokens, ), dtype=np.int64)


def del_array(arr: np.ndarray) -> None:
    """Return ``arr`` to the pool for reuse by a later
    ``alloc_array(len(arr))`` call. The array is not cleared."""
    size = len(arr)
    _POOL[size].append(arr)
88 changes: 44 additions & 44 deletions vllm/core/block/block_table.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import List, Optional

import numpy as np

from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator
from vllm.utils import Device, cdiv, chunk_list
from vllm.utils import Device, cdiv, chunk_array


class BlockTable:
Expand Down Expand Up @@ -50,20 +52,22 @@ def __init__(
self._blocks: List[Block] = _blocks

self._max_block_sliding_window = max_block_sliding_window
# Use helper method instead of directly calculating, as blocks
# may not be allocated.
self._num_full_slots = len(self._get_all_token_ids())

_num_full_slots = 0
for block in self._blocks:
_num_full_slots += block.num_tokens
self._num_full_slots = _num_full_slots

@staticmethod
def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
def get_num_required_blocks(token_ids: np.ndarray, block_size: int) -> int:
"""Calculates the minimum number of blocks required to store a given
sequence of token IDs.

This assumes worst-case scenario, where every block requires a new
allocation (e.g. ignoring prefix caching).

Args:
token_ids (List[int]): The sequence of token IDs to be stored.
token_ids (np.ndarray): The sequence of token IDs to be stored.
block_size (int): The maximum number of tokens that can be stored in
a single block.

Expand All @@ -74,27 +78,28 @@ def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
return cdiv(len(token_ids), block_size)

def allocate(self,
token_ids: List[int],
token_ids: np.ndarray,
device: Device = Device.GPU) -> None:
"""Allocates memory blocks for storing the given sequence of token IDs.

This method allocates the required number of blocks to store the given
sequence of token IDs.

Args:
token_ids (List[int]): The sequence of token IDs to be stored.
token_ids (np.ndarray): The sequence of token IDs to be stored.
device (Device, optional): The device on which the blocks should be
allocated. Defaults to Device.GPU.
"""
assert not self._is_allocated
assert token_ids
assert not self._blocks
len_token_ids = len(token_ids)
assert len_token_ids > 0
self._blocks = self._allocate_blocks_for_token_ids(prev_block=None,
token_ids=token_ids,
device=device)
self._num_full_slots = len(token_ids)
self._num_full_slots = len_token_ids

def append_token_ids(self,
token_ids: List[int],
token_ids: np.ndarray,
num_lookahead_slots: int = 0,
num_computed_slots: Optional[int] = None) -> None:
"""Appends a sequence of token IDs to the existing blocks in the
Expand All @@ -110,7 +115,7 @@ def append_token_ids(self,
separate block.

Args:
token_ids (List[int]): The sequence of token IDs to be appended.
token_ids (np.ndarray): The sequence of token IDs to be appended.
num_computed_slots (Optional[int]): The number of KV cache slots
that are already filled (computed).
When sliding window is enabled, this is used to compute how many
Expand All @@ -119,7 +124,7 @@ def append_token_ids(self,
Without chunked prefill, it should be the same as
_num_full_slots.
"""
assert self._is_allocated, "no blocks have been allocated"
assert self._blocks, "no blocks have been allocated"
assert len(self._blocks) > 0

# Drop blocks that are no longer needed due to sliding window
Expand Down Expand Up @@ -163,7 +168,7 @@ def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
# Currently the block table only supports
# appending tokens to GPU blocks.
device = Device.GPU
assert self._is_allocated
assert self._blocks

if self._num_empty_slots >= num_empty_slots:
return
Expand All @@ -190,8 +195,7 @@ def fork(self) -> "BlockTable":
BlockTable: A new BlockTable instance with a copy of the blocks from
the current instance.
"""
assert self._is_allocated
assert len(self._blocks) > 0
assert self._blocks
forked_blocks = self._allocator.fork(self._blocks[-1])
return BlockTable(
block_size=self._block_size,
Expand All @@ -208,7 +212,6 @@ def free(self) -> None:
occupied by each block. After freeing all the blocks, the `_blocks` list
is set to `None`.
"""
assert self._is_allocated
for block in self._blocks:
self._allocator.free(block)
self._blocks = []
Expand All @@ -227,21 +230,21 @@ def physical_block_ids(self) -> List[Optional[int]]:
List[int]: A list of physical block indices for the blocks in the
BlockTable.
"""
assert self._is_allocated
return [block.block_id for block in self._blocks]

def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
def get_unseen_token_ids(self,
sequence_token_ids: np.ndarray) -> np.ndarray:
"""Get the number of "unseen" tokens in the sequence.

Unseen tokens are tokens in the sequence corresponding to this block
table, but are not yet appended to this block table.

Args:
sequence_token_ids (List[int]): The list of token ids in the
sequence_token_ids (np.ndarray): The list of token ids in the
sequence.

Returns:
List[int]: The postfix of sequence_token_ids that has not yet been
np.ndarray: The postfix of sequence_token_ids that has not yet been
appended to the block table.
"""

Expand All @@ -250,10 +253,11 @@ def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
return sequence_token_ids[self.num_full_slots:]

def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
token_ids: List[int],
token_ids: np.ndarray,
device: Device) -> List[Block]:
blocks: List[Block] = []
for block_token_ids in chunk_list(token_ids, self._block_size):
for i in range(0, len(token_ids), self._block_size):
block_token_ids = token_ids[i:i + self._block_size]
if len(block_token_ids) == self._block_size:
# If the block is full, create an immutable block.
prev_block = self._allocator.allocate_immutable(
Expand All @@ -267,29 +271,23 @@ def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],

return blocks

def _get_all_token_ids(self) -> List[int]:
# NOTE: This function is O(seq_len); use sparingly.
token_ids: List[int] = []

if not self._is_allocated:
return token_ids
def _get_all_token_ids(self) -> np.ndarray:
# NOTE: This function is O(seq_len); use only for testing.
token_id_arrays: List[np.ndarray] = []

for block in self._blocks:
token_ids.extend(block.token_ids)

return token_ids
if self._blocks:
for block in self._blocks:
token_id_arrays.append(block.token_ids)

@property
def _is_allocated(self) -> bool:
return len(self._blocks) > 0
return np.concatenate(token_id_arrays)

@property
def blocks(self) -> Optional[List[Block]]:
return self._blocks

@property
def _num_empty_slots(self) -> int:
assert self._is_allocated
assert self._blocks
return len(self._blocks) * self._block_size - self._num_full_slots

@property
Expand All @@ -303,27 +301,29 @@ def num_full_slots(self) -> int:
return self._num_full_slots

def get_num_blocks_touched_by_append_slots(
self, token_ids: List[int], num_lookahead_slots: int) -> int:
self, token_ids: np.ndarray, num_lookahead_slots: int) -> int:
"""Determine how many blocks will be "touched" by appending the token
ids.

This is required for the scheduler to determine whether a sequence can
continue generation, or if it must be preempted.
"""

all_token_ids = token_ids + [-1] * num_lookahead_slots
token_blocks = self._chunk_token_blocks_for_append(all_token_ids)
return len(token_blocks)
size = len(token_ids) + num_lookahead_slots
first_chunk_size = self._block_size - (self._num_full_slots %
self._block_size)
n_blocks = 1 + len(range(first_chunk_size, size, self._block_size))
return n_blocks

def _chunk_token_blocks_for_append(
self, token_ids: List[int]) -> List[List[int]]:
self, token_ids: np.ndarray) -> List[np.ndarray]:
"""Split the token ids into block-sized chunks so they can be easily
appended to blocks. The first such "token block" may have less token ids
than the block size, since the last allocated block may be partially
full.
"""
first_chunk_size = self._block_size - (self._num_full_slots %
self._block_size)
token_blocks = [token_ids[:first_chunk_size]] + chunk_list(
token_blocks = [token_ids[:first_chunk_size]] + chunk_array(
token_ids[first_chunk_size:], self._block_size)
return token_blocks
16 changes: 11 additions & 5 deletions vllm/core/block/cpu_gpu_block_allocator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Dict, FrozenSet, List, Optional, Tuple

import numpy as np

from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId,
DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
Expand Down Expand Up @@ -131,15 +133,15 @@ def allocate_mutable(self, prev_block: Optional[Block],
return self._allocators[device].allocate_mutable(prev_block)

def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block:
token_ids: np.ndarray, device: Device) -> Block:
"""Allocates a new immutable block with the provided token IDs on the
specified device.

Args:
prev_block (Optional[Block]): The previous block in the sequence.
Used for prefix hashing.
token_ids (List[int]): The list of token IDs to be stored in the new
block.
token_ids (np.ndarray): The list of token IDs to be stored in the
new block.
device (Device): The device on which to allocate the new block.

Returns:
Expand Down Expand Up @@ -326,7 +328,7 @@ def __init__(self, proxy: Block):
super().__init__()
self._proxy = proxy

def append_token_ids(self, token_ids: List[BlockId]):
def append_token_ids(self, token_ids: np.ndarray):
raise ValueError("null block should not be modified")

@property
Expand All @@ -338,13 +340,17 @@ def block_id(self, value: Optional[BlockId]):
raise ValueError("null block should not be modified")

@property
def token_ids(self) -> List[BlockId]:
def token_ids(self) -> np.ndarray:
return self._proxy.token_ids

@property
def num_empty_slots(self) -> BlockId:
return self._proxy.num_empty_slots

@property
def num_tokens(self) -> BlockId:
return self._proxy.num_tokens

@property
def is_full(self):
return self._proxy.is_full
Expand Down
17 changes: 12 additions & 5 deletions vllm/core/block/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod
from typing import Dict, FrozenSet, List, Optional, Protocol, Tuple

import numpy as np

from vllm.utils import Device

BlockId = int
Expand All @@ -9,7 +11,7 @@
class Block(ABC):

@abstractmethod
def append_token_ids(self, token_ids: List[int]) -> None:
def append_token_ids(self, token_ids: np.ndarray) -> None:
pass

@property
Expand All @@ -25,14 +27,19 @@ def block_id(self, value: Optional[int]) -> None:

@property
@abstractmethod
def token_ids(self) -> List[int]:
def token_ids(self) -> np.ndarray:
pass

@property
@abstractmethod
def num_empty_slots(self) -> int:
pass

@property
@abstractmethod
def num_tokens(self) -> int:
pass

@property
@abstractmethod
def is_full(self) -> bool:
Expand Down Expand Up @@ -70,7 +77,7 @@ class Factory(Protocol):
def __call__(
self,
prev_block: Optional["Block"],
token_ids: List[int],
token_ids: np.ndarray,
block_size: int,
allocator: "BlockAllocator",
block_id: Optional[int] = None,
Expand All @@ -97,7 +104,7 @@ def allocate_mutable(self, prev_block: Optional[Block]) -> Block:

@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block:
token_ids: np.ndarray) -> Block:
pass

@abstractmethod
Expand Down Expand Up @@ -180,7 +187,7 @@ def allocate_mutable(self, prev_block: Optional[Block],

@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block:
token_ids: np.ndarray, device: Device) -> Block:
pass

@abstractmethod
Expand Down
Loading