96 changes: 91 additions & 5 deletions tests/v1/spec_decode/test_eagle.py
@@ -19,6 +19,7 @@
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
@@ -90,10 +91,24 @@ def test_prepare_inputs():
device=device,
)

# Rejected tokens per request: [1, 3, 2]
num_rejected_tokens = torch.tensor([1, 3, 2],
dtype=torch.int32,
device=device)
# If there are `k` sampled tokens, then `k-1` tokens are draft tokens
# from the previous iteration, and the last token is the bonus token sampled
# from the base model.
num_draft_tokens = [3, 6, 4] # one less than query_lens
# num rejected tokens is [1, 3, 2]
ACCEPT_TOKEN = 0
BONUS_TOKEN = 1
REJECT_TOKEN = -1
sampled_token_ids = [
[ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN],
[
ACCEPT_TOKEN, ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN,
REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN
],
[ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN]
]
sampled_token_ids = [[i for i in seq if i != REJECT_TOKEN]
for seq in sampled_token_ids]

# Expected calculations:
# query_len_per_req = [4, 7, 5]
@@ -125,14 +140,85 @@ def test_prepare_inputs():
proposer = _create_proposer("eagle", 1)

updated_metadata, token_indices = proposer.prepare_inputs(
common_attn_metadata, num_rejected_tokens.cpu())
common_attn_metadata, sampled_token_ids, num_draft_tokens)

assert torch.equal(updated_metadata.query_start_loc,
expected_cu_num_tokens)
assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
assert torch.equal(token_indices, expected_token_indices)

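For reference, here is a minimal sketch of the index arithmetic this test exercises: each request drops its last `num_rejected_tokens[i]` positions and the surviving token indices are re-packed contiguously. This is an illustration derived from the comments above (query lens [4, 7, 5], rejected [1, 3, 2]), not the actual EagleProposer.prepare_inputs implementation; variable names are illustrative.

import torch

query_lens = torch.tensor([4, 7, 5])
num_rejected = torch.tensor([1, 3, 2])

# Each request keeps query_len - num_rejected tokens.
new_query_lens = query_lens - num_rejected                      # [3, 4, 3]
cu_num_tokens = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
cu_num_tokens[1:] = torch.cumsum(new_query_lens, dim=0)         # [0, 3, 7, 10]

# Start offset of each request in the original flattened batch.
old_starts = torch.cumsum(query_lens, dim=0) - query_lens       # [0, 4, 11]
token_indices = torch.cat([
    torch.arange(start, start + n)
    for start, n in zip(old_starts.tolist(), new_query_lens.tolist())
])
# token_indices -> [0, 1, 2, 4, 5, 6, 7, 11, 12, 13]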

def test_prepare_inputs_deferred():
"""
Input scenario is 3 requests with num_speculative_tokens == 2 and:
- Request 1: query_len = 3, rejected = 1
- Request 2: query_len = 3, rejected = 0
- Request 3: query_len = 3, rejected = 2

Expected outputs:
token_indices: [0, 1, 2,
3, 4, 5,
6, 7, 8]
Reason: Deferred computation should not disturb the original indices.

token_indices_to_sample: [1, 5, 6]
Reason: After accounting for rejections, these are the valid token positions
from the original indices to sample from.
"""

device = torch.device(current_platform.device_type)

expected_token_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8],
dtype=torch.int32,
device=device)
expected_token_indices_to_sample = torch.tensor([1, 5, 6],
dtype=torch.int32,
device=device)

num_speculative_tokens = 2
batch_spec = BatchSpec(
seq_lens=[3, 3, 3],
query_lens=[3, 3, 3],
)

common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
device=device,
)

# Needed for cu_num_draft_tokens, which is expected to be [3, 6, 9]
expected_query_start_loc = torch.tensor([0, 3, 6, 9],
dtype=torch.int32,
device=device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids=[[0] * num_speculative_tokens] * 3,
device=device,
)

# num_rejected_tokens = [1, 0, 2]
# num_draft_tokens = [2, 2, 2]
# valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens
valid_sampled_tokens_count = torch.tensor([2, 3, 1],
dtype=torch.int32,
device=device)

proposer = _create_proposer("eagle", num_speculative_tokens)

output_metadata, token_indices, token_indices_to_sample = \
proposer.prepare_inputs_deferred(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count)

assert output_metadata.max_query_len == 3
assert torch.equal(output_metadata.query_start_loc,
expected_query_start_loc)
assert torch.equal(token_indices, expected_token_indices)
assert torch.equal(token_indices_to_sample,
expected_token_indices_to_sample)
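A minimal sketch of the expectation checked above, assuming the deferred path leaves the token indices untouched and only moves each request's sampling position back by the number of rejected tokens. Names and tensors here are illustrative, not the vLLM API.

import torch

query_start_loc = torch.tensor([0, 3, 6, 9], dtype=torch.int32)
num_draft_tokens = torch.tensor([2, 2, 2], dtype=torch.int32)
valid_sampled_tokens_count = torch.tensor([2, 3, 1], dtype=torch.int32)

# num_rejected = num_draft_tokens + 1 - valid_sampled_tokens_count
num_rejected = num_draft_tokens + 1 - valid_sampled_tokens_count  # [1, 0, 2]

# Deferred computation does not disturb the original indices.
token_indices = torch.arange(query_start_loc[-1].item())          # [0, 1, ..., 8]

# Sample from the last token of each request, shifted back by the rejections.
token_indices_to_sample = query_start_loc[1:] - 1 - num_rejected
# -> [3-1-1, 6-1-0, 9-1-2] = [1, 5, 6]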


@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
5 changes: 5 additions & 0 deletions vllm/config/__init__.py
@@ -1949,6 +1949,11 @@ class SpeculativeConfig:
disable_by_batch_size: Optional[int] = None
"""Disable speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided."""
disable_padded_batch: bool = False
Collaborator:

Could we add an assertion for non-EAGLE speculation and for attention backends that are not supported?

Collaborator (Author):

I don't think that's necessary. For non-EAGLE speculation the value of the flag is always ignored, and since the default is False (padded), it doesn't seem right to have users opt in to disabling padding when they use non-EAGLE spec. The simplest solution is the current one: ignore the flag for non-EAGLE speculation. Is there another behaviour you had in mind?

As for attention backends, this should be a strictly more widely available feature than the current non-padded speculative decoding, so the existing assertions in the proposer should be sufficient.

"""Disable input padding for speculative decoding. If set to True,
speculative input batches can contain sequences of different lengths,
which may only be supported by certain attention backends. This currently
only affects the EAGLE method of speculation."""

# Ngram proposer configuration
prompt_lookup_max: Optional[int] = None
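
For context, a hedged usage sketch of the new flag: assuming the speculative_config dict accepted by LLM() maps its keys onto SpeculativeConfig fields, padded speculative batches could be disabled as shown below. The model names are taken from the test file above; the token count is illustrative only.

from vllm import LLM

# Illustrative configuration; the speculative_config mapping is an assumption
# about how SpeculativeConfig fields are populated, not taken from this PR.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens": 2,
        "disable_padded_batch": True,  # opt out of padded speculative batches
    },
)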