46 changes: 42 additions & 4 deletions tests/v1/core/test_scheduler.py
@@ -1950,7 +1950,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
assert len(scheduler.waiting) == 1


def test_priority_scheduling_preemption_when_out_of_kv():
def test_priority_scheduling_preemption_and_resumption_when_out_of_kv():
"""Test that priority scheduling preempts lower priority requests
when out of KV cache space."""
# Create scheduler with very limited memory to force preemption
@@ -1959,6 +1959,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
max_num_batched_tokens=200,
num_blocks=5, # Can hold 64 tokens (first block is null)
block_size=16, # Standard block size
use_kv_connector=True,
)

# Create a request and schedule it
@@ -1970,12 +1971,13 @@ def test_priority_scheduling_preemption_when_out_of_kv():
starting_idx=0,
)[0]
scheduler.add_request(request_low)
# 1st schedule
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 1

# Simulate model execution
# Simulate model execution - 1st decode
model_output = ModelRunnerOutput(
req_ids=[request_low.request_id],
req_id_to_index={request_low.request_id: 0},
@@ -1996,6 +1998,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
starting_idx=1,
)[0]
scheduler.add_request(request_high)
# 2nd schedule
output = scheduler.schedule()
# KV cache should be full at this point
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
@@ -2004,7 +2007,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 2

# Simulate model execution
# Simulate model execution - 2nd decode
requests = [request_low, request_high]
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
Expand All @@ -2017,7 +2020,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
)
scheduler.update_from_output(output, model_output)

# Schedule again - this should trigger preemption
# 3rd schedule - this should trigger preemption
# req_low needs 32 tokens = 2 blocks
# req_high needs 33 tokens = 3 blocks
# so together they don't fit in the 4 usable blocks.
@@ -2027,9 +2030,44 @@ def test_priority_scheduling_preemption_when_out_of_kv():
assert len(output.scheduled_new_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 1
assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED
assert len(scheduler.waiting) == 1
assert len(scheduler.running) == 1

# Simulate model execution - 3rd decode
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
sampled_token_ids=[[], [100]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# Finish the high-priority request to make room for the preempted request to resume
scheduler.update_from_output(output, model_output)
scheduler.finish_requests(request_high.request_id, RequestStatus.FINISHED_STOPPED)

# 4th schedule - this should trigger the resumption
output = scheduler.schedule()
scheduled_cached_reqs = output.scheduled_cached_reqs
resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption

assert len(output.scheduled_new_reqs) == 0
assert scheduled_cached_reqs.num_reqs == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 1

# Preempted request resumed in scheduled_cached_reqs
assert len(resumed_from_preemption) == 1
assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1
assert resumed_from_preemption[0]
assert scheduled_cached_reqs.req_ids[0] == request_low.request_id
assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None
# Resumed tokens include 30 prompt tokens and 2 decoded tokens
assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 32
assert scheduled_cached_reqs.resumed_req_token_ids[0][31] == 100


@pytest.mark.parametrize(
("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"),
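The preemption asserted in the 3rd-schedule step of test_priority_scheduling_preemption_and_resumption_when_out_of_kv above comes down to block arithmetic. A minimal sketch of that computation, using the block size and token counts from the test's comments; the blocks_needed helper is made up for illustration and is not part of vLLM:

    def blocks_needed(num_tokens: int, block_size: int = 16) -> int:
        # Ceiling division: a partially filled block still occupies a whole block.
        return -(-num_tokens // block_size)

    usable_blocks = 4                 # 5 blocks in the pool, minus the null first block
    low_blocks = blocks_needed(32)    # req_low: 30 prompt + 2 decoded tokens -> 2 blocks
    high_blocks = blocks_needed(33)   # req_high: 33 tokens in total -> 3 blocks
    assert low_blocks + high_blocks == 5 > usable_blocks
    # The two requests no longer fit together, so the lower-priority one is preempted.
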
1 change: 1 addition & 0 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -257,6 +257,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
req_ids=[req_id],
resumed_from_preemption=[False],
new_token_ids=[[]],
resumed_req_token_ids=[None],
new_block_ids=([[0]],),
num_computed_tokens=[0],
num_output_tokens=[0],
4 changes: 4 additions & 0 deletions vllm/v1/core/sched/output.py
@@ -98,6 +98,9 @@
# NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
# When PP is not used, new_token_ids will be empty.
new_token_ids: list[list[int]]
# If resumed_from_preemption is True, propagate the full token ids to the
# connector; otherwise the entry is None.
resumed_req_token_ids: list[list[int] | None]

Check notice (GitHub Actions / bc_lint) on line 103 in vllm/v1/core/sched/output.py: Function CachedRequestData: resumed_req_token_ids was added
new_block_ids: list[tuple[list[int], ...] | None]
num_computed_tokens: list[int]
num_output_tokens: list[int]
@@ -112,6 +115,7 @@
req_ids=[],
resumed_from_preemption=[],
new_token_ids=[],
resumed_req_token_ids=[],
Collaborator:

can you add a proper test case to make sure this is populated correctly?

Contributor Author:

Sure. Created a new test case of preemption -> resumption including this field.

pytest tests/v1/core/test_scheduler.py::test_priority_scheduling_preemption_and_resumption_when_out_of_kv

new_block_ids=[],
num_computed_tokens=[],
num_output_tokens=[],
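To illustrate how the new field is meant to be consumed, here is a minimal sketch of a connector-side helper that pulls the full token ids only for requests flagged as resumed. The trimmed ResumedBatchView dataclass and tokens_for_resumed_reqs function are illustrative stand-ins, not vLLM's actual CachedRequestData or connector API:

    from dataclasses import dataclass


    @dataclass
    class ResumedBatchView:
        """Made-up, trimmed stand-in with only the fields used below."""

        req_ids: list[str]
        resumed_from_preemption: list[bool]
        resumed_req_token_ids: list[list[int] | None]


    def tokens_for_resumed_reqs(batch: ResumedBatchView) -> dict[str, list[int]]:
        """Collect the full token ids for requests that resumed from preemption.

        Entries for running (non-resumed) requests are None and are skipped.
        """
        resumed: dict[str, list[int]] = {}
        for req_id, was_resumed, token_ids in zip(
            batch.req_ids, batch.resumed_from_preemption, batch.resumed_req_token_ids
        ):
            if was_resumed:
                assert token_ids is not None
                resumed[req_id] = token_ids
        return resumed

Because the scheduler change below builds the three lists index-aligned, zipping them is enough; no extra lookup table is needed.
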
18 changes: 13 additions & 5 deletions vllm/v1/core/sched/scheduler.py
@@ -709,10 +709,15 @@ def _make_cached_request_data(
req_ids: list[str] = []
new_token_ids: list[list[int]] = []
new_block_ids: list[tuple[list[int], ...] | None] = []
resumed_req_token_ids: list[list[int] | None] = []
num_computed_tokens: list[int] = []
num_output_tokens: list[int] = []

for req in itertools.chain(running_reqs, resumed_reqs):
# Because resumed_reqs is usually empty, it is more efficient to do
# in-place appending so that we don't need to allocate a new list.
resumed_from_preemption = [False] * len(running_reqs)
resumed_from_preemption += [True] * len(resumed_reqs)
for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
req_id = req.request_id
req_ids.append(req_id)
num_tokens = num_scheduled_tokens[req_id] - len(
@@ -728,20 +733,23 @@
req.num_computed_tokens : req.num_computed_tokens + num_tokens
]
new_token_ids.append(token_ids)
resumed_token_ids = None
if resumed_from_preemption[idx]:
resumed_token_ids = req.all_token_ids[
: req.num_computed_tokens + num_tokens
]
resumed_req_token_ids.append(resumed_token_ids)
new_block_ids.append(
req_to_new_blocks[req_id].get_block_ids(allow_none=True)
)
num_computed_tokens.append(req.num_computed_tokens)
num_output_tokens.append(req.num_output_tokens)
# Because resumed_reqs is usually empty, it is more efficient to do
# in-place appending so that we don't need to allocate a new list.
resumed_from_preemption = [False] * len(running_reqs)
resumed_from_preemption += [True] * len(resumed_reqs)

return CachedRequestData(
req_ids=req_ids,
resumed_from_preemption=resumed_from_preemption,
new_token_ids=new_token_ids,
resumed_req_token_ids=resumed_req_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
num_output_tokens=num_output_tokens,
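
For resumed entries, the slice req.all_token_ids[: req.num_computed_tokens + num_tokens] covers everything the request has accumulated so far, i.e. the prompt plus the previously decoded tokens. A minimal sketch of that indexing, assuming (as in the test above) that the preempted request restarts with num_computed_tokens == 0 and is rescheduled for all 32 tokens; FakeRequest is a made-up stand-in for the scheduler's Request object:

    class FakeRequest:
        """Made-up stand-in exposing only the fields the slice relies on."""

        def __init__(self, prompt_token_ids: list[int], output_token_ids: list[int]):
            self.all_token_ids = prompt_token_ids + output_token_ids
            self.num_computed_tokens = 0  # reset when the request was preempted


    req = FakeRequest(prompt_token_ids=list(range(30)), output_token_ids=[99, 100])
    num_tokens = 32  # scheduled tokens for the resumed request: full prompt + 2 decoded

    resumed_token_ids = req.all_token_ids[: req.num_computed_tokens + num_tokens]
    assert len(resumed_token_ids) == 32  # 30 prompt tokens + 2 decoded tokens
    assert resumed_token_ids[31] == 100  # the last decoded token is included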