From 79f61e1b9dcb5e9e819e741b5b6d39916a7b2f81 Mon Sep 17 00:00:00 2001 From: Himanshu Jaju Date: Tue, 4 Mar 2025 16:56:13 +0000 Subject: [PATCH 1/5] Follow detokenize sampling param Signed-off-by: Himanshu Jaju --- vllm/v1/engine/output_processor.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 75c638a854f8..5bfeefd868ad 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -36,6 +36,7 @@ def __init__( prompt_token_ids: list[int], logprobs_processor: LogprobsProcessor, detokenizer: IncrementalDetokenizer, + detokenize: bool, max_tokens_param: Optional[int], arrival_time: float, queue: Optional[asyncio.Queue[RequestOutput]], @@ -51,6 +52,7 @@ def __init__( self.prompt_len = len(prompt_token_ids) self.logprobs_processor = logprobs_processor self.detokenizer = detokenizer + self.detokenize = detokenize self.max_tokens_param = max_tokens_param self.is_prefilling = True self.queue = queue @@ -85,6 +87,7 @@ def from_new_request( tokenizer=tokenizer, request=request, ), + detokenize=request.sampling_params.detokenize, max_tokens_param=(request.sampling_params.max_tokens if request.sampling_params is not None else None), arrival_time=request.arrival_time, @@ -156,7 +159,7 @@ def _new_completion_output( delta = self.output_kind == RequestOutputKind.DELTA # Prepare text and token_ids, based on delta mode - text = self.detokenizer.get_next_output_text(finished, delta) + text = self.detokenizer.get_next_output_text(finished, delta) if self.detokenize else "" if not delta: token_ids = self.detokenizer.output_token_ids @@ -290,10 +293,11 @@ def process_outputs( # 2) Detokenize the token ids into text and check for stop # strings. - stop_string = req_state.detokenizer.update(new_token_ids) - if stop_string and finish_reason != FinishReason.STOP: - finish_reason = FinishReason.STOP - stop_reason = stop_string + if req_state.detokenize: + stop_string = req_state.detokenizer.update(new_token_ids) + if stop_string and finish_reason != FinishReason.STOP: + finish_reason = FinishReason.STOP + stop_reason = stop_string # 3) Compute sample and prompt logprobs for request, # if required. 
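Note on the change above: patch 1 threads the `detokenize` sampling parameter through `RequestState` so that text generation and stop-string matching are skipped when the caller opts out. From the caller's side the intended behaviour is the one exercised by the e2e test added later in this series. A minimal usage sketch (not part of the diff; the model name and prompt are placeholders, any small decoder-only model works):

```python
# Usage sketch for the flag this patch threads through (assumed placeholders:
# model name and prompt). With detokenize=False the engine skips incremental
# detokenization and stop-string matching, returning token IDs with empty text.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enforce_eager=True)

params = SamplingParams(temperature=0, max_tokens=16, detokenize=False)
completion = llm.generate("Hello, my name is", params)[0].outputs[0]

assert completion.text == ""           # no text was produced
assert len(completion.token_ids) > 0   # token IDs are still returned
```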
From 4f9fce15c054751ea2d65599e53f87c87be22e16 Mon Sep 17 00:00:00 2001 From: Himanshu Jaju Date: Tue, 4 Mar 2025 17:02:46 +0000 Subject: [PATCH 2/5] precommit Signed-off-by: Himanshu Jaju --- vllm/v1/engine/output_processor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 5bfeefd868ad..35156e40eae5 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -159,7 +159,8 @@ def _new_completion_output( delta = self.output_kind == RequestOutputKind.DELTA # Prepare text and token_ids, based on delta mode - text = self.detokenizer.get_next_output_text(finished, delta) if self.detokenize else "" + text = self.detokenizer.get_next_output_text( + finished, delta) if self.detokenize else "" if not delta: token_ids = self.detokenizer.output_token_ids From bbf5e7904a77ee1bef1de03ca5f831a0a99671e4 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 5 Mar 2025 15:20:37 -0800 Subject: [PATCH 3/5] Move logic inside detokenizer; also apply to logprobs Signed-off-by: Nick Hill --- tests/v1/sample/test_sampling_params_e2e.py | 24 +++++++++++++ vllm/v1/engine/detokenizer.py | 40 ++++++++++++--------- vllm/v1/engine/logprobs.py | 25 +++++++------ vllm/v1/engine/output_processor.py | 17 ++++----- 4 files changed, 70 insertions(+), 36 deletions(-) diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py index 4e88feae44dd..54b31fcb1ee0 100644 --- a/tests/v1/sample/test_sampling_params_e2e.py +++ b/tests/v1/sample/test_sampling_params_e2e.py @@ -79,9 +79,33 @@ def test_stop_token_ids(model): stop_token_ids = [stop_token_id_0, stop_token_id_1] params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids) + output = model.generate(PROMPT, params) assert output[0].outputs[0].token_ids[-1] == stop_token_id_0 +def test_detokenize_false(model): + """Check that detokenize=False option works.""" + + output = model.generate(PROMPT, SamplingParams(detokenize=False)) + assert len(output[0].outputs[0].token_ids) > 0 + assert len(output[0].outputs[0].text) == 0 + + output = model.generate( + PROMPT, SamplingParams(detokenize=False, logprobs=3, + prompt_logprobs=3)) + assert len(output[0].outputs[0].token_ids) > 0 + assert len(output[0].outputs[0].text) == 0 + + prompt_logprobs = output[0].prompt_logprobs + sampled_logprobs = output[0].outputs[0].logprobs + assert len(prompt_logprobs) > 0 + assert len(sampled_logprobs) > 0 + for all_logprobs in (prompt_logprobs, sampled_logprobs): + for logprobs in all_logprobs: + assert 3 <= len(logprobs) <= 4 + assert all(lp.decoded_token is None for lp in logprobs.values()) + + def test_bad_words(model): """Check that we respect bad words.""" diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 4a1636f49495..92754920b62d 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional from vllm.engine.output_processor.stop_checker import StopChecker @@ -16,41 +16,46 @@ class IncrementalDetokenizer: # Generation data - output_text: str - tokens: list[str] token_ids: list[int] - prompt_len: int + output_text: str = "" + tokens: list[str] = field(default_factory=list) + prompt_len: int = 0 # Stop strings - stop: list[str] - include_stop_str_in_output: bool + stop: list[str] = field(default_factory=list) + 
include_stop_str_in_output: bool = False # Metadata for incremental detokenization - prefix_offset: int - read_offset: int + prefix_offset: int = 0 + read_offset: int = 0 # Parameters for detokenization - skip_special_tokens: bool - spaces_between_special_tokens: bool + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True - # Tokenizer for this request - tokenizer: AnyTokenizer + # Tokenizer for this request, + # None if detokenization is disabled. + tokenizer: Optional[AnyTokenizer] = None # Accounting for stop string buffering - stop_buffer_length: int + stop_buffer_length: int = 0 _last_output_text_offset: int = 0 @property def output_token_ids(self) -> list[int]: - return self.token_ids[self.prompt_len:] + return self.token_ids if not self.prompt_len else ( + self.token_ids[self.prompt_len:]) @classmethod def from_new_request( cls, - tokenizer: AnyTokenizer, + tokenizer: Optional[AnyTokenizer], request: EngineCoreRequest, ) -> "IncrementalDetokenizer": + if tokenizer is None: + return cls(token_ids=[]) + tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens( tokenizer=tokenizer, prompt_ids=request.prompt_token_ids, @@ -66,7 +71,6 @@ def from_new_request( stop_buffer_length = 0 return cls( - output_text="", tokens=tokens, # Detokenizer mutates this list, so need a unique copy. # NOTE(Nick): could we take ownership of it though? @@ -93,6 +97,10 @@ def update(self, new_token_ids: list[int]) -> Optional[str]: Return matched stop string or None. """ + if self.tokenizer is None: + self.token_ids.extend(new_token_ids) + return None + # 1) Detokenize the new token ids incrementally. # TODO(woosuk): This method becomes very inefficient when the number of # new_token_ids is more than 1. We need to optimize this. diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py index 7f572163ead4..500de14e57d6 100644 --- a/vllm/v1/engine/logprobs.py +++ b/vllm/v1/engine/logprobs.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import itertools +from collections.abc import Iterable from dataclasses import dataclass from typing import Optional @@ -13,12 +14,15 @@ logger = init_logger(__name__) +NONES = itertools.repeat(None) + @dataclass class LogprobsProcessor: - # Tokenizer for this request - tokenizer: AnyTokenizer + # Tokenizer for this request, + # None if detokenization is disabled. + tokenizer: Optional[AnyTokenizer] # Logprobs for this request logprobs: Optional[SampleLogprobs] @@ -30,7 +34,7 @@ class LogprobsProcessor: @classmethod def from_new_request( cls, - tokenizer: AnyTokenizer, + tokenizer: Optional[AnyTokenizer], request: EngineCoreRequest, ) -> "LogprobsProcessor": num_logprobs = request.sampling_params.logprobs @@ -66,8 +70,8 @@ def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None: token_ids_lst): # Detokenize (non-incrementally). - decoded_tokens = convert_ids_list_to_tokens( - self.tokenizer, token_ids) + decoded_tokens = NONES if self.tokenizer is None else ( + convert_ids_list_to_tokens(self.tokenizer, token_ids)) # Sampler puts the sampled logprob in first. sampled_token_logprob = logprobs[0] @@ -103,9 +107,9 @@ def _update_prompt_logprobs( # Detokenize non-incrementally. # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps] - decoded_tokens = convert_ids_list_to_tokens( - self.tokenizer, - token_ids.flatten().tolist()) + decoded_tokens = None if self.tokenizer is None else ( + convert_ids_list_to_tokens(self.tokenizer, + token_ids.flatten().tolist())) # Recover shapes. 
num_prompt_tokens, num_logprobs = logprobs.shape @@ -121,7 +125,8 @@ def _update_prompt_logprobs( # Handle flattening. offset = pos * num_logprobs offset_end = offset + num_logprobs - decoded_tokens_for_pos = decoded_tokens[offset:offset_end] + decoded_tokens_for_pos = NONES \ + if decoded_tokens is None else decoded_tokens[offset:offset_end] # Update with the Logprob dictionary for this pos. self.prompt_logprobs.append( @@ -153,7 +158,7 @@ def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]: def _make_logprob_dict( logprobs: list[float], logprob_token_ids: list[int], - decoded_tokens: list[str], + decoded_tokens: Iterable[Optional[str]], rank: int, num_logprobs: int, ) -> dict[int, Logprob]: diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 35156e40eae5..aca0233e416b 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -36,7 +36,6 @@ def __init__( prompt_token_ids: list[int], logprobs_processor: LogprobsProcessor, detokenizer: IncrementalDetokenizer, - detokenize: bool, max_tokens_param: Optional[int], arrival_time: float, queue: Optional[asyncio.Queue[RequestOutput]], @@ -52,7 +51,6 @@ def __init__( self.prompt_len = len(prompt_token_ids) self.logprobs_processor = logprobs_processor self.detokenizer = detokenizer - self.detokenize = detokenize self.max_tokens_param = max_tokens_param self.is_prefilling = True self.queue = queue @@ -70,6 +68,8 @@ def from_new_request( queue: Optional[asyncio.Queue[RequestOutput]], log_stats: bool, ) -> "RequestState": + if not request.sampling_params.detokenize: + tokenizer = None return cls( request_id=request.request_id, parent_req=parent_req, @@ -87,7 +87,6 @@ def from_new_request( tokenizer=tokenizer, request=request, ), - detokenize=request.sampling_params.detokenize, max_tokens_param=(request.sampling_params.max_tokens if request.sampling_params is not None else None), arrival_time=request.arrival_time, @@ -159,8 +158,7 @@ def _new_completion_output( delta = self.output_kind == RequestOutputKind.DELTA # Prepare text and token_ids, based on delta mode - text = self.detokenizer.get_next_output_text( - finished, delta) if self.detokenize else "" + text = self.detokenizer.get_next_output_text(finished, delta) if not delta: token_ids = self.detokenizer.output_token_ids @@ -294,11 +292,10 @@ def process_outputs( # 2) Detokenize the token ids into text and check for stop # strings. - if req_state.detokenize: - stop_string = req_state.detokenizer.update(new_token_ids) - if stop_string and finish_reason != FinishReason.STOP: - finish_reason = FinishReason.STOP - stop_reason = stop_string + stop_string = req_state.detokenizer.update(new_token_ids) + if stop_string and finish_reason != FinishReason.STOP: + finish_reason = FinishReason.STOP + stop_reason = stop_string # 3) Compute sample and prompt logprobs for request, # if required. 
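Note on patch 3: instead of branching on a `detokenize` flag in the output processor, the tokenizer on `IncrementalDetokenizer` and `LogprobsProcessor` becomes `Optional`, and `NONES = itertools.repeat(None)` stands in for the decoded-token list when detokenization is disabled. Because `zip()` stops at the shortest input, the infinite `None` iterator slots into the existing dictionary-building loop with no special casing. A standalone sketch of that pattern (simplified: plain tuples instead of vLLM's `Logprob` objects, hypothetical helper name):

```python
# Sketch of the NONES = itertools.repeat(None) trick used in logprobs.py.
# zip() stops at the shortest iterable, so an infinite None iterator can
# replace the decoded-token list and the same loop serves both code paths.
import itertools
from typing import Iterable, Optional

NONES = itertools.repeat(None)

def make_logprob_dict(logprobs: list[float],
                      token_ids: list[int],
                      decoded: Iterable[Optional[str]]) -> dict[int, tuple]:
    # Each entry maps token_id -> (logprob, decoded_token or None).
    return {tid: (lp, tok)
            for tid, lp, tok in zip(token_ids, logprobs, decoded)}

with_text = make_logprob_dict([-0.1, -1.3], [42, 7], ["Hello", " world"])
without_text = make_logprob_dict([-0.1, -1.3], [42, 7], NONES)
assert all(tok is None for _, tok in without_text.values())
```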
From e70dd32469975ca71f1e0ca6f371ebb102b6a0fd Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 5 Mar 2025 22:53:48 -0800 Subject: [PATCH 4/5] disable APC in test to allow prompt logprobs Signed-off-by: Nick Hill --- tests/v1/sample/test_sampling_params_e2e.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py index 54b31fcb1ee0..ae30381e0e30 100644 --- a/tests/v1/sample/test_sampling_params_e2e.py +++ b/tests/v1/sample/test_sampling_params_e2e.py @@ -14,7 +14,10 @@ @pytest.fixture(scope="module") def model() -> LLM: - return LLM(MODEL, enforce_eager=True) + # Disable prefix caching so that we can test prompt logprobs. + # TODO remove this after https://github.com/vllm-project/vllm/pull/13949 + # is merged + return LLM(MODEL, enforce_eager=True, enable_prefix_caching=False) def test_n_gt_1(model): From a6aa82f582a4a4497a4eec62dd1dcf06f9219dae Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 6 Mar 2025 07:58:48 -0800 Subject: [PATCH 5/5] Fix unit test Signed-off-by: Nick Hill --- tests/v1/sample/test_sampling_params_e2e.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py index ae30381e0e30..4090927f378e 100644 --- a/tests/v1/sample/test_sampling_params_e2e.py +++ b/tests/v1/sample/test_sampling_params_e2e.py @@ -101,9 +101,9 @@ def test_detokenize_false(model): prompt_logprobs = output[0].prompt_logprobs sampled_logprobs = output[0].outputs[0].logprobs - assert len(prompt_logprobs) > 0 - assert len(sampled_logprobs) > 0 - for all_logprobs in (prompt_logprobs, sampled_logprobs): + assert len(prompt_logprobs) > 1 + assert len(sampled_logprobs) > 1 + for all_logprobs in (prompt_logprobs[1:], sampled_logprobs): for logprobs in all_logprobs: assert 3 <= len(logprobs) <= 4 assert all(lp.decoded_token is None for lp in logprobs.values())
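Note on patch 5: the test now iterates `prompt_logprobs[1:]` and requires `len(...) > 1` because the first prompt position has no preceding context and its entry comes back as `None`, so calling `len()` or `.values()` on it would fail. A sketch of the fixed check in isolation (assumes a local vLLM install; the model name and prompt are placeholders, and prefix caching is disabled as in the test fixture above):

```python
# Why the fixed test skips index 0: the first prompt token has no logprob
# entry (assumed to be None, per the behaviour the test change implies).
from vllm import LLM, SamplingParams

llm = LLM("facebook/opt-125m", enforce_eager=True, enable_prefix_caching=False)
out = llm.generate("The capital of France is",
                   SamplingParams(detokenize=False, logprobs=3,
                                  prompt_logprobs=3))[0]

assert out.prompt_logprobs[0] is None          # first token: no context
for logprobs in out.prompt_logprobs[1:]:       # remaining positions
    assert 3 <= len(logprobs) <= 4
    assert all(lp.decoded_token is None for lp in logprobs.values())
```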