179 changes: 174 additions & 5 deletions tests/v1/spec_decode/test_eagle.py
@@ -19,6 +19,8 @@
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
@@ -64,6 +66,86 @@ def _create_proposer(
device=current_platform.device_type)


def test_prepare_next_token_ids():
"""
Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded.
Each produces a device tensor of next_token_ids, taking as input either
the GPU tensor of sampled_token_ids (with -1 marking rejected tokens) or
the CPU Python list[list[int]] with the rejected tokens removed.
"""
device = torch.device(current_platform.device_type)

num_requests = 4
num_speculative_tokens = 4
batch_spec = BatchSpec(
seq_lens=[num_speculative_tokens + 1] * num_requests,
query_lens=[num_speculative_tokens + 1] * num_requests,
)

req_ids = [f"req_{i+1}" for i in range(num_requests)]
mock_input_batch = mock.MagicMock(spec=InputBatch)
mock_input_batch.req_ids = req_ids
mock_input_batch.num_reqs = num_requests
mock_input_batch.vocab_size = 100

mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids}
mock_requests = {}
for req_id in req_ids:
mock_request = mock.MagicMock(spec=CachedRequestState)
# Each request will have a backup next token id of 10, 20, 30, 40
mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10
mock_request.num_computed_tokens = 0
mock_requests[req_id] = mock_request

sampled_token_ids = [
[0, 1, -1, -1, -1], # 1 accepted, 3 rejected, "1" sampled
[0, 1, 2, 3, 4], # all accepted, "4" sampled
[-1, -1, -1, -1, -1], # sampling skipped, use backup token "30"
[-1, -1, -1, -1, -1] # this request will be discarded
]
sampled_token_ids_tensor = torch.tensor(sampled_token_ids,
dtype=torch.int32,
device=device)
sampled_token_ids_cpu = [[i for i in seq if i != -1]
for seq in sampled_token_ids]

expected_next_token_ids_cpu = [1, 4, 30, 40]
expected_next_token_ids_tensor = torch.tensor(expected_next_token_ids_cpu,
dtype=torch.int32,
device=device)

proposer = _create_proposer("eagle", num_speculative_tokens)

next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu(
sampled_token_ids_cpu, mock_requests, mock_input_batch,
mock_num_scheduled_tokens)

assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)

common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
device=device,
)

discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device)
num_discarded_reqs = 1

expected_valid_sampled_tokens_count = torch.tensor([2, 5, 0, 0],
dtype=torch.int32,
device=device)

next_token_ids_from_padded, valid_sampled_tokens_count = \
proposer.prepare_next_token_ids_padded(
common_attn_metadata, sampled_token_ids_tensor, mock_requests,
mock_input_batch, discarded_req_indices, num_discarded_reqs)

assert torch.equal(next_token_ids_from_padded,
expected_next_token_ids_tensor)
assert torch.equal(valid_sampled_tokens_count,
expected_valid_sampled_tokens_count)


def test_prepare_inputs():
"""
cu_target_query_lens: [0, a, a + b, a + b + c]
@@ -90,10 +172,24 @@ def test_prepare_inputs():
device=device,
)

- # Rejected tokens per request: [1, 3, 2]
- num_rejected_tokens = torch.tensor([1, 3, 2],
- dtype=torch.int32,
- device=device)
# If there are `k` sampled tokens, then `k-1` tokens are draft tokens
# from the previous iteration, and the last token is the bonus token sampled
# from the base model.
num_draft_tokens = [3, 6, 4] # one less than query_lens
# num rejected tokens is [1, 3, 2]
ACCEPT_TOKEN = 0
BONUS_TOKEN = 1
REJECT_TOKEN = -1
sampled_token_ids = [
[ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN],
[
ACCEPT_TOKEN, ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN,
REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN
],
[ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN]
]
sampled_token_ids = [[i for i in seq if i != REJECT_TOKEN]
for seq in sampled_token_ids]

# Expected calculations:
# query_len_per_req = [4, 7, 5]
@@ -125,14 +221,85 @@
proposer = _create_proposer("eagle", 1)

updated_metadata, token_indices = proposer.prepare_inputs(
- common_attn_metadata, num_rejected_tokens.cpu())
common_attn_metadata, sampled_token_ids, num_draft_tokens)

assert torch.equal(updated_metadata.query_start_loc,
expected_cu_num_tokens)
assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
assert torch.equal(token_indices, expected_token_indices)
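
# Illustrative sketch only (names are ours, not vLLM's): the bookkeeping
# behind the comment "num rejected tokens is [1, 3, 2]" in test_prepare_inputs
# above. With REJECT_TOKEN entries stripped, each request keeps its accepted
# draft tokens plus one bonus/recovery token, so the rejected count falls out
# of the list lengths.
_num_draft_tokens = [3, 6, 4]
_sampled_lens = [3, 4, 3]  # lengths of the filtered sampled_token_ids rows
_num_rejected = [d + 1 - s for d, s in zip(_num_draft_tokens, _sampled_lens)]
assert _num_rejected == [1, 3, 2]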


def test_prepare_inputs_padded():
"""
Input scenario is 3 requests with num_speculative_tokens == 2 and:
- Request 1: query_len = 3, rejected = 1
- Request 2: query_len = 3, rejected = 0
- Request 3: query_len = 3, rejected = 2

Expected outputs:
token_indices: [0, 1, 2,
3, 4, 5,
6, 7, 8]
Reason: Deferred computation should not disturb the original indices.

token_indices_to_sample: [1, 5, 6]
Reason: After accounting for rejections, these are the valid token positions
from the original indices to sample from.
"""

device = torch.device(current_platform.device_type)

expected_token_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8],
dtype=torch.int32,
device=device)
expected_token_indices_to_sample = torch.tensor([1, 5, 6],
dtype=torch.int32,
device=device)

num_speculative_tokens = 2
batch_spec = BatchSpec(
seq_lens=[3, 3, 3],
query_lens=[3, 3, 3],
)

common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
device=device,
)

# Needed for cu_num_draft_tokens, which is expected to be [3, 6, 9]
expected_query_start_loc = torch.tensor([0, 3, 6, 9],
dtype=torch.int32,
device=device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids=[[0] * num_speculative_tokens] * 3,
device=device,
)

# num_rejected_tokens = [1, 0, 2]
# num_draft_tokens = [2, 2, 2]
# valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens
valid_sampled_tokens_count = torch.tensor([2, 3, 1],
dtype=torch.int32,
device=device)

proposer = _create_proposer("eagle", num_speculative_tokens)

output_metadata, token_indices, token_indices_to_sample = \
proposer.prepare_inputs_padded(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count)

assert output_metadata.max_query_len == 3
assert torch.equal(output_metadata.query_start_loc,
expected_query_start_loc)
assert torch.equal(token_indices, expected_token_indices)
assert torch.equal(token_indices_to_sample,
expected_token_indices_to_sample)
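
# Illustrative sketch only (assumed relationship, not the implementation):
# with padded drafter batches the token layout is left untouched, so the
# position to sample for request i is
#   query_start_loc[i] + valid_sampled_tokens_count[i] - 1.
import torch
_qsl = torch.tensor([0, 3, 6, 9])
_valid = torch.tensor([2, 3, 1])
assert (_qsl[:-1] + _valid - 1).tolist() == [1, 5, 6]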


@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
@@ -373,6 +540,7 @@ def create_deterministic_logits(token_ids):
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)

@@ -526,6 +694,7 @@ def create_deterministic_logits(token_ids, k: int):
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
assert result.shape == (batch_size, num_speculative_tokens)
5 changes: 5 additions & 0 deletions vllm/config/speculative.py
@@ -83,6 +83,11 @@ class SpeculativeConfig:
disable_by_batch_size: Optional[int] = None
"""Disable speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided."""
disable_padded_drafter_batch: bool = False
"""Disable input padding for speculative decoding. If set to True,
speculative input batches can contain sequences of different lengths,
which may only be supported by certain attention backends. This currently
only affects the EAGLE method of speculation."""

Review comment (Collaborator): Is this flag necessary?

Reply (Author): We want to maintain both pathways so that this PR's changes
can be disabled if needed. It is not yet clear what the performance will look
like once new backend support is added and piecewise CUDA graphs are enabled
for them, so this flag makes it easy to compare performance and decide whether
this change should be rolled back (or have its default swapped) in the future.

# Ngram proposer configuration
prompt_lookup_max: Optional[int] = None
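
A hedged usage sketch (not part of the diff): one way the new
disable_padded_drafter_batch flag might be exercised from the offline API.
The speculative_config dict keys and the LLM entrypoint shown here are
assumptions about the surrounding vLLM API, not part of this PR; the model
and drafter checkpoints are the ones used in the tests above.

from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens": 4,
        # New flag from this PR: opt out of padded drafter batches and use
        # the pre-existing (unpadded) pathway instead.
        "disable_padded_drafter_batch": True,
    },
)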