[Speculative Decoding] Enabling bonus token in speculative decoding for KV cache based models #5765
Conversation
Pull from head
LMK once it's ready for review @sroy745
Approach looks good. Seems this will conflict with #5799; we should work out a way to combine both.
and batch sizes when bonus token acceptance is enabled. It ensures
correctness by comparing the output of speculative decoding with the baseline.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
Test looks good. Can you manually check the draft acceptance rate? For the same draft/target model it should be 100%. Without your fix (and with bonus token enabled) the acceptance rate goes to something like 80%.
Ideally we'd have a test for this; not sure how easy that is.
Yeah, I thought of adding such an e2e test, but I couldn't find an easy way to access the metrics_collector and the stats. As you suggested, I added an e2e test in test_multistep_correctness.py with the draft and target model set to the same model ("JackFram/llama-68m") and the number of speculative tokens set to 3. The system efficiency with the bonus token enabled is 1.0 and with the bonus token disabled is 0.75.
With bonus token
INFO 07-08 06:27:27 metrics.py:316] Speculative metrics: Draft acceptance rate: 1.000, System efficiency: 1.000, Number of speculative tokens: 3, Number of accepted tokens: 21696, Number of draft tokens tokens: 21696, Number of emitted tokens tokens: 28928.
Without bonus token
Speculative metrics: Draft acceptance rate: 1.000, System efficiency: 0.750, Number of speculative tokens: 3, Number of accepted tokens: 21696, Number of draft tokens tokens: 21696, Number of emitted tokens tokens: 21696.
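For reference, a rough back-of-the-envelope check of where the 1.0 vs 0.75 system efficiency comes from (a sketch; the variable names are mine, not vLLM's):

```python
# With num_speculative_tokens = 3 and a 100% draft acceptance rate, each
# verified step can emit up to 3 + 1 tokens once the bonus token is kept.
num_spec_tokens = 3
accepted = 21696
steps = accepted // num_spec_tokens              # 7232 scoring steps
max_emittable = steps * (num_spec_tokens + 1)    # 28928 tokens

print(28928 / max_emittable)   # 1.00 -> system efficiency with bonus token
print(21696 / max_emittable)   # 0.75 -> system efficiency without bonus token
```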
indices_of_original_sequence_groups = []
for seq_group in execute_model_req.seq_group_metadata_list:
    seq_ids_with_bonus_tokens = []
    for seq_id, seq_data in seq_group.seq_data.items():
Approach looks good. Pretty messy, but I don't see anything easy that would fix that.
Tried to simplify it a bit by moving some of the logic into two helper functions, _shallow_copy_seq_group_metadata and _copy_seq_metadata_excluding_last_token. PTAL.
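For readers following along, a toy sketch of what such helpers conceptually do. The classes and field names below are made-up stand-ins, not vLLM's actual SequenceGroupMetadata/SequenceData API:

```python
from copy import copy, deepcopy
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class ToySeqData:
    prompt_token_ids: List[int]
    output_token_ids: List[int]
    num_computed_tokens: int = 0

@dataclass
class ToySeqGroupMetadata:
    request_id: str
    seq_data: Dict[int, ToySeqData] = field(default_factory=dict)

def shallow_copy_seq_group_metadata(sgm: ToySeqGroupMetadata) -> ToySeqGroupMetadata:
    # Copy the container and its seq_data dict so per-request edits don't
    # leak back into the original request; the payload objects are shared.
    new_sgm = copy(sgm)
    new_sgm.seq_data = dict(sgm.seq_data)
    return new_sgm

def copy_seq_data_excluding_last_token(seq_data: ToySeqData) -> ToySeqData:
    # Deep-copy one sequence's data, dropping its last output token and
    # decrementing num_computed_tokens so the draft model recomputes the KV
    # for that position.
    new_data = deepcopy(seq_data)
    new_data.output_token_ids = new_data.output_token_ids[:-1]
    new_data.num_computed_tokens = max(0, new_data.num_computed_tokens - 1)
    return new_data
```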
target_worker.execute_model.assert_called_once_with(execute_model_req)

@torch.inference_mode()
def test_populate_seq_ids_with_bonus_tokens():
One edge case is where a sequence is skipped but is still present in seq_with_bonus_token_in_last_step.
Added this case in Forward Pass #2.
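To illustrate the edge case, a purely toy example. The set name follows the discussion above, but the update rule shown is just one plausible way to keep it consistent, not the PR's actual logic:

```python
# Sequences that received a bonus token in the previous step.
seq_with_bonus_token_in_last_step = {101, 102, 103}

# This step, sequence 102 is skipped (not scheduled); 101 and 103 run,
# and only 103 gets a bonus token again.
scheduled_this_step = {101, 103}
got_bonus_this_step = {103}

# Entries for skipped sequences must be preserved untouched; only the
# scheduled sequences' entries are refreshed.
seq_with_bonus_token_in_last_step = (
    (seq_with_bonus_token_in_last_step - scheduled_this_step) | got_bonus_this_step
)
print(seq_with_bonus_token_in_last_step)  # {102, 103}
```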
Awesome. Will take a look tomorrow.
Can we manually verify the following?
- When bonus tokens are enabled and the same model is used for draft and target, we get 100% draft acceptance rate. This indicates that the KV of the draft model is ~equal to the KV of the target model.
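A minimal sketch of how one might run this check offline, assuming the engine kwargs available around the time of this PR (speculative_model, num_speculative_tokens, use_v2_block_manager) and reading the acceptance rate from the logged metrics:

```python
from vllm import LLM, SamplingParams

# Use the same checkpoint for draft and target so every proposed token should
# be accepted; the "Draft acceptance rate" in the logged Speculative metrics
# should then be ~1.0.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="JackFram/llama-68m",  # draft == target
    num_speculative_tokens=3,
    use_v2_block_manager=True,
    disable_log_stats=False,  # keep the Speculative metrics log line
)

outputs = llm.generate(
    ["San Francisco is a"] * 64,
    SamplingParams(temperature=0.0, max_tokens=128),
)
# Then check the "Speculative metrics: Draft acceptance rate: ..." log line.
```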
# Also reduce num_computed_tokens by 1 since we are not
# including the last output token.
Can we add more of a comment here motivating this? I thought the workers themselves don't take this into account unless chunked prefill is enabled.
I included this to keep the value consistent. It's probably not needed since it only gets used for chunked prefill. However, I see this being updated elsewhere in the spec decode code, e.g. https://sourcegraph.com/github.com/vllm-project/vllm/-/blob/vllm/spec_decode/draft_model_runner.py?L116, hence I added it. Is this confusing? Should I remove it?
Gotcha -- yeah can leave in. Add a comment?
Added a note.
@torch.inference_mode()
def test_expand_execute_model_request_for_bonus_tokens():
Why do we need both this test and test_same_output_for_multi_step_with_batch_expansion?
Fine to have both, but we should add something to the docstring that explains the difference.
Removed this test as well as test_filter_model_output. Both should be covered by test_same_output_for_multi_step_with_batch_expansion as you suggested.
output_indices_to_retain = random.sample(range(num_steps),
                                         max(1, num_steps // 2))
I am confused why we sample from range(num_steps). Shouldn't it be range(batch_size)?
It should be range(batch_size). I removed the test, though, in favor of test_same_output_for_multi_step_with_batch_expansion.
# Forward Pass: 0
# Set the last token ID to -1 for all indices not in
# seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
# those indices.
accepted_token_ids[mask, -1:] = -1
worker = SpecDecodeWorker(draft_worker,
Can we break the three tests in this test into their own tests? It will be easier to follow and debug.
I modified the test to initialize the internal data structures with some fake data and then run a forward pass to simulate all 3 cases. PTAL. It should be simpler to follow now while covering all the cases.
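A toy illustration of the masking shown in the snippet above, with made-up token ids and indices:

```python
import torch

# Each row holds the accepted tokens for one sequence; the last column is the
# bonus-token slot.
accepted_token_ids = torch.tensor([
    [11, 12, 13, 14],
    [21, 22, 23, 24],
    [31, 32, 33, 34],
])
seq_indexes_with_bonus_tokens = [0, 2]

# mask selects the rows that did NOT get a bonus token.
mask = torch.ones(accepted_token_ids.shape[0], dtype=torch.bool)
mask[seq_indexes_with_bonus_tokens] = False

accepted_token_ids[mask, -1:] = -1
print(accepted_token_ids)
# tensor([[11, 12, 13, 14],
#         [21, 22, 23, -1],
#         [31, 32, 33, 34]])
```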
Oh I just saw your response
This is awesome! So exciting to see this working!
Thanks for the review. Addressed the comments. PTAL
Looks great!
Thanks for the review. Addressed your comment and rebased. Should be ready to merge once the tests complete.
Merged!
FIX #4212
In this PR we make the following changes
Some numbers from the e2e tests. Note that the e2e tests don't use CUDA graphs. The draft model is JackFram/llama-68m and the target model is JackFram/llama-160m, with a batch size of 64. Completion time for num_speculation = 1 shows a ~33% speedup (see the quick check after the logs below).
w/o bonus Processed prompts: 100%|█████████████████64/64 [00:06<00:00, 10.13it/s, est. speed input: 78.48 toks/s, output: 2592.40 toks/s]
with bonus Processed prompts: 100%|█████████████████████ 64/64 [00:04<00:00, 15.50it/s, est. speed input: 120.13 toks/s, output: 3968.28 toks/s]
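Rough arithmetic behind the ~33% figure, using the rounded timings from the progress bars above (approximate):

```python
time_without_bonus = 6.0   # seconds, from "00:06"
time_with_bonus = 4.0      # seconds, from "00:04"

print(1 - time_with_bonus / time_without_bonus)  # ~0.33 -> ~33% shorter completion time
print(3968.28 / 2592.40)                         # ~1.53x output token throughput
```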