
Commit 05fcd1b

[V1][Perf] Faster incremental detokenization (#15137)
Signed-off-by: Nick Hill <[email protected]>
1 parent 7c02d6a commit 05fcd1b

File tree

7 files changed: +316, -144 lines changed


requirements/common.txt

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ blake3
 py-cpuinfo
 transformers >= 4.51.1
 huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads.
-tokenizers >= 0.19.1 # Required for Llama 3.
+tokenizers >= 0.21.1 # Required for fast incremental detokenization.
 protobuf # Required by LlamaTokenizer.
 fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
 aiohttp
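
The comment on the bumped requirement ties tokenizers >= 0.21.1 to the fast incremental detokenization path. A plausible reason is the streaming-decode support (DecodeStream) that the 0.21 line of the tokenizers library ships; the sketch below illustrates that library feature on its own, under the assumption that this is what the fast path builds on. The model choice ("gpt2") is arbitrary and not taken from this commit.

# Hedged sketch of token-by-token decoding with tokenizers >= 0.21.
# Illustrates the library capability only, not vLLM's internal code.
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tok = Tokenizer.from_pretrained("gpt2")
ids = tok.encode("Hello here, this is a simple test").ids

stream = DecodeStream(skip_special_tokens=True)
text = ""
for token_id in ids:
    piece = stream.step(tok, token_id)
    # step() returns None while the accumulated bytes are not yet a complete
    # UTF-8 character, so incremental output never contains broken characters.
    if piece is not None:
        text += piece

print(text)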

requirements/test.in

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ opencv-python-headless >= 4.11.0 # required for video test
 datamodel_code_generator # required for minicpm3 test
 lm-eval[api]==0.4.8 # required for model evaluation test
 transformers==4.51.1
+tokenizers==0.21.1
 huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
 # quantization
 bitsandbytes>=0.45.3

requirements/test.txt

Lines changed: 4 additions & 2 deletions
@@ -624,8 +624,10 @@ tiktoken==0.7.0
     #   mistral-common
 timm==1.0.11
     # via -r requirements/test.in
-tokenizers==0.21.0
-    # via transformers
+tokenizers==0.21.1
+    # via
+    #   -r requirements/test.in
+    #   transformers
 torch==2.6.0
     # via
     #   -r requirements/test.in

tests/lora/test_llama_tp.py

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
     ]
     sampling_params = vllm.SamplingParams(temperature=0,
                                           max_tokens=256,
+                                          skip_special_tokens=False,
                                           stop=["[/assistant]"])
     outputs = llm.generate(
         prompts,

tests/tokenization/test_detokenize.py

Lines changed: 137 additions & 55 deletions
@@ -4,14 +4,22 @@
 from typing import Any, Optional
 
 import pytest
-from transformers import AutoTokenizer
+from transformers import (AutoTokenizer, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)
 
 from vllm.inputs import token_inputs
 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
-from vllm.transformers_utils.detokenizer import (Detokenizer,
-                                                 detokenize_incrementally)
+from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
+from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
+                                        IncrementalDetokenizer,
+                                        SlowIncrementalDetokenizer)
+
+SPECIAL_TOKS_TRUTH = [
+    "Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
+]
 
 TRUTH = [
     "Hello here, this is a simple test",
@@ -22,7 +30,8 @@
     # incomplete UTF-8 characters
     # see https://github.com/vllm-project/vllm/pull/9625
     "ပုံပြင်လေးပြောပြပါ်",
-]
+] + SPECIAL_TOKS_TRUTH
+
 TOKENIZERS = [
     "facebook/opt-125m",
     "gpt2",
@@ -38,26 +47,37 @@
 ]
 
 
-def _run_incremental_decode(tokenizer, all_input_ids,
-                            skip_special_tokens: bool, starting_index: int):
-    decoded_text = ""
-    offset = 0
-    token_offset = 0
-    prev_tokens = None
-    for i in range(starting_index, len(all_input_ids)):
-        new_tokens, text, offset, token_offset = detokenize_incrementally(
-            tokenizer,
-            all_input_ids[:i + 1],
-            prev_tokens,
-            offset,
-            token_offset,
-            skip_special_tokens=skip_special_tokens)
-        decoded_text += text
-        if prev_tokens is None:
-            prev_tokens = new_tokens
-        else:
-            prev_tokens += new_tokens
-    return decoded_text
+def _run_incremental_decode(tokenizer,
+                            all_input_ids,
+                            skip_special_tokens: bool,
+                            starting_index: int,
+                            spaces_between_special_tokens: bool = True,
+                            fast: Optional[bool] = None):
+
+    prompt_token_ids = all_input_ids[:starting_index]
+
+    params = SamplingParams(
+        skip_special_tokens=skip_special_tokens,
+        spaces_between_special_tokens=spaces_between_special_tokens,
+    )
+    request = EngineCoreRequest("", "", prompt_token_ids, None, None, None,
+                                params, None, 0.0, None)
+
+    if fast is None:
+        detokenizer = IncrementalDetokenizer.from_new_request(
+            tokenizer, request)
+    elif fast:
+        detokenizer = FastIncrementalDetokenizer(tokenizer, request)
+    else:
+        detokenizer = SlowIncrementalDetokenizer(tokenizer, request)
+
+    output_text = ""
+    for i, token_id in enumerate(all_input_ids[starting_index:]):
+        detokenizer.update([token_id], False)
+        finished = i == len(all_input_ids) - 1
+        output_text += detokenizer.get_next_output_text(finished, delta=True)
+
+    return output_text, detokenizer.output_token_ids
 
 
 @pytest.fixture
@@ -85,11 +105,13 @@ def test_mistral_edge_case(tokenizer, truth):
     starting_index = 0
     all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
 
-    decoded_text = _run_incremental_decode(tokenizer,
-                                           all_input_ids,
-                                           skip_special_tokens=True,
-                                           starting_index=starting_index)
+    decoded_text, out_ids = _run_incremental_decode(
+        tokenizer,
+        all_input_ids,
+        skip_special_tokens=True,
+        starting_index=starting_index)
     assert decoded_text == truth
+    assert out_ids == all_input_ids[starting_index:]
 
 
 @pytest.fixture
@@ -106,40 +128,86 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
 @pytest.mark.parametrize("with_prompt", [True, False])
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
 @pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
-def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens):
+@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
+@pytest.mark.parametrize("fast", (True, False))
+def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
+                          spaces_between_special_tokens, fast):
+    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
+        pytest.skip()
+
+    if skip_special_tokens and not spaces_between_special_tokens:
+        pytest.skip()
+
+    if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
+        # Fix up inconsistency in fast/slow tokenizer behaviour.
+        tokenizer.add_special_tokens({
+            "additional_special_tokens": [
+                at for at in
+                tokenizer._tokenizer.get_added_tokens_decoder().values()
+                if at.special
+            ]
+        })
+
+    extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \
+        else {"spaces_between_special_tokens": spaces_between_special_tokens}
+
+    truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
+    if tokenizer.bos_token_id is not None:
+        truth_tokens.insert(0, tokenizer.bos_token_id)
+    truth_tokens.append(tokenizer.eos_token_id)
+
+    new_truth = tokenizer.decode(truth_tokens,
+                                 skip_special_tokens=skip_special_tokens,
+                                 **extra_decode_args)
+
     if with_prompt:
-        truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
-        prompt_input_ids = truth_tokens[:len(truth) // 2]
-        generated_input_ids = truth_tokens[len(truth) // 2:]
+        num_prompt_tokens = len(
+            tokenizer(truth[:len(truth) // 2],
+                      add_special_tokens=False).input_ids)
+        if tokenizer.bos_token_id is not None:
+            num_prompt_tokens += 1
+
+        prompt_input_ids = truth_tokens[:num_prompt_tokens]
+        generated_input_ids = truth_tokens[num_prompt_tokens:]
         all_input_ids = prompt_input_ids + generated_input_ids
         starting_index = len(prompt_input_ids)
         prompt = tokenizer.decode(prompt_input_ids,
-                                  skip_special_tokens=skip_special_tokens)
-        generated = truth[len(prompt):]
+                                  skip_special_tokens=skip_special_tokens,
+                                  **extra_decode_args)
+
+        generated = new_truth[len(prompt):]
     else:
-        generated = truth
+        generated = new_truth
         starting_index = 0
-        all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
-        if skip_special_tokens:
-            if tokenizer.bos_token_id is not None:
-                all_input_ids = [tokenizer.bos_token_id] + all_input_ids
-                starting_index += 1
-            all_input_ids = all_input_ids + [tokenizer.eos_token_id]
+        all_input_ids = truth_tokens
 
-    decoded_text = _run_incremental_decode(
+    decoded_text, out_ids = _run_incremental_decode(
         tokenizer,
        all_input_ids,
         skip_special_tokens=skip_special_tokens,
-        starting_index=starting_index)
+        starting_index=starting_index,
+        spaces_between_special_tokens=spaces_between_special_tokens,
+        fast=fast)
 
     assert decoded_text == generated
+    assert out_ids == all_input_ids[starting_index:]
 
-    decoded_text = _run_incremental_decode(
+
+@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
+@pytest.mark.parametrize("fast", (True, False))
+def test_oov_decode(tokenizer, fast):
+    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
+        pytest.skip()
+
+    decoded_text, out_ids = _run_incremental_decode(
         tokenizer, [len(tokenizer)],
-        skip_special_tokens=skip_special_tokens,
-        starting_index=starting_index)
+        skip_special_tokens=True,
+        starting_index=0,
+        spaces_between_special_tokens=True,
+        fast=fast)
 
     assert decoded_text == ''
+    assert out_ids == [len(tokenizer)]
 
 
 @pytest.fixture
@@ -165,15 +233,14 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
 @pytest.fixture(name="complete_sequence_token_ids")
 def create_complete_sequence_token_ids(complete_sequence: str,
                                        tokenizer) -> list[int]:
-    complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
-    return complete_sequence_token_ids
+    return tokenizer(complete_sequence, add_special_tokens=False).input_ids
 
 
 def create_sequence(prompt_token_ids=None):
-    prompt_token_ids = prompt_token_ids or [1]
+    prompt_token_ids = prompt_token_ids or []
     return Sequence(
         seq_id=0,
-        inputs=token_inputs(prompt_token_ids, prompt="<s>"),
+        inputs=token_inputs(prompt_token_ids),
         block_size=16,
     )
 
@@ -224,7 +291,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
     assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
     assert sequential_result != "".join(sequential_logprobs_text_other_token)
 
-    if skip_special_tokens:
+    if not skip_special_tokens:
         # Text for logprobs for the chosen token should be the same as the
         # generated text. Note that this will only be true if we skip
         # special tokens.
@@ -233,10 +300,23 @@
 
 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
+def test_decode_prompt_logprobs(complete_sequence: str,
+                                complete_sequence_token_ids: list[int],
                                 detokenizer: Detokenizer):
+
+    # We want to use skip_special_tokens=False here but Mistral tokenizers
+    # don't support that.
+    if complete_sequence not in SPECIAL_TOKS_TRUTH:
+        skip_special_tokens = True
+    elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None),
+                        MistralTokenizer):
+        skip_special_tokens = False
+    else:
+        pytest.skip("MistralTokenizers don't support "
+                    "skip_special_tokens=False")
+        return
     """Verify Detokenizer decodes prompt logprobs correctly."""
-    sampling_params = SamplingParams(skip_special_tokens=True,
+    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
                                      prompt_logprobs=1)
 
     # Run sequentially.
@@ -256,8 +336,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
     # decoded_prompt_logprobs doesn't contain the first token.
     token_ids = complete_sequence_token_ids
     tokenizer = detokenizer.get_tokenizer_for_seq(seq)
-    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
-    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
+    text_full = tokenizer.decode(token_ids,
+                                 skip_special_tokens=skip_special_tokens)
+    text_first = tokenizer.decode(token_ids[0],
+                                  skip_special_tokens=skip_special_tokens)
     text = text_full[len(text_first):]
 
     # Text for logprobs for the chosen token should be the same as the
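
For reference, a minimal sketch of driving the new incremental detokenizer classes directly, mirroring the _run_incremental_decode helper above. The positional EngineCoreRequest arguments are copied from that helper (they are not documented here), and the tokenizer and prompt/generation split are arbitrary assumptions for illustration.

# Sketch based on the test helper above; not an official usage example.
from transformers import AutoTokenizer

from vllm import SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import IncrementalDetokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # arbitrary fast tokenizer
prompt_ids = tokenizer("Hello here,", add_special_tokens=False).input_ids
generated_ids = tokenizer(" this is a simple test",
                          add_special_tokens=False).input_ids

params = SamplingParams(skip_special_tokens=True)
# Positional fields mirror the EngineCoreRequest construction in the test.
request = EngineCoreRequest("", "", prompt_ids, None, None, None, params,
                            None, 0.0, None)

# from_new_request presumably picks FastIncrementalDetokenizer for Rust-backed
# (PreTrainedTokenizerFast) tokenizers and SlowIncrementalDetokenizer otherwise.
detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)

streamed = ""
for i, token_id in enumerate(generated_ids):
    detokenizer.update([token_id], False)
    finished = i == len(generated_ids) - 1
    streamed += detokenizer.get_next_output_text(finished, delta=True)

print(streamed)  # incremental output for the generated tokens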

vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py

Lines changed: 9 additions & 0 deletions
@@ -70,6 +70,15 @@ def __init__(self, tokenizer: AnyTokenizer):
                 "Mistral Tool Parser could not locate the tool call token in "
                 "the tokenizer!")
 
+    def adjust_request(
+            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
+        if request.tools and request.tool_choice != 'none':
+            # do not skip special tokens because mistral uses the special
+            # tokens to indicate the start and end of the tool calls
+            # information.
+            request.skip_special_tokens = False
+        return request
+
     def extract_tool_calls(
         self,
         model_output: str,
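
The new adjust_request hook only flips skip_special_tokens off when tools are in play, so the special tokens Mistral uses to delimit tool-call information survive detokenization and can be parsed back out. A tiny self-contained sketch of that condition, using a stand-in object rather than a real ChatCompletionRequest (field values are hypothetical):

# Stand-in for ChatCompletionRequest; illustrative only.
from types import SimpleNamespace

request = SimpleNamespace(tools=[{"type": "function"}],
                          tool_choice="auto",
                          skip_special_tokens=True)

# Same condition as adjust_request above.
if request.tools and request.tool_choice != 'none':
    request.skip_special_tokens = False

print(request.skip_special_tokens)  # -> False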
