diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py
index 4e88feae44dd..4090927f378e 100644
--- a/tests/v1/sample/test_sampling_params_e2e.py
+++ b/tests/v1/sample/test_sampling_params_e2e.py
@@ -14,7 +14,10 @@
 
 @pytest.fixture(scope="module")
 def model() -> LLM:
-    return LLM(MODEL, enforce_eager=True)
+    # Disable prefix caching so that we can test prompt logprobs.
+    # TODO remove this after https://github.com/vllm-project/vllm/pull/13949
+    # is merged
+    return LLM(MODEL, enforce_eager=True, enable_prefix_caching=False)
 
 
 def test_n_gt_1(model):
@@ -79,9 +82,33 @@ def test_stop_token_ids(model):
 
     stop_token_ids = [stop_token_id_0, stop_token_id_1]
     params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
+
     output = model.generate(PROMPT, params)
     assert output[0].outputs[0].token_ids[-1] == stop_token_id_0
 
 
+def test_detokenize_false(model):
+    """Check that detokenize=False option works."""
+
+    output = model.generate(PROMPT, SamplingParams(detokenize=False))
+    assert len(output[0].outputs[0].token_ids) > 0
+    assert len(output[0].outputs[0].text) == 0
+
+    output = model.generate(
+        PROMPT, SamplingParams(detokenize=False, logprobs=3,
+                               prompt_logprobs=3))
+    assert len(output[0].outputs[0].token_ids) > 0
+    assert len(output[0].outputs[0].text) == 0
+
+    prompt_logprobs = output[0].prompt_logprobs
+    sampled_logprobs = output[0].outputs[0].logprobs
+    assert len(prompt_logprobs) > 1
+    assert len(sampled_logprobs) > 1
+    for all_logprobs in (prompt_logprobs[1:], sampled_logprobs):
+        for logprobs in all_logprobs:
+            assert 3 <= len(logprobs) <= 4
+            assert all(lp.decoded_token is None for lp in logprobs.values())
+
+
 def test_bad_words(model):
     """Check that we respect bad words."""
diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py
index 4a1636f49495..92754920b62d 100644
--- a/vllm/v1/engine/detokenizer.py
+++ b/vllm/v1/engine/detokenizer.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional
 
 from vllm.engine.output_processor.stop_checker import StopChecker
@@ -16,41 +16,46 @@
 class IncrementalDetokenizer:
 
     # Generation data
-    output_text: str
-    tokens: list[str]
     token_ids: list[int]
-    prompt_len: int
+    output_text: str = ""
+    tokens: list[str] = field(default_factory=list)
+    prompt_len: int = 0
 
     # Stop strings
-    stop: list[str]
-    include_stop_str_in_output: bool
+    stop: list[str] = field(default_factory=list)
+    include_stop_str_in_output: bool = False
 
     # Metadata for incremental detokenization
-    prefix_offset: int
-    read_offset: int
+    prefix_offset: int = 0
+    read_offset: int = 0
 
     # Parameters for detokenization
-    skip_special_tokens: bool
-    spaces_between_special_tokens: bool
+    skip_special_tokens: bool = True
+    spaces_between_special_tokens: bool = True
 
-    # Tokenizer for this request
-    tokenizer: AnyTokenizer
+    # Tokenizer for this request,
+    # None if detokenization is disabled.
+    tokenizer: Optional[AnyTokenizer] = None
 
     # Accounting for stop string buffering
-    stop_buffer_length: int
+    stop_buffer_length: int = 0
     _last_output_text_offset: int = 0
 
     @property
     def output_token_ids(self) -> list[int]:
-        return self.token_ids[self.prompt_len:]
+        return self.token_ids if not self.prompt_len else (
+            self.token_ids[self.prompt_len:])
 
     @classmethod
     def from_new_request(
         cls,
-        tokenizer: AnyTokenizer,
+        tokenizer: Optional[AnyTokenizer],
        request: EngineCoreRequest,
     ) -> "IncrementalDetokenizer":
 
+        if tokenizer is None:
+            return cls(token_ids=[])
+
         tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
             tokenizer=tokenizer,
             prompt_ids=request.prompt_token_ids,
@@ -66,7 +71,6 @@ def from_new_request(
             stop_buffer_length = 0
 
         return cls(
-            output_text="",
             tokens=tokens,
             # Detokenizer mutates this list, so need a unique copy.
             # NOTE(Nick): could we take ownership of it though?
@@ -93,6 +97,10 @@ def update(self, new_token_ids: list[int]) -> Optional[str]:
         Return matched stop string or None.
         """
 
+        if self.tokenizer is None:
+            self.token_ids.extend(new_token_ids)
+            return None
+
         # 1) Detokenize the new token ids incrementally.
         # TODO(woosuk): This method becomes very inefficient when the number of
         # new_token_ids is more than 1. We need to optimize this.
diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py
index 7f572163ead4..500de14e57d6 100644
--- a/vllm/v1/engine/logprobs.py
+++ b/vllm/v1/engine/logprobs.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import itertools
+from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Optional
 
@@ -13,12 +14,15 @@
 
 logger = init_logger(__name__)
 
+NONES = itertools.repeat(None)
+
 
 @dataclass
 class LogprobsProcessor:
 
-    # Tokenizer for this request
-    tokenizer: AnyTokenizer
+    # Tokenizer for this request,
+    # None if detokenization is disabled.
+    tokenizer: Optional[AnyTokenizer]
 
     # Logprobs for this request
     logprobs: Optional[SampleLogprobs]
@@ -30,7 +34,7 @@ class LogprobsProcessor:
     @classmethod
     def from_new_request(
         cls,
-        tokenizer: AnyTokenizer,
+        tokenizer: Optional[AnyTokenizer],
         request: EngineCoreRequest,
     ) -> "LogprobsProcessor":
         num_logprobs = request.sampling_params.logprobs
@@ -66,8 +70,8 @@ def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
                                              token_ids_lst):
 
             # Detokenize (non-incrementally).
-            decoded_tokens = convert_ids_list_to_tokens(
-                self.tokenizer, token_ids)
+            decoded_tokens = NONES if self.tokenizer is None else (
+                convert_ids_list_to_tokens(self.tokenizer, token_ids))
 
             # Sampler puts the sampled logprob in first.
             sampled_token_logprob = logprobs[0]
@@ -103,9 +107,9 @@ def _update_prompt_logprobs(
 
         # Detokenize non-incrementally.
         # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
-        decoded_tokens = convert_ids_list_to_tokens(
-            self.tokenizer,
-            token_ids.flatten().tolist())
+        decoded_tokens = None if self.tokenizer is None else (
+            convert_ids_list_to_tokens(self.tokenizer,
                                       token_ids.flatten().tolist()))
 
         # Recover shapes.
         num_prompt_tokens, num_logprobs = logprobs.shape
@@ -121,7 +125,8 @@ def _update_prompt_logprobs(
             # Handle flattening.
             offset = pos * num_logprobs
             offset_end = offset + num_logprobs
-            decoded_tokens_for_pos = decoded_tokens[offset:offset_end]
+            decoded_tokens_for_pos = NONES \
+                if decoded_tokens is None else decoded_tokens[offset:offset_end]
 
             # Update with the Logprob dictionary for this pos.
             self.prompt_logprobs.append(
@@ -153,7 +158,7 @@ def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
 def _make_logprob_dict(
     logprobs: list[float],
     logprob_token_ids: list[int],
-    decoded_tokens: list[str],
+    decoded_tokens: Iterable[Optional[str]],
     rank: int,
     num_logprobs: int,
 ) -> dict[int, Logprob]:
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 75c638a854f8..aca0233e416b 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -68,6 +68,8 @@ def from_new_request(
         queue: Optional[asyncio.Queue[RequestOutput]],
         log_stats: bool,
     ) -> "RequestState":
+        if not request.sampling_params.detokenize:
+            tokenizer = None
         return cls(
             request_id=request.request_id,
             parent_req=parent_req,