
Commit b3e081a

[Misc] Add multistep chunked-prefill support for FlashInfer (vllm-project#10467)

elfiegg authored and mzusman committed
1 parent 253ce05

5 files changed: +169 -109 lines changed

csrc/prepare_inputs/advance_step.cu
Lines changed: 10 additions & 0 deletions

@@ -95,6 +95,16 @@ __global__ void advance_step_flashinfer_kernel(
     long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
     int const* block_tables_ptr, int64_t const block_tables_stride,
     int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
+  int const n_pad = num_seqs - num_queries;
+  if (n_pad && blockIdx.x == 0) {
+    // Handle cuda graph padding
+    int const offset = num_queries;
+    for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
+      input_tokens_ptr[offset + i] = 0;
+      input_positions_ptr[offset + i] = 0;
+      slot_mapping_ptr[offset + i] = -1;
+    }
+  }
   int num_query_blocks = div_ceil(num_queries, num_threads);

   if (blockIdx.x < num_query_blocks) {
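
Editor's note: the added kernel block clears the CUDA-graph padding tail on the device; entries beyond num_queries are graph padding, so their token and position are zeroed and their slot mapping is set to -1 so they are treated as padding downstream. Below is a rough host-side sketch of the same bookkeeping, for illustration only (pad_flashinfer_inputs is a made-up name, not a vLLM API, and the -1 padding convention is assumed from the diff):

import torch

def pad_flashinfer_inputs(input_tokens: torch.Tensor,
                          input_positions: torch.Tensor,
                          slot_mapping: torch.Tensor,
                          num_queries: int) -> None:
    # Entries [num_queries:] correspond to CUDA-graph padding: give them a
    # dummy token/position of 0 and a slot mapping of -1 (padding marker).
    input_tokens[num_queries:] = 0
    input_positions[num_queries:] = 0
    slot_mapping[num_queries:] = -1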

tests/multi_step/test_correctness_llm.py
Lines changed: 16 additions & 1 deletion

@@ -5,6 +5,8 @@

 import pytest

+from tests.kernels.utils import override_backend_env_variable
+
 from ..models.utils import check_logprobs_close, check_outputs_equal

 MODELS = [
@@ -19,10 +21,11 @@
 @pytest.mark.parametrize("tp_size", [1])
 @pytest.mark.parametrize("enable_chunked_prefill", [False, True])
 @pytest.mark.parametrize("max_tokens", [5])
-@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("enforce_eager", [True, False])
 @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
 @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
 @pytest.mark.parametrize("num_logprobs", [None, 5])
+@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"])
 def test_multi_step_llm(
     hf_runner,
     vllm_runner,
@@ -36,6 +39,8 @@ def test_multi_step_llm(
     num_scheduler_steps: int,
     num_prompts: int,
     num_logprobs: Optional[int],
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
     """Test vLLM engine with multi-step scheduling via sync LLM Engine.

@@ -63,6 +68,7 @@ def test_multi_step_llm(
       num_logprobs: corresponds to the `logprobs` argument to the OpenAI
           completions endpoint; `None` -> 1 logprob returned.
     """
+    override_backend_env_variable(monkeypatch, attention_backend)

     prompts = example_prompts
     if len(prompts) < num_prompts:
@@ -114,6 +120,7 @@
 @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
 @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
 @pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)])
+@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"])
 def test_multi_step_llm_w_prompt_logprobs(
     vllm_runner,
     example_prompts,
@@ -126,6 +133,8 @@ def test_multi_step_llm_w_prompt_logprobs(
     num_prompts: int,
     num_logprobs: Optional[int],
     num_prompt_logprobs: Optional[int],
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
     """Test prompt logprobs with multi-step scheduling via sync LLM Engine.

@@ -155,6 +164,7 @@ def test_multi_step_llm_w_prompt_logprobs(
           note that this argument is not supported by the
           OpenAI completions endpoint.
     """
+    override_backend_env_variable(monkeypatch, attention_backend)

     prompts = example_prompts
     if len(prompts) < num_prompts:
@@ -205,6 +215,7 @@ def test_multi_step_llm_w_prompt_logprobs(
 @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
 @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
 @pytest.mark.parametrize("num_logprobs", [None, 5])
+@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"])
 def test_multi_step_llm_chunked_prefill_prefix_cache(
     vllm_runner,
     example_prompts,
@@ -216,6 +227,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
     num_scheduler_steps: int,
     num_prompts: int,
     num_logprobs: Optional[int],
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
     """Test vLLM engine with multi-step+"single-step chunked prefill"+APC.

@@ -278,6 +291,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
     #
     # The Incorrect scheduling behavior - if it occurs - will cause an exception
     # in the model runner resulting from `do_sample=False`.
+    override_backend_env_variable(monkeypatch, attention_backend)
+
     assert len(example_prompts) >= 2
     challenge_prompts = copy.deepcopy(example_prompts)
     challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient '
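
Editor's note: the tests now parametrize the attention backend and pin it with override_backend_env_variable(monkeypatch, attention_backend) before the engine is built. A minimal stand-in for that helper, assuming it simply sets vLLM's backend selector environment variable (the real helper in tests/kernels/utils may do more):

def override_backend_env_variable(monkeypatch, backend_name: str) -> None:
    # VLLM_ATTENTION_BACKEND is the variable vLLM consults when choosing an
    # attention backend; pytest's monkeypatch restores it after the test.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend_name)

With the parametrization in place, a single backend's cases can be selected from the command line, for example: pytest tests/multi_step/test_correctness_llm.py -k FLASHINFER.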

vllm/attention/backends/flashinfer.py
Lines changed: 24 additions & 5 deletions

@@ -256,7 +256,12 @@ def prepare_graph_input_buffers(self,
     def begin_forward(self, model_input):
         assert not self._is_graph_capturing
         state = self
-        if model_input.attn_metadata.use_cuda_graph:
+        use_cuda_graph = model_input.attn_metadata.use_cuda_graph
+        is_decode = model_input.attn_metadata.num_prefills == 0
+        # In case of multistep chunked-prefill, there might be prefill requests
+        # scheduled while CUDA graph mode is enabled. We don't run graph in that
+        # case.
+        if use_cuda_graph and is_decode:
             batch_size = model_input.input_tokens.shape[0]
             state = (self.runner.graph_runners[model_input.virtual_engine]
                      [batch_size].attn_state)
@@ -429,10 +434,24 @@ def advance_step(self,
         Update metadata in-place to advance one decode step.
         """

-        assert not turn_prefills_into_decodes, \
-            ("Chunked prefill is not supported with flashinfer yet."
-             "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
-             "specific parameter.")
+        if turn_prefills_into_decodes:
+            # When Multi-Step is enabled with Chunked-Prefill, prefills and
+            # decodes are scheduled together. In the first step, all the
+            # prefills turn into decodes. This update reflects that
+            # conversion.
+            assert self.num_decode_tokens + self.num_prefills == num_seqs
+            # Flashinfer doesn't support speculative decoding + chunked-prefill
+            # + multi-step scheduling yet.
+            assert self.decode_query_len == 1
+            self.num_decode_tokens += self.num_prefills
+            self.num_prefills = 0
+            self.num_prefill_tokens = 0
+            self.max_prefill_seq_len = 0
+            self.max_query_len = 1
+
+            self.slot_mapping = self.slot_mapping[:num_seqs]
+        else:
+            assert self.seq_lens_tensor is not None

         assert num_seqs > 0
         assert num_queries > 0
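
Editor's note: advance_step now accepts turn_prefills_into_decodes instead of rejecting it; after the first multi-step iteration every scheduled prefill becomes a single-token decode, so the prefill counters collapse into the decode counters. A standalone sketch of that bookkeeping follows (toy field names mirror the diff, this is not the real FlashInfer metadata class):

from dataclasses import dataclass, field
from typing import List

@dataclass
class StepMeta:
    num_prefills: int
    num_prefill_tokens: int
    num_decode_tokens: int
    max_prefill_seq_len: int
    max_query_len: int
    decode_query_len: int
    slot_mapping: List[int] = field(default_factory=list)

def turn_prefills_into_decodes(meta: StepMeta, num_seqs: int) -> None:
    # Every request is now a decode producing exactly one token per step.
    assert meta.num_decode_tokens + meta.num_prefills == num_seqs
    assert meta.decode_query_len == 1  # no speculative decoding here
    meta.num_decode_tokens += meta.num_prefills
    meta.num_prefills = 0
    meta.num_prefill_tokens = 0
    meta.max_prefill_seq_len = 0
    meta.max_query_len = 1
    meta.slot_mapping = meta.slot_mapping[:num_seqs]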

vllm/worker/model_runner.py
Lines changed: 118 additions & 102 deletions

@@ -5,6 +5,7 @@
 import time
 import warnings
 import weakref
+from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
                     Tuple, Type, TypeVar, Union)
@@ -1028,6 +1029,8 @@ def __init__(

         self.has_inner_state = model_config.has_inner_state

+        self.in_profile_run = False
+
         # When using CUDA graph, the input block tables must be padded to
         # max_seq_len_to_capture. However, creating the block table in
         # Python can be expensive. To optimize this, we cache the block table
@@ -1228,110 +1231,123 @@ def _prepare_model_input_tensors(

         return builder.build()  # type: ignore

+    @contextmanager
+    def set_in_profile_run(self):
+        self.in_profile_run = True
+        try:
+            yield
+        finally:
+            self.in_profile_run = False
+
     @torch.inference_mode()
     def profile_run(self) -> None:
-        # Enable top-k sampling to reflect the accurate memory usage.
-        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
-        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
-        max_num_seqs = self.scheduler_config.max_num_seqs
-        # This represents the maximum number of different requests
-        # that will have unique loras, an therefore the max amount of memory
-        # consumption create dummy lora request copies from the lora request
-        # passed in, which contains a lora from the lora warmup path.
-        dummy_lora_requests: List[LoRARequest] = []
-        dummy_lora_requests_per_seq: List[LoRARequest] = []
-        if self.lora_config:
-            assert self.lora_manager is not None
-            with self.lora_manager.dummy_lora_cache():
-                for idx in range(self.lora_config.max_loras):
-                    lora_id = idx + 1
-                    dummy_lora_request = LoRARequest(
-                        lora_name=f"warmup_{lora_id}",
-                        lora_int_id=lora_id,
-                        lora_path="/not/a/real/path",
-                    )
-                    self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                     rank=LORA_WARMUP_RANK)
-                    dummy_lora_requests.append(dummy_lora_request)
-                dummy_lora_requests_per_seq = [
-                    dummy_lora_requests[idx % len(dummy_lora_requests)]
-                    for idx in range(max_num_seqs)
-                ]
-
-        # Profile memory usage with max_num_sequences sequences and the total
-        # number of tokens equal to max_num_batched_tokens.
-        seqs: List[SequenceGroupMetadata] = []
-        # Additional GPU memory may be needed for multi-modal encoding, which
-        # needs to be accounted for when calculating the GPU blocks for
-        # vLLM blocker manager.
-        # To exercise the worst scenario for GPU memory consumption,
-        # the number of seqs (batch_size) is chosen to maximize the number
-        # of images processed.
-
-        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
-            self.model_config)
-        if max_mm_tokens > 0:
-            max_num_seqs_orig = max_num_seqs
-            max_num_seqs = min(max_num_seqs,
-                               max_num_batched_tokens // max_mm_tokens)
-            if max_num_seqs < 1:
-                expr = (f"min({max_num_seqs_orig}, "
-                        f"{max_num_batched_tokens} // {max_mm_tokens})")
-                logger.warning(
-                    "Computed max_num_seqs (%s) to be less than 1. "
-                    "Setting it to the minimum value of 1.", expr)
-                max_num_seqs = 1
-
-        batch_size = 0
-        for group_id in range(max_num_seqs):
-            seq_len = (max_num_batched_tokens // max_num_seqs +
-                       (group_id < max_num_batched_tokens % max_num_seqs))
-            batch_size += seq_len
-
-            dummy_data = self.input_registry \
-                .dummy_data_for_profiling(self.model_config,
-                                          seq_len,
-                                          self.mm_registry)
-
-            seq = SequenceGroupMetadata(
-                request_id=str(group_id),
-                is_prompt=True,
-                seq_data={group_id: dummy_data.seq_data},
-                sampling_params=sampling_params,
-                block_tables=None,
-                lora_request=dummy_lora_requests_per_seq[group_id]
-                if dummy_lora_requests_per_seq else None,
-                multi_modal_data=dummy_data.multi_modal_data,
-                multi_modal_placeholders=dummy_data.multi_modal_placeholders,
-            )
-            seqs.append(seq)
-
-        # Run the model with the dummy inputs.
-        num_layers = self.model_config.get_num_layers(self.parallel_config)
-        # use an empty tensor instead of `None`` to force Dynamo to pass
-        # it by reference, rather by specializing on the value ``None``.
-        # the `dtype` argument does not matter, and we use `float32` as
-        # a placeholder (it has wide hardware support).
-        # it is important to create tensors inside the loop, rather than
-        # multiplying the list, to avoid Dynamo from treating them as
-        # tensor aliasing.
-        kv_caches = [
-            torch.tensor([], dtype=torch.float32, device=self.device)
-            for _ in range(num_layers)
-        ]
-        finished_requests_ids = [seq.request_id for seq in seqs]
-        model_input = self.prepare_model_input(
-            seqs, finished_requests_ids=finished_requests_ids)
-        intermediate_tensors = None
-        if not get_pp_group().is_first_rank:
-            intermediate_tensors = self.model.make_empty_intermediate_tensors(
-                batch_size=batch_size,
-                dtype=self.model_config.dtype,
-                device=self.device)
-
-        self.execute_model(model_input, kv_caches, intermediate_tensors)
-        torch.cuda.synchronize()
-        return
+        with self.set_in_profile_run():
+            # Enable top-k sampling to reflect the accurate memory usage.
+            sampling_params = \
+                SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
+            max_num_batched_tokens = \
+                self.scheduler_config.max_num_batched_tokens
+            max_num_seqs = self.scheduler_config.max_num_seqs
+            # This represents the maximum number of different requests
+            # that will have unique loras, an therefore the max amount of memory
+            # consumption create dummy lora request copies from the lora request
+            # passed in, which contains a lora from the lora warmup path.
+            dummy_lora_requests: List[LoRARequest] = []
+            dummy_lora_requests_per_seq: List[LoRARequest] = []
+            if self.lora_config:
+                assert self.lora_manager is not None
+                with self.lora_manager.dummy_lora_cache():
+                    for idx in range(self.lora_config.max_loras):
+                        lora_id = idx + 1
+                        dummy_lora_request = LoRARequest(
+                            lora_name=f"warmup_{lora_id}",
+                            lora_int_id=lora_id,
+                            lora_path="/not/a/real/path",
+                        )
+                        self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                         rank=LORA_WARMUP_RANK)
+                        dummy_lora_requests.append(dummy_lora_request)
+                    dummy_lora_requests_per_seq = [
+                        dummy_lora_requests[idx % len(dummy_lora_requests)]
+                        for idx in range(max_num_seqs)
+                    ]
+
+            # Profile memory usage with max_num_sequences sequences and the
+            # total number of tokens equal to max_num_batched_tokens.
+            seqs: List[SequenceGroupMetadata] = []
+            # Additional GPU memory may be needed for multi-modal encoding,
+            # which needs to be accounted for when calculating the GPU blocks
+            # for vLLM blocker manager.
+            # To exercise the worst scenario for GPU memory consumption,
+            # the number of seqs (batch_size) is chosen to maximize the number
+            # of images processed.
+
+            max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
+                self.model_config)
+            if max_mm_tokens > 0:
+                max_num_seqs_orig = max_num_seqs
+                max_num_seqs = min(max_num_seqs,
+                                   max_num_batched_tokens // max_mm_tokens)
+                if max_num_seqs < 1:
+                    expr = (f"min({max_num_seqs_orig}, "
+                            f"{max_num_batched_tokens} // {max_mm_tokens})")
+                    logger.warning(
+                        "Computed max_num_seqs (%s) to be less than 1. "
+                        "Setting it to the minimum value of 1.", expr)
+                    max_num_seqs = 1
+
+            batch_size = 0
+            for group_id in range(max_num_seqs):
+                seq_len = (max_num_batched_tokens // max_num_seqs +
+                           (group_id < max_num_batched_tokens % max_num_seqs))
+                batch_size += seq_len
+
+                dummy_data = self.input_registry \
+                    .dummy_data_for_profiling(self.model_config,
+                                              seq_len,
+                                              self.mm_registry)
+
+                seq = SequenceGroupMetadata(
+                    request_id=str(group_id),
+                    is_prompt=True,
+                    seq_data={group_id: dummy_data.seq_data},
+                    sampling_params=sampling_params,
+                    block_tables=None,
+                    lora_request=dummy_lora_requests_per_seq[group_id]
+                    if dummy_lora_requests_per_seq else None,
+                    multi_modal_data=dummy_data.multi_modal_data,
+                    multi_modal_placeholders=dummy_data.
+                    multi_modal_placeholders,
+                )
+                seqs.append(seq)
+
+            # Run the model with the dummy inputs.
+            num_layers = self.model_config.get_num_layers(self.parallel_config)
+            # use an empty tensor instead of `None`` to force Dynamo to pass
+            # it by reference, rather by specializing on the value ``None``.
+            # the `dtype` argument does not matter, and we use `float32` as
+            # a placeholder (it has wide hardware support).
+            # it is important to create tensors inside the loop, rather than
+            # multiplying the list, to avoid Dynamo from treating them as
+            # tensor aliasing.
+            kv_caches = [
+                torch.tensor([], dtype=torch.float32, device=self.device)
+                for _ in range(num_layers)
+            ]
+            finished_requests_ids = [seq.request_id for seq in seqs]
+            model_input = self.prepare_model_input(
+                seqs, finished_requests_ids=finished_requests_ids)
+            intermediate_tensors = None
+            if not get_pp_group().is_first_rank:
+                intermediate_tensors = \
+                    self.model.make_empty_intermediate_tensors(
+                        batch_size=batch_size,
+                        dtype=self.model_config.dtype,
+                        device=self.device)
+
+            self.execute_model(model_input, kv_caches, intermediate_tensors)
+            torch.cuda.synchronize()
+            return

     def remove_all_loras(self):
         if not self.lora_manager:
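
Editor's note: profile_run is now wrapped in set_in_profile_run(), which flips self.in_profile_run for the duration of the dummy profiling pass; the hunks shown here do not include the consumer of that flag, so the sketch below only illustrates the try/finally flag-scoping pattern (a toy class, not the real GPUModelRunnerBase):

from contextlib import contextmanager

class RunnerFlagDemo:
    def __init__(self) -> None:
        self.in_profile_run = False

    @contextmanager
    def set_in_profile_run(self):
        # try/finally guarantees the flag is cleared even if profiling raises.
        self.in_profile_run = True
        try:
            yield
        finally:
            self.in_profile_run = False

runner = RunnerFlagDemo()
with runner.set_in_profile_run():
    assert runner.in_profile_run       # callers could branch on this flag
assert not runner.in_profile_run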
