Commit 8437bae

[Speculative decoding 3/9] Worker which speculates, scores, and applies rejection sampling (#3103)
1 parent f48c679 commit 8437bae

21 files changed: +2786 -215 lines
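The commit title names the three stages the new worker chains together: propose draft tokens, score them with the target model, then accept via rejection sampling. A minimal sketch of that control flow for one decoding step (every name here is hypothetical; none of this mirrors the actual vLLM worker API introduced by the commit):

    # Hypothetical sketch of one speculative decoding step. All of the
    # objects and method names below are illustrative placeholders.
    def spec_decode_step(draft_model, target_model, rejection_sampler,
                         input_ids, k):
        # 1. Speculate: the cheap draft model proposes k tokens.
        proposal_ids, proposal_probs = draft_model.propose(input_ids, k)

        # 2. Score: the target model evaluates all proposals in a single
        #    forward pass, producing probabilities at each proposed position.
        target_probs = target_model.score(input_ids, proposal_ids)

        # 3. Rejection sampling: accept a prefix of the proposals such that
        #    the emitted tokens are distributed as if sampled from the
        #    target model alone.
        return rejection_sampler(target_probs, proposal_probs, proposal_ids)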

.buildkite/test-pipeline.yaml

Lines changed: 4 additions & 1 deletion
@@ -28,7 +28,7 @@ steps:
   num_gpus: 2 # only support 1 or 2 for now.

 - label: Engine Test
-  command: pytest -v -s engine
+  command: pytest -v -s engine test_sequence.py

 - label: Entrypoints Test
   command: pytest -v -s entrypoints
@@ -52,6 +52,9 @@ steps:
 - label: Worker Test
   command: pytest -v -s worker

+- label: Speculative decoding tests
+  command: pytest -v -s spec_decode
+
 - label: LoRA Test
   command: pytest -v -s lora --forked
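Assuming the pipeline runs these commands from the tests directory, as the sibling steps (engine, worker, lora) suggest, the new CI step should be reproducible locally with:

    pytest -v -s spec_decode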

File renamed without changes.
tests/spec_decode/test_batch_expansion.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@

import torch
import pytest

from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer

from .utils import mock_worker, create_seq_group_metadata_from_prompts


@pytest.mark.parametrize('num_target_seq_ids', [100])
def test_create_target_seq_id_iterator(num_target_seq_ids: int):
    """Verify all new sequence ids are greater than all input
    seq ids.
    """
    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)

    all_seq_ids = [
        [1, 3, 5, 7],
        list(range(100)) + [0],
        [100],
    ]

    for seq_ids in all_seq_ids:
        max_seq_id = max(seq_ids)
        iterator = scorer._create_target_seq_id_iterator(seq_ids)  # pylint: disable=protected-access
        for _ in range(num_target_seq_ids):
            assert next(iterator) > max_seq_id


@pytest.mark.parametrize('k', [1, 2, 6])
def test_get_token_ids_to_score(k: int):
    """Verify correct tokens are selected for scoring.
    """
    proposal_token_ids = torch.tensor(
        list(range(k)),
        dtype=torch.int64,
        device='cuda',
    )

    expected_output = [
        [],
    ]
    for i in range(proposal_token_ids.shape[0]):
        expected_output.append(proposal_token_ids[:i + 1].tolist())

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
    actual_output = scorer._get_token_ids_to_score(proposal_token_ids)  # pylint: disable=protected-access

    actual_output = [
        x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output
    ]

    assert actual_output == expected_output


@pytest.mark.parametrize('k', [1, 2, 6])
def test_create_single_target_seq_group_metadata(k: int):
    """Verify correct creation of a batch-expanded seq group metadata.
    """

    prompt_tokens = [1, 2, 3]
    prev_output_tokens = [4, 5, 6]

    token_ids = list(range(k))

    num_tokens_processed = len(prompt_tokens) + len(prev_output_tokens) - 1

    final_seq_len = len(prompt_tokens) + len(prev_output_tokens) + len(
        token_ids)

    block_size = 32
    input_seq_group_metadata = create_seq_group_metadata_from_prompts(
        [prompt_tokens], 2048 // block_size, block_size, [final_seq_len],
        [prev_output_tokens], [num_tokens_processed])[0]

    input_seq_id = list(input_seq_group_metadata.seq_data.keys())[0]
    target_seq_id = 100

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
    output = scorer._create_single_target_seq_group_metadata(  # pylint: disable=protected-access
        input_seq_group_metadata,
        input_seq_id,
        target_seq_id,
        token_ids,
    )

    assert output.request_id == input_seq_group_metadata.request_id
    assert len(output.seq_data) == 1
    assert output.seq_data[target_seq_id].get_prompt_token_ids(
    ) == prompt_tokens
    assert output.seq_data[target_seq_id].get_output_token_ids(
    ) == prev_output_tokens + token_ids

    assert len(output.block_tables) == 1
    assert output.block_tables[
        target_seq_id] == input_seq_group_metadata.block_tables[input_seq_id]
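The two private helpers exercised by the first two tests have small, well-pinned contracts. A standalone behavioral sketch consistent with what the tests assert (illustrative only, not the vLLM implementation):

    import itertools

    def create_target_seq_id_iterator(seq_ids):
        # Target ids must never collide with existing sequence ids, so
        # count upward from one past the current maximum -- the property
        # asserted in test_create_target_seq_id_iterator.
        return itertools.count(start=max(seq_ids) + 1)

    def get_token_ids_to_score(proposal_token_ids):
        # Batch expansion scores every prefix of the proposal, including
        # the empty prefix, so k proposal tokens expand into k + 1 scoring
        # sequences -- the expected_output built in
        # test_get_token_ids_to_score.
        return [
            list(proposal_token_ids[:i])
            for i in range(len(proposal_token_ids) + 1)
        ]

Scoring the empty prefix means the target model still yields a next-token distribution even when every proposed token is rejected.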

tests/spec_decode/test_metrics.py

Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@

import torch
import math
import pytest

from unittest.mock import MagicMock

from vllm.spec_decode.metrics import AsyncMetricsCollector


def test_initial_call_returns_none():
    """Expect first call to get metrics to return None.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collector = AsyncMetricsCollector(rej_sampler)
    collector.init_gpu_tensors(rank=0)
    maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert maybe_metrics is None


def test_second_call_returns_metrics():
    """Expect second call to not return None.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collect_interval_s = 5.0
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)
    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is not None


@pytest.mark.parametrize("rank", [1, 2, 3, 4])
def test_nonzero_rank_noop(rank):
    """Verify nonzero ranks don't collect metrics.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collector = AsyncMetricsCollector(rej_sampler)
    collector.init_gpu_tensors(rank=rank)
    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is None


def test_noop_until_time():
    """Verify metrics aren't collected until enough time passes.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collect_interval_s = 5.0
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s - 0.1, collect_interval_s - 0.1,
        collect_interval_s + 0.1, collect_interval_s + 0.1
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)

    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is None

    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is not None


@pytest.mark.parametrize("has_data", [True, False])
def test_initial_metrics_has_correct_values(has_data: bool):
    """Test correctness of metrics data.
    """
    if has_data:
        num_accepted_tokens = 103
        num_emitted_tokens = 104
        num_draft_tokens = 105
    else:
        num_accepted_tokens = 0
        num_emitted_tokens = 0
        num_draft_tokens = 0
    k = 5

    num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens(
        num_draft_tokens, k)

    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = num_draft_tokens

    collect_interval_s = 5.0
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)
    _ = collector.maybe_collect_rejsample_metrics(k)
    metrics = collector.maybe_collect_rejsample_metrics(k)

    assert metrics.num_spec_tokens == k
    assert metrics.accepted_tokens == num_accepted_tokens
    assert metrics.draft_tokens == num_draft_tokens
    assert metrics.emitted_tokens == num_emitted_tokens

    if has_data:
        assert metrics.draft_acceptance_rate == num_accepted_tokens / num_draft_tokens
        assert metrics.system_efficiency == num_emitted_tokens / num_possible_tokens
    else:
        assert math.isnan(metrics.draft_acceptance_rate)
        assert math.isnan(metrics.system_efficiency)
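The two derived metrics asserted at the end reduce to simple ratios. A sketch of the arithmetic, with the caveat that get_max_num_accepted_tokens is not shown in this diff, so the k + 1 accounting below is an assumption:

    def draft_acceptance_rate(num_accepted: int, num_draft: int) -> float:
        # Fraction of draft-model proposals the rejection sampler accepted.
        return num_accepted / num_draft if num_draft > 0 else float('nan')

    def system_efficiency(num_emitted: int, num_draft: int, k: int) -> float:
        # Assumed bound: each speculation step proposes k tokens and can
        # emit at most k + 1 (all k accepted plus one token sampled from
        # the target model). The actual body of get_max_num_accepted_tokens
        # is not part of this diff.
        max_emitted = (num_draft // k) * (k + 1)
        return num_emitted / max_emitted if max_emitted > 0 else float('nan')

With the has_data values above, acceptance would be 103/105 ≈ 0.981 and, under this k + 1 assumption, efficiency 104/((105 // 5) * 6) = 104/126 ≈ 0.825.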
