diff --git a/vllm/array_pool.py b/vllm/array_pool.py
new file mode 100644
index 000000000000..b3e9dbed7eb5
--- /dev/null
+++ b/vllm/array_pool.py
@@ -0,0 +1,25 @@
+"""A pool of reusable, fixed-size int64 numpy arrays.
+
+Arrays are keyed by their length so that token buffers of a given size can
+be recycled across sequences instead of being reallocated.
+"""
+from collections import defaultdict
+from typing import Dict, List
+
+import numpy as np
+
+_POOL: Dict[int, List[np.ndarray]] = defaultdict(list)
+
+
+def alloc_array(max_tokens: int) -> np.ndarray:
+    """Return an int64 array of size max_tokens, reusing one if possible."""
+    # NOTE: arrays taken from the pool are not re-zeroed; callers are
+    # expected to overwrite the slots they use.
+    if max_tokens in _POOL and _POOL[max_tokens]:
+        return _POOL[max_tokens].pop()
+    return np.zeros((max_tokens, ), dtype=np.int64)
+
+
+def del_array(arr: np.ndarray) -> None:
+    """Return an array to the pool for reuse by alloc_array."""
+    _POOL[len(arr)].append(arr)
diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py
index d705f3d91a07..5c68c3df186d 100644
--- a/vllm/core/block/block_table.py
+++ b/vllm/core/block/block_table.py
@@ -1,7 +1,9 @@
 from typing import List, Optional
 
+import numpy as np
+
 from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator
-from vllm.utils import Device, cdiv, chunk_list
+from vllm.utils import Device, cdiv, chunk_array
 
 
 class BlockTable:
@@ -50,12 +52,14 @@ def __init__(
         self._blocks: List[Block] = _blocks
 
         self._max_block_sliding_window = max_block_sliding_window
-        # Use helper method instead of directly calculating, as blocks
-        # may not be allocated.
-        self._num_full_slots = len(self._get_all_token_ids())
+
+        _num_full_slots = 0
+        for block in self._blocks:
+            _num_full_slots += block.num_tokens
+        self._num_full_slots = _num_full_slots
 
     @staticmethod
-    def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
+    def get_num_required_blocks(token_ids: np.ndarray, block_size: int) -> int:
         """Calculates the minimum number of blocks required to store a given
         sequence of token IDs.
 
@@ -63,7 +67,7 @@ def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
         allocation (e.g. ignoring prefix caching).
 
         Args:
-            token_ids (List[int]): The sequence of token IDs to be stored.
+            token_ids (np.ndarray): The sequence of token IDs to be stored.
             block_size (int): The maximum number of tokens that can be stored
                 in a single block.
 
@@ -74,7 +78,7 @@ def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
         return cdiv(len(token_ids), block_size)
 
     def allocate(self,
-                 token_ids: List[int],
+                 token_ids: np.ndarray,
                  device: Device = Device.GPU) -> None:
         """Allocates memory blocks for storing the given sequence of token IDs.
 
@@ -82,19 +86,20 @@ def allocate(self,
         sequence of token IDs.
 
         Args:
-            token_ids (List[int]): The sequence of token IDs to be stored.
+            token_ids (np.ndarray): The sequence of token IDs to be stored.
             device (Device, optional): The device on which the blocks should be
                 allocated. Defaults to Device.GPU.
         """
-        assert not self._is_allocated
-        assert token_ids
+        assert not self._blocks
+        len_token_ids = len(token_ids)
+        assert len_token_ids > 0
         self._blocks = self._allocate_blocks_for_token_ids(prev_block=None,
                                                            token_ids=token_ids,
                                                            device=device)
-        self._num_full_slots = len(token_ids)
+        self._num_full_slots = len_token_ids
 
     def append_token_ids(self,
-                         token_ids: List[int],
+                         token_ids: np.ndarray,
                          num_lookahead_slots: int = 0,
                          num_computed_slots: Optional[int] = None) -> None:
         """Appends a sequence of token IDs to the existing blocks in the
@@ -110,7 +115,7 @@ def append_token_ids(self,
             separate block.
 
         Args:
-            token_ids (List[int]): The sequence of token IDs to be appended.
+            token_ids (np.ndarray): The sequence of token IDs to be appended.
             num_computed_slots (Optional[int]): The number of KV cache slots
                 that are already filled (computed).
                 When sliding window is enabled, this is used to compute how many
@@ -119,7 +124,7 @@ def append_token_ids(self,
             Without chunked prefill, it should be the same as
             _num_full_slots.
         """
-        assert self._is_allocated, "no blocks have been allocated"
+        assert self._blocks, "no blocks have been allocated"
         assert len(self._blocks) > 0
 
         # Drop blocks that are no longer needed due to sliding window
@@ -163,7 +168,7 @@ def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
         # Currently the block table only supports
         # appending tokens to GPU blocks.
         device = Device.GPU
-        assert self._is_allocated
+        assert self._blocks
 
         if self._num_empty_slots >= num_empty_slots:
             return
@@ -190,8 +195,7 @@ def fork(self) -> "BlockTable":
             BlockTable: A new BlockTable instance with a copy of the blocks from
                 the current instance.
         """
-        assert self._is_allocated
-        assert len(self._blocks) > 0
+        assert self._blocks
         forked_blocks = self._allocator.fork(self._blocks[-1])
         return BlockTable(
             block_size=self._block_size,
@@ -208,7 +212,6 @@ def free(self) -> None:
         occupied by each block. After freeing all the blocks, the `_blocks` list
-        is set to `None`.
+        is reset to an empty list.
         """
-        assert self._is_allocated
         for block in self._blocks:
             self._allocator.free(block)
         self._blocks = []
@@ -227,21 +230,21 @@ def physical_block_ids(self) -> List[Optional[int]]:
             List[int]: A list of physical block indices for the blocks in the
                 BlockTable.
         """
-        assert self._is_allocated
         return [block.block_id for block in self._blocks]
 
-    def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
-        """Get the number of "unseen" tokens in the sequence.
+    def get_unseen_token_ids(self,
+                             sequence_token_ids: np.ndarray) -> np.ndarray:
+        """Get the "unseen" tokens in the sequence.
 
         Unseen tokens are tokens in the sequence corresponding to this block
         table, but are not yet appended to this block table.
 
         Args:
-            sequence_token_ids (List[int]): The list of token ids in the
+            sequence_token_ids (np.ndarray): The token ids in the
                 sequence.
 
         Returns:
-            List[int]: The postfix of sequence_token_ids that has not yet been
+            np.ndarray: The postfix of sequence_token_ids that has not yet been
                 appended to the block table.
         """
 
@@ -250,10 +253,11 @@ def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
         return sequence_token_ids[self.num_full_slots:]
 
     def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
-                                       token_ids: List[int],
+                                       token_ids: np.ndarray,
                                        device: Device) -> List[Block]:
         blocks: List[Block] = []
-        for block_token_ids in chunk_list(token_ids, self._block_size):
+        for i in range(0, len(token_ids), self._block_size):
+            block_token_ids = token_ids[i:i + self._block_size]
             if len(block_token_ids) == self._block_size:
                 # If the block is full, create an immutable block.
                 prev_block = self._allocator.allocate_immutable(
@@ -267,21 +271,15 @@ def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
 
         return blocks
 
-    def _get_all_token_ids(self) -> List[int]:
-        # NOTE: This function is O(seq_len); use sparingly.
-        token_ids: List[int] = []
-
-        if not self._is_allocated:
-            return token_ids
-
-        for block in self._blocks:
-            token_ids.extend(block.token_ids)
-
-        return token_ids
+    def _get_all_token_ids(self) -> np.ndarray:
+        # NOTE: This function is O(seq_len); use only for testing.
+        if not self._blocks:
+            # An unallocated table holds no token ids; mirror the old
+            # list-based behavior by returning an empty array.
+            return np.zeros(0, dtype=np.int64)
+
+        token_id_arrays = [block.token_ids for block in self._blocks]
+        return np.concatenate(token_id_arrays)
 
-    @property
-    def _is_allocated(self) -> bool:
-        return len(self._blocks) > 0
-
     @property
     def blocks(self) -> Optional[List[Block]]:
@@ -289,7 +287,7 @@ def blocks(self) -> Optional[List[Block]]:
 
     @property
     def _num_empty_slots(self) -> int:
-        assert self._is_allocated
+        assert self._blocks
         return len(self._blocks) * self._block_size - self._num_full_slots
 
     @property
@@ -303,7 +301,7 @@ def num_full_slots(self) -> int:
         return self._num_full_slots
 
     def get_num_blocks_touched_by_append_slots(
-            self, token_ids: List[int], num_lookahead_slots: int) -> int:
+            self, token_ids: np.ndarray, num_lookahead_slots: int) -> int:
         """Determine how many blocks will be "touched" by appending the token
         ids.
 
@@ -311,12 +309,14 @@ def get_num_blocks_touched_by_append_slots(
         This is required for the scheduler to determine whether a sequence can
         continue generation, or if it must be preempted.
         """
-        all_token_ids = token_ids + [-1] * num_lookahead_slots
-        token_blocks = self._chunk_token_blocks_for_append(all_token_ids)
-        return len(token_blocks)
+        size = len(token_ids) + num_lookahead_slots
+        first_chunk_size = self._block_size - (self._num_full_slots %
+                                               self._block_size)
+        # One first chunk, plus one full block per block_size tokens after it.
+        return 1 + len(range(first_chunk_size, size, self._block_size))
 
     def _chunk_token_blocks_for_append(
-            self, token_ids: List[int]) -> List[List[int]]:
+            self, token_ids: np.ndarray) -> List[np.ndarray]:
         """Split the token ids into block-sized chunks so they can be easily
         appended to blocks. The first such "token block" may have less token
         ids than the block size, since the last allocated block may be partially
@@ -324,6 +324,6 @@ def _chunk_token_blocks_for_append(
         """
         first_chunk_size = self._block_size - (self._num_full_slots %
                                                self._block_size)
-        token_blocks = [token_ids[:first_chunk_size]] + chunk_list(
+        token_blocks = [token_ids[:first_chunk_size]] + chunk_array(
             token_ids[first_chunk_size:], self._block_size)
         return token_blocks
diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py
index 255aae9d1731..009439708476 100644
--- a/vllm/core/block/cpu_gpu_block_allocator.py
+++ b/vllm/core/block/cpu_gpu_block_allocator.py
@@ -1,5 +1,7 @@
 from typing import Dict, FrozenSet, List, Optional, Tuple
 
+import numpy as np
+
 from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId,
                                         DeviceAwareBlockAllocator)
 from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
@@ -131,15 +133,15 @@ def allocate_mutable(self, prev_block: Optional[Block],
         return self._allocators[device].allocate_mutable(prev_block)
 
     def allocate_immutable(self, prev_block: Optional[Block],
-                           token_ids: List[int], device: Device) -> Block:
+                           token_ids: np.ndarray, device: Device) -> Block:
         """Allocates a new immutable block with the provided token IDs on the
         specified device.
 
         Args:
             prev_block (Optional[Block]): The previous block in the sequence.
                 Used for prefix hashing.
-            token_ids (List[int]): The list of token IDs to be stored in the new
-                block.
+            token_ids (np.ndarray): The token IDs to be stored in the
+                new block.
             device (Device): The device on which to allocate the new block.
 
         Returns:
@@ -326,7 +328,7 @@ def __init__(self, proxy: Block):
         super().__init__()
         self._proxy = proxy
 
-    def append_token_ids(self, token_ids: List[BlockId]):
+    def append_token_ids(self, token_ids: np.ndarray):
         raise ValueError("null block should not be modified")
 
     @property
@@ -338,13 +340,17 @@ def block_id(self, value: Optional[BlockId]):
         raise ValueError("null block should not be modified")
 
     @property
-    def token_ids(self) -> List[BlockId]:
+    def token_ids(self) -> np.ndarray:
         return self._proxy.token_ids
 
     @property
     def num_empty_slots(self) -> BlockId:
         return self._proxy.num_empty_slots
 
+    @property
+    def num_tokens(self) -> BlockId:
+        return self._proxy.num_tokens
+
     @property
     def is_full(self):
         return self._proxy.is_full
diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py
index 4b20856a1b42..3e25581ee432 100644
--- a/vllm/core/block/interfaces.py
+++ b/vllm/core/block/interfaces.py
@@ -1,6 +1,8 @@
 from abc import ABC, abstractmethod
 from typing import Dict, FrozenSet, List, Optional, Protocol, Tuple
 
+import numpy as np
+
 from vllm.utils import Device
 
 BlockId = int
@@ -9,7 +11,7 @@
 class Block(ABC):
 
     @abstractmethod
-    def append_token_ids(self, token_ids: List[int]) -> None:
+    def append_token_ids(self, token_ids: np.ndarray) -> None:
         pass
 
     @property
@@ -25,7 +27,7 @@ def block_id(self, value: Optional[int]) -> None:
 
     @property
     @abstractmethod
-    def token_ids(self) -> List[int]:
+    def token_ids(self) -> np.ndarray:
         pass
 
     @property
@@ -33,6 +35,11 @@ def token_ids(self) -> List[int]:
     def num_empty_slots(self) -> int:
         pass
 
+    @property
+    @abstractmethod
+    def num_tokens(self) -> int:
+        pass
+
     @property
     @abstractmethod
     def is_full(self) -> bool:
@@ -70,7 +77,7 @@ class Factory(Protocol):
         def __call__(
             self,
             prev_block: Optional["Block"],
-            token_ids: List[int],
+            token_ids: np.ndarray,
             block_size: int,
             allocator: "BlockAllocator",
             block_id: Optional[int] = None,
@@ -97,7 +104,7 @@ def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
 
     @abstractmethod
     def allocate_immutable(self, prev_block: Optional[Block],
-                           token_ids: List[int]) -> Block:
+                           token_ids: np.ndarray) -> Block:
         pass
 
     @abstractmethod
@@ -180,7 +187,7 @@ def allocate_mutable(self, prev_block: Optional[Block],
 
     @abstractmethod
     def allocate_immutable(self, prev_block: Optional[Block],
-                           token_ids: List[int], device: Device) -> Block:
+                           token_ids: np.ndarray, device: Device) -> Block:
         pass
 
     @abstractmethod
diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py
index 50f27bab3377..6ddf26882940 100644
--- a/vllm/core/block/naive_block.py
+++ b/vllm/core/block/naive_block.py
@@ -1,5 +1,9 @@
+import weakref
 from typing import FrozenSet, Iterable, List, Optional, Set, Tuple
 
+import numpy as np
+
+from vllm.array_pool import alloc_array, del_array
 from vllm.core.block.common import (CopyOnWriteTracker, RefCounter,
                                     get_all_blocks_recursively)
 from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
@@ -51,7 +55,7 @@ def __init__(
 
     def allocate_immutable(self,
                            prev_block: Optional[Block],
-                           token_ids: List[int],
+                           token_ids: np.ndarray,
                            device: Optional[Device] = None) -> Block:
         """Allocates a new immutable block with the given token IDs, linked to
         the previous block.
@@ -87,7 +91,7 @@ def allocate_mutable(self,
         block_id = self._allocate_new_block_id()
         return self._create_block(
             prev_block=prev_block,
-            token_ids=[],
+            token_ids=np.zeros(0, dtype=np.int64),
             block_id=block_id,
             block_size=self._block_size,
             allocator=self,
@@ -286,7 +290,7 @@ class NaiveBlock(Block):
 
     Args:
         prev_block (Block): The previous block in the sequence.
-        token_ids (List[int]): The initial token IDs to be stored in the block.
+        token_ids (np.ndarray): The initial token IDs to be stored in the block.
         block_size (int): The maximum number of token IDs that can be stored in
             the block.
         allocator (BlockAllocator): The block allocator associated with this
@@ -300,12 +304,13 @@ class NaiveBlock(Block):
 
     def __init__(self,
                  prev_block: Optional[Block],
-                 token_ids: List[int],
+                 token_ids: np.ndarray,
                  block_size: int,
                  allocator: BlockAllocator,
                  block_id: Optional[int] = None,
                  _cow_target: Optional[Block] = None):
-        self._token_ids: List[int] = []
+        self._token_ids: np.ndarray = alloc_array(block_size)
+        self._num_tokens = 0
         self._block_size = block_size
         self._prev_block = prev_block
         self._block_id = block_id
@@ -313,8 +318,9 @@ def __init__(self,
         self._cow_target = _cow_target if _cow_target is not None else self
 
         self._append_token_ids_no_cow(token_ids)
+        self._finalizer = weakref.finalize(self, del_array, self._token_ids)
 
-    def append_token_ids(self, token_ids: List[int]) -> None:
+    def append_token_ids(self, token_ids: np.ndarray) -> None:
         """Appends the given token IDs to the block, instructing the allocator
         to perform a copy-on-write if necessary.
 
@@ -327,9 +333,12 @@ def append_token_ids(self, token_ids: List[int]) -> None:
         self._block_id = (self._allocator.cow_block_if_not_appendable(
             self._cow_target))
 
-    def _append_token_ids_no_cow(self, token_ids: List[int]) -> None:
-        assert self.num_empty_slots >= len(token_ids)
-        self._token_ids.extend(token_ids)
+    def _append_token_ids_no_cow(self, token_ids: np.ndarray) -> None:
+        len_new_tokens = len(token_ids)
+        new_len = self._num_tokens + len_new_tokens
+        assert new_len <= self._block_size
+        self._token_ids[self._num_tokens:new_len] = token_ids
+        self._num_tokens = new_len
 
     @property
     def computed(self) -> bool:
@@ -357,14 +366,19 @@ def block_id(self, value: Optional[int]) -> None:
 
     @property
     def is_full(self) -> bool:
-        return self.num_empty_slots == 0
+        return self._num_tokens == self._block_size
 
     @property
     def num_empty_slots(self) -> int:
-        return self._block_size - len(self._token_ids)
+        return self._block_size - self._num_tokens
+
+    @property
+    def num_tokens(self) -> int:
+        return self._num_tokens
 
     @property
-    def token_ids(self) -> List[int]:
-        return self._token_ids
+    def token_ids(self) -> np.ndarray:
+        # Only the filled prefix of the pooled buffer is valid data.
+        return self._token_ids[:self._num_tokens]
 
     @property
diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py
index 2df7d74e4ff1..0cd88b647343 100644
--- a/vllm/core/block/prefix_caching_block.py
+++ b/vllm/core/block/prefix_caching_block.py
@@ -4,6 +4,8 @@
 from os.path import commonprefix
 from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
 
+import numpy as np
+
 from vllm.core.block.common import (CopyOnWriteTracker,
                                     get_all_blocks_recursively)
 from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
@@ -76,7 +78,7 @@ def __init__(
     def _create_block(
         self,
         prev_block: Optional[Block],
-        token_ids: List[int],
+        token_ids: np.ndarray,
         block_size: int,
         allocator: BlockAllocator,
         block_id: Optional[int] = None,
@@ -96,7 +98,7 @@ def _create_block(
 
     def allocate_immutable(self,
                            prev_block: Optional[Block],
-                           token_ids: List[int],
+                           token_ids: np.ndarray,
                            device: Optional[Device] = None) -> Block:
         """Allocates an immutable block with the given token IDs, reusing cached
         blocks if possible.
@@ -182,7 +184,7 @@ def allocate_mutable(self,
             # its kvcache
             block = self._create_block(
                 prev_block=prev_block,
-                token_ids=[],
+                token_ids=np.zeros(0, dtype=np.int64),
                 block_size=self._block_size,
                 allocator=self,
                 block_id=block_id,
@@ -504,7 +506,7 @@ class PrefixCachingBlock(Block):
     Args:
         prev_block (Optional[PrefixCachingBlock]): The previous block in the
             sequence.
-        token_ids (List[int]): The initial token IDs to be stored in the block.
+        token_ids (np.ndarray): The initial token IDs to be stored in the block.
        block_size (int): The maximum number of token IDs that can be stored in
            the block.
        prefix_caching_allocator (BlockAllocator): The prefix
@@ -516,7 +518,7 @@ class PrefixCachingBlock(Block):
     def __init__(
         self,
         prev_block: Optional[Block],
-        token_ids: List[int],
+        token_ids: np.ndarray,
         block_size: int,
         prefix_caching_allocator: BlockAllocator,
         block_id: Optional[int] = None,
@@ -560,14 +562,14 @@ def last_accessed(self) -> float:
     def last_accessed(self, last_accessed_ts: float):
         self._last_accessed = last_accessed_ts
 
-    def append_token_ids(self, token_ids: List[int]) -> None:
+    def append_token_ids(self, token_ids: np.ndarray) -> None:
         """Appends the given token IDs to the block and registers the block as
         immutable if the block becomes full.
 
         Internally, the naive block handles CoW.
 
         Args:
-            token_ids (List[int]): The token IDs to be appended to the block.
+            token_ids (np.ndarray): The token IDs to be appended to the block.
         """
-        assert token_ids
+        assert len(token_ids) > 0
         self._block.append_token_ids(token_ids)
@@ -597,6 +599,10 @@ def is_full(self) -> bool:
     def num_empty_slots(self) -> int:
         return self._block.num_empty_slots
 
+    @property
+    def num_tokens(self) -> int:
+        return self._block.num_tokens
+
     @property
     def num_tokens_total(self) -> int:
         """return the total tokens so far.
@@ -623,7 +629,7 @@ def block_size(self) -> int:
         return self._block.block_size
 
     @property
-    def token_ids(self) -> List[int]:
+    def token_ids(self) -> np.ndarray:
         return self._block.token_ids
 
     @property
@@ -666,7 +672,7 @@ def content_hash(self) -> Optional[int]:
 
     @staticmethod
     def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
-                          cur_block_token_ids: List[int]) -> int:
+                          cur_block_token_ids: np.ndarray) -> int:
         """Computes a hash value corresponding to the contents of a block and
         the contents of the preceding block(s). The hash value is used for
         prefix caching.
@@ -678,7 +684,7 @@ def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
             the sequence.
         - prev_block_hash (Optional[int]): The hash of the previous block. None
             if this is the first block.
-        - cur_block_token_ids (List[int]): A list of token ids in the current
+        - cur_block_token_ids (np.ndarray): An array of token ids in the current
             block. The current block is assumed to be full.
 
         Returns:
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 4b427b1fb2f2..8cadaf3ff567 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -464,8 +464,12 @@ def _add_processed_request(
         seq_id = next(self.seq_counter)
         eos_token_id = self._get_eos_token_id(lora_request)
 
-        seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
-                       lora_request)
+        seq = Sequence(seq_id,
+                       processed_inputs,
+                       block_size,
+                       eos_token_id,
+                       lora_request,
+                       max_seq_len=self.model_config.max_model_len)
 
         # Create a SequenceGroup based on SamplingParams or PoolingParams
         if isinstance(params, SamplingParams):
diff --git a/vllm/sequence.py b/vllm/sequence.py
index c618c3692611..08c8f4d1f4af 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -1,13 +1,17 @@
 """Sequence and its related classes."""
 import copy
 import enum
+import hashlib
 import math
+import weakref
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 
+from vllm.array_pool import alloc_array, del_array
 from vllm.inputs import LLMInputs
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
@@ -107,55 +111,76 @@ class SequenceData:
         prompt_token_ids: The token IDs of the prompt.
         output_token_ids: The token IDs of the output. Set to an empty list if
             None.
+        max_seq_len: The maximum sequence length. A buffer of this size is
+            allocated for the tokens. By default, it is set to 16k for test
+            purposes. During inference, it should be set to the maximum
+            sequence length of the model.
 
     Attributes:
         prompt_token_ids: The token IDs of the prompt.
         output_token_ids: The token IDs of the output.
         cumulative_logprob: The cumulative log probability of the output.
+        tokens: Array holding all the tokens (prompt + output).
+
+    NOTE: take special care when choosing between `list` and `np.ndarray`
+    return types in this class. `get_prompt_token_ids` and
+    `get_output_token_ids` return `list` because they are used to build the
+    request output returned to the user, which must contain Python lists.
+    They are still cheap: each returns a reference to a list that is
+    already stored on the object. `get_token_ids` instead returns an
+    `np.ndarray` view, which avoids a data copy and allows array
+    operations to be performed on the tokens.
     """
 
     def __init__(
         self,
         prompt_token_ids: List[int],
         output_token_ids: Optional[List[int]] = None,
+        max_seq_len: Optional[int] = None,
     ) -> None:
+        self.max_seq_len = max_seq_len or 16 * 1024
+        self.tokens = alloc_array(self.max_seq_len)
+        self.prompt_token_ids_list = prompt_token_ids
+        self.num_prompt_tokens = len(prompt_token_ids)
+        self.tokens[:self.num_prompt_tokens] = prompt_token_ids
         if output_token_ids is None:
             output_token_ids = []
-
-        self.prompt_token_ids = prompt_token_ids
-        self._prompt_token_ids_tuple = tuple(prompt_token_ids)
-        self.output_token_ids = output_token_ids
+        self.num_output_tokens = len(output_token_ids)
+        self.output_token_ids_list = output_token_ids
+        self.tokens[self.num_prompt_tokens:self.num_prompt_tokens +
+                    self.num_output_tokens] = output_token_ids
         self.cumulative_logprob = 0.0
         # The number of tokens that are computed (that run against the model).
         self._num_computed_tokens = 0
         self._stage: SequenceStage = SequenceStage.PREFILL
+        self._finalizer = weakref.finalize(self, del_array, self.tokens)
 
     def append_token_id(self, token_id: int, logprob: float) -> None:
-        self.output_token_ids.append(token_id)
+        self.tokens[self.num_prompt_tokens + self.num_output_tokens] = token_id
+        self.output_token_ids_list.append(token_id)
+        self.num_output_tokens += 1
         self.cumulative_logprob += logprob
 
     def get_len(self) -> int:
-        return len(self.output_token_ids) + len(self.prompt_token_ids)
+        return self.num_prompt_tokens + self.num_output_tokens
 
     def get_prompt_len(self) -> int:
-        return len(self.prompt_token_ids)
+        return self.num_prompt_tokens
 
     def get_output_len(self) -> int:
-        return len(self.output_token_ids)
+        return self.num_output_tokens
 
-    def get_token_ids(self) -> List[int]:
-        return self.prompt_token_ids + self.output_token_ids
+    def get_token_ids(self) -> np.ndarray:
+        return self.tokens[:self.num_prompt_tokens + self.num_output_tokens]
 
-    def get_prefix_token_ids(
-            self, num_tokens: int
-    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
-        """Get prefix tokens, and make the return value hashable"""
-        prompt_length = len(self.prompt_token_ids)
-        if num_tokens > prompt_length:
-            return (self._prompt_token_ids_tuple,
-                    tuple(self.output_token_ids[:num_tokens - prompt_length]))
-        else:
-            return (self._prompt_token_ids_tuple[:num_tokens], None)
+    def hash_prefix_token_ids(self, num_tokens: int) -> bytes:
+        """Return a SHA-256 digest of the first num_tokens tokens."""
+        data = self.tokens[:num_tokens]
+        # Expose the raw bytes of the prefix without copying,
+        buffer = memoryview(data)  # type: ignore
+        # and digest them; bytes are hashable, unlike np.ndarray.
+        hash_value = hashlib.sha256(buffer).digest()
+        return hash_value
 
     def get_num_computed_tokens(self) -> int:
         """Return the number of prefill tokens that are already computed."""
@@ -186,15 +211,14 @@ def get_num_uncomputed_tokens(self) -> int:
         return self.get_len() - self.get_num_computed_tokens()
 
     def get_last_token_id(self) -> int:
-        if not self.output_token_ids:
-            return self.prompt_token_ids[-1]
-        return self.output_token_ids[-1]
+        return int(self.tokens[self.num_prompt_tokens +
+                               self.num_output_tokens - 1])
 
     def get_prompt_token_ids(self) -> List[int]:
-        return self.prompt_token_ids
+        return self.prompt_token_ids_list
 
     def get_output_token_ids(self) -> List[int]:
-        return self.output_token_ids
+        return self.output_token_ids_list
 
     @property
     def stage(self) -> SequenceStage:
@@ -202,8 +226,8 @@ def stage(self) -> SequenceStage:
 
     def __repr__(self) -> str:
         return (f"SequenceData("
-                f"prompt_token_ids={self.prompt_token_ids}, "
-                f"output_token_ids={self.output_token_ids}, "
+                f"prompt_token_ids={self.get_prompt_token_ids()}, "
+                f"output_token_ids={self.get_output_token_ids()}, "
                 f"cumulative_logprob={self.cumulative_logprob})")
 
 
@@ -225,6 +249,7 @@ def __init__(
         block_size: int,
         eos_token_id: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
+        max_seq_len: Optional[int] = None,
     ) -> None:
         self.seq_id = seq_id
         self.inputs = inputs
@@ -232,10 +257,14 @@ def __init__(
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
 
-        self.data = SequenceData(self.prompt_token_ids)
+        self.max_seq_len = max_seq_len
+        self.data = SequenceData(self.inputs["prompt_token_ids"],
+                                 max_seq_len=max_seq_len)
+        self.prompt_token_ids: List[int] = self.inputs["prompt_token_ids"]
+        self.prompt: Optional[str] = self.inputs.get("prompt")
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
 
         self.status = SequenceStatus.WAITING
         self.stop_reason: Union[int, str, None] = None
@@ -247,15 +276,7 @@ def __init__(
 
     @property
     def n_blocks(self) -> int:
-        return math.ceil(self.get_len() / self.block_size)
-
-    @property
-    def prompt(self) -> Optional[str]:
-        return self.inputs.get("prompt")
-
-    @property
-    def prompt_token_ids(self) -> List[int]:
-        return self.inputs["prompt_token_ids"]
+        return math.ceil(self.data.get_len() / self.block_size)
 
     @property
     def multi_modal_data(self) -> Optional["MultiModalData"]:
@@ -278,8 +299,8 @@ def hash_of_block(self, logical_idx: int) -> int:
         # TODO: The current hashing function is O(L^2). We should optimize
         # this in the future.
         num_tokens = self.num_hashed_tokens_of_block(logical_idx)
-        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
-        return hash((hashed_tokens, self.lora_int_id))
+        tokens_hash = self.data.hash_prefix_token_ids(num_tokens)
+        return hash((tokens_hash, self.lora_int_id))
 
     def num_hashed_tokens_of_block(self, logical_idx: int):
         return logical_idx * self.block_size + self.block_size
@@ -306,7 +327,7 @@ def get_prompt_len(self) -> int:
 
     def get_output_len(self) -> int:
         return self.data.get_output_len()
 
-    def get_token_ids(self) -> List[int]:
+    def get_token_ids(self) -> np.ndarray:
         return self.data.get_token_ids()
 
     def get_prompt_token_ids(self) -> List[int]:
@@ -316,7 +337,7 @@ def get_last_token_id(self) -> int:
         return self.data.get_last_token_id()
 
     def get_output_token_ids(self) -> List[int]:
-        return self.data.output_token_ids
+        return self.data.get_output_token_ids()
 
     def get_cumulative_logprob(self) -> float:
         return self.data.cumulative_logprob
diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py
index 40516556344e..08c77dd0811d 100644
--- a/vllm/spec_decode/batch_expansion.py
+++ b/vllm/spec_decode/batch_expansion.py
@@ -283,6 +283,7 @@ def _create_single_target_seq_group_metadata(
             SequenceData(
                 prompt_token_ids=prompt_token_ids,
                 output_token_ids=new_output_token_ids,
+                max_seq_len=seq_data.max_seq_len,
             ),
         }
         # This is a hack. Technically, spec decoding should compute
diff --git a/vllm/utils.py b/vllm/utils.py
index 92abdb3fb9b1..c4ccc0c45662 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -393,6 +393,11 @@ def chunk_list(lst: List[T], chunk_size: int) -> List[List[T]]:
     return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
 
 
+def chunk_array(arr: np.ndarray, chunk_size: int) -> List[np.ndarray]:
+    """Return successive chunk_size-sized views of arr."""
+    return [arr[i:i + chunk_size] for i in range(0, len(arr), chunk_size)]
+
+
 def cdiv(a: int, b: int) -> int:
     """Ceiling division."""
     return -(a // -b)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 181442490a82..5442b00ec75f 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -314,8 +314,9 @@ def _prepare_model_input_tensors(
 
         If cuda graph is required, this API automatically pads inputs.
         """
-        input_tokens: List[int] = []
-        input_positions: List[int] = []
+        batch_size = 0
+        input_tokens: List[np.ndarray] = []
+        input_positions: List[np.ndarray] = []
         slot_mapping: List[int] = []
         lora_index_mapping: List[int] = []
         lora_prompt_mapping: List[int] = []
@@ -390,9 +391,7 @@ def _prepare_model_input_tensors(
             if is_prompt:
                 tokens = seq_data.get_token_ids()[context_len:seq_len]
             else:
-                # Optimization. get_token_ids requires the entire copy of
-                # tokens.
-                tokens = [seq_data.get_last_token_id()]
+                tokens = seq_data.get_token_ids()[-1:]
 
             # Prefix cache was hit.
             # Prefix is not supported with sliding_window
@@ -476,8 +475,9 @@ def _prepare_model_input_tensors(
             context_lens.append(sliding_context_len)
             query_len = sliding_seq_len - sliding_context_len
             query_lens.append(query_len)
-            input_tokens.extend(tokens)
-            input_positions.extend(list(range(context_len, seq_len)))
+            input_tokens.append(tokens)
+            batch_size += seq_len - context_len
+            input_positions.append(np.arange(context_len, seq_len))
 
             lora_id = seq_group_metadata.lora_int_id
 
@@ -554,7 +554,6 @@ def _prepare_model_input_tensors(
                     slot = block_number * self.block_size + block_offset
                     slot_mapping.append(slot)
 
-        batch_size = len(input_tokens)
         max_query_len = max(query_lens)
         max_prefill_seq_len = max(prefill_seq_lens, default=0)
         max_decode_seq_len = max(decode_seq_lens, default=0)
@@ -569,9 +568,11 @@ def _prepare_model_input_tensors(
         if use_captured_graph:
             graph_batch_size = _get_graph_batch_size(batch_size)
             assert graph_batch_size >= batch_size
+            zero_pad_array = np.zeros(graph_batch_size - batch_size,
+                                      dtype=np.int64)
+            input_tokens.append(zero_pad_array)
+            input_positions.append(zero_pad_array)
             for _ in range(graph_batch_size - batch_size):
-                input_tokens.append(0)
-                input_positions.append(0)
                 slot_mapping.append(_PAD_SLOT_ID)
                 seq_lens.append(1)
                 block_tables.append([])
@@ -611,12 +612,12 @@ def _prepare_model_input_tensors(
                                    dtype=seq_start_loc.dtype,
                                    out=seq_start_loc[1:])
 
-        input_tokens_tensor = torch.tensor(input_tokens,
-                                           dtype=torch.long,
-                                           device=self.device)
-        input_positions_tensor = torch.tensor(input_positions,
-                                              dtype=torch.long,
-                                              device=self.device)
+        input_tokens_array = np.concatenate(input_tokens)
+        input_tokens_tensor = torch.from_numpy(input_tokens_array).to(
+            device=self.device)
+        input_positions_array = np.concatenate(input_positions)
+        input_positions_tensor = torch.from_numpy(input_positions_array).to(
+            device=self.device)
         slot_mapping_tensor = torch.tensor(slot_mapping,
                                            dtype=torch.long,
                                            device=self.device)
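
Review notes: the three sketches below are illustrative only and are not part of the patch. This first one is a minimal, self-contained model of the buffer-recycling pattern the change introduces: alloc_array/del_array keep a per-size free list of int64 arrays, and an owner object (SequenceData and NaiveBlock in the patch) hands its buffer back through a weakref.finalize callback once it is garbage collected. TokenBuffer is a hypothetical stand-in, not vLLM code.

import weakref
from collections import defaultdict
from typing import Dict, List

import numpy as np

_POOL: Dict[int, List[np.ndarray]] = defaultdict(list)


def alloc_array(n: int) -> np.ndarray:
    # Reuse a pooled buffer of the right size when one exists.
    return _POOL[n].pop() if _POOL[n] else np.zeros(n, dtype=np.int64)


def del_array(arr: np.ndarray) -> None:
    _POOL[len(arr)].append(arr)


class TokenBuffer:
    """Hypothetical stand-in for SequenceData's token storage."""

    def __init__(self, capacity: int):
        self.tokens = alloc_array(capacity)
        self.num_tokens = 0
        # Hand the buffer back to the pool when this object is collected.
        # The callback must not reference `self`, or the object would be
        # kept alive forever.
        self._finalizer = weakref.finalize(self, del_array, self.tokens)

    def append(self, token_id: int) -> None:
        self.tokens[self.num_tokens] = token_id
        self.num_tokens += 1

    def token_ids(self) -> np.ndarray:
        # A view of the filled prefix, not a copy.
        return self.tokens[:self.num_tokens]


buf = TokenBuffer(8)
for t in (1, 2, 3):
    buf.append(t)
assert buf.token_ids().tolist() == [1, 2, 3]
backing = buf.tokens
del buf  # finalizer fires (CPython refcounting); buffer returns to pool
assert _POOL[8][-1] is backing  # the same array is ready for reuse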
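
Second sketch: the reasoning behind replacing get_prefix_token_ids (which built tuples of the prefix) with hash_prefix_token_ids. A memoryview over the leading slice of the token array exposes the prefix bytes without copying, and a SHA-256 digest of those bytes is a compact bytes object that hash_of_block can mix with lora_int_id exactly as before. hash_prefix below is a hypothetical helper mirroring the patch, not the vLLM API.

import hashlib

import numpy as np


def hash_prefix(tokens: np.ndarray, num_tokens: int) -> bytes:
    # sha256 accepts any buffer; a slice starting at zero is contiguous,
    # so no bytes are copied before hashing.
    return hashlib.sha256(memoryview(tokens[:num_tokens])).digest()


tokens = np.zeros(16, dtype=np.int64)  # oversized buffer, as in the patch
tokens[:5] = [11, 22, 33, 44, 55]      # only five slots hold real data

digest = hash_prefix(tokens, 3)
tokens[4] = 99                         # past the prefix: digest unchanged
assert hash_prefix(tokens, 3) == digest
tokens[1] = 0                          # inside the prefix: digest changes
assert hash_prefix(tokens, 3) != digest

# The digest is hashable, so hash_of_block can combine it as before:
lora_int_id = 0
block_hash = hash((digest, lora_int_id))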
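
Last sketch: a brute-force cross-check of the closed-form chunk count now used in get_num_blocks_touched_by_append_slots. The first chunk has first_chunk_size slots (the tail of the last partially filled block); every further block_size tokens touch one more block, which is what `1 + len(range(first_chunk_size, size, block_size))` counts. Helper names are hypothetical.

def touched_blocks(num_full_slots: int, block_size: int, size: int) -> int:
    # Closed form used in the patch.
    first_chunk_size = block_size - (num_full_slots % block_size)
    return 1 + len(range(first_chunk_size, size, block_size))


def touched_blocks_ref(num_full_slots: int, block_size: int,
                       size: int) -> int:
    # Reference: materialize the chunks like the old chunk_list-based code.
    ids = list(range(size))
    first = block_size - (num_full_slots % block_size)
    rest = ids[first:]
    chunks = [ids[:first]] + [
        rest[i:i + block_size] for i in range(0, len(rest), block_size)
    ]
    return len(chunks)


for full in range(33):
    for size in range(1, 40):
        assert touched_blocks(full, 16, size) == touched_blocks_ref(
            full, 16, size)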