178 changes: 178 additions & 0 deletions examples/offline_inference/spans/spans_benchmark.py
@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
import os
import time
import random

# necessary for spans to work
os.environ["VLLM_USE_V1"] = "1"
# to ensure deterministic behaviour
os.environ["TOKENIZERS_PARALLELISM"] = "False"

# optional: pin the attention backend and GPU, and disable engine multiprocessing
os.environ['VLLM_ATTENTION_BACKEND'] = "TRITON_ATTN_VLLM_V1"
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = '0'

# standard imports
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt


# helper functions
def pad(toklist):
    # pad the span out to a multiple of the 16-token block size,
    # keeping the final token in place
    padtok = int(os.environ["VLLM_V1_SPANS_TOKEN_PAD"])
    return toklist[:-1] + [padtok] * ((16 - len(toklist)) % 16) + toklist[-1:]


def avg(list_of_numbers):
    return sum(list_of_numbers) / max(len(list_of_numbers), 1)


def wrap(prompt):
    if isinstance(prompt[0], list):
        return [TokensPrompt(prompt_token_ids=p) for p in prompt]
    return TokensPrompt(prompt_token_ids=prompt)


def initialize_vllm(model,
                    temp=0.6,
                    logprobs=None,
                    max_toks=131072,
                    max_generated_toks=1):
    # boot up vLLM
    samp_params_preload = SamplingParams(temperature=temp, max_tokens=1)
    samp_params_generate = SamplingParams(temperature=temp,
                                          max_tokens=max_generated_toks,
                                          logprobs=logprobs)
    llm = LLM(
        model=model,
        gpu_memory_utilization=0.9,
        enforce_eager=True,  # <- so it boots faster
        block_size=16,
        max_model_len=max_toks,
        max_num_seqs=4,
    )
    tok = llm.get_tokenizer()
    tok_fun = lambda x: tok.convert_tokens_to_ids(tok.tokenize(x))
    return samp_params_preload, samp_params_generate, tok_fun, llm


def main():
    model_names = [
        "ldsjmdy/Tulu3-Block-FT",  # <- finetuned to handle block-attention
        "ldsjmdy/Tulu3-RAG",  # <- baseline
    ]
    model_name = model_names[0]

    # tokens that need to be set to perform block-attention
    PAD_TOK = 27  # <- "<"
    SPAN_TOK = 10  # <- "+"
    SPAN_RECOMP_TOK = 31  # <- "@"

    # vLLM-specific env vars

    # enables block attention
    # -> when this line is not commented out, we expect a speedup
    #    in the execution of the last two .generate calls
    os.environ['VLLM_V1_SPANS_ENABLED'] = 'True'

    # the token that tells vLLM "this is the beginning of a span"
    os.environ['VLLM_V1_SPANS_TOKEN_PLUS'] = str(SPAN_TOK)

    # token that tells vLLM:
    # "from here on, recompute KV vectors if any previous tokens differ"
    os.environ['VLLM_V1_SPANS_TOKEN_CROSS'] = str(SPAN_RECOMP_TOK)

    # prints every step of the span process if set to true
    # os.environ['VLLM_V1_SPANS_DEBUG'] = 'True'

    # disables the adjustment of positional encodings when a KV cache
    # block is loaded at a different position than the one it was stored at
    # -> when this line is not commented out,
    #    spans overlap in their positional encodings
    os.environ['VLLM_V1_SPANS_DISABLE_REPOSITION'] = 'True'

    # general env vars

    # our helper function uses this token to pad spans
    os.environ['VLLM_V1_SPANS_TOKEN_PAD'] = str(PAD_TOK)

    # now we instantiate the model
    samp_params_preload, samp_params_generate, tok, llm = initialize_vllm(
        model_name, max_generated_toks=1)
    # model_name, max_generated_toks=1, max_toks=2048)

    # components of the prompt template
    prefix = pad([SPAN_RECOMP_TOK] + tok(
        "<|system|>\nYou are an intelligent AI assistant. "
        "Please answer questions based on the user's instructions. "
        "Below are some reference documents that may help you in "
        "answering the user's question."))
    midfx = [SPAN_RECOMP_TOK] + tok(
        "<|user|>\nPlease write a high-quality answer for the "
        "given question using only the provided search documents "
        "(some of which might be irrelevant).\nQuestion: ")
    postfx = tok('\n<|assistant|>\n')

    print("---->", postfx)

    times = []
    for ndocs in [1, 2, 4, 8]:
        for dlen in [512, 1024, 2048, 4096, 8192]:
            print(f"<!> DOCLENGTH {dlen} NUMDOCS {ndocs}")

            doc_toks = tok(
                "Sequence Transduction Models and Template-Assisted Selective Epitaxy")
            docs = [pad([SPAN_TOK] + random.choices(doc_toks, k=dlen))
                    for _ in range(ndocs)]

            # user query
            query = midfx + tok(
                "Tell me which one concerns deep learning. "
                "Indicate your answer with a number in brackets.") + postfx

            for i in range(3):
                print(f"<!> ITERATION {i}")

                # preload the documents and the prefix into the KV cache
                ts_pre = time.time()
                llm.generate(
                    [wrap(d) for d in docs] + [wrap(prefix)],
                    sampling_params=samp_params_preload, use_tqdm=False)
                te_pre = time.time() - ts_pre

                ts_gen = time.time()

                # this will now load the prefix and every document
                # from the KV cache, regardless of their order
                random.shuffle(docs)
                llm.generate(wrap(prefix + sum(docs, []) + query),
                             sampling_params=samp_params_generate,
                             use_tqdm=False)

                # this should also run faster:
                random.shuffle(docs)
                llm.generate(wrap(prefix + sum(docs, []) + query),
                             sampling_params=samp_params_generate,
                             use_tqdm=False)

                te_gen = time.time() - ts_gen

                print(f"doc preload time / TTFT : {te_pre:.4f} / {te_gen:.4f} (s)")
                times.append(dict(
                    preload_time=te_pre,
                    gen_time=te_gen,
                    it=i,
                    doc_len=dlen,
                    num_docs=ndocs,
                ))


if __name__ == '__main__':
    main()
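For orientation, here is a minimal, self-contained sketch (not part of the PR) of the property the `pad()` helper above relies on: every span is padded out to a multiple of the 16-token block size while its final token stays last, so each document span occupies whole KV-cache blocks. Token IDs below are made up for illustration.

```python
# Standalone illustration of the padding invariant used by the benchmark.
# BLOCK_SIZE and PAD_TOK mirror the values set above; token IDs are hypothetical.
BLOCK_SIZE = 16
PAD_TOK = 27


def pad(toklist: list[int]) -> list[int]:
    # insert pad tokens before the final token until the length
    # is a multiple of BLOCK_SIZE
    n_pad = (BLOCK_SIZE - len(toklist)) % BLOCK_SIZE
    return toklist[:-1] + [PAD_TOK] * n_pad + toklist[-1:]


span = [10] + [100, 101, 102, 103, 104]  # SPAN_TOK followed by 5 document tokens
padded = pad(span)
assert len(padded) % BLOCK_SIZE == 0     # the span fills whole 16-token blocks
assert padded[-1] == span[-1]            # the last real token stays last
```

Because the prefix and every document are padded this way, each span starts and ends on a block boundary, so its KV-cache blocks can be recognised and reused even when the documents are concatenated in a different order in the final prompt.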
26 changes: 20 additions & 6 deletions vllm/v1/core/block_pool.py
@@ -73,9 +73,17 @@ def __init__(
        self.enable_kv_cache_events = enable_kv_cache_events
        self.kv_event_queue: list[KVCacheEvent] = []

    def _closest_cache_hit(
        self, cached_blocks: dict[int, KVCacheBlock],
        position: int,
    ) -> KVCacheBlock:
        return min(list(cached_blocks.values()),
                   key=lambda x: abs(x.position - position))

    def get_cached_block(
            self, block_hash: BlockHash,
            kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]:
            kv_cache_group_ids: list[int],
            position: Optional[int] = None) -> Optional[list[KVCacheBlock]]:
        """Get the cached block by the block hash for each group in
        `kv_cache_group_ids`, or None if cache miss for any group.
        If there are duplicated blocks, we return the first block in the cache.
@@ -95,7 +103,11 @@ def get_cached_block(
                block_hash_with_group_id)
            if not cached_blocks_one_group:
                return None
            first_block = next(iter(cached_blocks_one_group.values()))
            if position is not None and len(cached_blocks_one_group) > 1:
                first_block = self._closest_cache_hit(cached_blocks_one_group,
                                                      position)
            else:
                first_block = next(iter(cached_blocks_one_group.values()))
            cached_blocks.append(first_block)
        return cached_blocks
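A standalone sketch (not from the PR) of the tie-breaking rule `_closest_cache_hit` introduces: when several cached blocks share the same hash, prefer the one whose stored position is closest to the position where the block is needed now. The `Block` class below is a simplified stand-in for `KVCacheBlock`.

```python
from dataclasses import dataclass


@dataclass
class Block:  # hypothetical stand-in for KVCacheBlock
    block_id: int
    position: int


def closest_cache_hit(cached_blocks: dict[int, Block], position: int) -> Block:
    # pick the duplicate whose cached position is nearest to the requested one
    return min(cached_blocks.values(), key=lambda b: abs(b.position - position))


duplicates = {1: Block(1, 0), 2: Block(2, 64), 3: Block(3, 128)}
assert closest_cache_hit(duplicates, position=80).block_id == 2
```

When no position is supplied, or there is only one candidate, `get_cached_block` keeps its previous behaviour of returning the first entry.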

@@ -193,17 +205,19 @@ def _set_block_positions(self, new_full_blocks: list[KVCacheBlock],
        debug logging that prints each block's tokens, to help
        debug span-related workflows.
        """
        dbg = envs.VLLM_V1_SPANS_DEBUG
        pos = 0
        nfb_ids = {b.block_id for b in new_full_blocks}
        for blk in blocks:
            if blk in new_full_blocks:
            if blk.block_id in nfb_ids:
                blk.position = pos
                if envs.VLLM_V1_SPANS_DEBUG:
                if dbg:
                    # this prints the tokens assigned to a new block
                    # in the KV cache
                    blk_tks = request.all_token_ids[pos:pos + 16]
                    assert blk.block_hash is not None
                    bhash = str(abs(blk.block_hash.block_hash.hash_value)
                                )[:4] if blk.block_hash.block_hash else None
                    bhash = str(blk.block_hash
                                )[:4] if blk.block_hash else None
                    print('[SPANS -> block_pool] assigning to pos', pos,
                          'with hash', bhash, 'block: ', blk_tks)
            pos += 16
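A short aside, not part of the diff: the switch from `blk in new_full_blocks` to `blk.block_id in nfb_ids` replaces a per-block list scan with a constant-time set lookup keyed on the block id, independent of how `KVCacheBlock` defines equality. A toy sketch of the pattern, with a hypothetical `Block` class and the 16-token block size assumed elsewhere in the PR:

```python
from dataclasses import dataclass


@dataclass
class Block:  # hypothetical stand-in for KVCacheBlock
    block_id: int
    position: int = -1


blocks = [Block(i) for i in range(4)]
new_full_blocks = blocks[1:3]                    # blocks that just became full
nfb_ids = {b.block_id for b in new_full_blocks}  # O(1) membership lookups

pos = 0
for blk in blocks:
    if blk.block_id in nfb_ids:                  # id-based test, no list scan
        blk.position = pos
    pos += 16                                    # block_size == 16

assert [b.position for b in blocks] == [-1, 16, 32, -1]
```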
83 changes: 51 additions & 32 deletions vllm/v1/core/kv_cache_manager.py
@@ -18,9 +18,11 @@

@dataclass
class BlockRepositionRequest:
    block_id: int
    kvc_pos: int
    prompt_pos: int
    cached_pos: int
    cached_blockid: int
    prompt_blockpos: int
    prompt_reqid: str


@dataclass
@@ -190,13 +192,48 @@ def get_computed_blocks(self,
        computed_blocks, num_new_computed_tokens = (
            self.coordinator.find_longest_cache_hit(request.block_hashes,
                                                    max_cache_hit_length))

        repo_reqs = []
        if envs.VLLM_V1_SPANS_ENABLED:
            # check which computed blocks have an incorrect position
            # match, or come after one: each block's position in the
            # prompt is compared against the position it was cached at
            non_match_idx = -1
            non_match_found = False
            for i, block in enumerate(computed_blocks[0]):
                if block.is_null:  # null blocks don't have a meaningful position
                    continue
                prompt_pos = self.block_size * i
                cached_pos = block.position
                # find the first block whose position didn't match
                if prompt_pos != cached_pos and not non_match_found:
                    non_match_found = True
                    non_match_idx = i
                # from that block onwards, record reposition requests
                if non_match_found:
                    repo_reqs.append(
                        BlockRepositionRequest(prompt_pos, cached_pos,
                                               block.block_id, i,
                                               request.request_id))
            # if any repositioning is needed, exclude those blocks from the
            # computed blocks and num_new_computed_tokens, so that new
            # blocks get allocated that we can copy KV values into
            if non_match_found:
                computed_blocks = (computed_blocks[0][:non_match_idx], )
                num_new_computed_tokens = len(
                    computed_blocks[0]) * self.block_size


        if envs.VLLM_V1_SPANS_DEBUG:
            print(
                "[SPANS -> kv_cache_manager] here's the blocks hashed in "
                "this request:",
                [str(abs(b.hash_value))[:4] for b in request.block_hashes])
                [str(b)[-4:] for b in request.block_hashes])
            kvcache_contents = [
                str(abs(b.block_hash.block_hash.hash_value))[:4]
                str(b.block_hash)[-4:]
                if b.block_hash else None for b in self.block_pool.blocks
                if b._block_hash
            ]
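To make the control flow of the new position check easier to follow, here is a condensed standalone sketch (simplified names, not the actual vLLM types): walk the cache-hit blocks in prompt order, find the first block whose cached position differs from its position in the new prompt, keep only the hit prefix before that point, and emit reposition requests for everything from the mismatch onwards.

```python
from dataclasses import dataclass

BLOCK_SIZE = 16


@dataclass
class CachedBlock:    # hypothetical stand-in for KVCacheBlock
    block_id: int
    position: int     # token offset at which the KV values were computed


@dataclass
class RepoRequest:    # hypothetical stand-in for BlockRepositionRequest
    prompt_pos: int
    cached_pos: int
    block_id: int


def split_cache_hit(hit_blocks: list[CachedBlock]):
    """Return (usable hit prefix, reposition requests for the remainder)."""
    repo_reqs: list[RepoRequest] = []
    non_match_idx = len(hit_blocks)
    for i, blk in enumerate(hit_blocks):
        prompt_pos = i * BLOCK_SIZE
        if repo_reqs or prompt_pos != blk.position:
            if not repo_reqs:
                non_match_idx = i  # first mismatch: cut the usable hit here
            repo_reqs.append(RepoRequest(prompt_pos, blk.position, blk.block_id))
    return hit_blocks[:non_match_idx], repo_reqs


# blocks 7 and 8 were cached at the right offsets, block 9 was cached elsewhere
hits = [CachedBlock(7, 0), CachedBlock(8, 16), CachedBlock(9, 128)]
prefix, repos = split_cache_hit(hits)
assert [b.block_id for b in prefix] == [7, 8]
assert [(r.prompt_pos, r.cached_pos) for r in repos] == [(32, 128)]
```

Cutting the hit at the first mismatch is what allows fresh blocks to be allocated for the remainder, into which the cached KV values can then be copied or repositioned.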
@@ -212,43 +249,25 @@
"[SPANS -> kv_cache_manager] here's the number of blocks " \
"that hit the cache:",
[
str(abs(b.block_hash.block_hash.hash_value))[:4]
str(b.block_hash)[-4:]
if b.block_hash else None for b in computed_blocks[0]
])

        blocks_to_reposition = []
        if envs.VLLM_V1_SPANS_ENABLED:
            # Spans does not yet support hybrid models
            assert len(computed_blocks) == 1
            for i, b in enumerate(computed_blocks[0]):
                prompt_pos = i * 16
                kvc_pos = b.position
                if envs.VLLM_V1_SPANS_DEBUG:
                    print(
                        f"[SPANS -> kv_cache_manager] checking block "
                        f"{b.block_id} with prompt pos {prompt_pos} "
                        f"and kv pos {kvc_pos}")
                assert isinstance(kvc_pos, int)
                if kvc_pos != prompt_pos:
                    if envs.VLLM_V1_SPANS_DEBUG:
                        print(
                            f"[SPANS -> kv_cache_manager] from pos: {kvc_pos} "
                            f"to prompt pos: {prompt_pos} repositioning needed")

                    blocks_to_reposition.append(
                        BlockRepositionRequest(b.block_id, kvc_pos,
                                               prompt_pos))
                    b.position = int(prompt_pos)
            # count blocks that need repositioning vs. duplication (copying)
            num_repo = len([r for r in repo_reqs
                            if r.prompt_pos != r.cached_pos])
            num_copy = len(repo_reqs) - num_repo
            print(
                "[SPANS -> kv_cache_manager] here's the number of blocks",
                f"total: {len(repo_reqs)} to reposition: {num_repo},",
                f"to copy: {num_copy}")

        if self.log_stats:
            assert self.prefix_cache_stats is not None
            self.prefix_cache_stats.requests += 1
            self.prefix_cache_stats.queries += request.num_tokens
            self.prefix_cache_stats.hits += num_new_computed_tokens

        return KVCacheBlocks(computed_blocks, blocks_to_reposition),\
        return KVCacheBlocks(computed_blocks, repo_reqs),\
            num_new_computed_tokens

    def allocate_slots(
1 change: 1 addition & 0 deletions vllm/v1/core/kv_cache_utils.py
@@ -372,6 +372,7 @@ def append_n(self, blocks: list[KVCacheBlock]) -> None:
"""
if len(blocks) == 0:
return
blocks = list({b.block_id: b for b in blocks}.values())
self.num_free_blocks += len(blocks)

last_block = self.fake_free_list_tail.prev_free_block
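The one-line change to `append_n` deduplicates the incoming blocks by `block_id` while keeping their order, relying on Python dicts preserving insertion order; presumably this guards against the same block being returned to the free list twice. A tiny illustration with a hypothetical `Block` stand-in:

```python
from dataclasses import dataclass


@dataclass
class Block:  # hypothetical stand-in for KVCacheBlock
    block_id: int


blocks = [Block(3), Block(5), Block(3), Block(7), Block(5)]
deduped = list({b.block_id: b for b in blocks}.values())
assert [b.block_id for b in deduped] == [3, 5, 7]
```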
9 changes: 8 additions & 1 deletion vllm/v1/core/sched/scheduler.py
@@ -389,6 +389,11 @@ def schedule(self) -> SchedulerOutput:
                        len(new_computed_blocks.blocks_to_reposition) > 0:
                    blocks_to_reposition.extend(
                        new_computed_blocks.blocks_to_reposition)

                    # TODO (Nathan) find something smarter to do than this
                    token_budget += \
                        len(new_computed_blocks.blocks_to_reposition) \
                        * self.block_size

                # Get externally-cached tokens if using a KVConnector.
                if self.connector is not None:
@@ -545,8 +550,10 @@ def schedule(self) -> SchedulerOutput:
            self.waiting.prepend_requests(skipped_waiting_requests)

        # Check if the scheduling constraints are satisfied.
        # TODO make this smarter for spans
        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + \
            len(blocks_to_reposition) * self.block_size
        assert token_budget >= 0
        assert len(self.running) <= self.max_num_running_reqs
        # Since some requests in the RUNNING queue may not be scheduled in
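The scheduler change is plain arithmetic: per the TODO, the token budget and the scheduling assertion are both relaxed by `len(blocks_to_reposition) * block_size`, so repositioned blocks do not exhaust the budget. A worked example with the 16-token block size assumed throughout this PR and otherwise made-up numbers:

```python
# hypothetical numbers; a block_size of 16 matches the rest of the PR
block_size = 16
max_num_scheduled_tokens = 8192
num_blocks_to_reposition = 24

extra_budget = num_blocks_to_reposition * block_size      # 24 * 16 = 384
relaxed_limit = max_num_scheduled_tokens + extra_budget   # 8192 + 384 = 8576
assert (extra_budget, relaxed_limit) == (384, 8576)
```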