@@ -27,33 +27,24 @@ def _set_prompt_hidden_states(
2727 self ,
2828 prompt_hidden_states_tensor : torch .Tensor ,
2929 ) -> None :
30- """Update with prompt logprobs from EngineCore.
31-
32- Args:
33- prompt_logprobs_tensors: tuple containing the prompt logprobs
34- tensors.
35-
36- """
37-
3830 # We only need to set the prompt hidden states once.
39- # TODO: check logprobs
4031 assert self .prompt_hidden_states is None
4132
4233 self .prompt_hidden_states = prompt_hidden_states_tensor
4334
4435 def pop_prompt_hidden_states (self ) -> Optional [PromptLogprobs ]:
45- """Pop and return all request prompt logprobs
36+ """Pop and return all request prompt hidden states
4637
47- The logprobs processor aggregates prompt chunk logprobs
38+ The hidden states processor aggregates prompt chunk hidden states
4839 over one or more prefill chunks. This method returns
49- all prompt logprobs at once and then forgets them.
40+ all prompt hidden states at once and then forgets them.
5041 Ensures correct RequestOutputKind.DELTA semantics
51- wherein all prompt logprobs are returned at once at
42+ wherein all prompt hidden states are returned at once at
5243 the end of prefill.
5344
5445 Returns:
55- None if prompt logprobs are disabled for this request.
56- List of all prompt logprobs , otherwise.
46+ None if prompt hidden states are disabled for this request.
47+ List of all prompt hidden states , otherwise.
5748 """
5849 plp = self .prompt_hidden_states
5950 if plp :
@@ -62,5 +53,4 @@ def pop_prompt_hidden_states(self) -> Optional[PromptLogprobs]:
6253
6354 def update_from_output (self , output : EngineCoreOutput ) -> None :
6455 if output .prompt_hidden_states is not None :
65- print ("lxy update_from_output" )
6656 self ._set_prompt_hidden_states (output .prompt_hidden_states )
0 commit comments