
Commit cd57935

hj-mistral and njhill authored
[V1] Do not detokenize if sampling param detokenize is False (#14224)
Signed-off-by: Himanshu Jaju <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent 9f1710f commit cd57935

File tree

4 files changed, +69 -27 lines changed

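
For context before the diff: with this change, passing detokenize=False in SamplingParams takes effect on the V1 engine, so a request can skip text decoding entirely while still returning token IDs and logprobs. A rough usage sketch follows; the model name and prompt are placeholders, and the assertions mirror the new test added below.

    from vllm import LLM, SamplingParams

    # Placeholder model and prompt, for illustration only.
    llm = LLM(model="facebook/opt-125m")

    # With detokenize=False the engine skips incremental detokenization:
    # token IDs are still returned, but no text is produced and any
    # requested logprobs carry decoded_token=None.
    params = SamplingParams(detokenize=False, logprobs=3)
    outputs = llm.generate(["The capital of France is"], params)

    completion = outputs[0].outputs[0]
    assert len(completion.token_ids) > 0
    assert len(completion.text) == 0
    assert all(lp.decoded_token is None
               for lp in completion.logprobs[0].values())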

tests/v1/sample/test_sampling_params_e2e.py

Lines changed: 28 additions & 1 deletion
@@ -14,7 +14,10 @@

 @pytest.fixture(scope="module")
 def model() -> LLM:
-    return LLM(MODEL, enforce_eager=True)
+    # Disable prefix caching so that we can test prompt logprobs.
+    # TODO remove this after https://github.com/vllm-project/vllm/pull/13949
+    # is merged
+    return LLM(MODEL, enforce_eager=True, enable_prefix_caching=False)


 def test_n_gt_1(model):
@@ -87,9 +90,33 @@ def test_stop_token_ids(model):

     stop_token_ids = [stop_token_id_0, stop_token_id_1]
     params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
+    output = model.generate(PROMPT, params)
     assert output[0].outputs[0].token_ids[-1] == stop_token_id_0


+def test_detokenize_false(model):
+    """Check that detokenize=False option works."""
+
+    output = model.generate(PROMPT, SamplingParams(detokenize=False))
+    assert len(output[0].outputs[0].token_ids) > 0
+    assert len(output[0].outputs[0].text) == 0
+
+    output = model.generate(
+        PROMPT, SamplingParams(detokenize=False, logprobs=3,
+                               prompt_logprobs=3))
+    assert len(output[0].outputs[0].token_ids) > 0
+    assert len(output[0].outputs[0].text) == 0
+
+    prompt_logprobs = output[0].prompt_logprobs
+    sampled_logprobs = output[0].outputs[0].logprobs
+    assert len(prompt_logprobs) > 1
+    assert len(sampled_logprobs) > 1
+    for all_logprobs in (prompt_logprobs[1:], sampled_logprobs):
+        for logprobs in all_logprobs:
+            assert 3 <= len(logprobs) <= 4
+            assert all(lp.decoded_token is None for lp in logprobs.values())
+
+
 def test_bad_words(model):
     """Check that we respect bad words."""

vllm/v1/engine/detokenizer.py

Lines changed: 24 additions & 16 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional

 from vllm.engine.output_processor.stop_checker import StopChecker
@@ -16,41 +16,46 @@
 class IncrementalDetokenizer:

     # Generation data
-    output_text: str
-    tokens: list[str]
     token_ids: list[int]
-    prompt_len: int
+    output_text: str = ""
+    tokens: list[str] = field(default_factory=list)
+    prompt_len: int = 0

     # Stop strings
-    stop: list[str]
-    include_stop_str_in_output: bool
+    stop: list[str] = field(default_factory=list)
+    include_stop_str_in_output: bool = False

     # Metadata for incremental detokenization
-    prefix_offset: int
-    read_offset: int
+    prefix_offset: int = 0
+    read_offset: int = 0

     # Parameters for detokenization
-    skip_special_tokens: bool
-    spaces_between_special_tokens: bool
+    skip_special_tokens: bool = True
+    spaces_between_special_tokens: bool = True

-    # Tokenizer for this request
-    tokenizer: AnyTokenizer
+    # Tokenizer for this request,
+    # None if detokenization is disabled.
+    tokenizer: Optional[AnyTokenizer] = None

     # Accounting for stop string buffering
-    stop_buffer_length: int
+    stop_buffer_length: int = 0
     _last_output_text_offset: int = 0

     @property
     def output_token_ids(self) -> list[int]:
-        return self.token_ids[self.prompt_len:]
+        return self.token_ids if not self.prompt_len else (
+            self.token_ids[self.prompt_len:])

     @classmethod
     def from_new_request(
         cls,
-        tokenizer: AnyTokenizer,
+        tokenizer: Optional[AnyTokenizer],
         request: EngineCoreRequest,
     ) -> "IncrementalDetokenizer":

+        if tokenizer is None:
+            return cls(token_ids=[])
+
         tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
             tokenizer=tokenizer,
             prompt_ids=request.prompt_token_ids,
@@ -66,7 +71,6 @@ def from_new_request(
         stop_buffer_length = 0

         return cls(
-            output_text="",
             tokens=tokens,
             # Detokenizer mutates this list, so need a unique copy.
             # NOTE(Nick): could we take ownership of it though?
@@ -93,6 +97,10 @@ def update(self, new_token_ids: list[int]) -> Optional[str]:
         Return matched stop string or None.
         """

+        if self.tokenizer is None:
+            self.token_ids.extend(new_token_ids)
+            return None
+
         # 1) Detokenize the new token ids incrementally.
         # TODO(woosuk): This method becomes very inefficient when the number of
         # new_token_ids is more than 1. We need to optimize this.
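
The heart of the detokenizer change is a null-tokenizer fast path: with tokenizer=None, update() only accumulates token IDs and never builds text or checks stop strings, which is also why the remaining fields now need defaults. Below is a simplified, self-contained sketch of that pattern, not the actual vLLM class.

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class TinyDetokenizer:
        # Simplified stand-in for IncrementalDetokenizer.
        token_ids: list[int]
        output_text: str = ""
        tokenizer: Optional[object] = None  # None => detokenization disabled

        def update(self, new_token_ids: list[int]) -> Optional[str]:
            if self.tokenizer is None:
                # Fast path: just record the IDs; no text, no stop strings.
                self.token_ids.extend(new_token_ids)
                return None
            # The real class incrementally decodes text and checks stop
            # strings here; omitted in this sketch.
            return None


    detok = TinyDetokenizer(token_ids=[])
    detok.update([1, 2, 3])
    assert detok.output_text == "" and detok.token_ids == [1, 2, 3]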

vllm/v1/engine/logprobs.py

Lines changed: 15 additions & 10 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import itertools
+from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Optional

@@ -13,12 +14,15 @@

 logger = init_logger(__name__)

+NONES = itertools.repeat(None)
+

 @dataclass
 class LogprobsProcessor:

-    # Tokenizer for this request
-    tokenizer: AnyTokenizer
+    # Tokenizer for this request,
+    # None if detokenization is disabled.
+    tokenizer: Optional[AnyTokenizer]

     # Logprobs for this request
     logprobs: Optional[SampleLogprobs]
@@ -30,7 +34,7 @@ class LogprobsProcessor:
     @classmethod
     def from_new_request(
         cls,
-        tokenizer: AnyTokenizer,
+        tokenizer: Optional[AnyTokenizer],
         request: EngineCoreRequest,
     ) -> "LogprobsProcessor":
         num_logprobs = request.sampling_params.logprobs
@@ -66,8 +70,8 @@ def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
                 token_ids_lst):

             # Detokenize (non-incrementally).
-            decoded_tokens = convert_ids_list_to_tokens(
-                self.tokenizer, token_ids)
+            decoded_tokens = NONES if self.tokenizer is None else (
+                convert_ids_list_to_tokens(self.tokenizer, token_ids))

             # Sampler puts the sampled logprob in first.
             sampled_token_logprob = logprobs[0]
@@ -103,9 +107,9 @@ def _update_prompt_logprobs(

         # Detokenize non-incrementally.
         # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
-        decoded_tokens = convert_ids_list_to_tokens(
-            self.tokenizer,
-            token_ids.flatten().tolist())
+        decoded_tokens = None if self.tokenizer is None else (
+            convert_ids_list_to_tokens(self.tokenizer,
+                                       token_ids.flatten().tolist()))

         # Recover shapes.
         num_prompt_tokens, num_logprobs = logprobs.shape
@@ -121,7 +125,8 @@ def _update_prompt_logprobs(
             # Handle flattening.
             offset = pos * num_logprobs
             offset_end = offset + num_logprobs
-            decoded_tokens_for_pos = decoded_tokens[offset:offset_end]
+            decoded_tokens_for_pos = NONES \
+                if decoded_tokens is None else decoded_tokens[offset:offset_end]

             # Update with the Logprob dictionary for this pos.
             self.prompt_logprobs.append(
@@ -153,7 +158,7 @@ def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
     def _make_logprob_dict(
         logprobs: list[float],
         logprob_token_ids: list[int],
-        decoded_tokens: list[str],
+        decoded_tokens: Iterable[Optional[str]],
         rank: int,
         num_logprobs: int,
     ) -> dict[int, Logprob]:
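
The module-level NONES = itertools.repeat(None) is what lets the logprob dictionaries be assembled by the same zip-driven code whether decoded token strings exist or not: repeat(None) with no count never exhausts, so it can stand in for a decoded-token list of any length and be reused across requests. A simplified sketch of the idea follows; Logprob and make_logprob_dict here are illustrative stand-ins, not vLLM's own definitions, and the rank assignment is simplified.

    import itertools
    from collections.abc import Iterable
    from dataclasses import dataclass
    from typing import Optional

    NONES = itertools.repeat(None)  # infinite iterator of None


    @dataclass
    class Logprob:  # simplified stand-in
        logprob: float
        rank: int
        decoded_token: Optional[str]


    def make_logprob_dict(logprobs: list[float],
                          token_ids: list[int],
                          decoded_tokens: Iterable[Optional[str]]):
        # decoded_tokens may be a real list of strings or the NONES iterator;
        # zip() stops at the shortest input, so repeat(None) pairs every token
        # id with None without any per-position branching.
        return {
            tid: Logprob(lp, rank, tok)
            for rank, (lp, tid, tok) in enumerate(
                zip(logprobs, token_ids, decoded_tokens), start=1)
        }


    with_text = make_logprob_dict([-0.1, -2.3], [42, 7], ["Hello", " world"])
    without_text = make_logprob_dict([-0.1, -2.3], [42, 7], NONES)
    assert all(lp.decoded_token is None for lp in without_text.values())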

vllm/v1/engine/output_processor.py

Lines changed: 2 additions & 0 deletions
@@ -68,6 +68,8 @@ def from_new_request(
         queue: Optional[asyncio.Queue[RequestOutput]],
         log_stats: bool,
     ) -> "RequestState":
+        if not request.sampling_params.detokenize:
+            tokenizer = None
         return cls(
             request_id=request.request_id,
             parent_req=parent_req,
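
These two lines are the single decision point: the tokenizer is dropped at RequestState construction, and both the detokenizer and the logprobs processor treat tokenizer=None as "decoding disabled", so no extra flag has to be threaded through their APIs. A minimal sketch of how that None propagates; all class names below are simplified stand-ins for illustration.

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class MiniDetokenizer:          # simplified stand-in
        tokenizer: Optional[object] = None


    @dataclass
    class MiniLogprobsProcessor:    # simplified stand-in
        tokenizer: Optional[object]


    @dataclass
    class MiniRequestState:
        detokenizer: MiniDetokenizer
        logprobs_processor: MiniLogprobsProcessor

        @classmethod
        def from_new_request(cls, tokenizer: Optional[object],
                             detokenize: bool) -> "MiniRequestState":
            # Mirrors the two added lines above: drop the tokenizer when
            # detokenize=False so every downstream component sees None
            # and skips decoding.
            if not detokenize:
                tokenizer = None
            return cls(
                detokenizer=MiniDetokenizer(tokenizer=tokenizer),
                logprobs_processor=MiniLogprobsProcessor(tokenizer=tokenizer),
            )


    state = MiniRequestState.from_new_request(tokenizer=object(),
                                              detokenize=False)
    assert state.detokenizer.tokenizer is None
    assert state.logprobs_processor.tokenizer is None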
