diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py
index 7203d635c2fa..a084a6a7b0cf 100644
--- a/tests/multi_step/test_correctness_async_llm.py
+++ b/tests/multi_step/test_correctness_async_llm.py
@@ -164,7 +164,7 @@ async def test_multi_step_pp_smoke(
     Args:
       tp_size: degree of tensor-parallelism
       pp_size: degree of pipeline-parallelism
-      eager_mode
+      monkeypatch: fixture which we use to temporarily override backend env var
     """
 
     model = "JackFram/llama-160m"
@@ -223,3 +223,134 @@ async def test_multi_step_pp_smoke(
     test_generations = get_client_text_generations(test_completions)
 
     assert ref_generations == test_generations
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("tp_size", [1])
+@pytest.mark.parametrize("pp_size", [1])
+@pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
+@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
+@pytest.mark.parametrize("max_output_len", [7])
+@pytest.mark.parametrize("n,best_of", [
+    (1, 3),
+    (2, 2),
+    (2, 3),
+])
+@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"])
+@pytest.mark.parametrize("is_async", [False, True])
+@pytest.mark.parametrize("num_logprobs", [None, 5])
+@pytest.mark.asyncio
+async def test_multi_step_llm_best_of_fallback_async(
+    monkeypatch,
+    example_prompts,
+    model: str,
+    tp_size: int,
+    pp_size: int,
+    enforce_eager: bool,
+    num_scheduler_steps: int,
+    num_prompts: int,
+    max_output_len: int,
+    n: int,
+    best_of: int,
+    attention_backend: str,
+    is_async: bool,
+    num_logprobs: Optional[int],
+) -> None:
+    """Test vLLM server with multi-step & best_of > 1
+
+    Currently multi-step scheduling does not support best_of > 1 or beam
+    search; however, the default behavior is for the engine to fall back on
+    single-step scheduling rather than failing.
+
+    Args:
+      monkeypatch: fixture which we use to temporarily override backend
+                   env var
+      example_prompts: test fixture providing example prompts
+      model: model under test (same for single- and multi-step engines)
+      tp_size: degree of tensor-parallelism
+      pp_size: degree of pipeline-parallelism
+      enforce_eager: whether the engine should enforce eager mode
+      num_scheduler_steps: for multi-step scheduling, GPU-side steps per
+                           GPU -> CPU output transfer
+      num_prompts: number of example prompts under test
+      max_output_len: maximum number of tokens to generate per request
+      n: num seqs to output per :class:`SequenceGroup`
+      best_of: num seqs per :class:`SequenceGroup` from which to choose
+      attention_backend: attention backend under test
+      is_async: if True, use async output processor
+      num_logprobs: number of logprobs to return per token
+    """
+
+    override_backend_env_variable(monkeypatch, attention_backend)
+
+    prompts = example_prompts
+    if len(prompts) < num_prompts:
+        prompts = prompts * ((num_prompts // len(prompts)) + 1)
+    prompts = prompts[:num_prompts]
+    assert len(prompts) == num_prompts
+
+    server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"]
+    ms_server_args = DEFAULT_SERVER_ARGS + \
+        ["--num-scheduler-steps", f"{num_scheduler_steps}"]
+
+    if not is_async:
+        ms_server_args += ["--disable-async-output-proc"]
+
+    if enforce_eager:
+        ms_server_args.append("--enforce-eager")
+
+    distributed_args = [
+        "--tensor-parallel-size",
+        str(tp_size),
+        "--pipeline-parallel-size",
+        str(pp_size),
+    ]
+
+    # Requests will share a random seed
+    seed = 42
+
+    # Spin up client/server & issue completion API requests.
+    # Default `max_wait_seconds` is 240 but was empirically
+    # raised 5x to 1200 *just for this test* due to
+    # observed timeouts in GHA CI
+    ref_completions = await completions_with_server_args(
+        prompts,
+        model,
+        server_args + distributed_args,
+        num_logprobs,
+        max_wait_seconds=5 * 240,
+        best_of=best_of,
+        n=n,
+        max_tokens=max_output_len,
+        temperature=1.0,
+        seed=seed)
+    test_completions = await completions_with_server_args(
+        prompts,
+        model,
+        ms_server_args + distributed_args,
+        num_logprobs,
+        max_wait_seconds=5 * 240,
+        best_of=best_of,
+        n=n,
+        max_tokens=max_output_len,
+        temperature=1.0,
+        seed=seed)
+
+    # Assert multi-step scheduling produces identical tokens
+    # to single-step scheduling.
+    ref_generations = get_client_text_generations(ref_completions)
+    test_generations = get_client_text_generations(test_completions)
+    assert ref_generations == test_generations
+
+    # Assert multi-step scheduling produces nearly-identical logprobs
+    # to single-step scheduling.
+    ref_text_logprobs = get_client_text_logprob_generations(ref_completions)
+    test_text_logprobs = get_client_text_logprob_generations(test_completions)
+    check_logprobs_close(
+        outputs_0_lst=ref_text_logprobs,
+        outputs_1_lst=test_text_logprobs,
+        name_0="single-step",
+        name_1="multi-step",
+    )
diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py
index cc1fd1925201..0dab97cf378b 100644
--- a/tests/multi_step/test_correctness_llm.py
+++ b/tests/multi_step/test_correctness_llm.py
@@ -5,6 +5,9 @@
 
 import pytest
 
+from vllm import SamplingParams
+from vllm.entrypoints.utils import STR_MULTI_STEP_BEAM_SEARCH_NOT_SUPPORTED
+
 from ..models.utils import check_logprobs_close, check_outputs_equal
 
 MODELS = [
@@ -192,11 +195,173 @@ def test_multi_step_llm_w_prompt_logprobs(
     check_logprobs_close(
         outputs_0_lst=single_step_vllm_outputs,
         outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
+        name_0="single_step_vllm",
+        name_1="multi_step_vllm",
     )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("tp_size", [1])
+@pytest.mark.parametrize("enforce_eager", [False, True])
+@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
+@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
+@pytest.mark.parametrize("max_output_len", [7])
+@pytest.mark.parametrize("n,best_of", [
+    (1, 2),
+    (2, 2),
+    (2, 3),
+])
+@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
+@pytest.mark.parametrize("enable_prefix_caching", [True, False])
+def test_multi_step_llm_best_of_fallback(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    tp_size: int,
+    enforce_eager: bool,
+    num_scheduler_steps: int,
+    num_prompts: int,
+    max_output_len: int,
+    n: int,
+    best_of: int,
+    enable_chunked_prefill: bool,
+    enable_prefix_caching: bool,
+) -> None:
+    """Test vLLM engine with multi-step & best_of > 1
+
+    Currently multi-step scheduling does not support best_of > 1 or beam
+    search; however, the default behavior is for the engine to fall back on
+    single-step scheduling rather than failing.
+
+    Two instantiations of the sync vLLM engine are tested, one with
+    single-step and one with multi-step scheduling.
+
+    Each instantiation of vLLM is tested in 3 phases:
+    1. Batch of requests without best_of > 1
+    2. Batch of requests with best_of > 1
+    3. Batch of requests without best_of > 1
+
+    For the instantiation of vLLM with multi-step scheduling, Phase 1 should
+    use multi-step scheduling, Phase 2 should fall back on single-step
+    scheduling, and Phase 3 should resume multi-step scheduling.
+
+    The other instantiation should use single-step scheduling for all phases.
+
+    Args:
+      vllm_runner: vLLM model runner fixture
+      example_prompts: test fixture providing example prompts
+      model: model under test (same for single- and multi-step engines)
+      dtype: tensor datatype for engine to utilize
+      tp_size: degree of tensor-parallelism
+      enforce_eager: whether the engine should enforce eager mode
+      num_scheduler_steps: for multi-step scheduling, GPU-side steps per
+                           GPU -> CPU output transfer
+      num_prompts: number of example prompts under test
+      max_output_len: the maximum number of tokens to generate
+      n: num seqs to output per :class:`SequenceGroup`
+      best_of: num seqs per :class:`SequenceGroup` from which to choose
+      enable_chunked_prefill: whether chunked prefill is enabled
+      enable_prefix_caching: whether prefix caching is enabled
+    """
+
+    prompts = example_prompts
+    if len(prompts) < num_prompts:
+        prompts = prompts * ((num_prompts // len(prompts)) + 1)
+    prompts = prompts[:num_prompts]
+    assert len(prompts) == num_prompts
+
+    # Sampling parameters with best_of > 1 which should trigger a
+    # multi-step scheduler to fall back on single-step scheduling
+    sampling_params_best_of_gt_1 = SamplingParams(
+        max_tokens=max_output_len,
+        ignore_eos=True,
+        temperature=1.0,
+        n=n,
+        best_of=best_of,
+        seed=42,
+    )
+
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enforce_eager=enforce_eager,
+            gpu_memory_utilization=0.7,
+            tensor_parallel_size=tp_size,
+            use_v2_block_manager=True,
+            num_scheduler_steps=1,
+            enable_chunked_prefill=enable_chunked_prefill,
+            enable_prefix_caching=enable_prefix_caching,
+    ) as vllm_model:
+        outputs_ss_best_of_gt_1 = vllm_model.generate(
+            prompts, sampling_params_best_of_gt_1)
+
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enforce_eager=enforce_eager,
+            gpu_memory_utilization=0.7,
+            tensor_parallel_size=tp_size,
+            use_v2_block_manager=True,
+            num_scheduler_steps=num_scheduler_steps,
+            enable_chunked_prefill=enable_chunked_prefill,
+            enable_prefix_caching=enable_prefix_caching,
+    ) as vllm_model:
+        outputs_ms_best_of_gt_1 = (vllm_model.generate(
+            prompts, sampling_params_best_of_gt_1))
+
+    check_outputs_equal(
+        outputs_0_lst=outputs_ss_best_of_gt_1,
+        outputs_1_lst=outputs_ms_best_of_gt_1,
+        name_0="outputs_ss_best_of_gt_1",
+        name_1="outputs_ms_best_of_gt_1",
+    )
+
+
+@pytest.mark.parametrize("model", ["JackFram/llama-160m"])
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("enforce_eager", [False, True])
+@pytest.mark.parametrize("num_scheduler_steps", [8])
+@pytest.mark.parametrize("max_output_len", [7])
+def test_multi_step_beam_search_fail(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    enforce_eager: bool,
+    num_scheduler_steps: int,
+    max_output_len: int,
+) -> None:
+    """Test that vLLM engine with multi-step fails if beam search is enabled.
+
+    Beam search is not supported with multi-step.
+
+    Args:
+      vllm_runner: vLLM model runner fixture
+      example_prompts: test fixture providing example prompts
+      model: model under test (same for single- and multi-step engines)
+      dtype: tensor datatype for engine to utilize
+      enforce_eager: whether the engine should enforce eager mode
+      num_scheduler_steps: for multi-step scheduling, GPU-side steps per
+                           GPU -> CPU output transfer
+      max_output_len: the maximum number of tokens to generate
+    """
+
+    with pytest.raises(ValueError,
+                       match=STR_MULTI_STEP_BEAM_SEARCH_NOT_SUPPORTED), \
+            vllm_runner(
+                model,
+                dtype=dtype,
+                enforce_eager=enforce_eager,
+                gpu_memory_utilization=0.7,
+                tensor_parallel_size=1,
+                use_v2_block_manager=True,
+                num_scheduler_steps=num_scheduler_steps,
+            ) as vllm_model:
+        vllm_model.generate_beam_search(example_prompts, 2, max_output_len)
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("tp_size", [1])
diff --git a/tests/utils.py b/tests/utils.py
index e983104e3cb0..a248a17c4a5e 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -13,6 +13,7 @@
 import openai
 import pytest
 import requests
+from openai._types import NOT_GIVEN, NotGiven
 from openai.types.completion import Completion
 from typing_extensions import ParamSpec, assert_never
 
@@ -615,6 +616,10 @@ async def completions_with_server_args(
     num_logprobs: Optional[int],
     max_wait_seconds: int = 240,
     max_tokens: Union[int, list] = 5,
+    best_of: Union[int, NotGiven] = NOT_GIVEN,
+    n: Union[int, NotGiven] = NOT_GIVEN,
+    temperature: Union[float, NotGiven] = 0,
+    seed: Union[int, NotGiven] = NOT_GIVEN,
 ) -> List[Completion]:
     '''Construct a remote OpenAI server, obtain an async client to the
     server & invoke the completions API to obtain completions.
@@ -647,10 +652,13 @@ async def completions_with_server_args(
     client = server.get_async_client()
     outputs = [ client.completions.create(model=model_name,
                                           prompt=[p],
-                                          temperature=0,
+                                          temperature=temperature,
                                           stream=False,
                                           max_tokens=max_tok,
-                                          logprobs=num_logprobs) \
+                                          logprobs=num_logprobs,
+                                          best_of=best_of,
+                                          n=n,
+                                          seed=seed) \
                for p, max_tok in zip(prompts, max_tokens) ]
     outputs = await asyncio.gather(*outputs)
 
@@ -663,8 +671,7 @@ def get_client_text_generations(completions: List[Completion]) -> List[str]:
     '''Extract generated tokens from the output of a request made to an
     Open-AI-protocol completions endpoint.
     '''
-    assert all([len(x.choices) == 1 for x in completions])
-    return [x.choices[0].text for x in completions]
+    return [c.text for x in completions for c in x.choices]
 
 
 def get_client_text_logprob_generations(
diff --git a/vllm/config.py b/vllm/config.py
index 25f841231ded..54871bfe33b7 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1071,6 +1071,13 @@ def __init__(self,
         self.multi_step_stream_outputs = multi_step_stream_outputs
         self.send_delta_data = send_delta_data
         self.policy = policy
+
+        # `engine_permits_multi_step_scheduling` reflects the user-specified
+        # multi-step config. `current_step_is_multi_step` may be modified to
+        # override `engine_permits_multi_step_scheduling` in any given call
+        # to `schedule()`.
+        self.current_step_is_multi_step = (
+            self.engine_permits_multi_step_scheduling)
         self._verify_args()
 
     def _verify_args(self) -> None:
@@ -1103,7 +1110,12 @@ def _verify_args(self) -> None:
                 "equal to 1.")
 
     @property
-    def is_multi_step(self) -> bool:
+    def engine_permits_multi_step_scheduling(self) -> bool:
+        """Base multi-step setting, configured by the user.
+
+        Can be overridden per step by the scheduler if multi-step scheduling
+        is not supported for a given request.
+ """ return self.num_scheduler_steps > 1 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 88733b8f53b8..1844c48c86be 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -395,6 +395,12 @@ def __init__( # for processing and deallocation by the free_finished_seq_groups() self._async_stopped: List[SequenceGroup] = [] + # Multi-step scheduling is not supported with best_of > 1, + # so if multi-step is enabled, count the number of unfinished + # requests incompatible with multi-step & only allow a given + # step to use multi-step if the count is zero. + self._multi_step_incompat_req_count = 0 + @property def next_cache_id(self): return (self.cache_id + 1) % self.num_cache_iters @@ -408,9 +414,65 @@ def num_decoding_tokens_per_seq(self) -> int: """The number of new tokens.""" return 1 + @property + def _current_step_is_multi_step(self) -> bool: + return self.scheduler_config.current_step_is_multi_step + + @_current_step_is_multi_step.setter + def _current_step_is_multi_step(self, value) -> None: + self.scheduler_config.current_step_is_multi_step = value + + @property + def _engine_permits_multi_step_scheduling(self) -> bool: + return self.scheduler_config.engine_permits_multi_step_scheduling + + def _seq_group_has_multi_step_incompat_sample_params( + self, + seq_group: SequenceGroup, + ) -> bool: + """:class:`SequenceGroup`s with best_of>1 are incompat. w/ multi-step""" + return (seq_group.sampling_params is not None + and seq_group.sampling_params.best_of is not None + and seq_group.sampling_params.best_of > 1) + + def _maybe_record_new_sg_w_multi_step_incompat_sample_params( + self, + seq_group: SequenceGroup, + ) -> None: + """If new req has best_of>1 & engine supports multistep, ++count""" + if (self._engine_permits_multi_step_scheduling + and self._seq_group_has_multi_step_incompat_sample_params( + seq_group)): + self._multi_step_incompat_req_count += 1 + + def _maybe_record_finished_sg_w_multi_step_incompat_sample_params( + self, + seq_group: SequenceGroup, + ) -> None: + """If finished req has best_of>1 & engine supports multistep, --count""" + if (self._engine_permits_multi_step_scheduling + and self._seq_group_has_multi_step_incompat_sample_params( + seq_group)): + assert self._multi_step_incompat_req_count > 0 + self._multi_step_incompat_req_count -= 1 + + def _maybe_disable_multi_step_by_sampling_params(self) -> None: + """Disable multi-step unless engine & all unfinished reqs support it""" + self._current_step_is_multi_step = ( + self._engine_permits_multi_step_scheduling + and self._multi_step_incompat_req_count == 0) + + def _is_multi_step_temporarily_disabled(self) -> bool: + """The engine supports multi-step; the current step disabled it.""" + return (self._engine_permits_multi_step_scheduling + and not self._current_step_is_multi_step) + def add_seq_group(self, seq_group: SequenceGroup) -> None: # Add sequence groups to the waiting queue. self.waiting.append(seq_group) + # Detect & count seq groups incompatible with multi-step + self._maybe_record_new_sg_w_multi_step_incompat_sample_params( + seq_group) def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None: # Add sequence groups to the running queue. 
@@ -521,7 +583,7 @@ def _schedule_running(
         ret.preempted.clear()
         ret.swapped_out.clear()
 
-        ret.num_lookahead_slots = self._get_num_lookahead_slots(
+        ret.num_lookahead_slots = self._get_current_step_num_lookahead_slots(
             is_prefill=False, enable_chunking=enable_chunking)
 
         ret.decode_seq_groups_list.clear()
@@ -685,7 +747,8 @@ def _schedule_swapped(
             is_prefill = seq_group.is_prefill()
             alloc_status = self.block_manager.can_swap_in(
                 seq_group,
-                self._get_num_lookahead_slots(is_prefill, enable_chunking))
+                self._get_current_step_num_lookahead_slots(
+                    is_prefill, enable_chunking))
             if alloc_status == AllocStatus.LATER:
                 break
             elif alloc_status == AllocStatus.NEVER:
@@ -747,14 +810,14 @@ def _schedule_swapped(
             prefill_seq_groups=prefill_seq_groups,
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_copy=blocks_to_copy,
-            num_lookahead_slots=self._get_num_lookahead_slots(
+            num_lookahead_slots=self._get_current_step_num_lookahead_slots(
                 is_prefill=False, enable_chunking=enable_chunking),
             infeasible_seq_groups=infeasible_seq_groups,
         )
 
     def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
         if self.scheduler_config.chunked_prefill_enabled and \
-                not self.scheduler_config.is_multi_step:
+                not self._current_step_is_multi_step:
             prompt_limit = self.scheduler_config.max_model_len
         else:
             prompt_limit = min(self.scheduler_config.max_model_len,
@@ -902,9 +965,10 @@ def _schedule_prefills(
                 continue
 
             num_lookahead_slots: int = 0
-            if self.scheduler_config.is_multi_step and enable_chunking:
-                num_lookahead_slots = self._get_num_lookahead_slots(
-                    True, enable_chunking)
+            if (self._current_step_is_multi_step and enable_chunking):
+                num_lookahead_slots = (
+                    self._get_current_step_num_lookahead_slots(
+                        True, enable_chunking))
 
             # If the sequence group cannot be allocated, stop.
             can_allocate = self.block_manager.can_allocate(
@@ -948,7 +1012,7 @@ def _schedule_prefills(
             waiting_queue.popleft()
             self._allocate_and_set_running(seq_group)
 
-            if enable_chunking and self.scheduler_config.is_multi_step:
+            if (enable_chunking and self._current_step_is_multi_step):
                 blocks_to_copy: List[Tuple[int, int]] = []
                 # init_multi_step_from_lookahead_slots happens in append_slots
                 self._append_slots(seq_group, blocks_to_copy, enable_chunking)
@@ -962,7 +1026,8 @@ def _schedule_prefills(
                 num_lookahead_slots,
                 num_scheduler_steps=self.scheduler_config.
                 num_scheduler_steps,
-                is_multi_step=self.scheduler_config.is_multi_step,
+                is_multi_step=self.scheduler_config.
+                current_step_is_multi_step,
                 enable_chunking=enable_chunking)
 
             seq_groups.append(
@@ -979,7 +1044,7 @@ def _schedule_prefills(
         return SchedulerPrefillOutputs(
             seq_groups=seq_groups,
             ignored_seq_groups=ignored_seq_groups,
-            num_lookahead_slots=self._get_num_lookahead_slots(
+            num_lookahead_slots=self._get_current_step_num_lookahead_slots(
                 is_prefill=True, enable_chunking=enable_chunking))
 
     def _schedule_default(self) -> SchedulerOutputs:
@@ -990,6 +1055,7 @@ def _schedule_default(self) -> SchedulerOutputs:
         decodes. If there's a pressure on GPU memory, decode requests can
         be swapped or preempted.
         """
+
         # Include running requests to the budget.
         budget = SchedulingBudget(
             token_budget=self.scheduler_config.max_num_batched_tokens,
@@ -1172,6 +1238,10 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:
 
     def _schedule(self) -> SchedulerOutputs:
         """Schedule queued requests."""
+        # Configure the (non-)use of multi-step scheduling in this step
+        self._maybe_disable_multi_step_by_sampling_params()
+
+        # Choose appropriate scheduler
         if self.scheduler_config.chunked_prefill_enabled:
             return self._schedule_chunked_prefill()
         else:
@@ -1190,13 +1260,13 @@ def _can_append_slots(self, seq_group: SequenceGroup,
             return False
 
         is_prefill = seq_group.is_prefill()
-        num_lookahead_slots = self._get_num_lookahead_slots(
+        num_lookahead_slots = self._get_current_step_num_lookahead_slots(
            is_prefill, enable_chunking)
 
         if is_prefill and num_lookahead_slots > 0:
             # Appending prefill slots only happens multi-step and
             # chunked-prefill are enabled together.
-            assert self.scheduler_config.is_multi_step and enable_chunking
+            assert (self._current_step_is_multi_step and enable_chunking)
 
         return self.block_manager.can_append_slots(
             seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)
@@ -1394,6 +1464,11 @@ def free_finished_seq_groups(self) -> None:
                 self._free_finished_seq_group(seq_group)
                 if not seq_group.is_finished():
                     remaining.append(seq_group)
+                else:
+                    # If seq group had best_of>1 & is finished, decrement
+                    # counter of seq groups incompatible with multi-step
+                    self._maybe_record_finished_sg_w_multi_step_incompat_sample_params(
+                        seq_group)
 
         self.running = remaining
 
@@ -1406,6 +1481,10 @@ def free_finished_seq_groups(self) -> None:
 
                 # Free finished seqs
                 self._free_finished_seqs(seq_group)
+                # If seq group had best_of>1 & is finished, decrement counter
+                # of seq groups incompatible with multi-step
+                self._maybe_record_finished_sg_w_multi_step_incompat_sample_params(
+                    seq_group)
 
         self._async_stopped.clear()
 
@@ -1431,17 +1510,17 @@ def _append_slots(self,
             enable_chunking (bool): True if chunked prefill is enabled.
         """
         is_prefill: bool = seq_group.is_prefill()
-        num_lookahead_slots: int = self._get_num_lookahead_slots(
+        num_lookahead_slots: int = self._get_current_step_num_lookahead_slots(
             is_prefill, enable_chunking)
 
         seq_group.init_multi_step_from_lookahead_slots(
             num_lookahead_slots,
             num_scheduler_steps=self.scheduler_config.num_scheduler_steps,
-            is_multi_step=self.scheduler_config.is_multi_step,
+            is_multi_step=self._current_step_is_multi_step,
             enable_chunking=enable_chunking)
 
         seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING
-        if self.scheduler_config.is_multi_step and enable_chunking:
+        if self._current_step_is_multi_step and enable_chunking:
             # In multi-step chunked-prefill any sequence type can have
             # slots appended.
             seq_status = None
@@ -1557,8 +1636,8 @@ def _passed_delay(self, now: float) -> bool:
             passed_delay = True
         return passed_delay
 
-    def _get_num_lookahead_slots(self, is_prefill: bool,
-                                 enable_chunking: bool) -> int:
+    def _get_current_step_num_lookahead_slots(self, is_prefill: bool,
+                                              enable_chunking: bool) -> int:
         """The number of slots to allocate per sequence per step, beyond known
         token ids. Speculative decoding uses these slots to store KV
         activations of tokens which may or may not be accepted.
@@ -1570,8 +1649,11 @@ def _get_num_lookahead_slots(self, is_prefill: bool,
         for the prefills for when the prefills turn into decodes in the first
         step.
""" + if self._is_multi_step_temporarily_disabled(): + return 0 + if is_prefill: - if self.scheduler_config.is_multi_step and enable_chunking: + if (self._current_step_is_multi_step and enable_chunking): # num_lookahead_slots was introduced in the context of decodes, # in Speculative Decoding. # When the num_scheduler_steps is 8, say, then the @@ -1609,7 +1691,7 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, # in a decode phase. Do not chunk. if enable_chunking and len(seqs) == 1: remaining_token_budget = budget.remaining_token_budget() - if self.scheduler_config.is_multi_step: + if self._current_step_is_multi_step: # The current multi-step + chunked prefill capability does # not actually support chunking prompts. # diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 1f57aecb6481..9c0743d6e7a9 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -303,7 +303,7 @@ async def step_async( if not allow_async_output_proc and len(ctx.output_queue) > 0: self._process_model_outputs(ctx=ctx) - if (self.scheduler_config.is_multi_step + if (self.scheduler_config.current_step_is_multi_step and scheduler_outputs.num_lookahead_slots > 0): # cache the scheduler outputs for the next iteration if we have # lookahead slots @@ -348,7 +348,7 @@ async def step_async( # we need to do this here so that last step's sampled_token_ids can # be passed to the next iteration for PP. - if self.scheduler_config.is_multi_step: + if self.scheduler_config.current_step_is_multi_step: self._update_cached_scheduler_output(virtual_engine, outputs) else: if len(ctx.output_queue) > 0: @@ -356,13 +356,13 @@ async def step_async( outputs = [] # Finish the current step for all the sequence groups. - if self.scheduler_config.is_multi_step: + if self.scheduler_config.current_step_is_multi_step: for seq_group in seq_group_metadata_list: seq_group.finish_step() if not self._has_remaining_steps(seq_group_metadata_list): # Clear the cache if we have finished all the steps - if self.scheduler_config.is_multi_step: + if self.scheduler_config.current_step_is_multi_step: self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1dd0f097c74f..db32df760e85 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -464,9 +464,20 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: # Create sequence output processor, e.g. for beam search or # speculative decoding. + self.output_processor = ( - SequenceGroupOutputProcessor.create_output_processor( + SequenceGroupOutputProcessor.create_single_step_output_processor( self.scheduler_config, + self.detokenizer, + self.scheduler, + self.seq_counter, + stop_checker=StopChecker( + self.scheduler_config.max_model_len, + get_tokenizer_for_seq, + ), + ) + ) if self.scheduler_config.num_lookahead_slots == 0 else ( + SequenceGroupOutputProcessor.create_multi_step_output_processor( self.detokenizer, self.scheduler, self.seq_counter, @@ -477,6 +488,23 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: ), )) + if self.scheduler_config.engine_permits_multi_step_scheduling: + # Multi-step only: construct a fallback single-step output + # processor for scenarios where a request utilizes a feature + # unsupported by multi-step + self.fallback_single_step_output_processor = ( + SequenceGroupOutputProcessor. 
+                create_single_step_output_processor(
+                    self.scheduler_config,
+                    self.detokenizer,
+                    self.scheduler,
+                    self.seq_counter,
+                    stop_checker=StopChecker(
+                        self.scheduler_config.max_model_len,
+                        get_tokenizer_for_seq,
+                    ),
+                ))
+
         self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
 
     def _initialize_kv_caches(self) -> None:
@@ -1010,7 +1038,7 @@ def _update_num_computed_tokens_for_multi_step_prefill(
         in multi-step are submitted in a single burst.
         """
 
-        assert self.scheduler_config.is_multi_step
+        assert self.scheduler_config.current_step_is_multi_step
 
         if not seq_group_meta.is_prompt:
             # num_computed_token updates for multi-step decodes happen after
@@ -1065,7 +1093,7 @@ def _process_model_outputs(self,
         has_multiple_outputs: bool = len(outputs) > 1
         outputs_by_sequence_group: List[List[SequenceGroupOutput]]
         if has_multiple_outputs:
-            assert self.scheduler_config.is_multi_step or \
+            assert self.scheduler_config.current_step_is_multi_step or \
                 self.speculative_config
             # Organize outputs by [step][sequence group] instead of
             # [sequence group][step].
@@ -1116,7 +1144,7 @@ def _process_model_outputs(self,
                 output = [outputs_by_sequence_group[0][i]]
 
             if not is_async:
-                if self.scheduler_config.is_multi_step:
+                if self.scheduler_config.current_step_is_multi_step:
                     # Updates happen only if the sequence is prefill
                     self._update_num_computed_tokens_for_multi_step_prefill(
                         seq_group, seq_group_meta, is_first_step_output)
@@ -1144,9 +1172,14 @@ def _process_model_outputs(self,
             if self.model_config.task == "embedding":
                 self._process_sequence_group_outputs(seq_group, output)
             else:
-                self.output_processor.process_prompt_logprob(seq_group, output)
+                selected_output_processor = (
+                    self.fallback_single_step_output_processor
+                    if self._should_force_single_step() else
+                    self.output_processor)
+                selected_output_processor.process_prompt_logprob(
+                    seq_group, output)
                 if seq_group_meta.do_sample:
-                    self.output_processor.process_outputs(
+                    selected_output_processor.process_outputs(
                         seq_group, output, is_async)
 
             if seq_group.is_finished():
@@ -1263,7 +1296,7 @@ def _advance_to_next_step(
             if seq_group.is_finished():
                 continue
 
-            if self.scheduler_config.is_multi_step:
+            if self.scheduler_config.current_step_is_multi_step:
                 # Updates happen only if the sequence is prefill
                 self._update_num_computed_tokens_for_multi_step_prefill(
                     seq_group, seq_group_metadata,
@@ -1283,7 +1316,7 @@ def _advance_to_next_step(
                 assert len(seq_group.seqs) == 1
                 seq = seq_group.seqs[0]
 
-                if self.scheduler_config.is_multi_step:
+                if self.scheduler_config.current_step_is_multi_step:
                     is_prefill_append = seq.data.get_num_uncomputed_tokens(
                     ) == 0
                     seq.append_token_id(sample.output_token, sample.logprobs)
@@ -1292,6 +1325,11 @@ def _advance_to_next_step(
                 else:
                     seq.append_token_id(sample.output_token, sample.logprobs)
 
+    def _should_force_single_step(self) -> bool:
+        """True if user configured multi-step but there is a best_of > 1 req"""
+        return (self.scheduler_config.engine_permits_multi_step_scheduling
+                and (not self.scheduler_config.current_step_is_multi_step))
+
     def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
@@ -1380,7 +1418,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
                 self._process_model_outputs(ctx=ctx)
 
-            if (self.scheduler_config.is_multi_step
+            if (self.scheduler_config.current_step_is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
                 # cache the scheduler outputs for the next iteration if we have
                 # lookahead slots
@@ -1412,7 +1450,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                 finished_requests_ids=finished_requests_ids,
                 # We use ExecuteModelRequest to pass the last sampled_token_ids
                 # to each of the non-last PP stages for in-place prepare_input.
-                last_sampled_token_ids=last_sampled_token_ids)
+                last_sampled_token_ids=last_sampled_token_ids,
+                force_single_step=self._should_force_single_step())
 
             if allow_async_output_proc:
                 execute_model_req.async_callback = self.async_callbacks[
@@ -1423,7 +1462,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
 
             # We need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
-            if self.scheduler_config.is_multi_step:
+            if self.scheduler_config.current_step_is_multi_step:
                 self._update_cached_scheduler_output(virtual_engine, outputs)
         else:
             # Nothing scheduled => If there is pending async postprocessor,
@@ -1434,13 +1473,13 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
             outputs = []
 
         # Finish the current step for all the sequence groups.
-        if self.scheduler_config.is_multi_step:
+        if self.scheduler_config.current_step_is_multi_step:
             for seq_group in seq_group_metadata_list:
                 seq_group.finish_step()
 
         if not self._has_remaining_steps(seq_group_metadata_list):
             # clear the cache if we have finished all the steps.
-            if self.scheduler_config.is_multi_step:
+            if self.scheduler_config.current_step_is_multi_step:
                 self.cached_scheduler_outputs[0] = SchedulerOutputState()
 
             # is_first_step_output is True only when the num_steps of all
@@ -1497,7 +1536,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
     def _has_remaining_steps(
         self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
     ) -> bool:
-        if (not self.scheduler_config.is_multi_step
+        if (not self.scheduler_config.current_step_is_multi_step
                 or not seq_group_metadata_list):
             return False
 
@@ -1543,7 +1582,7 @@ def _get_last_sampled_token_ids(
         self, virtual_engine: int) -> Optional[torch.Tensor]:
         cached_last_output = self.cached_scheduler_outputs[
             virtual_engine].last_output
-        if (self.scheduler_config.is_multi_step
+        if (self.scheduler_config.current_step_is_multi_step
                 and self.parallel_config.pipeline_parallel_size > 1
                 and cached_last_output is not None
                 and cached_last_output.sampled_token_ids_cpu is not None):
diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py
index 50adaf4e5918..aa612689d97e 100644
--- a/vllm/engine/output_processor/interfaces.py
+++ b/vllm/engine/output_processor/interfaces.py
@@ -23,37 +23,39 @@ class SequenceGroupOutputProcessor(ABC):
     """
 
     @staticmethod
-    def create_output_processor(
-        scheduler_config: SchedulerConfig,
+    def create_multi_step_output_processor(
        detokenizer: Detokenizer,
        scheduler: List[Scheduler],
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
        stop_checker: "StopChecker",
     ):
-        """Create an output processor.
+        """Create a multi-step output processor."""
+        # Importing here to avoid cycle.
+        from vllm.engine.output_processor.multi_step import (
+            MultiStepOutputProcessor)
+        return MultiStepOutputProcessor(
+            detokenizer,
+            scheduler,
+            seq_counter,
+            get_tokenizer_for_seq,
+            stop_checker,
+        )
 
-        This returns a single-step output processor if num_lookahead_slots is
-        zero, else returns a multi-step output processor.
-        """
-        if scheduler_config.num_lookahead_slots == 0:
-            # Importing here to avoid cycle.
-            from vllm.engine.output_processor.single_step import (
-                SingleStepOutputProcessor)
-            return SingleStepOutputProcessor(scheduler_config, detokenizer,
-                                             scheduler, seq_counter,
-                                             stop_checker)
-        else:
-            # Importing here to avoid cycle.
-            from vllm.engine.output_processor.multi_step import (
-                MultiStepOutputProcessor)
-            return MultiStepOutputProcessor(
-                detokenizer,
-                scheduler,
-                seq_counter,
-                get_tokenizer_for_seq,
-                stop_checker,
-            )
+    @staticmethod
+    def create_single_step_output_processor(
+        scheduler_config: SchedulerConfig,
+        detokenizer: Detokenizer,
+        scheduler: List[Scheduler],
+        seq_counter: Counter,
+        stop_checker: "StopChecker",
+    ):
+        """Create a single-step output processor."""
+        # Importing here to avoid cycle.
+        from vllm.engine.output_processor.single_step import (
+            SingleStepOutputProcessor)
+        return SingleStepOutputProcessor(scheduler_config, detokenizer,
+                                         scheduler, seq_counter, stop_checker)
 
     @abstractmethod
     def process_outputs(self, sequence_group: SequenceGroup,
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index db97fe0a0285..7dfba80c5d33 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -14,6 +14,7 @@
                                          apply_hf_chat_template,
                                          apply_mistral_chat_template,
                                          parse_chat_messages)
+from vllm.entrypoints.utils import STR_MULTI_STEP_BEAM_SEARCH_NOT_SUPPORTED
 from vllm.inputs import PromptType, TextPrompt, TokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
@@ -402,6 +403,10 @@ def beam_search(
             penalty, and stopping criteria, etc.?
         """
 
+        if (self.llm_engine.scheduler_config.
+                engine_permits_multi_step_scheduling):
+            raise ValueError(STR_MULTI_STEP_BEAM_SEARCH_NOT_SUPPORTED)
+
         beam_width = params.beam_width
         max_tokens = params.max_tokens
         temperature = params.temperature
diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py
new file mode 100644
index 000000000000..ee0fe1a25b52
--- /dev/null
+++ b/vllm/entrypoints/utils.py
@@ -0,0 +1,5 @@
+"""Utilities for entrypoints"""
+
+STR_MULTI_STEP_BEAM_SEARCH_NOT_SUPPORTED = (
+    "Currently beam search is not supported in "
+    "combination with multi-step scheduling.")
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index ed30d3186a45..c0f194f9b672 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -69,7 +69,7 @@ def _get_worker_kwargs(
     def _get_worker_module_and_class(
             self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
         worker_class_fn = None
-        if self.scheduler_config.is_multi_step:
+        if self.scheduler_config.engine_permits_multi_step_scheduling:
             worker_module_name = "vllm.worker.multi_step_worker"
             worker_class_name = "MultiStepWorker"
         elif self.speculative_config:
diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py
index d02fecb46f00..899230d5d350 100644
--- a/vllm/executor/ray_tpu_executor.py
+++ b/vllm/executor/ray_tpu_executor.py
@@ -70,7 +70,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
             )
 
         assert self.speculative_config is None
-        if self.scheduler_config.is_multi_step:
+        if self.scheduler_config.engine_permits_multi_step_scheduling:
             worker_module_name = "vllm.worker.multi_step_tpu_worker"
             worker_class_name = "MultiStepTPUWorker"
         else:
diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py
index 972649dedf33..9fb6f5564d48 100644
--- a/vllm/executor/tpu_executor.py
+++ b/vllm/executor/tpu_executor.py
@@ -62,7 +62,7 @@ def _create_worker(
         rank: int = 0,
         distributed_init_method: Optional[str] = None,
     ):
-        if self.scheduler_config.is_multi_step:
+        if self.scheduler_config.engine_permits_multi_step_scheduling:
             from vllm.worker.multi_step_tpu_worker import MultiStepTPUWorker
             worker = MultiStepTPUWorker(**self._get_worker_kwargs(
                 local_rank, rank, distributed_init_method))
diff --git a/vllm/sequence.py b/vllm/sequence.py
index fc936fbab0ea..b138f96f43bc 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -1290,6 +1290,8 @@ class ExecuteModelRequest(
     last_sampled_token_ids: Optional[torch.Tensor] = None
     # Async callback
     async_callback: Optional[Callable] = None
+    # Force single-step scheduling even if multi-step is enabled for the engine
+    force_single_step: bool = False
 
     @property
     def is_first_multi_step(self) -> bool:
@@ -1336,7 +1338,8 @@ def clone(
             finished_requests_ids=self.finished_requests_ids,
             last_sampled_token_ids=self.last_sampled_token_ids.clone()
             if self.last_sampled_token_ids is not None else None,
-            async_callback=self.async_callback)
+            async_callback=self.async_callback,
+            force_single_step=self.force_single_step)
 
 
 @dataclass
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 8b74f06e77be..848a2d9485d0 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -476,7 +476,7 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
         if inter_data.is_prompt:
             context_len = seq_data.get_num_computed_tokens()
             seq_len = min(seq_len, context_len + token_chunk_size)
-        elif self.runner.scheduler_config.is_multi_step or \
+        elif self.runner.scheduler_config.current_step_is_multi_step or \
             self.runner.model_config.is_encoder_decoder_model:
             context_len = seq_len - 1
         else:
@@ -748,8 +748,9 @@ def _get_cuda_graph_pad_size(self,
             int: Returns the determined number of padding sequences. If
             CUDA graphs is not viable, returns -1.
         """
-        is_mscp: bool = self.runner.scheduler_config.is_multi_step and \
-                        self.runner.scheduler_config.chunked_prefill_enabled
+        is_mscp: bool = (
+            self.runner.scheduler_config.current_step_is_multi_step
+            and self.runner.scheduler_config.chunked_prefill_enabled)
         decode_only = self.decode_only or is_mscp
         if not decode_only:
             # Early exit so we can treat num_seqs as the batch_size below.
diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py
index be2f0d79154d..4bff6e8ae0e3 100644
--- a/vllm/worker/multi_step_model_runner.py
+++ b/vllm/worker/multi_step_model_runner.py
@@ -323,7 +323,8 @@ def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
         # multi-step logic
         self._base_model_runner: GPUModelRunnerBase = base_model_runner
 
-        self.is_multi_step = self.scheduler_config.is_multi_step
+        self.is_multi_step = (
+            self.scheduler_config.engine_permits_multi_step_scheduling)
         self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
 
         # Using the PythonizationCache in Pipeline-Parallel clobbers the
@@ -469,6 +470,11 @@ def execute_model(
 
         # path for warm up runs
         if not model_input.is_multi_step:
+            # Single-step (fallback or warm-up) execution delegates to the
+            # base runner, so restore the sampler's default output handling.
+            self._base_model_runner.model.sampler.include_gpu_probs_tensor = (
+                False)
+            if frozen_model_input.sampling_metadata:
+                frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
+                    False)
             return self._base_model_runner.execute_model(
                 frozen_model_input, kv_caches, intermediate_tensors, num_steps)
 
diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py
index bf66f32d7d24..9a55cfb2430b 100644
--- a/vllm/worker/multi_step_worker.py
+++ b/vllm/worker/multi_step_worker.py
@@ -81,6 +81,9 @@ def _get_driver_input_and_broadcast(
                 frozen_model_input.attn_metadata._cached_prefill_metadata = None
                 frozen_model_input.attn_metadata._cached_decode_metadata = None
 
+            if execute_model_req.force_single_step:
+                model_input.is_multi_step = False
+
             model_input.is_first_multi_step = is_first_multi_step
             model_input.is_last_step = execute_model_req.is_last_step
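
Reviewer-facing usage sketch (not part of the diff): the snippet below illustrates, assuming a vLLM build that contains this patch, the fallback behavior the change introduces. It uses only the public `LLM`/`SamplingParams` API that the new tests exercise; the model name and sampling values are arbitrary placeholders.

    # Illustrative only -- not applied by this patch.
    from vllm import LLM, SamplingParams

    # Engine configured for multi-step scheduling (8 GPU-side steps per
    # GPU -> CPU output transfer), mirroring the new tests.
    llm = LLM(model="JackFram/llama-160m", num_scheduler_steps=8)

    # best_of > 1 is incompatible with multi-step scheduling. With this patch
    # the scheduler tracks such in-flight requests and schedules single-step
    # for any step in which that count is non-zero, instead of rejecting the
    # request.
    params = SamplingParams(n=2, best_of=3, temperature=1.0, seed=42,
                            max_tokens=16)
    for out in llm.generate(["The future of AI is"], params):
        print([completion.text for completion in out.outputs])

    # Beam search remains unsupported with multi-step: LLM.beam_search() now
    # raises ValueError(STR_MULTI_STEP_BEAM_SEARCH_NOT_SUPPORTED) up front.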