
Commit 9148095

xingyuliu authored and charlotte12l committed
initial
1 parent a742322 commit 9148095

File tree

12 files changed (+198, -2 lines)


vllm/outputs.py

Lines changed: 2 additions & 0 deletions
@@ -119,6 +119,7 @@ def __init__(
         *,
         multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
         kv_transfer_params: Optional[dict[str, Any]] = None,
+        prompt_hidden_states: Optional[torch.Tensor] = None,
         # Forward compatibility, code that uses args added in new release can
         # still run with older versions of vLLM without breaking.
         **kwargs: Any,
@@ -139,6 +140,7 @@ def __init__(
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
         self.num_cached_tokens = num_cached_tokens
         self.kv_transfer_params = kv_transfer_params
+        self.prompt_hidden_states = prompt_hidden_states
 
     def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
         """Merge subsequent RequestOutput into this one"""

vllm/sampling_params.py

Lines changed: 2 additions & 0 deletions
@@ -166,6 +166,8 @@ class SamplingParams(
     response. When set to -1, return all `vocab_size` log probabilities."""
     prompt_logprobs: Optional[int] = None
     """Number of log probabilities to return per prompt token."""
+    return_prompt_hidden_states: bool = False
+
     # NOTE: This parameter is only exposed at the engine level for now.
     # It is not exposed in the OpenAI API server, as the OpenAI API does
     # not support returning only a list of token IDs.
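
Together with the RequestOutput change above, the flag is switched on per request. A minimal end-to-end sketch, assuming the rest of the plumbing in this commit (including the model-runner side, which is not shown in these hunks) is wired up; the model name is only an example:

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")    # example model, not part of the commit
    params = SamplingParams(
        max_tokens=8,
        return_prompt_hidden_states=True,   # new flag added in this commit
    )
    outputs = llm.generate(["Hello, my name is"], params)
    # Expected to hold the prefill hidden states once the feature is complete;
    # stays None otherwise.
    print(outputs[0].prompt_hidden_states)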

vllm/v1/core/sched/scheduler.py

Lines changed: 3 additions & 0 deletions
@@ -845,6 +845,7 @@ def update_from_output(
         sampled_token_ids = model_runner_output.sampled_token_ids
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
+        prompt_hidden_states_dict = model_runner_output.prompt_hidden_states_dict
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
         pooler_outputs = model_runner_output.pooler_output
         num_nans_in_logits = model_runner_output.num_nans_in_logits
@@ -932,6 +933,7 @@ def update_from_output(
 
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
+            prompt_hidden_states = prompt_hidden_states_dict.get(req_id)
             if new_token_ids or pooler_output is not None \
                     or kv_transfer_params:
 
@@ -943,6 +945,7 @@ def update_from_output(
                     finish_reason=request.get_finished_reason(),
                     new_logprobs=new_logprobs,
                     new_prompt_logprobs_tensors=prompt_logprobs_tensors,
+                    prompt_hidden_states=prompt_hidden_states,
                     pooling_output=pooler_output,
                     stop_reason=request.stop_reason,
                     events=request.take_events(),

vllm/v1/engine/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -104,6 +104,7 @@ class EngineCoreOutput(
     new_logprobs: Optional[LogprobsLists] = None
     new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None
 
+    prompt_hidden_states: Optional[torch.Tensor] = None
     pooling_output: Optional[torch.Tensor] = None
 
     finish_reason: Optional[FinishReason] = None

vllm/v1/engine/core.py

Lines changed: 1 addition & 0 deletions
@@ -291,6 +291,7 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
         model_output = self.execute_model_with_error_logging(
             self.model_executor.execute_model,  # type: ignore
             scheduler_output)
+        print("lxy model_output to enginecoreoutput")
         engine_core_outputs = self.scheduler.update_from_output(
             scheduler_output, model_output)  # type: ignore
 
vllm/v1/engine/hidden_states.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import itertools
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+from vllm.logger import init_logger
+from vllm.v1.engine import EngineCoreOutput
+
+logger = init_logger(__name__)
+
+NONES = itertools.repeat(None)
+
+
+@dataclass
+class HiddenStatesProcessor:
+    prompt_hidden_states: Optional[torch.Tensor]
+
+    @classmethod
+    def from_new_request(cls) -> "HiddenStatesProcessor":
+        return cls(prompt_hidden_states=None)
+
+    def _set_prompt_hidden_states(
+        self,
+        prompt_hidden_states_tensor: torch.Tensor,
+    ) -> None:
+        """Update with the prompt hidden states from EngineCore.
+
+        Args:
+            prompt_hidden_states_tensor: tensor of hidden states for
+                the prompt tokens.
+        """
+
+        # We only need to set the prompt hidden states once.
+        # TODO: check logprobs
+        assert self.prompt_hidden_states is None
+
+        self.prompt_hidden_states = prompt_hidden_states_tensor
+
+    def pop_prompt_hidden_states(self) -> Optional[torch.Tensor]:
+        """Pop and return the request's prompt hidden states.
+
+        This method returns the prompt hidden states at once and then
+        forgets them, ensuring correct RequestOutputKind.DELTA semantics
+        wherein the prompt hidden states are returned once at the end
+        of prefill.
+
+        Returns:
+            None if prompt hidden states are disabled for this request.
+            The prompt hidden states tensor, otherwise.
+        """
+        hidden_states = self.prompt_hidden_states
+        if hidden_states is not None:
+            self.prompt_hidden_states = None
+        return hidden_states
+
+    def update_from_output(self, output: EngineCoreOutput) -> None:
+        if output.prompt_hidden_states is not None:
+            print("lxy update_from_output")
+            self._set_prompt_hidden_states(output.prompt_hidden_states)
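
To make the helper's lifecycle concrete, here is a hedged, standalone sketch of how the output processor drives it: one processor per request, the hidden states arrive once in an EngineCoreOutput at the end of prefill, and DELTA-style outputs pop them exactly once. The tensor shape and field values are illustrative only:

    import torch

    from vllm.v1.engine import EngineCoreOutput
    from vllm.v1.engine.hidden_states import HiddenStatesProcessor

    # One processor per request, created alongside the logprobs processor.
    proc = HiddenStatesProcessor.from_new_request()

    # Illustrative EngineCoreOutput carrying prefill hidden states
    # (the [prompt_len, hidden_size] shape is an assumption).
    output = EngineCoreOutput(
        request_id="req-0",
        new_token_ids=[42],
        prompt_hidden_states=torch.zeros(5, 4096),
    )
    proc.update_from_output(output)

    # DELTA semantics: the first pop returns the tensor, later pops return None.
    assert proc.pop_prompt_hidden_states() is not None
    assert proc.pop_prompt_hidden_states() is None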

vllm/v1/engine/llm_engine.py

Lines changed: 1 addition & 0 deletions
@@ -242,6 +242,7 @@ def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]:
 
         # 2) Process EngineCoreOutputs.
         iteration_stats = IterationStats() if self.log_stats else None
+        print("lxy call process_outputs")
         processed_outputs = self.output_processor.process_outputs(
             outputs.outputs,
             engine_core_timestamp=outputs.timestamp,

vllm/v1/engine/output_processor.py

Lines changed: 18 additions & 1 deletion
@@ -15,6 +15,7 @@
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
 from vllm.v1.engine.detokenizer import IncrementalDetokenizer
+from vllm.v1.engine.hidden_states import HiddenStatesProcessor
 from vllm.v1.engine.logprobs import LogprobsProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
@@ -93,6 +94,7 @@ def __init__(
         arrival_time: float,
         queue: Optional[RequestOutputCollector],
         log_stats: bool,
+        hidden_states_processor: Optional[HiddenStatesProcessor],
     ):
         self.request_id = request_id
         self.parent_req = parent_req
@@ -111,6 +113,7 @@ def __init__(
 
         self.stats = RequestStateStats(
             arrival_time=arrival_time) if log_stats else None
+        self.hidden_states_processor = hidden_states_processor
 
     @classmethod
     def from_new_request(
@@ -137,10 +140,12 @@ def from_new_request(
                 request=request,
             )
             max_tokens_param = sampling_params.max_tokens
+            hidden_states_processor = HiddenStatesProcessor.from_new_request()
         else:
             logprobs_processor = None
             detokenizer = None
             max_tokens_param = None
+            hidden_states_processor = None
             assert request.pooling_params is not None
             output_kind = request.pooling_params.output_kind
 
@@ -159,6 +164,7 @@ def from_new_request(
             arrival_time=request.arrival_time,
             queue=queue,
             log_stats=log_stats,
+            hidden_states_processor=hidden_states_processor,
         )
 
     def make_request_output(
@@ -204,7 +210,7 @@ def _new_request_output(
         finished: bool,
         kv_transfer_params: Optional[dict[str, Any]] = None,
     ) -> Union[RequestOutput, PoolingRequestOutput]:
-
+        # Seems this is where outputs are assembled
         first_output = outputs[0]
         if isinstance(first_output, PoolingOutput):
             assert len(outputs) == 1
@@ -215,17 +221,23 @@ def _new_request_output(
                 finished=finished,
             )
         assert self.logprobs_processor is not None
+        assert self.hidden_states_processor is not None
         if self.output_kind == RequestOutputKind.DELTA:
             # Side effect: logprobs processor forgets prompt logprobs
             prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
+            prompt_hidden_states = self.hidden_states_processor.pop_prompt_hidden_states(
+            )
         else:
             prompt_logprobs = self.logprobs_processor.prompt_logprobs
+            prompt_hidden_states = self.hidden_states_processor.prompt_hidden_states
 
+        # prompt logprobs and hidden states are attached to the RequestOutput here
         return RequestOutput(
             request_id=request_id,
             prompt=self.prompt,
             prompt_token_ids=self.prompt_token_ids,
             prompt_logprobs=prompt_logprobs,
+            prompt_hidden_states=prompt_hidden_states,
             outputs=cast(list[CompletionOutput], outputs),
             finished=finished,
             kv_transfer_params=kv_transfer_params,
@@ -399,6 +411,7 @@ def process_outputs(
             kv_transfer_params = engine_core_output.kv_transfer_params
             req_state.num_cached_tokens = engine_core_output.num_cached_tokens
             req_state.is_prefilling = False
+            prompt_hidden_states = engine_core_output.prompt_hidden_states
 
             if pooling_output is None:
                 assert req_state.detokenizer is not None
@@ -414,8 +427,12 @@ def process_outputs(
                 # if required.
                 req_state.logprobs_processor.update_from_output(
                     engine_core_output)
+                assert req_state.hidden_states_processor is not None
+                req_state.hidden_states_processor.update_from_output(
+                    engine_core_output)
 
             # 4) Create and handle RequestOutput objects.
+            print("lxy here make_request_output", prompt_hidden_states is None)
             if request_output := req_state.make_request_output(
                     new_token_ids, pooling_output, finish_reason, stop_reason,
                     kv_transfer_params):

vllm/v1/outputs.py

Lines changed: 4 additions & 0 deletions
@@ -105,6 +105,9 @@ class ModelRunnerOutput:
     # [prompt_len]
     prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
 
+    # req_id ->
+    prompt_hidden_states_dict: dict[str, Optional[torch.Tensor]]
+
     # [num_reqs, hidden_size]
     pooler_output: list[Optional[torch.Tensor]]
 
@@ -128,5 +131,6 @@ class DraftTokenIds:
     sampled_token_ids=[],
     logprobs=None,
    prompt_logprobs_dict={},
+    prompt_hidden_states_dict={},
     pooler_output=[],
     num_nans_in_logits=None)

vllm/v1/worker/gpu_input_batch.py

Lines changed: 7 additions & 1 deletion
@@ -217,10 +217,12 @@ def __init__(
         # NOTE(rob): num_prompt_logprobs only includes reqs
         # that are currently in the prefill phase.
         self.num_prompt_logprobs: dict[str, int] = {}
-
         # To accumulate prompt logprobs tensor chunks across prefill steps.
         self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
 
+        self.return_prompt_hidden_states_reqs: set[str] = set()
+        self.in_progress_prompt_hidden_states_cpu: dict[str, torch.Tensor] = {}
+
         # Internal representation of per-step batch state changes, used for
         # reordering persistent batch and generating logitsprocs batch state
         # updates. Should reset each step.
@@ -358,6 +360,9 @@ def add_request(
             self.num_prompt_logprobs[
                 req_id] = sampling_params.prompt_logprobs
 
+        if sampling_params.return_prompt_hidden_states:
+            self.return_prompt_hidden_states_reqs.add(req_id)
+
         if sampling_params.allowed_token_ids:
             self.has_allowed_token_ids.add(req_id)
             if self.allowed_token_ids_mask_cpu_tensor is None:
@@ -447,6 +452,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
         self.num_logprobs.pop(req_id, None)
         self.num_prompt_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
+        self.in_progress_prompt_hidden_states_cpu.pop(req_id, None)
 
         self.has_allowed_token_ids.discard(req_id)
         if self.allowed_token_ids_mask_cpu_tensor is not None:
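
The hunks above do not show the model-runner code that actually fills prompt_hidden_states_dict. Since the new fields mirror the prompt-logprobs path, a plausible accumulation pattern would look like the following purely illustrative helper (an assumption, not the commit's implementation): gate on return_prompt_hidden_states_reqs, accumulate each prefill chunk on CPU, and hand the finished tensor to ModelRunnerOutput.prompt_hidden_states_dict.

    import torch

    def accumulate_prompt_hidden_states(
        input_batch,                          # InputBatch with the new fields above
        req_id: str,
        chunk_hidden_states: torch.Tensor,    # [num_chunk_tokens, hidden_size]
        prompt_done: bool,
        prompt_hidden_states_dict: dict[str, torch.Tensor],
    ) -> None:
        """Illustrative helper only; not part of this commit."""
        if req_id not in input_batch.return_prompt_hidden_states_reqs:
            return
        # Mirror in_progress_prompt_logprobs_cpu: accumulate prefill chunks on CPU.
        in_progress = input_batch.in_progress_prompt_hidden_states_cpu
        chunk_cpu = chunk_hidden_states.to("cpu", non_blocking=True)
        if req_id in in_progress:
            in_progress[req_id] = torch.cat([in_progress[req_id], chunk_cpu])
        else:
            in_progress[req_id] = chunk_cpu
        if prompt_done:
            # Completed prompt: expose the tensor through the model runner output.
            prompt_hidden_states_dict[req_id] = in_progress.pop(req_id)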
