diff --git a/docs/features/README.md b/docs/features/README.md
index 5faf3768f321..684802301a44 100644
--- a/docs/features/README.md
+++ b/docs/features/README.md
@@ -54,7 +54,7 @@ th:not(:first-child) {
| beam-search | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | |
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
-\* Chunked prefill and prefix caching are only applicable to last-token pooling.
+\* Chunked prefill and prefix caching are only applicable to last-token or all pooling with causal attention.
^ LoRA is only applicable to the language backbone of multimodal models.
### Feature x Hardware
diff --git a/tests/entrypoints/pooling/classify/test_offline.py b/tests/entrypoints/pooling/classify/test_offline.py
index 1063c3b6b755..a07fcd372721 100644
--- a/tests/entrypoints/pooling/classify/test_offline.py
+++ b/tests/entrypoints/pooling/classify/test_offline.py
@@ -61,11 +61,8 @@ def get_outputs(use_activation):
@pytest.mark.skip_global_cleanup
-def test_encode_api(llm: LLM):
- # chunked prefill does not support all pooling
- err_msg = "pooling_task must be one of.+"
- with pytest.raises(ValueError, match=err_msg):
- llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
+def test_token_classify(llm: LLM):
+ llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
def test_score_api(llm: LLM):
diff --git a/tests/entrypoints/pooling/classify/test_online.py b/tests/entrypoints/pooling/classify/test_online.py
index 6fef68858695..1a6c33b455e6 100644
--- a/tests/entrypoints/pooling/classify/test_online.py
+++ b/tests/entrypoints/pooling/classify/test_online.py
@@ -255,21 +255,21 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
- # token_classify uses ALL pooling, which does not support chunked prefill.
task = "token_classify"
+ input_text = ["This product was excellent and exceeded my expectations"]
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
- "input": "test",
+ "input": input_text,
"encoding_format": "float",
"task": task,
},
)
- assert response.json()["error"]["type"] == "BadRequestError"
- assert response.json()["error"]["message"].startswith(
- f"Task {task} is not supported"
- )
+ poolings = PoolingResponse.model_validate(response.json())
+ assert len(poolings.data) == 1
+ assert len(poolings.data[0].data) == 8
+ assert len(poolings.data[0].data[0]) == 2
@pytest.mark.asyncio
diff --git a/tests/entrypoints/pooling/embed/test_offline.py b/tests/entrypoints/pooling/embed/test_offline.py
index f5eab4c29ae1..12b47b1a08a8 100644
--- a/tests/entrypoints/pooling/embed/test_offline.py
+++ b/tests/entrypoints/pooling/embed/test_offline.py
@@ -42,7 +42,7 @@ def llm():
@pytest.mark.skip_global_cleanup
-def test_encode_api(llm: LLM):
+def test_token_embed(llm: LLM):
outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
multi_vector = outputs[0].outputs.data
assert multi_vector.shape == (11, 384)
diff --git a/tests/entrypoints/pooling/reward/test_offline.py b/tests/entrypoints/pooling/reward/test_offline.py
index 0255704cecd9..b061b5514515 100644
--- a/tests/entrypoints/pooling/reward/test_offline.py
+++ b/tests/entrypoints/pooling/reward/test_offline.py
@@ -36,6 +36,13 @@ def llm():
cleanup_dist_env_and_memory()
+@pytest.mark.skip_global_cleanup
+def test_config(llm: LLM):
+ vllm_config = llm.llm_engine.vllm_config
+ assert vllm_config.cache_config.enable_prefix_caching
+ assert vllm_config.scheduler_config.enable_chunked_prefill
+
+
def test_pooling_params(llm: LLM):
def get_outputs(use_activation):
outputs = llm.reward(
diff --git a/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py b/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py
new file mode 100644
index 000000000000..c259c532220b
--- /dev/null
+++ b/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py
@@ -0,0 +1,53 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+from transformers import AutoModel
+
+from tests.models.utils import check_embeddings_close
+from vllm import TokensPrompt
+
+
+@pytest.mark.parametrize(
+ "model",
+ ["Qwen/Qwen3-Embedding-0.6B"],
+)
+@torch.inference_mode
+def test_embed_models(hf_runner, vllm_runner, model: str):
+ chunk_size = 10
+ n_prompt_tokens = [55, 56, 57]
+ token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
+
+ with vllm_runner(
+ model,
+ runner="pooling",
+ max_model_len=128,
+ max_num_batched_tokens=chunk_size,
+ enforce_eager=True,
+ # `enable_chunked_prefill`: Set to `False` instead of `None` in VllmRunner
+ enable_chunked_prefill=True,
+ enable_prefix_caching=True,
+ ) as vllm_model:
+ vllm_outputs = vllm_model.token_embed(
+ [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+ )
+
+ with hf_runner(
+ model,
+ auto_cls=AutoModel,
+ ) as hf_model:
+ hf_outputs = []
+ for token_prompt in token_prompts:
+ inputs = hf_model.wrap_device({"input_ids": torch.tensor([token_prompt])})
+ input_ids = inputs["input_ids"]
+ output = hf_model.model(input_ids)
+ hf_outputs.append(output.last_hidden_state.cpu().float()[0])
+
+ for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+ check_embeddings_close(
+ embeddings_0_lst=hf_output,
+ embeddings_1_lst=vllm_output,
+ name_0="hf",
+ name_1="vllm",
+ tol=1e-2,
+ )
diff --git a/tests/models/language/pooling/test_extract_hidden_states.py b/tests/models/language/pooling/test_extract_hidden_states.py
index 0d41b93233d5..488b27e2da0f 100644
--- a/tests/models/language/pooling/test_extract_hidden_states.py
+++ b/tests/models/language/pooling/test_extract_hidden_states.py
@@ -20,7 +20,6 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
max_model_len=128,
enforce_eager=True,
runner="pooling",
- enable_chunked_prefill=False,
enable_prefix_caching=True,
) as vllm_model:
pooling_outputs = vllm_model.llm.encode(
diff --git a/tests/test_config.py b/tests/test_config.py
index 019c0d6d8733..203447cd531f 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -629,8 +629,8 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
(
"internlm/internlm2-1_8b-reward",
"decoder",
- False,
- "Pooling models with all pooling does not support chunked prefill.",
+ True,
+ "Pooling models with causal attn and all pooling support chunked prefill.",
),
(
"BAAI/bge-base-en",
@@ -748,8 +748,8 @@ def test_is_chunked_prefill_supported(
(
"internlm/internlm2-1_8b-reward",
"decoder",
- False,
- "Pooling models with all pooling does not support prefix caching.",
+ True,
+ "Pooling models with causal attn and all pooling support prefix caching.",
),
(
"BAAI/bge-base-en",
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 655b7c995f6d..ae5189ce68d9 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -1780,20 +1780,22 @@ def is_chunked_prefill_supported(self) -> bool:
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
- if pooling_type in ["all", "mean", "step", "cls"]:
+ if pooling_type in ["mean", "step", "cls"]:
logger.debug(
"Pooling models with %s pooling does not "
"support chunked prefill.",
pooling_type,
)
return False
- else:
- # pooling_type == "last"
+ elif pooling_type in ["all", "last"]:
logger.debug(
- "Pooling models with causal attn and last pooling support "
- "chunked prefill."
+ "Pooling models with causal attn and %s pooling support "
+ "chunked prefill.",
+ pooling_type,
)
return True
+ else:
+ raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return attn_type != "encoder_decoder"
@@ -1817,20 +1819,22 @@ def is_prefix_caching_supported(self) -> bool:
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
- if pooling_type in ["all", "mean", "step", "cls"]:
+ if pooling_type in ["mean", "step", "cls"]:
logger.debug(
"Pooling models with %s pooling does not "
"support prefix caching.",
pooling_type,
)
return False
- else:
- # pooling_type == "last"
+ elif pooling_type in ["all", "last"]:
logger.debug(
- "Pooling models with causal attn and last pooling support "
- "prefix caching."
+ "Pooling models with causal attn and %s pooling support "
+ "prefix caching.",
+ pooling_type,
)
return True
+ else:
+ raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return False
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index 185e03e5f3bd..d1942689d7f5 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -127,14 +127,14 @@ def forward_all(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
- ) -> list[torch.Tensor] | torch.Tensor:
+ ) -> PoolerOutput:
raise NotImplementedError
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
- ) -> list[torch.Tensor] | torch.Tensor:
+ ) -> PoolerOutput:
pooling_cursor = pooling_metadata.pooling_cursor
return self.forward_all(hidden_states, pooling_cursor)
@@ -147,7 +147,7 @@ def forward_all(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
- ) -> list[torch.Tensor] | torch.Tensor:
+ ) -> PoolerOutput:
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with CLS pooling"
)
@@ -163,27 +163,65 @@ def forward_all(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
- ) -> list[torch.Tensor] | torch.Tensor:
+ ) -> PoolerOutput:
return hidden_states[pooling_cursor.last_token_indices_gpu]
class AllPool(PoolingMethod):
+ def __init__(self):
+ super().__init__()
+
+ vllm_config = get_current_vllm_config()
+ self.enable_chunked_prefill = (
+ vllm_config.scheduler_config.enable_chunked_prefill
+ )
+
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed", "token_classify"}
def forward_all(
- self,
- hidden_states: torch.Tensor,
- pooling_cursor: PoolingCursor,
- ) -> list[torch.Tensor] | torch.Tensor:
- assert not pooling_cursor.is_partial_prefill(), (
- "partial prefill not supported with ALL pooling"
+ self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor
+ ) -> PoolerOutput:
+ raise NotImplementedError(
+ "forward_all is not implemented for AllPool. Use forward instead."
)
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ pooling_metadata: PoolingMetadata,
+ ) -> PoolerOutput:
+ pooling_cursor = pooling_metadata.pooling_cursor
+ is_finished = pooling_cursor.is_finished()
hidden_states_lst = list(
hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
)
- return [hidden_states_lst[i] for i in pooling_cursor.index]
+ hidden_states_lst = [hidden_states_lst[i] for i in pooling_cursor.index]
+
+ if not self.enable_chunked_prefill:
+ return hidden_states_lst
+
+ pooling_states = pooling_metadata.pooling_states
+
+ # If chunked_prefill is enabled
+ # 1. first store the chunked hidden_states in pooling_states.hidden_states_cache
+ for p, hs_chunk in zip(pooling_states, hidden_states_lst):
+ p.hidden_states_cache.append(hs_chunk)
+
+ # 2. Once prefill is finished, send hidden_states_cache to PoolerHead
+ output_list: PoolerOutput = []
+ for p, finished in zip(pooling_states, is_finished):
+ if finished:
+ hidden_states_cache = p.hidden_states_cache
+ if len(hidden_states_cache) == 1:
+ output_list.append(hidden_states_cache[0])
+ else:
+ output_list.append(torch.concat(hidden_states_cache, dim=0))
+ p.clean()
+ else:
+ output_list.append(None)
+
+ return output_list
class MeanPool(PoolingMethod):
@@ -194,7 +232,7 @@ def forward_all(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
- ) -> list[torch.Tensor] | torch.Tensor:
+ ) -> PoolerOutput:
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with MEAN pooling"
)
@@ -399,7 +437,7 @@ def forward(
self,
pooled_data: list[torch.Tensor] | torch.Tensor,
pooling_metadata: PoolingMetadata,
- ):
+ ) -> PoolerOutput:
return self.activation(pooled_data)
@@ -418,7 +456,7 @@ def forward(
self,
pooled_data: list[torch.Tensor] | torch.Tensor,
pooling_metadata: PoolingMetadata,
- ):
+ ) -> PoolerOutput:
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_dimension]
@@ -586,8 +624,12 @@ def forward(
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
def forward(
- self, pooled_data: torch.Tensor, pooling_param: PoolingParams
- ) -> torch.Tensor:
+ self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams
+ ) -> PoolerOutput:
+ # for unfinished chunked prefill
+ if pooled_data is None:
+ return None
+
pooled_data = pooled_data.to(self.head_dtype)
# pooled_data shape: [n_tokens, hidden_dimension]
@@ -630,9 +672,13 @@ def get_supported_tasks(self) -> Set[PoolingTask]:
def forward(
self,
- hidden_states: torch.Tensor,
+ hidden_states: torch.Tensor | None,
pooling_param: PoolingParams,
- ) -> torch.Tensor:
+ ) -> PoolerOutput:
+ # for unfinished chunked prefill
+ if hidden_states is None:
+ return None
+
hidden_states = hidden_states.to(self.head_dtype)
# hidden_states shape: [n_token, hidden_size]
@@ -686,17 +732,20 @@ def extract_states(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
- ) -> torch.Tensor | list[torch.Tensor]:
+ ) -> PoolerOutput:
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
prompt_token_ids = pooling_metadata.get_prompt_token_ids()
-
- pooled_data = list[torch.Tensor]()
-
pooling_params = pooling_metadata.pooling_params
+ pooled_data: PoolerOutput = []
for data, token_id, pooling_param in zip(
pooled_data_lst, prompt_token_ids, pooling_params
):
+ # for unfinished chunked prefill
+ if data is None:
+ pooled_data.append(data)
+ continue
+
step_tag_id = pooling_param.step_tag_id
returned_token_ids = pooling_param.returned_token_ids
diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py
index 19052c8d49e4..9f34090e3107 100644
--- a/vllm/model_executor/models/terratorch.py
+++ b/vllm/model_executor/models/terratorch.py
@@ -64,7 +64,7 @@
from vllm.sequence import IntermediateTensors
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
-from .interfaces_base import default_pooling_type
+from .interfaces_base import attn_type
logger = init_logger(__name__)
@@ -220,7 +220,7 @@ def apply(
)
-@default_pooling_type("All")
+@attn_type("attention_free")
@MULTIMODAL_REGISTRY.register_processor(
TerratorchMultiModalProcessor,
info=TerratorchProcessingInfo,
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index 88ac6b4aeb4b..546eacebf83e 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -89,7 +89,7 @@ def empty_cpu(
# [num_reqs, ]
# The shape of each element depends on the pooler used
-PoolerOutput = torch.Tensor | list[torch.Tensor]
+PoolerOutput = list[torch.Tensor | None] | torch.Tensor | None
@dataclass
diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py
index 9ee588ea44ca..acd1a00e8755 100644
--- a/vllm/v1/pool/metadata.py
+++ b/vllm/v1/pool/metadata.py
@@ -17,6 +17,7 @@ class PoolingCursor:
first_token_indices_gpu: torch.Tensor
last_token_indices_gpu: torch.Tensor
prompt_lens_cpu: torch.Tensor
+ seq_lens_cpu: torch.Tensor
num_scheduled_tokens_cpu: torch.Tensor
def __getitem__(self, indices: slice):
@@ -25,12 +26,25 @@ def __getitem__(self, indices: slice):
first_token_indices_gpu=self.first_token_indices_gpu[indices],
last_token_indices_gpu=self.last_token_indices_gpu[indices],
prompt_lens_cpu=self.prompt_lens_cpu[indices],
+ seq_lens_cpu=self.seq_lens_cpu[indices],
num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices],
)
def is_partial_prefill(self):
return not torch.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu)
+ def is_finished(self):
+ return self.prompt_lens_cpu == self.seq_lens_cpu
+
+
+class PoolingStates:
+ def __init__(self):
+ # for chunked prefill with ALL pooling
+ self.hidden_states_cache: list[torch.Tensor] = []
+
+ def clean(self):
+ self.hidden_states_cache.clear()
+
@dataclass
class PoolingMetadata:
@@ -39,6 +53,7 @@ class PoolingMetadata:
prompt_lens: torch.Tensor # CPU Tensor
prompt_token_ids: torch.Tensor | None
pooling_params: list[PoolingParams]
+ pooling_states: list[PoolingStates]
pooling_cursor: PoolingCursor | None = None
def __post_init__(self) -> None:
@@ -60,6 +75,7 @@ def __getitem__(self, indices: slice):
if self.prompt_token_ids is None
else self.prompt_token_ids[indices],
pooling_params=self.pooling_params[indices],
+ pooling_states=self.pooling_states[indices],
pooling_cursor=None
if self.pooling_cursor is None
else self.pooling_cursor[indices],
@@ -74,15 +90,21 @@ def get_prompt_token_ids(self) -> list[torch.Tensor]:
return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]
def build_pooling_cursor(
- self, num_scheduled_tokens: list[int], device: torch.device
+ self,
+ num_scheduled_tokens: list[int],
+ seq_lens_cpu: torch.Tensor,
+ device: torch.device,
):
self.pooling_cursor = build_pooling_cursor(
- num_scheduled_tokens, self.prompt_lens, device
+ num_scheduled_tokens, seq_lens_cpu, self.prompt_lens, device
)
def build_pooling_cursor(
- num_scheduled_tokens: list[int], prompt_lens: torch.Tensor, device: torch.device
+ num_scheduled_tokens: list[int],
+ seq_lens_cpu: torch.Tensor,
+ prompt_lens: torch.Tensor,
+ device: torch.device,
):
assert len(prompt_lens) == len(num_scheduled_tokens)
@@ -99,5 +121,6 @@ def build_pooling_cursor(
first_token_indices_gpu=cumsum[:n_seq],
last_token_indices_gpu=cumsum[1:] - 1,
prompt_lens_cpu=prompt_lens,
+ seq_lens_cpu=seq_lens_cpu,
num_scheduled_tokens_cpu=num_scheduled_tokens_cpu,
)
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 516c76a5e4b1..ead7a3619dea 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -15,7 +15,7 @@
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.collection_utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
-from vllm.v1.pool.metadata import PoolingMetadata
+from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
from vllm.v1.sample.logits_processor import (
BatchUpdateBuilder,
LogitsProcessors,
@@ -33,7 +33,6 @@ class CachedRequestState:
prompt_token_ids: list[int] | None
mm_features: list[MultiModalFeatureSpec]
sampling_params: SamplingParams | None
- pooling_params: PoolingParams | None
generator: torch.Generator | None
block_ids: tuple[list[int], ...]
@@ -51,11 +50,18 @@ class CachedRequestState:
# Used when both async_scheduling and spec_decode are enabled.
prev_num_draft_len: int = 0
+ # for pooling models
+ pooling_params: PoolingParams | None = None
+ pooling_states: PoolingStates | None = None
+
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds
)
+ if self.pooling_params is not None:
+ self.pooling_states = PoolingStates()
+
@property
def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids)
@@ -255,7 +261,9 @@ def __init__(
# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()
+ # for pooling models
self.pooling_params: dict[str, PoolingParams] = {}
+ self.pooling_states: dict[str, PoolingStates] = {}
# Cached reference to the GPU tensor of previously sampled tokens
self.prev_sampled_token_ids: torch.Tensor | None = None
@@ -413,7 +421,11 @@ def add_request(
sampling_params.bad_words_token_ids
)
elif pooling_params := request.pooling_params:
+ pooling_states = request.pooling_states
+ assert pooling_states is not None
+
self.pooling_params[req_id] = pooling_params
+ self.pooling_states[req_id] = pooling_states
self.logits_processing_needs_token_ids[req_index] = (
pooling_params.requires_token_ids
)
@@ -469,6 +481,7 @@ def remove_request(self, req_id: str) -> int | None:
if self.is_pooling_model:
self.pooling_params.pop(req_id, None)
+ self.pooling_states.pop(req_id, None)
return req_index
self.greedy_reqs.discard(req_id)
@@ -837,13 +850,19 @@ def get_pooling_params(self) -> list[PoolingParams]:
assert len(self.req_ids) == len(self.pooling_params)
return [self.pooling_params[req_id] for req_id in self.req_ids]
+ def get_pooling_states(self) -> list[PoolingStates]:
+ assert len(self.req_ids) == len(self.pooling_states)
+ return [self.pooling_states[req_id] for req_id in self.req_ids]
+
def get_pooling_metadata(self) -> PoolingMetadata:
pooling_params = self.get_pooling_params()
+ pooling_states = self.get_pooling_states()
return PoolingMetadata(
prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]),
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params,
+ pooling_states=pooling_states,
)
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 3f043e3b2648..a7eb9cdae8b1 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -131,7 +131,7 @@
SamplerOutput,
make_empty_encoder_model_runner_output,
)
-from vllm.v1.pool.metadata import PoolingMetadata
+from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata
@@ -2291,20 +2291,6 @@ def get_supported_pooling_tasks(self) -> list[PoolingTask]:
supported_tasks = list(model.pooler.get_supported_tasks())
- if self.scheduler_config.enable_chunked_prefill:
- if "token_embed" in supported_tasks:
- supported_tasks.remove("token_embed")
- if "token_classify" in supported_tasks:
- supported_tasks.remove("token_classify")
-
- logger.debug_once(
- "Chunked prefill is not supported with "
- "token_embed and token_classify tasks "
- "which using ALL pooling. "
- "Please turn off chunked prefill by "
- "`--no-enable-chunked-prefill` before using it."
- )
-
if "score" in supported_tasks:
num_labels = getattr(self.model_config.hf_config, "num_labels", 0)
if num_labels != 1:
@@ -2381,11 +2367,12 @@ def _pool(
)
hidden_states = hidden_states[:num_scheduled_tokens]
+ seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs]
+
pooling_metadata = self.input_batch.get_pooling_metadata()
pooling_metadata.build_pooling_cursor(
- num_scheduled_tokens_np.tolist(), device=hidden_states.device
+ num_scheduled_tokens_np.tolist(), seq_lens_cpu, device=hidden_states.device
)
- seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs]
model = cast(VllmModelForPooling, self.model)
raw_pooler_output: PoolerOutput = model.pooler(
@@ -2393,7 +2380,7 @@ def _pool(
pooling_metadata=pooling_metadata,
)
raw_pooler_output = json_map_leaves(
- lambda x: x.to("cpu", non_blocking=True),
+ lambda x: x.to("cpu", non_blocking=True) if x is not None else x,
raw_pooler_output,
)
self._sync_device()
@@ -4248,10 +4235,13 @@ def _dummy_pooler_run_task(
prompt_lens=dummy_prompt_lens,
prompt_token_ids=dummy_token_ids,
pooling_params=[dummy_pooling_params] * num_reqs,
+ pooling_states=[PoolingStates() for i in range(num_reqs)],
)
dummy_metadata.build_pooling_cursor(
- num_scheduled_tokens_list, device=hidden_states.device
+ num_scheduled_tokens_list,
+ seq_lens_cpu=dummy_prompt_lens,
+ device=hidden_states.device,
)
try:
@@ -4278,22 +4268,12 @@ def _dummy_pooler_run(
supported_pooling_tasks = self.get_supported_pooling_tasks()
if not supported_pooling_tasks:
- if self.scheduler_config.enable_chunked_prefill:
- raise RuntimeError(
- f"Model {self.model_config.model} does not support "
- "any pooling tasks with chunked prefill enabled. "
- "Please add --no-enable-chunked-prefill to your "
- "config or CLI args. See "
- "https://docs.vllm.ai/en/latest/models/pooling_models.html "
- "to learn more."
- )
- else:
- raise RuntimeError(
- f"Model {self.model_config.model} does not support "
- "any pooling tasks. See "
- "https://docs.vllm.ai/en/latest/models/pooling_models.html "
- "to learn more."
- )
+ raise RuntimeError(
+ f"Model {self.model_config.model} does not support "
+ "any pooling tasks. See "
+ "https://docs.vllm.ai/en/latest/models/pooling_models.html "
+ "to learn more."
+ )
output_size = dict[PoolingTask, float]()
for task in supported_pooling_tasks: