diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index 16ceddf13511..69594ec72ab1 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -69,6 +69,8 @@ async def beam_search(
         ignore_eos = params.ignore_eos
         temperature = params.temperature
         length_penalty = params.length_penalty
+        stop = params.stop
+        stop_token_ids = params.stop_token_ids
 
         tokenizer = await self.get_tokenizer(lora_request=None)
         tokenizedPrompt = prompt if isinstance(
@@ -80,7 +82,10 @@ async def beam_search(
 
         beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                             max_tokens=1,
-                                            temperature=temperature)
+                                            temperature=temperature,
+                                            ignore_eos=ignore_eos,
+                                            stop=stop,
+                                            stop_token_ids=stop_token_ids)
         all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
         completed = []
 
@@ -117,8 +122,13 @@ async def beam_search(
                             cum_logprob=current_beam.cum_logprob +
                             logprob_obj.logprob)
 
-                        if token_id == tokenizer.eos_token_id and \
-                            not ignore_eos:
+                        if (
+                            result.outputs[0].finish_reason == "stop" or
+                            (
+                                token_id == tokenizer.eos_token_id
+                                and not ignore_eos
+                            )
+                        ):
                             completed.append(new_beam)
                         else:
                             new_beams.append(new_beam)
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 2010381076c7..86dfe538774d 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -375,6 +375,8 @@ def beam_search(
         temperature = params.temperature
         ignore_eos = params.ignore_eos
         length_penalty = params.length_penalty
+        stop = params.stop
+        stop_token_ids = params.stop_token_ids
 
         def sort_beams_key(x: BeamSearchSequence) -> float:
             return get_beam_search_score(x.tokens, x.cum_logprob,
@@ -387,7 +389,10 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
         # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534  # noqa
         beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                             max_tokens=1,
-                                            temperature=temperature)
+                                            temperature=temperature,
+                                            ignore_eos=ignore_eos,
+                                            stop=stop,
+                                            stop_token_ids=stop_token_ids)
         instances: List[BeamSearchInstance] = []
 
         for prompt in prompts:
@@ -436,8 +441,13 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
                                 cum_logprob=current_beam.cum_logprob +
                                 logprob_obj.logprob)
 
-                            if token_id == tokenizer.eos_token_id and \
-                                not ignore_eos:
+                            if (
+                                result.outputs[0].finish_reason == "stop" or
+                                (
+                                    token_id == tokenizer.eos_token_id and
+                                    not ignore_eos
+                                )
+                            ):
                                 instance.completed.append(new_beam)
                             else:
                                 instance_new_beams.append(new_beam)
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 6f1135f8093b..7c25e0affe33 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -594,6 +594,8 @@ def to_beam_search_params(self,
             ignore_eos=self.ignore_eos,
             temperature=temperature,
             length_penalty=self.length_penalty,
+            stop=self.stop,
+            stop_token_ids=self.stop_token_ids
         )
 
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 4f2ae75e65f3..d3865188fcb4 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -489,3 +489,5 @@ class BeamSearchParams(
     ignore_eos: bool = False
     temperature: float = 0.0
     length_penalty: float = 1.0
+    stop: Optional[Union[str, List[str]]] = None
+    stop_token_ids: Optional[List[int]] = None
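
Usage sketch (not part of the diff): with this change, stop and stop_token_ids flow from BeamSearchParams into the per-step SamplingParams, and any beam whose step finishes with reason "stop" is moved to the completed set instead of being extended further. A minimal example against the offline LLM.beam_search entrypoint; the model name, prompt, stop string, and stop token id below are illustrative assumptions:

    from vllm import LLM
    from vllm.sampling_params import BeamSearchParams

    llm = LLM(model="facebook/opt-125m")  # illustrative model choice

    params = BeamSearchParams(
        beam_width=4,
        max_tokens=64,
        # New fields from this diff: a beam that produces a stop string or
        # a stop token id is marked completed rather than extended on the
        # next beam-search step.
        stop=["\n\n"],
        stop_token_ids=[2],  # illustrative token id
    )

    outputs = llm.beam_search(["The capital of France is"], params)
    print(outputs[0].sequences[0].tokens)  # token ids of the best beam

The same fields reach the OpenAI-compatible server through to_beam_search_params, so `stop` and `stop_token_ids` in a completions request now apply when `use_beam_search` is enabled.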