
Conversation

@sroy745 sroy745 (Collaborator) commented Jun 22, 2024


FIX #4212

In this PR we make the following changes

  1. Updated the spec_decode_worker to keep track of the sequence ids that were assigned bonus token ids in their last forward pass. We record this only for the MultiStepWorker, since other worker types don't use the KV cache for token generation. Currently we don't clear sequence ids from this list even on sequence termination; we still need a way to be notified of termination so those sequence ids can be removed.
  2. Updated the MultiStepWorker to expand the batch during step 0. During batch expansion we check which sequence ids were assigned bonus tokens in their last forward pass, and for each of those sequences we add a new sequence (with the same seq_id) without the bonus token. Once the forward pass for step 0 completes, we filter the responses to retain only those that correspond to the original sequences. (A simplified sketch of this expansion is included right after this list.)
  3. Added a flag --disable-bonus-tokens-in-kv-cache to enable/disable bonus tokens for the MultiStepWorker.
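
To make the batch expansion in steps 1 and 2 concrete, here is a minimal, self-contained sketch. The SeqMeta class and the helper names are simplified stand-ins invented for illustration; they are not the actual vLLM classes or the code added in this PR.

```python
from dataclasses import dataclass
from typing import List, Set, Tuple


@dataclass
class SeqMeta:
    """Simplified stand-in for a sequence's metadata (not a vLLM class)."""
    seq_id: int
    token_ids: List[int]          # prompt + generated token ids so far
    num_computed_tokens: int = 0  # tokens whose KV entries are already cached


def expand_batch_for_bonus_tokens(
        batch: List[SeqMeta],
        seqs_with_bonus_token: Set[int]) -> Tuple[List[SeqMeta], List[int]]:
    """For every sequence that was assigned a bonus token in its last forward
    pass, insert a copy of it without that bonus token so the draft model can
    backfill the KV entry it never computed. Returns the expanded batch and
    the indices of the original sequences within it."""
    expanded: List[SeqMeta] = []
    original_indices: List[int] = []
    for seq in batch:
        if seq.seq_id in seqs_with_bonus_token:
            expanded.append(SeqMeta(seq.seq_id,
                                    seq.token_ids[:-1],
                                    max(0, seq.num_computed_tokens - 1)))
        original_indices.append(len(expanded))
        expanded.append(seq)
    return expanded, original_indices


def filter_original_outputs(outputs: List[int],
                            original_indices: List[int]) -> List[int]:
    """Keep only the step-0 outputs that belong to the original sequences."""
    return [outputs[i] for i in original_indices]


if __name__ == "__main__":
    batch = [SeqMeta(0, [1, 2, 3], 3), SeqMeta(1, [4, 5, 6, 7], 4)]
    expanded, idx = expand_batch_for_bonus_tokens(batch, {1})
    assert len(expanded) == 3 and idx == [0, 2]
    # Pretend the draft model produced one output token per expanded sequence.
    assert filter_original_outputs([101, 102, 103], idx) == [101, 103]
```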

Some numbers from the e2e tests (note that the e2e tests don't use CUDA graphs). The draft model is JackFram/llama-68m and the target model is JackFram/llama-160m, with a batch size of 64. Completion time for num_speculation = 1 shows a ~33% speedup:

  • w/o bonus Processed prompts: 100%|█████████████████64/64 [00:06<00:00, 10.13it/s, est. speed input: 78.48 toks/s, output: 2592.40 toks/s]

  • with bonus Processed prompts: 100%|█████████████████████ 64/64 [00:04<00:00, 15.50it/s, est. speed input: 120.13 toks/s, output: 3968.28 toks/s]

@sroy745 sroy745 marked this pull request as draft June 22, 2024 19:58
@sroy745 sroy745 changed the title Enabling bonus token in speculative decoding for KV cache based models [Speculative Decoding] [WIP] Enabling bonus token in speculative decoding for KV cache based models Jun 24, 2024
@sroy745 sroy745 marked this pull request as ready for review June 24, 2024 17:35
@sroy745 sroy745 changed the title [Speculative Decoding] [WIP] Enabling bonus token in speculative decoding for KV cache based models [Speculative Decoding] Enabling bonus token in speculative decoding for KV cache based models Jun 24, 2024
@cadedaniel (Collaborator)

LMK once it's ready for review @sroy745

@cadedaniel cadedaniel (Collaborator) left a comment

approach looks good. seems this will conflict with #5799. we should work out a way to combine both.

and batch sizes when bonus token acceptance is enabled. It ensures
correctness by comparing the output of speculative decoding with the baseline.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,

Collaborator

Test looks good. Can you manually check the draft acceptance rate? for the same draft/target model it should be 100%. without your fix (and with bonus token enabled) the acceptance rate goes to like 80%.

Ideally we have a test for this, not sure how easy that is..

Collaborator Author

Yeah, I thought of adding such an e2e test, but I couldn't find an easy way to access the metrics_collector and the stats. As you suggested, I added an e2e test in test_multistep_correctness.py with the draft and target model being the same ("JackFram/llama-68m") and the number of speculative tokens set to 3. The system efficiency with the bonus token enabled is 1.0 and with it disabled is 0.75.

With bonus token
INFO 07-08 06:27:27 metrics.py:316] Speculative metrics: Draft acceptance rate: 1.000, System efficiency: 1.000, Number of speculative tokens: 3, Number of accepted tokens: 21696, Number of draft tokens tokens: 21696, Number of emitted tokens tokens: 28928.

Without bonus token
Speculative metrics: Draft acceptance rate: 1.000, System efficiency: 0.750, Number of speculative tokens: 3, Number of accepted tokens: 21696, Number of draft tokens tokens: 21696, Number of emitted tokens tokens: 21696.
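
For context, these numbers line up with k = 3 speculative tokens and full acceptance: 21696 accepted tokens correspond to 21696 / 3 = 7232 scoring steps. With the bonus token each step emits k + 1 = 4 tokens, giving 7232 × 4 = 28928 emitted tokens and a system efficiency of 28928 / 28928 = 1.0; without it only the 3 accepted tokens per step are emitted, giving 21696 / 28928 = 0.75.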

indices_of_original_sequence_groups = []
for seq_group in execute_model_req.seq_group_metadata_list:
seq_ids_with_bonus_tokens = []
for seq_id, seq_data in seq_group.seq_data.items():

Collaborator

approach looks good. pretty messy, but nothing easy to fix there.

Collaborator Author

Tried to simplify it a bit by moving some of the logic into two helper functions, _shallow_copy_seq_group_metadata and _copy_seq_metadata_excluding_last_token. PTAL.
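
For readers following along, here is a rough sketch of what those two helpers do. The SequenceData/SequenceGroupMetadata classes below are simplified stand-ins with only the fields needed for illustration, not the actual vLLM definitions or the exact code in this PR.

```python
import copy
from typing import Dict, List


class SequenceData:
    """Simplified stand-in for vLLM's per-sequence data."""
    def __init__(self, token_ids: List[int], num_computed_tokens: int):
        self.token_ids = token_ids
        self.num_computed_tokens = num_computed_tokens


class SequenceGroupMetadata:
    """Simplified stand-in for vLLM's per-group metadata."""
    def __init__(self, request_id: str, seq_data: Dict[int, SequenceData]):
        self.request_id = request_id
        self.seq_data = seq_data


def _shallow_copy_seq_group_metadata(
        seq_group: SequenceGroupMetadata) -> SequenceGroupMetadata:
    # Copy the metadata object and its seq_data mapping, but not the
    # underlying SequenceData objects, so entries can be swapped out in the
    # copy without mutating the original group.
    new_group = copy.copy(seq_group)
    new_group.seq_data = dict(seq_group.seq_data)
    return new_group


def _copy_seq_metadata_excluding_last_token(
        seq_data: SequenceData) -> SequenceData:
    # Drop the last (bonus) output token and reduce the computed-token count
    # by one, since its KV entry was never written by the draft model.
    return SequenceData(seq_data.token_ids[:-1],
                        max(0, seq_data.num_computed_tokens - 1))
```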

target_worker.execute_model.assert_called_once_with(execute_model_req)

@torch.inference_mode()
def test_populate_seq_ids_with_bonus_tokens():

Collaborator

one edge case is where a sequence is skipped but is present in seq_with_bonus_token_in_last_step

Collaborator Author

Added this case in Forward pass # 2.

@cadedaniel (Collaborator)

Awesome. Will take a look tomorrow.

@cadedaniel cadedaniel (Collaborator) left a comment

Can we manually verify the following?

  • When bonus tokens are enabled and the same model is used for draft and target, we get 100% draft acceptance rate. This indicates that the KV of the draft model is ~equal to the KV of the target model.
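
A sketch of one way to run this manual check, assuming the speculative_model and num_speculative_tokens arguments exposed by vLLM's LLM entry point around the time of this PR; the exact arguments and the metrics log format may differ:

```python
from vllm import LLM, SamplingParams

# Same model for draft and target: every proposed token should match the
# target's output, so the reported draft acceptance rate should be ~1.0.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
    use_v2_block_manager=True,  # assumed requirement for spec decode at the time
)
outputs = llm.generate(["The future of AI is"] * 64,
                       SamplingParams(temperature=0.0, max_tokens=64))
# The engine periodically logs a "Speculative metrics: ..." line containing
# the draft acceptance rate and system efficiency discussed in this thread.
```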

Comment on lines +278 to +279
# Also reduce num_computed_tokens by 1 since we are not
# including the last output token.

Collaborator

Can we add more comments here motivating this? I thought the workers themselves don't take this into account unless chunked prefill is enabled.

Collaborator Author

I included this to keep the value consistent. It's probably not needed since it is only used for chunked prefill, but I see it being updated elsewhere in the spec-decode code, e.g. https://sourcegraph.com/github.com/vllm-project/vllm/-/blob/vllm/spec_decode/draft_model_runner.py?L116, hence I added it here. Is this confusing? Should I remove it?

Collaborator

Gotcha -- yeah can leave in. Add a comment?

Collaborator Author

Added a note.



@torch.inference_mode()
def test_expand_execute_model_request_for_bonus_tokens():

Collaborator

Why do we need both this test and test_same_output_for_multi_step_with_batch_expansion?

Collaborator

fine to have both but we should add something in the docstring which explains the difference

Collaborator Author

Removed this test as well as test_filter_model_output. Both should be covered by test_same_output_for_multi_step_with_batch_expansion as you suggested.

Comment on lines 670 to 671
output_indices_to_retain = random.sample(range(num_steps),
max(1, num_steps // 2))

Collaborator

I am confused why we sample from range(num_steps). Shouldn't it be range(batch_size)?

Collaborator Author

It should be range(batch_size). I removed the test, though, in favor of test_same_output_for_multi_step_with_batch_expansion.

Comment on lines 706 to 711
# Forward Pass : 0
# Set the last token ID to -1 for all indices not in
# seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
# those indices.
accepted_token_ids[mask, -1:] = -1
worker = SpecDecodeWorker(draft_worker,

Collaborator

Can we break the three cases in this test into their own tests? It will be easier to follow and debug.

Collaborator Author

I modified the test to initialize the internal data structures with some fake data and then run a forward pass to simulate all 3 cases. PTAL. It should be simpler to follow now while covering all the cases.

@cadedaniel (Collaborator)

Oh, I just saw your response:

> Yeah, I thought of adding such an e2e test, but I couldn't find an easy way to access the metrics_collector and the stats. As you suggested, I added an e2e test in test_multistep_correctness.py with the draft and target model being the same ("JackFram/llama-68m") and the number of speculative tokens set to 3. The system efficiency with the bonus token enabled is 1.0 and with it disabled is 0.75.
>
> With bonus token
> INFO 07-08 06:27:27 metrics.py:316] Speculative metrics: Draft acceptance rate: 1.000, System efficiency: 1.000, Number of speculative tokens: 3, Number of accepted tokens: 21696, Number of draft tokens tokens: 21696, Number of emitted tokens tokens: 28928.
>
> Without bonus token
> Speculative metrics: Draft acceptance rate: 1.000, System efficiency: 0.750, Number of speculative tokens: 3, Number of accepted tokens: 21696, Number of draft tokens tokens: 21696, Number of emitted tokens tokens: 21696.

This is awesome! So exciting to see this working!

@sroy745 sroy745 (Collaborator Author) commented Jul 10, 2024

Thanks for the review. Addressed the comments. PTAL

@cadedaniel cadedaniel (Collaborator) left a comment

Looks great!


@sroy745 sroy745 (Collaborator Author) left a comment

Thanks for the review. Addressed your comment and rebased. Should be ready to merge once the tests complete.


@cadedaniel cadedaniel merged commit ae151d7 into vllm-project:main Jul 10, 2024
@cadedaniel (Collaborator)

Merged!

adityagoel14 pushed a commit to adityagoel14/vllm-torchrun-test that referenced this pull request Jul 11, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025


Development

Successfully merging this pull request may close these issues.

[Speculative decoding] [Performance]: Re-enable bonus tokens

2 participants