diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 5baa2a1be4ab..dfa965c56766 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -1950,7 +1950,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
     assert len(scheduler.waiting) == 1
 
 
-def test_priority_scheduling_preemption_when_out_of_kv():
+def test_priority_scheduling_preemption_and_resumption_when_out_of_kv():
     """Test that priority scheduling preempts lower priority requests
     when out of KV cache space."""
     # Create scheduler with very limited memory to force preemption
@@ -1959,6 +1959,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
         max_num_batched_tokens=200,
         num_blocks=5,  # Can hold 64 tokens (first block is null)
         block_size=16,  # Standard block size
+        use_kv_connector=True,
     )
 
     # Create a request and schedule it
@@ -1970,12 +1971,13 @@ def test_priority_scheduling_preemption_when_out_of_kv():
         starting_idx=0,
     )[0]
     scheduler.add_request(request_low)
+    # 1st schedule
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == 1
     assert len(scheduler.waiting) == 0
     assert len(scheduler.running) == 1
 
-    # Simulate model execution
+    # Simulate model execution - 1st decode
     model_output = ModelRunnerOutput(
         req_ids=[request_low.request_id],
         req_id_to_index={request_low.request_id: 0},
@@ -1996,6 +1998,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
         starting_idx=1,
     )[0]
     scheduler.add_request(request_high)
+    # 2nd schedule
     output = scheduler.schedule()
     # KV cache should be full at this point
     assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
@@ -2004,7 +2007,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
     assert len(scheduler.waiting) == 0
     assert len(scheduler.running) == 2
 
-    # Simulate model execution
+    # Simulate model execution - 2nd decode
    requests = [request_low, request_high]
     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
@@ -2017,7 +2020,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
     )
     scheduler.update_from_output(output, model_output)
 
-    # Schedule again - this should trigger preemption
+    # 3rd schedule - this should trigger preemption
     # req_low needs 32 tokens = 2 blocks
     # req_high needs 33 tokens = 3 blocks
     # so doesn't fit in 4 blocks.
@@ -2027,9 +2030,44 @@ def test_priority_scheduling_preemption_when_out_of_kv():
     assert len(output.scheduled_new_reqs) == 0
     assert output.scheduled_cached_reqs.num_reqs == 1
     assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
+    assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED
     assert len(scheduler.waiting) == 1
     assert len(scheduler.running) == 1
 
+    # Simulate model execution - 3rd decode
+    model_output = ModelRunnerOutput(
+        req_ids=[req.request_id for req in requests],
+        req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
+        sampled_token_ids=[[], [100]],
+        # spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    # Finish the running request to make room for the preempted request to resume
+    scheduler.update_from_output(output, model_output)
+    scheduler.finish_requests(request_high.request_id, RequestStatus.FINISHED_STOPPED)
+
+    # 4th schedule - this should trigger the resumption
+    output = scheduler.schedule()
+    scheduled_cached_reqs = output.scheduled_cached_reqs
+    resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption
+
+    assert len(output.scheduled_new_reqs) == 0
+    assert scheduled_cached_reqs.num_reqs == 1
+    assert len(scheduler.waiting) == 0
+    assert len(scheduler.running) == 1
+
+    # Preempted request resumed in scheduled_cached_reqs
+    assert len(resumed_from_preemption) == 1
+    assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1
+    assert resumed_from_preemption[0]
+    assert scheduled_cached_reqs.req_ids[0] == request_low.request_id
+    assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None
+    # Resumed tokens include 30 prompt tokens and 2 decoded tokens
+    assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 32
+    assert scheduled_cached_reqs.resumed_req_token_ids[0][31] == 100
+
 
 
 @pytest.mark.parametrize(
     ("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"),
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index ef2956bd3ec2..5afbaccd48bb 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -257,6 +257,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
         req_ids=[req_id],
         resumed_from_preemption=[False],
         new_token_ids=[[]],
+        resumed_req_token_ids=[None],
         new_block_ids=([[0]],),
         num_computed_tokens=[0],
         num_output_tokens=[0],
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index cbce91b990a1..981c5e9c7636 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -98,6 +98,9 @@ class CachedRequestData:
     # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
     # When PP is not used, new_token_ids will be empty.
     new_token_ids: list[list[int]]
+    # If resumed_from_preemption is True, propagate the token ids to the
+    # connector; otherwise the entry is None.
+    resumed_req_token_ids: list[list[int] | None]
     new_block_ids: list[tuple[list[int], ...] | None]
     num_computed_tokens: list[int]
     num_output_tokens: list[int]
@@ -112,6 +115,7 @@ def make_empty(cls) -> CachedRequestData:
             req_ids=[],
             resumed_from_preemption=[],
             new_token_ids=[],
+            resumed_req_token_ids=[],
             new_block_ids=[],
             num_computed_tokens=[],
             num_output_tokens=[],
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index f81750047ecc..6829fed33e45 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -709,10 +709,15 @@ def _make_cached_request_data(
         req_ids: list[str] = []
         new_token_ids: list[list[int]] = []
         new_block_ids: list[tuple[list[int], ...] | None] = []
+        resumed_req_token_ids: list[list[int] | None] = []
         num_computed_tokens: list[int] = []
         num_output_tokens: list[int] = []
 
-        for req in itertools.chain(running_reqs, resumed_reqs):
+        # Because resumed_reqs is usually empty, it is more efficient to do
+        # in-place appending so that we don't need to allocate a new list.
+        resumed_from_preemption = [False] * len(running_reqs)
+        resumed_from_preemption += [True] * len(resumed_reqs)
+        for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
             req_id = req.request_id
             req_ids.append(req_id)
             num_tokens = num_scheduled_tokens[req_id] - len(
@@ -728,20 +733,23 @@ def _make_cached_request_data(
                 req.num_computed_tokens : req.num_computed_tokens + num_tokens
             ]
             new_token_ids.append(token_ids)
+            resumed_token_ids = None
+            if resumed_from_preemption[idx]:
+                resumed_token_ids = req.all_token_ids[
+                    : req.num_computed_tokens + num_tokens
+                ]
+            resumed_req_token_ids.append(resumed_token_ids)
             new_block_ids.append(
                 req_to_new_blocks[req_id].get_block_ids(allow_none=True)
             )
             num_computed_tokens.append(req.num_computed_tokens)
             num_output_tokens.append(req.num_output_tokens)
 
-        # Because resumed_reqs is usually empty, it is more efficient to do
-        # in-place appending so that we don't need to allocate a new list.
-        resumed_from_preemption = [False] * len(running_reqs)
-        resumed_from_preemption += [True] * len(resumed_reqs)
         return CachedRequestData(
             req_ids=req_ids,
             resumed_from_preemption=resumed_from_preemption,
             new_token_ids=new_token_ids,
+            resumed_req_token_ids=resumed_req_token_ids,
             new_block_ids=new_block_ids,
             num_computed_tokens=num_computed_tokens,
             num_output_tokens=num_output_tokens,
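
Illustrative sketch, not part of the patch above: one way a KV-connector-side consumer might read the new resumed_req_token_ids field to recover the full token sequence of a request resumed after preemption. StubCachedRequestData is a simplified stand-in for vllm.v1.core.sched.output.CachedRequestData, and rebuild_token_ids_for_resumed_reqs is a hypothetical helper name; neither exists in vLLM.

# Illustrative only: StubCachedRequestData and rebuild_token_ids_for_resumed_reqs
# are hypothetical names used for this sketch, not vLLM APIs.
from dataclasses import dataclass


@dataclass
class StubCachedRequestData:
    # Mirrors only the CachedRequestData fields referenced in this diff.
    req_ids: list[str]
    resumed_from_preemption: list[bool]
    new_token_ids: list[list[int]]
    resumed_req_token_ids: list[list[int] | None]


def rebuild_token_ids_for_resumed_reqs(
    cached: StubCachedRequestData,
) -> dict[str, list[int]]:
    """Return the full token ids of requests resumed from preemption.

    For resumed requests the scheduler now ships the prompt tokens plus all
    previously decoded tokens in resumed_req_token_ids, so a connector does
    not need to have tracked the request before it was preempted.
    """
    resumed: dict[str, list[int]] = {}
    for idx, req_id in enumerate(cached.req_ids):
        if cached.resumed_from_preemption[idx]:
            token_ids = cached.resumed_req_token_ids[idx]
            # Per the scheduler change, resumed entries are never None.
            assert token_ids is not None
            resumed[req_id] = list(token_ids)
    return resumed


if __name__ == "__main__":
    # Mirrors the scenario in the new test: 30 prompt tokens plus 2 decoded
    # tokens, with 100 as the last token sampled before preemption.
    data = StubCachedRequestData(
        req_ids=["req_low"],
        resumed_from_preemption=[True],
        new_token_ids=[[]],
        resumed_req_token_ids=[list(range(30)) + [99, 100]],
    )
    token_ids = rebuild_token_ids_for_resumed_reqs(data)["req_low"]
    assert len(token_ids) == 32 and token_ids[31] == 100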