Merged
Commits (49)
Changes shown are from 47 of the 49 commits.
5650b95
Merge pull request #1 from vllm-project/main
sroy745 May 29, 2024
8f36146
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
9e75057
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
db2c679
Merge branch 'vllm-project:main' into main
sroy745 Jun 7, 2024
8d7512c
Merge branch 'vllm-project:main' into main
sroy745 Jun 10, 2024
1473f74
Merge branch 'vllm-project:main' into main
sroy745 Jun 12, 2024
4013e1a
Merge branch 'vllm-project:main' into main
sroy745 Jun 14, 2024
6c55024
Enabling bonus token in speculative decoding for KV cache based models
sroy745 Jun 22, 2024
4e27a2c
Fixing comments and reverting changes in conftest.py
sroy745 Jun 22, 2024
fc20e26
Reverting changes in conftest
sroy745 Jun 22, 2024
55a8b2b
Fixing a syntax error
sroy745 Jun 22, 2024
2d7eb37
Fix syntax dif
sroy745 Jun 22, 2024
8c129ed
Add new SpecDecodeWorker tests
sroy745 Jun 24, 2024
12c6fad
Merge branch 'main' into vllm-sd-enable-bonus-token-1
sroy745 Jun 24, 2024
af94752
Fix function signatures and add test
sroy745 Jun 24, 2024
e96ac72
Fix interfaces
sroy745 Jun 24, 2024
d39c82f
Reduce num speculations in tests
sroy745 Jun 24, 2024
efce280
More cleanup and comments
sroy745 Jun 24, 2024
4ef9f24
Revert vllm/spec_decode/util.py
sroy745 Jun 24, 2024
3b50193
Revert vllm/spec_decode/util.py
sroy745 Jun 24, 2024
c4fae4f
Remove space
sroy745 Jun 24, 2024
5f9772c
More fixes
sroy745 Jun 24, 2024
79608c9
Fix passing function parameter
sroy745 Jun 24, 2024
69d0d47
Fix est_dynamic_spec_decode.py
sroy745 Jun 24, 2024
f7f3fd7
Merge branch 'main' into vllm-sd-enable-bonus-token-1
sroy745 Jul 2, 2024
bcadab2
enable bonus tokens always
sroy745 Jul 2, 2024
85d464f
Remove argument from arg_utils.py
sroy745 Jul 2, 2024
b3e1dda
Update spec_decode_worker
sroy745 Jul 2, 2024
32e9162
Fix a test
sroy745 Jul 5, 2024
16d300b
Merge branch 'main' into vllm-sd-enable-bonus-token-1
sroy745 Jul 6, 2024
f05bebe
Addressed comments
sroy745 Jul 8, 2024
327b595
Fix merge issues for spec_decode_worker.py
sroy745 Jul 8, 2024
d0c9634
Remove additional definition of variable
sroy745 Jul 8, 2024
d2f9dcc
Revert vllm/config.py
sroy745 Jul 8, 2024
847087c
Fix test
sroy745 Jul 8, 2024
e689de0
Revert changes to vllm/config.py
sroy745 Jul 8, 2024
fd63121
Fix config.py
sroy745 Jul 8, 2024
ca49aa1
Fix config.py
sroy745 Jul 8, 2024
c245d7a
Address comments
sroy745 Jul 10, 2024
fe06b46
Fix format
sroy745 Jul 10, 2024
72d6857
Revert a change
sroy745 Jul 10, 2024
aea0bbf
Make tests concise
sroy745 Jul 10, 2024
decae31
Fix a test comment
sroy745 Jul 10, 2024
1b459c3
Merge pull request #6 from vllm-project/main
sroy745 Jul 10, 2024
c9dcd58
Rebasing to include medusa_worker.py
sroy745 Jul 10, 2024
72217c1
Add comment
sroy745 Jul 10, 2024
fe84b0d
Add note
sroy745 Jul 10, 2024
9f9c5b7
Dummy commit
sroy745 Jul 10, 2024
780ac4d
Dummy commit
sroy745 Jul 10, 2024
11 changes: 7 additions & 4 deletions tests/spec_decode/test_dynamic_spec_decode.py
@@ -70,14 +70,17 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
     if queue_size < disable_by_batch_size:
         # Should raise exception when executing the mocked draft model.
         with pytest.raises(ValueError, match=exception_secret):
-            proposer.get_spec_proposals(execute_model_req=ExecuteModelRequest(
-                seq_group_metadata_list=seq_group_metadata_list,
-                num_lookahead_slots=k), )
+            proposer.get_spec_proposals(
+                execute_model_req=ExecuteModelRequest(
+                    seq_group_metadata_list=seq_group_metadata_list,
+                    num_lookahead_slots=k),
+                seq_ids_with_bonus_token_in_last_step=set())
     else:
         # Should not execute the draft model because spec decode is disabled
         # for all requests. Accordingly, the proposal length should be 0.
         proposals = proposer.get_spec_proposals(
             execute_model_req=ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
-                num_lookahead_slots=k), )
+                num_lookahead_slots=k),
+            seq_ids_with_bonus_token_in_last_step=set())
         assert proposals.proposal_lens.tolist() == [0] * batch_size
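The pattern above repeats throughout this PR: proposers now receive seq_ids_with_bonus_token_in_last_step alongside the ExecuteModelRequest. The sketch below only illustrates that calling convention; the wrapper function and its parameters are placeholders, not code from this diff, and the ExecuteModelRequest import path is assumed to match what these tests use.

# Illustration of the new calling convention; everything here except the
# argument name `seq_ids_with_bonus_token_in_last_step` is a placeholder,
# not code from this PR.
from typing import Set

from vllm.sequence import ExecuteModelRequest


def propose_next_tokens(proposer, seq_group_metadata_list, k: int,
                        accepted_bonus_seq_ids: Set[int]):
    """Hypothetical caller: `accepted_bonus_seq_ids` holds the IDs of
    sequences whose previous step kept the target model's bonus token."""
    return proposer.get_spec_proposals(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k),
        seq_ids_with_bonus_token_in_last_step=accepted_bonus_seq_ids)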
212 changes: 207 additions & 5 deletions tests/spec_decode/test_multi_step_worker.py
@@ -118,7 +118,8 @@ def test_same_output_for_single_step():
     actual_output, _ = multi_step_worker.sampler_output(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=multi_step_seq_group),
-        sample_len=num_steps)
+        sample_len=num_steps,
+        seq_ids_with_bonus_token_in_last_step=set())
     assert len(actual_output) == num_steps
     actual_output = actual_output[0]

@@ -210,7 +211,8 @@ def test_same_output_for_multi_step():
     multi_step_output, _ = multi_step_worker.sampler_output(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list),
-        sample_len=num_steps)
+        sample_len=num_steps,
+        seq_ids_with_bonus_token_in_last_step=set())

     # Run single-step repeatedly.
     zero_kv_cache(worker.cache_engine)
@@ -277,6 +279,203 @@ def test_same_output_for_multi_step():
                                       single_step_logprobs)


+@torch.inference_mode()
+def test_multi_step_with_batch_expansion_correct_output():
+    """
+    In this test we verify that the MultiStepWorker is able to handle bonus
+    tokens correctly. The test verifies that if a sequence has a
+    bonus token then the MultiStepWorker is able to expand the batch by adding
+    new sequences corresponding to the sequences with bonus tokens. The
+    expanded batch is then used for predicting the next tokens.
+    """
+    seed = 100
+    model_name = 'JackFram/llama-68m'
+
+    block_size = 16
+    num_gpu_blocks = 2048 // block_size
+    batch_size = 128
+    multi_step_worker = create_worker(
+        MultiStepWorker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+        model_runner_cls=TP1DraftModelRunner,
+    )
+    worker = create_worker(
+        Worker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+    )
+    random.seed(seed)
+    prompts = [[0] for _ in range(batch_size)]
+    num_steps = 2
+    final_prompt_lens = [(num_steps + 1) for prompt in prompts]
+    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
+    multi_step_worker.execute_model = patch_execute_model_with_seeds(
+        multi_step_worker, rand_seeds)
+    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
+    # Create the test continuations
+    continuations = [[random.randint(0, 1000)] for _ in prompts]
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=continuations,
+        final_prompt_lens=final_prompt_lens)
+
+    # Run single-step twice to generate 2 tokens. This
+    # will simulate the bonus token case with the second token
+    # being the bonus token.
+    zero_kv_cache(worker.cache_engine)
+    single_step_output: List[SamplerOutput] = []
+    set_random_seed(seed)
+    for _ in range(num_steps):
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            continuations=continuations,
+            final_prompt_lens=final_prompt_lens)
+        single_step_output.extend(
+            worker.execute_model(execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list)))
+        # Append output tokens to new sequence data.
+        for i, seq_group_output in enumerate(single_step_output[-1]):
+            continuations[i].append(seq_group_output.samples[0].output_token)
+
+    # Create continuations for the MultiStepWorker. The continuations have
+    # 2 tokens in order to simulate the bonus token case.
+    multi_step_continuations = []
+    for continuation in continuations:
+        multi_step_continuations.append(continuation[:2])
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=multi_step_continuations,
+        final_prompt_lens=final_prompt_lens)
+
+    # Run multi-step and verify that the third token prediction is accurate
+    # for all sequences.
+    zero_kv_cache(multi_step_worker.cache_engine)
+    all_seq_ids = {i for i in range(batch_size)}
+    multi_step_output, _ = multi_step_worker.sampler_output(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list),
+        sample_len=1,
+        seq_ids_with_bonus_token_in_last_step=all_seq_ids)
+    for index, output in enumerate(multi_step_output[-1].outputs):
+        assert (continuations[index][-1] == output.samples[0].output_token)
+
+
+@torch.inference_mode()
+def test_multi_step_with_batch_expansion_incorrect_output():
+    """
+    Tests the MultiStepWorker's ability to handle batch expansion with bonus
+    tokens in a negative case scenario. This test provides the MultiStepWorker
+    with a batch containing sequences with bonus tokens but specifies the
+    sequence IDs with bonus tokens incorrectly. The test verifies that the
+    MultiStepWorker generates correct tokens for the sequences where the
+    sequence ID is specified correctly and incorrect tokens for those where
+    the sequence ID is specified incorrectly.
+    """
+    seed = 100
+    model_name = 'JackFram/llama-68m'
+
+    block_size = 16
+    num_gpu_blocks = 2048 // block_size
+    batch_size = 128
+    multi_step_worker = create_worker(
+        MultiStepWorker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+        model_runner_cls=TP1DraftModelRunner,
+    )
+    worker = create_worker(
+        Worker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+    )
+    random.seed(seed)
+    prompts = [[0] for _ in range(batch_size)]
+    num_steps = 2
+    final_prompt_lens = [(num_steps + 1) for prompt in prompts]
+    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
+    multi_step_worker.execute_model = patch_execute_model_with_seeds(
+        multi_step_worker, rand_seeds)
+    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
+    # Create the test continuations
+    continuations = [[random.randint(0, 1000)] for _ in prompts]
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=continuations,
+        final_prompt_lens=final_prompt_lens)
+    # Run single-step twice to generate 2 tokens. This
+    # will simulate the bonus token case with the second token
+    # being the bonus token.
+    zero_kv_cache(worker.cache_engine)
+    single_step_output: List[SamplerOutput] = []
+    set_random_seed(seed)
+    for _ in range(num_steps):
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            continuations=continuations,
+            final_prompt_lens=final_prompt_lens)
+        single_step_output.extend(
+            worker.execute_model(execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list)))
+        # Append output tokens to new sequence data.
+        for i, seq_group_output in enumerate(single_step_output[-1]):
+            continuations[i].append(seq_group_output.samples[0].output_token)
+
+    # Create continuations for the MultiStepWorker. The continuations have
+    # 2 tokens in order to simulate the bonus token case.
+    multi_step_continuations = []
+    for continuation in continuations:
+        multi_step_continuations.append(continuation[:2])
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=multi_step_continuations,
+        final_prompt_lens=final_prompt_lens)
+
+    # Run multi-step. In this run INCORRECTLY specify that only the odd number
+    # sequences have bonus tokens. Verify that with this setting the third token
+    # prediction is accurate only for the odd numbered sequences. Also verify
+    # that the prediction might be wrong for some of the even numbered
+    # sequences.
+    zero_kv_cache(multi_step_worker.cache_engine)
+    set_random_seed(seed)
+    odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0}
+    multi_step_output, _ = multi_step_worker.sampler_output(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list),
+        sample_len=1,
+        seq_ids_with_bonus_token_in_last_step=odd_seq_ids)
+    num_mismatch = 0
+    for index, output in enumerate(multi_step_output[-1].outputs):
+        if (index % 2) != 0:
+            assert (continuations[index][-1] == output.samples[0].output_token)
+        elif (continuations[index][-1] != output.samples[0].output_token):
+            num_mismatch += 1
+    # The prediction is accurate for some of the sequences even without proper
+    # handling of the bonus tokens. Hence verify that the number of sequences
+    # for which there is a mismatch is > 0.
+    assert (num_mismatch > 0)
+
+
 @torch.inference_mode()
 def test_draft_proposals_full_speculation_len():
     """Verify Top1Proposer correctly handles case where all sequences
@@ -318,7 +517,8 @@ def test_draft_proposals_full_speculation_len():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=k), )
+            num_lookahead_slots=k),
+        seq_ids_with_bonus_token_in_last_step=set())

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -356,7 +556,8 @@ def test_draft_proposals_no_speculations():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=k), )
+            num_lookahead_slots=k),
+        seq_ids_with_bonus_token_in_last_step=set())

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -428,7 +629,8 @@ def test_draft_proposals_mixed_k():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=k), )
+            num_lookahead_slots=k),
+        seq_ids_with_bonus_token_in_last_step=set())

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
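The two tests added above exercise the batch-expansion path that makes bonus tokens usable with a draft model that keeps a KV cache: a sequence whose previous step accepted the bonus token has a token the draft model has not yet processed, so the worker adds an extra batch entry for that sequence before proposing. The sketch below is a rough paraphrase of that idea under my own assumptions, not the PR's MultiStepWorker implementation; expand_batch and its arguments are hypothetical.

# Hypothetical sketch of batch expansion for bonus tokens (not vLLM code).
from typing import Dict, List, Set, Tuple


def expand_batch(
        tokens_by_seq: Dict[int, List[int]],
        seq_ids_with_bonus: Set[int]) -> List[Tuple[int, List[int]]]:
    """Return the (seq_id, token_ids) entries the draft model should run on."""
    expanded: List[Tuple[int, List[int]]] = []
    for seq_id, tokens in tokens_by_seq.items():
        if seq_id in seq_ids_with_bonus:
            # Extra entry without the trailing bonus token, so the draft
            # model also processes the token that precedes it and its KV
            # cache catches up before the next proposal step.
            expanded.append((seq_id, tokens[:-1]))
        expanded.append((seq_id, tokens))
    return expanded


# Example: sequence 1 accepted a bonus token last step, sequence 2 did not.
print(expand_batch({1: [5, 9, 12], 2: [7, 4]}, seq_ids_with_bonus={1}))
# [(1, [5, 9]), (1, [5, 9, 12]), (2, [7, 4])]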
9 changes: 6 additions & 3 deletions tests/spec_decode/test_ngram_worker.py
@@ -53,7 +53,8 @@ def test_ngram_algo_correctness_for_single_no_match():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=proposal_len), )
+            num_lookahead_slots=proposal_len),
+        seq_ids_with_bonus_token_in_last_step=None)

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -121,7 +122,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=proposal_len), )
+            num_lookahead_slots=proposal_len),
+        seq_ids_with_bonus_token_in_last_step=None)

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -193,7 +195,8 @@ def test_ngram_algo_correctness_for_batches_match_all():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=proposal_len), )
+            num_lookahead_slots=proposal_len),
+        seq_ids_with_bonus_token_in_last_step=None)

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
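These n-gram tests pass None for the new argument rather than an empty set, which fits how the proposer works: it looks matches up in the tokens it is given and keeps no KV cache, so there is no draft-model state to reconcile with a bonus token. The function below is a small self-contained sketch of that style of proposal, written only to show why the argument can be ignored here; it is an illustration, not vLLM's NGramWorker.

# Minimal sketch of n-gram-style proposal (illustration only, not vLLM code).
from typing import List, Optional, Set


def propose_by_ngram_lookup(
        token_ids: List[int],
        ngram_size: int,
        proposal_len: int,
        seq_ids_with_bonus_token_in_last_step: Optional[Set[int]] = None,
) -> List[int]:
    """Propose `proposal_len` tokens by matching the trailing n-gram.

    The bonus-token argument is accepted for interface compatibility but
    never used: the lookup is stateless.
    """
    tail = token_ids[-ngram_size:]
    # Scan earlier positions for the same n-gram and reuse what followed it.
    for start in range(len(token_ids) - ngram_size - 1, -1, -1):
        if token_ids[start:start + ngram_size] == tail:
            follow = token_ids[start + ngram_size:start + ngram_size + proposal_len]
            if len(follow) == proposal_len:
                return follow
    return []  # no match: propose nothing for this sequence


# Example: the trailing bigram (3, 4) appeared earlier, followed by 5 and 6.
print(propose_by_ngram_lookup([1, 3, 4, 5, 6, 2, 3, 4], ngram_size=2, proposal_len=2))
# [5, 6]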