From 996cf2de5cf7bc5aa7ab452c02ecda50e2d0cdcc Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sun, 2 Jun 2024 00:01:30 +0000 Subject: [PATCH 1/2] Fix hashing logic for non-full blocks --- vllm/block.py | 3 +- vllm/core/block_manager_v1.py | 56 ++++++++++++++--------------------- vllm/sequence.py | 25 +++++++++++----- 3 files changed, 41 insertions(+), 43 deletions(-) diff --git a/vllm/block.py b/vllm/block.py index 2cc6b947f225..3549ee4e9d77 100644 --- a/vllm/block.py +++ b/vllm/block.py @@ -1,5 +1,5 @@ """Token blocks.""" -from typing import List +from typing import Optional, List from vllm.utils import Device @@ -25,6 +25,7 @@ def __init__( self.token_ids = [_BLANK_TOKEN_ID] * block_size self.num_tokens = 0 + self.block_hash: Optional[int] = None def is_empty(self) -> bool: return self.num_tokens == 0 diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 201cba309f6e..2fee1eb2485a 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -262,8 +262,7 @@ def __init__( self.cross_block_tables: Dict[str, BlockTable] = {} def _get_seq_num_required_blocks(self, seq: Sequence) -> int: - return 0 if seq is None \ - else len(seq.logical_token_blocks) + return 0 if seq is None else len(seq.logical_token_blocks) def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share @@ -275,8 +274,8 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) cross_num_required_blocks = self._get_seq_num_required_blocks( seq_group.get_encoder_seq()) - num_required_blocks = self_num_required_blocks + \ - cross_num_required_blocks + num_required_blocks = (self_num_required_blocks + + cross_num_required_blocks) if self.block_sliding_window is not None: @@ -293,9 +292,9 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def _allocate_sequence(self, \ - seq: Sequence, \ - ref_count: int, \ + def _allocate_sequence(self, + seq: Sequence, + ref_count: int, is_encoder_decoder: bool = True) -> BlockTable: # Allocate new physical token blocks that will store the prompt tokens. num_prompt_blocks = len(seq.logical_token_blocks) @@ -328,10 +327,8 @@ def allocate(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # decoder prompt. seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - block_table: BlockTable = \ - self._allocate_sequence(seq, - seq_group.num_seqs(), - is_encoder_decoder) + block_table: BlockTable = self._allocate_sequence( + seq, seq_group.num_seqs(), is_encoder_decoder) # Assign the self-attention block tables for each sequence. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): @@ -368,6 +365,7 @@ def _promote_last_block( # Compute a new hash for the block so that it can be shared by other # Sequences new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) + assert new_hash is not None, "Last block is not full." # if new_hash is already in the cached table, then free last_block # and return the cached version @@ -406,9 +404,7 @@ def _allocate_last_physical_block( # content hash. 
if not self.enable_caching: return self.gpu_allocator.allocate() - block_hash: Optional[int] = None - if (self._is_last_block_full(seq)): - block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) + block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) num_hashed_tokens = seq.num_hashed_tokens_of_block( len(seq.logical_token_blocks) - 1) @@ -553,18 +549,14 @@ def swap_in(self, # dict is efficient in lookup `if cpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): - self.block_tables[seq.seq_id] = \ - self._swap_block_table(self.block_tables[seq.seq_id], - self.cpu_allocator, - self.gpu_allocator, - mapping) + self.block_tables[seq.seq_id] = self._swap_block_table( + self.block_tables[seq.seq_id], self.cpu_allocator, + self.gpu_allocator, mapping) if seq_group.is_encoder_decoder(): - self.cross_block_tables[request_id] = \ - self._swap_block_table(self.cross_block_tables[request_id], - self.cpu_allocator, - self.gpu_allocator, - mapping) + self.cross_block_tables[request_id] = self._swap_block_table( + self.cross_block_tables[request_id], self.cpu_allocator, + self.gpu_allocator, mapping) return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] @@ -580,18 +572,14 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: # dict is efficient in lookup `if gpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - self.block_tables[seq.seq_id] = \ - self._swap_block_table(self.block_tables[seq.seq_id], - self.gpu_allocator, - self.cpu_allocator, - mapping) + self.block_tables[seq.seq_id] = self._swap_block_table( + self.block_tables[seq.seq_id], self.gpu_allocator, + self.cpu_allocator, mapping) if seq_group.is_encoder_decoder(): - self.cross_block_tables[request_id] = \ - self._swap_block_table(self.cross_block_tables[request_id], - self.gpu_allocator, - self.cpu_allocator, - mapping) + self.cross_block_tables[request_id] = self._swap_block_table( + self.cross_block_tables[request_id], self.gpu_allocator, + self.cpu_allocator, mapping) return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] diff --git a/vllm/sequence.py b/vllm/sequence.py index ac5c234d052b..e05356e9f61a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -269,15 +269,24 @@ def get_output_text_to_return(self, buffer_length: int): return self.output_text[:-buffer_length] if truncate else ( self.output_text) - def hash_of_block(self, logical_idx: int) -> int: - # TODO This can produce incorrect hash when block size > prompt size - - # Compute the number of tokens in the sequence + def hash_of_block(self, logical_idx: int) -> Optional[int]: + """Return the hash of the block if it is full.""" # TODO: The current hashing function is O(L^2). We should optimize # this in the future. 
- num_tokens = self.num_hashed_tokens_of_block(logical_idx) - hashed_tokens = self.data.get_prefix_token_ids(num_tokens) - return hash((hashed_tokens, self.lora_int_id)) + assert logical_idx < len(self.logical_token_blocks), ( + f"logical_idx={logical_idx} is out of range for " + f"logical_token_blocks={len(self.logical_token_blocks)}") + block = self.logical_token_blocks[logical_idx] + if block.block_hash is not None: + return block.block_hash + if not block.is_full(): + return None + num_hashed_tokens = self.num_hashed_tokens_of_block(logical_idx) + hashed_tokens = self.data.get_prefix_token_ids(num_hashed_tokens) + block_hash = hash((hashed_tokens, self.lora_int_id)) + # Cache the block hash for future use. + block.block_hash = block_hash + return block_hash def num_hashed_tokens_of_block(self, logical_idx: int): return logical_idx * self.block_size + self.block_size @@ -632,7 +641,7 @@ class SequenceGroupMetadata: state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. encoder_seq_data: Optional sequence data for encoder prompt - (SequenceGroup.encoder_seq). Should be None + (SequenceGroup.encoder_seq). Should be None unless you are working with an encoder/decoder model. cross_block_table: Optional cross-attention block table associated From 1936d7bab00332047c444f04b7c01276e33cb8bb Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sun, 2 Jun 2024 00:02:54 +0000 Subject: [PATCH 2/2] format --- vllm/block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/block.py b/vllm/block.py index 3549ee4e9d77..e3231ae592f9 100644 --- a/vllm/block.py +++ b/vllm/block.py @@ -1,5 +1,5 @@ """Token blocks.""" -from typing import Optional, List +from typing import List, Optional from vllm.utils import Device
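
For reference, here is a minimal, standalone sketch of the behavior these patches introduce. The `Block` and `Seq` classes below are simplified stand-ins invented for illustration; they are not the real `vllm.block.LogicalTokenBlock` and `vllm.sequence.Sequence`, which carry extra state (device, ref counts, the LoRA id that the real hash also mixes in). The point is the contract of the fixed `hash_of_block`: it returns None for a block that is not yet full, and caches the hash on the block once it is.

# Simplified sketch of the patched hashing contract (illustrative only;
# Block and Seq are hypothetical stand-ins, not vLLM's actual classes).
from typing import List, Optional

class Block:
    """Stand-in for LogicalTokenBlock: a fixed-size slot of token ids."""

    def __init__(self, block_size: int) -> None:
        self.block_size = block_size
        self.token_ids: List[int] = []
        # Cached content hash; stays None until the block is full.
        self.block_hash: Optional[int] = None

    def is_full(self) -> bool:
        return len(self.token_ids) == self.block_size

class Seq:
    """Stand-in for Sequence: a growing list of fixed-size blocks."""

    def __init__(self, block_size: int) -> None:
        self.block_size = block_size
        self.blocks: List[Block] = [Block(block_size)]

    def append(self, token_id: int) -> None:
        if self.blocks[-1].is_full():
            self.blocks.append(Block(self.block_size))
        self.blocks[-1].token_ids.append(token_id)

    def hash_of_block(self, logical_idx: int) -> Optional[int]:
        """Return the content hash of a block, or None if it is not full."""
        block = self.blocks[logical_idx]
        if block.block_hash is not None:
            return block.block_hash  # reuse the cached hash
        if not block.is_full():
            return None  # the fix: never hash a partially filled block
        # Hash the whole token prefix up to and including this block, so two
        # sequences that share a prefix map to the same physical block.
        num_tokens = (logical_idx + 1) * self.block_size
        prefix = tuple(t for b in self.blocks for t in b.token_ids)[:num_tokens]
        block.block_hash = hash(prefix)
        return block.block_hash

seq = Seq(block_size=4)
for t in (1, 2, 3):
    seq.append(t)
assert seq.hash_of_block(0) is None       # 3 of 4 slots filled: no hash yet
seq.append(4)
assert seq.hash_of_block(0) is not None   # block full: hash computed and cached

Returning None instead of hashing a partial block keeps the "is this block full?" check in one place: callers such as _allocate_last_physical_block no longer need to guard with _is_last_block_full themselves, and _promote_last_block can assert the hash is present because promotion only happens for full blocks. Caching the hash on the block also avoids recomputing the O(L) prefix hash on every allocation.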