Merged
47 commits
86c0f38
Support chunked prefill with ALL pooling
noooop Oct 18, 2025
6bd49f2
fix
noooop Oct 18, 2025
44c6ee1
fix
noooop Oct 18, 2025
86f0868
fix
noooop Oct 18, 2025
7c1d68d
Update vllm/model_executor/layers/pooler.py
noooop Oct 18, 2025
f903415
Update vllm/model_executor/layers/pooler.py
noooop Oct 18, 2025
72df85d
Update vllm/model_executor/layers/pooler.py
noooop Oct 18, 2025
6b6e7a8
fix deep copy
noooop Oct 18, 2025
9aef354
fix
noooop Oct 18, 2025
d574b6c
+ tests
noooop Oct 18, 2025
26351d7
Merge branch 'main' into all_pooling_plus_chunked_prefill2
noooop Oct 18, 2025
5c4b13c
fix
noooop Oct 18, 2025
178ccd2
fix StepPooler
noooop Oct 18, 2025
43291db
fix StepPooler
noooop Oct 18, 2025
bb9a4ad
+ preempted_req
noooop Oct 21, 2025
08a0739
Merge branch 'main' into all_pooling_plus_chunked_prefill2
noooop Oct 21, 2025
29b3d1d
Merge branch 'main' into all_pooling_plus_chunked_prefill2
noooop Oct 28, 2025
eea5f6c
update
noooop Oct 28, 2025
41ff486
update
noooop Oct 28, 2025
70c0965
Merge branch 'main' into all_pooling_plus_chunked_prefill2
noooop Nov 17, 2025
e8f222e
update
noooop Nov 17, 2025
d78b2cf
Merge branch 'main' into all_pooling_plus_chunked_prefill2
noooop Nov 17, 2025
fb8197b
update
noooop Nov 17, 2025
c985ced
conflicts
noooop Dec 1, 2025
0235532
Merge branch 'main' into all_pooling_plus_chunked_prefill2
noooop Dec 1, 2025
2d7f0c7
update
noooop Dec 1, 2025
193e049
update
noooop Dec 1, 2025
0da5b30
fix
noooop Dec 1, 2025
da00d3c
Merge branch 'main' into all_pooling_plus_chunked_prefill2
noooop Dec 1, 2025
dc0127c
fix Terratorch
noooop Dec 1, 2025
bcc66e5
Merge branch 'main' into all_pooling_plus_chunked_prefill2
noooop Dec 2, 2025
1351cd6
conflicts
noooop Dec 2, 2025
2e1168b
+ docs
noooop Dec 2, 2025
bbe7113
Merge branch 'main' into all_pooling_plus_chunked_prefill2
noooop Dec 2, 2025
cc92d0b
Merge branch 'main' into all_pooling_plus_chunked_prefill2
noooop Dec 4, 2025
3373b90
+ PoolingParamsInternalStates
noooop Dec 4, 2025
42d2f49
+ PoolingStates
noooop Dec 4, 2025
e921aed
Merge branch 'main' into all_pooling_plus_chunked_prefill2
noooop Dec 4, 2025
011295c
fix
noooop Dec 4, 2025
3ea3144
Merge branch 'main' into all_pooling_plus_chunked_prefill2
noooop Dec 4, 2025
3716b1e
+ CachedRequestState.pooling_states
noooop Dec 4, 2025
241b60b
- Unnecessary modification
noooop Dec 4, 2025
367d6eb
fix
noooop Dec 4, 2025
65c7910
num_scheduled_tokens_cpu
noooop Dec 4, 2025
63a4782
Merge branch 'main' into all_pooling_plus_chunked_prefill2
noooop Dec 4, 2025
43bf1ef
update
noooop Dec 4, 2025
f7f320e
Merge branch 'main' into all_pooling_plus_chunked_prefill2
noooop Dec 4, 2025
2 changes: 1 addition & 1 deletion docs/features/README.md
@@ -54,7 +54,7 @@ th:not(:first-child) {
| beam-search |||| [](https://github.com/vllm-project/vllm/issues/6137) ||||||| [](https://github.com/vllm-project/vllm/issues/7968) |||| |
| [prompt-embeds](prompt_embeds.md) ||||||||||||||||

\* Chunked prefill and prefix caching are only applicable to last-token pooling.
\* Chunked prefill and prefix caching are only applicable to last-token or all pooling with causal attention.
<sup>^</sup> LoRA is only applicable to the language backbone of multimodal models.

### Feature x Hardware
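A minimal offline sketch of what this documentation change describes, based on the tests added in this PR (the model name, flags, and output access mirror those tests; treat it as an illustration rather than canonical usage):

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-Embedding-0.6B",   # model used by the new chunked-prefill test
    runner="pooling",
    enable_chunked_prefill=True,
    enable_prefix_caching=True,
)
outputs = llm.encode(
    ["This product was excellent and exceeded my expectations"],
    pooling_task="token_embed",          # ALL pooling: one vector per prompt token
    use_tqdm=False,
)
print(outputs[0].outputs.data.shape)     # (num_prompt_tokens, hidden_size)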
7 changes: 2 additions & 5 deletions tests/entrypoints/pooling/classify/test_offline.py
@@ -61,11 +61,8 @@ def get_outputs(use_activation):


@pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM):
# chunked prefill does not support all pooling
err_msg = "pooling_task must be one of.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
def test_token_classify(llm: LLM):
llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)


def test_score_api(llm: LLM):
12 changes: 6 additions & 6 deletions tests/entrypoints/pooling/classify/test_online.py
@@ -255,21 +255,21 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
# token_classify uses ALL pooling, which does not support chunked prefill.
task = "token_classify"
input_text = ["This product was excellent and exceeded my expectations"]
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": "test",
"input": input_text,
"encoding_format": "float",
"task": task,
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(
f"Task {task} is not supported"
)
poolings = PoolingResponse.model_validate(response.json())
assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 8
assert len(poolings.data[0].data[0]) == 2


@pytest.mark.asyncio
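For the online server, a token_classify request to the /pooling endpoint now returns per-token results instead of a BadRequestError. A hedged sketch of the request shape, following the test above (the URL and served model name are placeholders):

import requests

response = requests.post(
    "http://localhost:8000/pooling",          # placeholder URL for a running `vllm serve` instance
    json={
        "model": "<served-classifier-model>",  # placeholder for the served model name
        "input": ["This product was excellent and exceeded my expectations"],
        "encoding_format": "float",
        "task": "token_classify",
    },
)
result = response.json()
# One entry per prompt; each entry carries one score vector per prompt token
# (the test above expects 8 tokens x 2 classes for this input).
print(len(result["data"][0]["data"]), len(result["data"][0]["data"][0]))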
2 changes: 1 addition & 1 deletion tests/entrypoints/pooling/embed/test_offline.py
@@ -42,7 +42,7 @@ def llm():


@pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM):
def test_token_embed(llm: LLM):
outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
multi_vector = outputs[0].outputs.data
assert multi_vector.shape == (11, 384)
7 changes: 7 additions & 0 deletions tests/entrypoints/pooling/reward/test_offline.py
@@ -36,6 +36,13 @@ def llm():
cleanup_dist_env_and_memory()


@pytest.mark.skip_global_cleanup
def test_config(llm: LLM):
vllm_config = llm.llm_engine.vllm_config
assert vllm_config.cache_config.enable_prefix_caching
assert vllm_config.scheduler_config.enable_chunked_prefill


def test_pooling_params(llm: LLM):
def get_outputs(use_activation):
outputs = llm.reward(
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from transformers import AutoModel

from tests.models.utils import check_embeddings_close
from vllm import TokensPrompt


@pytest.mark.parametrize(
    "model",
    ["Qwen/Qwen3-Embedding-0.6B"],
)
@torch.inference_mode
def test_embed_models(hf_runner, vllm_runner, model: str):
    chunk_size = 10
    n_prompt_tokens = [55, 56, 57]
    token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]

    with vllm_runner(
        model,
        runner="pooling",
        max_model_len=128,
        max_num_batched_tokens=chunk_size,
        enforce_eager=True,
        # VllmRunner defaults `enable_chunked_prefill` to `False` (not `None`),
        # so enable it explicitly here
        enable_chunked_prefill=True,
        enable_prefix_caching=True,
    ) as vllm_model:
        vllm_outputs = vllm_model.token_embed(
            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
        )

    with hf_runner(
        model,
        auto_cls=AutoModel,
    ) as hf_model:
        hf_outputs = []
        for token_prompt in token_prompts:
            inputs = hf_model.wrap_device({"input_ids": torch.tensor([token_prompt])})
            input_ids = inputs["input_ids"]
            output = hf_model.model(input_ids)
            hf_outputs.append(output.last_hidden_state.cpu().float()[0])

    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        check_embeddings_close(
            embeddings_0_lst=hf_output,
            embeddings_1_lst=vllm_output,
            name_0="hf",
            name_1="vllm",
            tol=1e-2,
        )
@@ -20,7 +20,6 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
max_model_len=128,
enforce_eager=True,
runner="pooling",
enable_chunked_prefill=False,
enable_prefix_caching=True,
) as vllm_model:
pooling_outputs = vllm_model.llm.encode(
8 changes: 4 additions & 4 deletions tests/test_config.py
@@ -629,8 +629,8 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
(
"internlm/internlm2-1_8b-reward",
"decoder",
False,
"Pooling models with all pooling does not support chunked prefill.",
True,
"Pooling models with causal attn and all pooling support chunked prefill.",
),
(
"BAAI/bge-base-en",
@@ -748,8 +748,8 @@ def test_is_chunked_prefill_supported(
(
"internlm/internlm2-1_8b-reward",
"decoder",
False,
"Pooling models with all pooling does not support prefix caching.",
True,
"Pooling models with causal attn and all pooling support prefix caching.",
),
(
"BAAI/bge-base-en",
24 changes: 14 additions & 10 deletions vllm/config/model.py
@@ -1780,20 +1780,22 @@ def is_chunked_prefill_supported(self) -> bool:
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["all", "mean", "step", "cls"]:
if pooling_type in ["mean", "step", "cls"]:
logger.debug(
"Pooling models with %s pooling does not "
"support chunked prefill.",
pooling_type,
)
return False
else:
# pooling_type == "last"
elif pooling_type in ["all", "last"]:
logger.debug(
"Pooling models with causal attn and last pooling support "
"chunked prefill."
"Pooling models with causal attn and %s pooling support "
"chunked prefill.",
pooling_type,
)
return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return attn_type != "encoder_decoder"
@@ -1817,20 +1819,22 @@ def is_prefix_caching_supported(self) -> bool:
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["all", "mean", "step", "cls"]:
if pooling_type in ["mean", "step", "cls"]:
logger.debug(
"Pooling models with %s pooling does not "
"support prefix caching.",
pooling_type,
)
return False
else:
# pooling_type == "last"
elif pooling_type in ["all", "last"]:
logger.debug(
"Pooling models with causal attn and last pooling support "
"prefix caching."
"Pooling models with causal attn and %s pooling support "
"prefix caching.",
pooling_type,
)
return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return False
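Both checks now treat ALL pooling like LAST pooling on decoder (causal) attention, while MEAN, STEP, and CLS pooling still require the full prompt in one pass. A standalone sketch of that decoder-branch rule (the helper name is illustrative, not vLLM API):

def decoder_pooling_supports_chunking(pooling_type: str) -> bool:
    # Mirrors the decoder-attention branch of is_chunked_prefill_supported
    # and is_prefix_caching_supported after this change.
    pooling_type = pooling_type.lower()
    if pooling_type in ("mean", "step", "cls"):
        return False  # these still need the whole prompt in a single prefill
    if pooling_type in ("all", "last"):
        return True   # ALL pooling now joins LAST pooling here
    raise ValueError(f"{pooling_type=} not supported.")

assert decoder_pooling_supports_chunking("ALL") is True
assert decoder_pooling_supports_chunking("mean") is False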
93 changes: 71 additions & 22 deletions vllm/model_executor/layers/pooler.py
@@ -127,14 +127,14 @@ def forward_all(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor:
) -> PoolerOutput:
raise NotImplementedError

def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> list[torch.Tensor] | torch.Tensor:
) -> PoolerOutput:
pooling_cursor = pooling_metadata.pooling_cursor
return self.forward_all(hidden_states, pooling_cursor)

Expand All @@ -147,7 +147,7 @@ def forward_all(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor:
) -> PoolerOutput:
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with CLS pooling"
)
@@ -163,27 +163,65 @@ def forward_all(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor:
) -> PoolerOutput:
return hidden_states[pooling_cursor.last_token_indices_gpu]


class AllPool(PoolingMethod):
def __init__(self):
super().__init__()

vllm_config = get_current_vllm_config()
self.enable_chunked_prefill = (
vllm_config.scheduler_config.enable_chunked_prefill
)

def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed", "token_classify"}

def forward_all(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor:
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with ALL pooling"
self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor
) -> PoolerOutput:
raise NotImplementedError(
"forward_all is not implemented for AllPool. Use forward instead."
)

def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooling_cursor = pooling_metadata.pooling_cursor
is_finished = pooling_cursor.is_finished()
hidden_states_lst = list(
hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
)
return [hidden_states_lst[i] for i in pooling_cursor.index]
hidden_states_lst = [hidden_states_lst[i] for i in pooling_cursor.index]

if not self.enable_chunked_prefill:
return hidden_states_lst

pooling_states = pooling_metadata.pooling_states

# If chunked_prefill is enabled
# 1. first store the chunked hidden_states in pooling_states.hidden_states_cache
for p, hs_chunk in zip(pooling_states, hidden_states_lst):
p.hidden_states_cache.append(hs_chunk)

# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
output_list: PoolerOutput = []
for p, finished in zip(pooling_states, is_finished):
if finished:
hidden_states_cache = p.hidden_states_cache
if len(hidden_states_cache) == 1:
output_list.append(hidden_states_cache[0])
else:
output_list.append(torch.concat(hidden_states_cache, dim=0))
p.clean()
else:
output_list.append(None)

return output_list


class MeanPool(PoolingMethod):
Expand All @@ -194,7 +232,7 @@ def forward_all(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor:
) -> PoolerOutput:
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with MEAN pooling"
)
@@ -399,7 +437,7 @@ def forward(
self,
pooled_data: list[torch.Tensor] | torch.Tensor,
pooling_metadata: PoolingMetadata,
):
) -> PoolerOutput:
return self.activation(pooled_data)


Expand All @@ -418,7 +456,7 @@ def forward(
self,
pooled_data: list[torch.Tensor] | torch.Tensor,
pooling_metadata: PoolingMetadata,
):
) -> PoolerOutput:
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_dimension]
@@ -586,8 +624,12 @@ def forward(

class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
def forward(
self, pooled_data: torch.Tensor, pooling_param: PoolingParams
) -> torch.Tensor:
self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams
) -> PoolerOutput:
# for unfinished chunked prefill
if pooled_data is None:
return None

pooled_data = pooled_data.to(self.head_dtype)
# pooled_data shape: [n_tokens, hidden_dimension]

@@ -630,9 +672,13 @@ def get_supported_tasks(self) -> Set[PoolingTask]:

def forward(
self,
hidden_states: torch.Tensor,
hidden_states: torch.Tensor | None,
pooling_param: PoolingParams,
) -> torch.Tensor:
) -> PoolerOutput:
# for unfinished chunked prefill
if hidden_states is None:
return None

hidden_states = hidden_states.to(self.head_dtype)
# hidden_states shape: [n_token, hidden_size]

@@ -686,17 +732,20 @@ def extract_states(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> torch.Tensor | list[torch.Tensor]:
) -> PoolerOutput:
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
prompt_token_ids = pooling_metadata.get_prompt_token_ids()

pooled_data = list[torch.Tensor]()

pooling_params = pooling_metadata.pooling_params

pooled_data: PoolerOutput = []
for data, token_id, pooling_param in zip(
pooled_data_lst, prompt_token_ids, pooling_params
):
# for unfinished chunked prefill
if data is None:
pooled_data.append(data)
continue

step_tag_id = pooling_param.step_tag_id
returned_token_ids = pooling_param.returned_token_ids

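The core of the change is AllPool.forward: with chunked prefill enabled, each chunk's hidden states are appended to a per-request cache, and only once the request's prefill has finished are the cached chunks concatenated and handed to the pooler head; unfinished requests yield None, which the heads now pass through. A self-contained sketch of that accumulation pattern (class and function names here are illustrative, not the vLLM internals):

import torch

class AllPoolCache:
    def __init__(self) -> None:
        self.chunks: list[torch.Tensor] = []

def pool_all_chunked(
    cache: AllPoolCache,
    hidden_chunk: torch.Tensor,   # [num_scheduled_tokens, hidden_size] for one request
    finished: bool,
) -> torch.Tensor | None:
    cache.chunks.append(hidden_chunk)
    if not finished:
        return None               # more prefill chunks still to come
    out = cache.chunks[0] if len(cache.chunks) == 1 else torch.cat(cache.chunks, dim=0)
    cache.chunks.clear()
    return out                    # [prompt_len, hidden_size], one row per prompt token

# Example: a 7-token prompt prefilled in chunks of 3, 3, and 1 tokens.
cache = AllPoolCache()
hidden = torch.randn(7, 16)
assert pool_all_chunked(cache, hidden[0:3], finished=False) is None
assert pool_all_chunked(cache, hidden[3:6], finished=False) is None
full = pool_all_chunked(cache, hidden[6:7], finished=True)
assert full is not None and torch.equal(full, hidden)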