16 changes: 13 additions & 3 deletions vllm/engine/protocol.py
@@ -69,6 +69,8 @@ async def beam_search(
         ignore_eos = params.ignore_eos
         temperature = params.temperature
         length_penalty = params.length_penalty
+        stop = params.stop
+        stop_token_ids = params.stop_token_ids
 
         tokenizer = await self.get_tokenizer(lora_request=None)
         tokenizedPrompt = prompt if isinstance(
@@ -80,7 +82,10 @@
 
         beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                             max_tokens=1,
-                                            temperature=temperature)
+                                            temperature=temperature,
+                                            ignore_eos=ignore_eos,

[Contributor Author] Here and below: I'm not sure whether informing the beam_search_params of ignore_eos does anything, since it did not produce the "stop" finish reason as I expected.

[Contributor Author] OK, after some more testing this is actually necessary -- otherwise the EOS ignoring does not happen fully.

+                                            stop=stop,
+                                            stop_token_ids=stop_token_ids)
         all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
         completed = []
 
@@ -117,8 +122,13 @@ async def beam_search(
                             cum_logprob=current_beam.cum_logprob +
                             logprob_obj.logprob)
 
-                        if token_id == tokenizer.eos_token_id and \
-                            not ignore_eos:
+                        if (
+                            result.outputs[0].finish_reason == "stop" or
+                            (
+                                token_id == tokenizer.eos_token_id
+                                and not ignore_eos
+                            )
+                        ):
                             completed.append(new_beam)
                         else:
                             new_beams.append(new_beam)
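The rewritten completion check is the core of the change: a beam now finishes either because the engine reported a finish_reason of "stop" (a stop string or stop token id fired) or because EOS was sampled while ignore_eos is false. As the review thread notes, ignore_eos must also be set on the per-step SamplingParams so the engine keeps generating past EOS. A minimal standalone sketch of the same predicate (the helper name is hypothetical, not part of the codebase):

```python
from typing import Optional

def beam_is_complete(token_id: int,
                     finish_reason: Optional[str],
                     eos_token_id: int,
                     ignore_eos: bool) -> bool:
    # Hypothetical standalone version of the condition added above:
    # a beam completes when a stop criterion fired in the engine, or
    # when EOS was produced and is not being ignored.
    hit_stop_criterion = finish_reason == "stop"
    hit_eos = token_id == eos_token_id and not ignore_eos
    return hit_stop_criterion or hit_eos
```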
16 changes: 13 additions & 3 deletions vllm/entrypoints/llm.py
@@ -375,6 +375,8 @@ def beam_search(
         temperature = params.temperature
         ignore_eos = params.ignore_eos
         length_penalty = params.length_penalty
+        stop = params.stop
+        stop_token_ids = params.stop_token_ids
 
         def sort_beams_key(x: BeamSearchSequence) -> float:
             return get_beam_search_score(x.tokens, x.cum_logprob,
@@ -387,7 +389,10 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
         # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
         beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                             max_tokens=1,
-                                            temperature=temperature)
+                                            temperature=temperature,
+                                            ignore_eos=ignore_eos,
+                                            stop=stop,
+                                            stop_token_ids=stop_token_ids)
         instances: List[BeamSearchInstance] = []
 
         for prompt in prompts:
@@ -436,8 +441,13 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
                             cum_logprob=current_beam.cum_logprob +
                             logprob_obj.logprob)
 
-                        if token_id == tokenizer.eos_token_id and \
-                            not ignore_eos:
+                        if (
+                            result.outputs[0].finish_reason == "stop" or
+                            (
+                                token_id == tokenizer.eos_token_id and
+                                not ignore_eos
+                            )
+                        ):
                             instance.completed.append(new_beam)
                         else:
                             instance_new_beams.append(new_beam)
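With the params threaded through, stop criteria can be used from the synchronous entry point as well. A usage sketch, assuming the LLM.beam_search signature shown above and the BeamSearchSequence fields from the diff; the model name and stop values are illustrative:

```python
from vllm import LLM
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="facebook/opt-125m")  # model choice is illustrative
params = BeamSearchParams(
    beam_width=4,
    max_tokens=64,
    stop=["\n\n"],       # beams now complete when a stop string is generated...
    stop_token_ids=[2],  # ...or when one of these token ids is produced
)
outputs = llm.beam_search(["The capital of France is"], params)
best = outputs[0].sequences[0]  # completed beams come back sorted best-first
print(best.cum_logprob, best.tokens)
```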
2 changes: 2 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -594,6 +594,8 @@ def to_beam_search_params(self,
             ignore_eos=self.ignore_eos,
             temperature=temperature,
             length_penalty=self.length_penalty,
+            stop=self.stop,
+            stop_token_ids=self.stop_token_ids
         )
 
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
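Since to_beam_search_params now forwards the request's stop and stop_token_ids, stop criteria supplied through the OpenAI-compatible server apply under beam search too. A hedged request sketch -- the endpoint, model, and the use_beam_search/stop_token_ids extra fields are assumptions about vLLM's OpenAI protocol, not shown in this diff:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.completions.create(
    model="facebook/opt-125m",
    prompt="The capital of France is",
    stop=["\n\n"],                # forwarded into BeamSearchParams.stop
    extra_body={
        "use_beam_search": True,  # vLLM-specific extra parameter (assumed)
        "stop_token_ids": [2],    # forwarded into BeamSearchParams.stop_token_ids
    },
)
print(resp.choices[0].text)
```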
2 changes: 2 additions & 0 deletions vllm/sampling_params.py
@@ -489,3 +489,5 @@ class BeamSearchParams(
     ignore_eos: bool = False
     temperature: float = 0.0
     length_penalty: float = 1.0
+    stop: Optional[Union[str, List[str]]] = None
+    stop_token_ids: Optional[List[int]] = None
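The two new fields mirror their SamplingParams counterparts; per the annotation, stop accepts either a single string or a list of strings. A minimal construction sketch (values are illustrative):

```python
from vllm.sampling_params import BeamSearchParams

# `stop` may be a single string or a list, matching the
# Optional[Union[str, List[str]]] annotation above.
params_single = BeamSearchParams(beam_width=4, max_tokens=64, stop="###")
params_multi = BeamSearchParams(beam_width=4, max_tokens=64,
                                stop=["###", "\n\n"], stop_token_ids=[2])
```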