diff --git a/docs/features/README.md b/docs/features/README.md
index 5faf3768f321..684802301a44 100644
--- a/docs/features/README.md
+++ b/docs/features/README.md
@@ -54,7 +54,7 @@ th:not(:first-child) {
 | beam-search | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | |
 | [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |

-\* Chunked prefill and prefix caching are only applicable to last-token pooling.
+\* Chunked prefill and prefix caching are only applicable to last-token or all pooling with causal attention.
 ^ LoRA is only applicable to the language backbone of multimodal models.

 ### Feature x Hardware
diff --git a/tests/entrypoints/pooling/classify/test_offline.py b/tests/entrypoints/pooling/classify/test_offline.py
index 1063c3b6b755..a07fcd372721 100644
--- a/tests/entrypoints/pooling/classify/test_offline.py
+++ b/tests/entrypoints/pooling/classify/test_offline.py
@@ -61,11 +61,8 @@ def get_outputs(use_activation):


 @pytest.mark.skip_global_cleanup
-def test_encode_api(llm: LLM):
-    # chunked prefill does not support all pooling
-    err_msg = "pooling_task must be one of.+"
-    with pytest.raises(ValueError, match=err_msg):
-        llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
+def test_token_classify(llm: LLM):
+    llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)


 def test_score_api(llm: LLM):
diff --git a/tests/entrypoints/pooling/classify/test_online.py b/tests/entrypoints/pooling/classify/test_online.py
index 6fef68858695..1a6c33b455e6 100644
--- a/tests/entrypoints/pooling/classify/test_online.py
+++ b/tests/entrypoints/pooling/classify/test_online.py
@@ -255,21 +255,21 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
-    # token_classify uses ALL pooling, which does not support chunked prefill.
     task = "token_classify"
+    input_text = ["This product was excellent and exceeded my expectations"]
     response = requests.post(
         server.url_for("pooling"),
         json={
             "model": model_name,
-            "input": "test",
+            "input": input_text,
             "encoding_format": "float",
             "task": task,
         },
     )
-    assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(
-        f"Task {task} is not supported"
-    )
+    poolings = PoolingResponse.model_validate(response.json())
+    assert len(poolings.data) == 1
+    assert len(poolings.data[0].data) == 8
+    assert len(poolings.data[0].data[0]) == 2


 @pytest.mark.asyncio
diff --git a/tests/entrypoints/pooling/embed/test_offline.py b/tests/entrypoints/pooling/embed/test_offline.py
index f5eab4c29ae1..12b47b1a08a8 100644
--- a/tests/entrypoints/pooling/embed/test_offline.py
+++ b/tests/entrypoints/pooling/embed/test_offline.py
@@ -42,7 +42,7 @@ def llm():


 @pytest.mark.skip_global_cleanup
-def test_encode_api(llm: LLM):
+def test_token_embed(llm: LLM):
     outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
     multi_vector = outputs[0].outputs.data
     assert multi_vector.shape == (11, 384)
diff --git a/tests/entrypoints/pooling/reward/test_offline.py b/tests/entrypoints/pooling/reward/test_offline.py
index 0255704cecd9..b061b5514515 100644
--- a/tests/entrypoints/pooling/reward/test_offline.py
+++ b/tests/entrypoints/pooling/reward/test_offline.py
@@ -36,6 +36,13 @@ def llm():
     cleanup_dist_env_and_memory()


+@pytest.mark.skip_global_cleanup
+def test_config(llm: LLM):
+    vllm_config = llm.llm_engine.vllm_config
+    assert vllm_config.cache_config.enable_prefix_caching
+    assert vllm_config.scheduler_config.enable_chunked_prefill
+
+
 def test_pooling_params(llm: LLM):
     def get_outputs(use_activation):
         outputs = llm.reward(
diff --git a/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py b/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py
new file mode 100644
index 000000000000..c259c532220b
--- /dev/null
+++ b/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py
@@ -0,0 +1,53 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+from transformers import AutoModel
+
+from tests.models.utils import check_embeddings_close
+from vllm import TokensPrompt
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["Qwen/Qwen3-Embedding-0.6B"],
+)
+@torch.inference_mode
+def test_embed_models(hf_runner, vllm_runner, model: str):
+    chunk_size = 10
+    n_prompt_tokens = [55, 56, 57]
+    token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
+
+    with vllm_runner(
+        model,
+        runner="pooling",
+        max_model_len=128,
+        max_num_batched_tokens=chunk_size,
+        enforce_eager=True,
+        # `enable_chunked_prefill`: Set to `False` instead of `None` in VllmRunner
+        enable_chunked_prefill=True,
+        enable_prefix_caching=True,
+    ) as vllm_model:
+        vllm_outputs = vllm_model.token_embed(
+            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+        )
+
+    with hf_runner(
+        model,
+        auto_cls=AutoModel,
+    ) as hf_model:
+        hf_outputs = []
+        for token_prompt in token_prompts:
+            inputs = hf_model.wrap_device({"input_ids": torch.tensor([token_prompt])})
+            input_ids = inputs["input_ids"]
+            output = hf_model.model(input_ids)
+            hf_outputs.append(output.last_hidden_state.cpu().float()[0])
+
+    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+        check_embeddings_close(
+            embeddings_0_lst=hf_output,
+            embeddings_1_lst=vllm_output,
+            name_0="hf",
+            name_1="vllm",
+            tol=1e-2,
+        )
diff --git a/tests/models/language/pooling/test_extract_hidden_states.py b/tests/models/language/pooling/test_extract_hidden_states.py
index 0d41b93233d5..488b27e2da0f 100644
--- a/tests/models/language/pooling/test_extract_hidden_states.py
+++ b/tests/models/language/pooling/test_extract_hidden_states.py
@@ -20,7 +20,6 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
         max_model_len=128,
         enforce_eager=True,
         runner="pooling",
-        enable_chunked_prefill=False,
         enable_prefix_caching=True,
     ) as vllm_model:
         pooling_outputs = vllm_model.llm.encode(
diff --git a/tests/test_config.py b/tests/test_config.py
index 019c0d6d8733..203447cd531f 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -629,8 +629,8 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
         (
             "internlm/internlm2-1_8b-reward",
             "decoder",
-            False,
-            "Pooling models with all pooling does not support chunked prefill.",
+            True,
+            "Pooling models with causal attn and all pooling support chunked prefill.",
         ),
         (
             "BAAI/bge-base-en",
@@ -748,8 +748,8 @@ def test_is_chunked_prefill_supported(
         (
             "internlm/internlm2-1_8b-reward",
             "decoder",
-            False,
-            "Pooling models with all pooling does not support prefix caching.",
+            True,
+            "Pooling models with causal attn and all pooling support prefix caching.",
         ),
         (
             "BAAI/bge-base-en",
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 655b7c995f6d..ae5189ce68d9 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -1780,20 +1780,22 @@ def is_chunked_prefill_supported(self) -> bool:
             return False
         elif attn_type == "decoder":
             pooling_type = self.pooler_config.pooling_type.lower()
-            if pooling_type in ["all", "mean", "step", "cls"]:
+            if pooling_type in ["mean", "step", "cls"]:
                 logger.debug(
                     "Pooling models with %s pooling does not "
                     "support chunked prefill.",
                     pooling_type,
                 )
                 return False
-            else:
-                # pooling_type == "last"
+            elif pooling_type in ["all", "last"]:
                 logger.debug(
-                    "Pooling models with causal attn and last pooling support "
-                    "chunked prefill."
+                    "Pooling models with causal attn and %s pooling support "
+                    "chunked prefill.",
+                    pooling_type,
                 )
                 return True
+            else:
+                raise ValueError(f"{pooling_type=} not supported.")
         # vllm currently does not have pooling models using hybrid,
         # attention_free or encoder_decoder attn types.
         return attn_type != "encoder_decoder"
@@ -1817,20 +1819,22 @@ def is_prefix_caching_supported(self) -> bool:
             return False
         elif attn_type == "decoder":
             pooling_type = self.pooler_config.pooling_type.lower()
-            if pooling_type in ["all", "mean", "step", "cls"]:
+            if pooling_type in ["mean", "step", "cls"]:
                 logger.debug(
                     "Pooling models with %s pooling does not "
                     "support prefix caching.",
                     pooling_type,
                 )
                 return False
-            else:
-                # pooling_type == "last"
+            elif pooling_type in ["all", "last"]:
                 logger.debug(
-                    "Pooling models with causal attn and last pooling support "
-                    "prefix caching."
+                    "Pooling models with causal attn and %s pooling support "
+                    "prefix caching.",
+                    pooling_type,
                 )
                 return True
+            else:
+                raise ValueError(f"{pooling_type=} not supported.")
         # vllm currently does not have pooling models using hybrid,
         # attention_free or encoder_decoder attn types.
         return False
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index 185e03e5f3bd..d1942689d7f5 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -127,14 +127,14 @@ def forward_all(
         self,
         hidden_states: torch.Tensor,
         pooling_cursor: PoolingCursor,
-    ) -> list[torch.Tensor] | torch.Tensor:
+    ) -> PoolerOutput:
         raise NotImplementedError

     def forward(
         self,
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
-    ) -> list[torch.Tensor] | torch.Tensor:
+    ) -> PoolerOutput:
         pooling_cursor = pooling_metadata.pooling_cursor
         return self.forward_all(hidden_states, pooling_cursor)

@@ -147,7 +147,7 @@ def forward_all(
         self,
         hidden_states: torch.Tensor,
         pooling_cursor: PoolingCursor,
-    ) -> list[torch.Tensor] | torch.Tensor:
+    ) -> PoolerOutput:
         assert not pooling_cursor.is_partial_prefill(), (
             "partial prefill not supported with CLS pooling"
         )
@@ -163,27 +163,65 @@ def forward_all(
         self,
         hidden_states: torch.Tensor,
         pooling_cursor: PoolingCursor,
-    ) -> list[torch.Tensor] | torch.Tensor:
+    ) -> PoolerOutput:
         return hidden_states[pooling_cursor.last_token_indices_gpu]


 class AllPool(PoolingMethod):
+    def __init__(self):
+        super().__init__()
+
+        vllm_config = get_current_vllm_config()
+        self.enable_chunked_prefill = (
+            vllm_config.scheduler_config.enable_chunked_prefill
+        )
+
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"token_embed", "token_classify"}

     def forward_all(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_cursor: PoolingCursor,
-    ) -> list[torch.Tensor] | torch.Tensor:
-        assert not pooling_cursor.is_partial_prefill(), (
-            "partial prefill not supported with ALL pooling"
+        self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor
+    ) -> PoolerOutput:
+        raise NotImplementedError(
+            "forward_all is not implemented for AllPool. Use forward instead."
         )

+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        pooling_cursor = pooling_metadata.pooling_cursor
+        is_finished = pooling_cursor.is_finished()
         hidden_states_lst = list(
             hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
         )
-        return [hidden_states_lst[i] for i in pooling_cursor.index]
+        hidden_states_lst = [hidden_states_lst[i] for i in pooling_cursor.index]
+
+        if not self.enable_chunked_prefill:
+            return hidden_states_lst
+
+        pooling_states = pooling_metadata.pooling_states
+
+        # If chunked_prefill is enabled
+        # 1. first store the chunked hidden_states in pooling_states.hidden_states_cache
+        for p, hs_chunk in zip(pooling_states, hidden_states_lst):
+            p.hidden_states_cache.append(hs_chunk)
+
+        # 2. Once prefill is finished, send hidden_states_cache to PoolerHead
+        output_list: PoolerOutput = []
+        for p, finished in zip(pooling_states, is_finished):
+            if finished:
+                hidden_states_cache = p.hidden_states_cache
+                if len(hidden_states_cache) == 1:
+                    output_list.append(hidden_states_cache[0])
+                else:
+                    output_list.append(torch.concat(hidden_states_cache, dim=0))
+                p.clean()
+            else:
+                output_list.append(None)
+
+        return output_list


 class MeanPool(PoolingMethod):
@@ -194,7 +232,7 @@ def forward_all(
         self,
         hidden_states: torch.Tensor,
         pooling_cursor: PoolingCursor,
-    ) -> list[torch.Tensor] | torch.Tensor:
+    ) -> PoolerOutput:
         assert not pooling_cursor.is_partial_prefill(), (
             "partial prefill not supported with MEAN pooling"
         )
@@ -399,7 +437,7 @@ def forward(
         self,
         pooled_data: list[torch.Tensor] | torch.Tensor,
         pooling_metadata: PoolingMetadata,
-    ):
+    ) -> PoolerOutput:
         return self.activation(pooled_data)


@@ -418,7 +456,7 @@ def forward(
         self,
         pooled_data: list[torch.Tensor] | torch.Tensor,
         pooling_metadata: PoolingMetadata,
-    ):
+    ) -> PoolerOutput:
         if isinstance(pooled_data, list):
             pooled_data = torch.stack(pooled_data)
         # pooled_data shape: [batchsize, hidden_dimension]
@@ -586,8 +624,12 @@ def forward(
 class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
     def forward(
-        self, pooled_data: torch.Tensor, pooling_param: PoolingParams
-    ) -> torch.Tensor:
+        self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams
+    ) -> PoolerOutput:
+        # for unfinished chunked prefill
+        if pooled_data is None:
+            return None
+
         pooled_data = pooled_data.to(self.head_dtype)
         # pooled_data shape: [n_tokens, hidden_dimension]
@@ -630,9 +672,13 @@ def get_supported_tasks(self) -> Set[PoolingTask]:
     def forward(
         self,
-        hidden_states: torch.Tensor,
+        hidden_states: torch.Tensor | None,
         pooling_param: PoolingParams,
-    ) -> torch.Tensor:
+    ) -> PoolerOutput:
+        # for unfinished chunked prefill
+        if hidden_states is None:
+            return None
+
         hidden_states = hidden_states.to(self.head_dtype)
         # hidden_states shape: [n_token, hidden_size]
@@ -686,17 +732,20 @@ def extract_states(
         self,
         hidden_states: torch.Tensor | list[torch.Tensor],
         pooling_metadata: PoolingMetadata,
-    ) -> torch.Tensor | list[torch.Tensor]:
+    ) -> PoolerOutput:
         pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
         prompt_token_ids = pooling_metadata.get_prompt_token_ids()
-
-        pooled_data = list[torch.Tensor]()
-        pooling_params = pooling_metadata.pooling_params
+        pooled_data: PoolerOutput = []

         for data, token_id, pooling_param in zip(
             pooled_data_lst, prompt_token_ids, pooling_params
         ):
+            # for unfinished chunked prefill
+            if data is None:
+                pooled_data.append(data)
+                continue
+
             step_tag_id = pooling_param.step_tag_id
             returned_token_ids = pooling_param.returned_token_ids
diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py
index 19052c8d49e4..9f34090e3107 100644
--- a/vllm/model_executor/models/terratorch.py
+++ b/vllm/model_executor/models/terratorch.py
@@ -64,7 +64,7 @@
 from vllm.sequence import IntermediateTensors

 from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
-from .interfaces_base import default_pooling_type
+from .interfaces_base import attn_type

 logger = init_logger(__name__)

@@ -220,7 +220,7 @@ def apply(
     )


-@default_pooling_type("All")
+@attn_type("attention_free")
 @MULTIMODAL_REGISTRY.register_processor(
     TerratorchMultiModalProcessor,
     info=TerratorchProcessingInfo,
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index 88ac6b4aeb4b..546eacebf83e 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -89,7 +89,7 @@ def empty_cpu(

 # [num_reqs, ]
 # The shape of each element depends on the pooler used
-PoolerOutput = torch.Tensor | list[torch.Tensor]
+PoolerOutput = list[torch.Tensor | None] | torch.Tensor | None


 @dataclass
diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py
index 9ee588ea44ca..acd1a00e8755 100644
--- a/vllm/v1/pool/metadata.py
+++ b/vllm/v1/pool/metadata.py
@@ -17,6 +17,7 @@ class PoolingCursor:
     first_token_indices_gpu: torch.Tensor
     last_token_indices_gpu: torch.Tensor
     prompt_lens_cpu: torch.Tensor
+    seq_lens_cpu: torch.Tensor
     num_scheduled_tokens_cpu: torch.Tensor

     def __getitem__(self, indices: slice):
@@ -25,12 +26,25 @@ def __getitem__(self, indices: slice):
             first_token_indices_gpu=self.first_token_indices_gpu[indices],
             last_token_indices_gpu=self.last_token_indices_gpu[indices],
             prompt_lens_cpu=self.prompt_lens_cpu[indices],
+            seq_lens_cpu=self.seq_lens_cpu[indices],
             num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices],
         )

     def is_partial_prefill(self):
         return not torch.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu)

+    def is_finished(self):
+        return self.prompt_lens_cpu == self.seq_lens_cpu
+
+
+class PoolingStates:
+    def __init__(self):
+        # for chunked prefill with ALL pooling
+        self.hidden_states_cache: list[torch.Tensor] = []
+
+    def clean(self):
+        self.hidden_states_cache.clear()
+

 @dataclass
 class PoolingMetadata:
@@ -39,6 +53,7 @@ class PoolingMetadata:
     prompt_lens: torch.Tensor  # CPU Tensor
     prompt_token_ids: torch.Tensor | None
     pooling_params: list[PoolingParams]
+    pooling_states: list[PoolingStates]
     pooling_cursor: PoolingCursor | None = None

     def __post_init__(self) -> None:
@@ -60,6 +75,7 @@ def __getitem__(self, indices: slice):
             if self.prompt_token_ids is None
             else self.prompt_token_ids[indices],
             pooling_params=self.pooling_params[indices],
+            pooling_states=self.pooling_states[indices],
             pooling_cursor=None
             if self.pooling_cursor is None
             else self.pooling_cursor[indices],
@@ -74,15 +90,21 @@ def get_prompt_token_ids(self) -> list[torch.Tensor]:
         return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]

     def build_pooling_cursor(
-        self, num_scheduled_tokens: list[int], device: torch.device
+        self,
+        num_scheduled_tokens: list[int],
+        seq_lens_cpu: torch.Tensor,
+        device: torch.device,
     ):
         self.pooling_cursor = build_pooling_cursor(
-            num_scheduled_tokens, self.prompt_lens, device
+            num_scheduled_tokens, seq_lens_cpu, self.prompt_lens, device
         )


 def build_pooling_cursor(
-    num_scheduled_tokens: list[int], prompt_lens: torch.Tensor, device: torch.device
+    num_scheduled_tokens: list[int],
+    seq_lens_cpu: torch.Tensor,
+    prompt_lens: torch.Tensor,
+    device: torch.device,
 ):
     assert len(prompt_lens) == len(num_scheduled_tokens)
@@ -99,5 +121,6 @@ def build_pooling_cursor(
         first_token_indices_gpu=cumsum[:n_seq],
         last_token_indices_gpu=cumsum[1:] - 1,
         prompt_lens_cpu=prompt_lens,
+        seq_lens_cpu=seq_lens_cpu,
         num_scheduled_tokens_cpu=num_scheduled_tokens_cpu,
     )
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 516c76a5e4b1..ead7a3619dea 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -15,7 +15,7 @@
 from vllm.utils import length_from_prompt_token_ids_or_embeds
 from vllm.utils.collection_utils import swap_dict_values
 from vllm.v1.outputs import LogprobsTensors
-from vllm.v1.pool.metadata import PoolingMetadata
+from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
 from vllm.v1.sample.logits_processor import (
     BatchUpdateBuilder,
     LogitsProcessors,
@@ -33,7 +33,6 @@ class CachedRequestState:
     prompt_token_ids: list[int] | None
     mm_features: list[MultiModalFeatureSpec]
     sampling_params: SamplingParams | None
-    pooling_params: PoolingParams | None
     generator: torch.Generator | None

     block_ids: tuple[list[int], ...]
@@ -51,11 +50,18 @@ class CachedRequestState:
     # Used when both async_scheduling and spec_decode are enabled.
     prev_num_draft_len: int = 0

+    # for pooling models
+    pooling_params: PoolingParams | None = None
+    pooling_states: PoolingStates | None = None
+
     def __post_init__(self):
         self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
             self.prompt_token_ids, self.prompt_embeds
         )

+        if self.pooling_params is not None:
+            self.pooling_states = PoolingStates()
+
     @property
     def num_tokens(self) -> int:
         return self.num_prompt_tokens + len(self.output_token_ids)
@@ -255,7 +261,9 @@ def __init__(
         # This is updated each time the batch constituents change.
         self.sampling_metadata = self._make_sampling_metadata()

+        # for pooling models
         self.pooling_params: dict[str, PoolingParams] = {}
+        self.pooling_states: dict[str, PoolingStates] = {}

         # Cached reference to the GPU tensor of previously sampled tokens
         self.prev_sampled_token_ids: torch.Tensor | None = None
@@ -413,7 +421,11 @@ def add_request(
                 sampling_params.bad_words_token_ids
             )
         elif pooling_params := request.pooling_params:
+            pooling_states = request.pooling_states
+            assert pooling_states is not None
+
             self.pooling_params[req_id] = pooling_params
+            self.pooling_states[req_id] = pooling_states
             self.logits_processing_needs_token_ids[req_index] = (
                 pooling_params.requires_token_ids
             )
@@ -469,6 +481,7 @@ def remove_request(self, req_id: str) -> int | None:

         if self.is_pooling_model:
             self.pooling_params.pop(req_id, None)
+            self.pooling_states.pop(req_id, None)
             return req_index

         self.greedy_reqs.discard(req_id)
@@ -837,13 +850,19 @@ def get_pooling_params(self) -> list[PoolingParams]:
         assert len(self.req_ids) == len(self.pooling_params)
         return [self.pooling_params[req_id] for req_id in self.req_ids]

+    def get_pooling_states(self) -> list[PoolingStates]:
+        assert len(self.req_ids) == len(self.pooling_states)
+        return [self.pooling_states[req_id] for req_id in self.req_ids]
+
     def get_pooling_metadata(self) -> PoolingMetadata:
         pooling_params = self.get_pooling_params()
+        pooling_states = self.get_pooling_states()

         return PoolingMetadata(
             prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]),
             prompt_token_ids=self.sampling_metadata.prompt_token_ids,
             pooling_params=pooling_params,
+            pooling_states=pooling_states,
         )

     def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 3f043e3b2648..a7eb9cdae8b1 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -131,7 +131,7 @@
     SamplerOutput,
     make_empty_encoder_model_runner_output,
 )
-from vllm.v1.pool.metadata import PoolingMetadata
+from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
 from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
 from vllm.v1.sample.logits_processor.interface import LogitsProcessor
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -2291,20 +2291,6 @@ def get_supported_pooling_tasks(self) -> list[PoolingTask]:

         supported_tasks = list(model.pooler.get_supported_tasks())

-        if self.scheduler_config.enable_chunked_prefill:
-            if "token_embed" in supported_tasks:
-                supported_tasks.remove("token_embed")
-            if "token_classify" in supported_tasks:
-                supported_tasks.remove("token_classify")
-
-            logger.debug_once(
-                "Chunked prefill is not supported with "
-                "token_embed and token_classify tasks "
-                "which using ALL pooling. "
-                "Please turn off chunked prefill by "
-                "`--no-enable-chunked-prefill` before using it."
-            )
-
         if "score" in supported_tasks:
             num_labels = getattr(self.model_config.hf_config, "num_labels", 0)
             if num_labels != 1:
@@ -2381,11 +2367,12 @@ def _pool(
         )
         hidden_states = hidden_states[:num_scheduled_tokens]

+        seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs]
+
         pooling_metadata = self.input_batch.get_pooling_metadata()
         pooling_metadata.build_pooling_cursor(
-            num_scheduled_tokens_np.tolist(), device=hidden_states.device
+            num_scheduled_tokens_np.tolist(), seq_lens_cpu, device=hidden_states.device
         )
-        seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs]

         model = cast(VllmModelForPooling, self.model)
         raw_pooler_output: PoolerOutput = model.pooler(
@@ -2393,7 +2380,7 @@
             hidden_states=hidden_states,
             pooling_metadata=pooling_metadata,
         )
         raw_pooler_output = json_map_leaves(
-            lambda x: x.to("cpu", non_blocking=True),
+            lambda x: x.to("cpu", non_blocking=True) if x is not None else x,
             raw_pooler_output,
         )
         self._sync_device()
@@ -4248,10 +4235,13 @@ def _dummy_pooler_run_task(
             prompt_lens=dummy_prompt_lens,
             prompt_token_ids=dummy_token_ids,
             pooling_params=[dummy_pooling_params] * num_reqs,
+            pooling_states=[PoolingStates() for i in range(num_reqs)],
         )

         dummy_metadata.build_pooling_cursor(
-            num_scheduled_tokens_list, device=hidden_states.device
+            num_scheduled_tokens_list,
+            seq_lens_cpu=dummy_prompt_lens,
+            device=hidden_states.device,
         )

         try:
@@ -4278,22 +4268,12 @@ def _dummy_pooler_run(
         supported_pooling_tasks = self.get_supported_pooling_tasks()

         if not supported_pooling_tasks:
-            if self.scheduler_config.enable_chunked_prefill:
-                raise RuntimeError(
-                    f"Model {self.model_config.model} does not support "
-                    "any pooling tasks with chunked prefill enabled. "
-                    "Please add --no-enable-chunked-prefill to your "
-                    "config or CLI args. See "
-                    "https://docs.vllm.ai/en/latest/models/pooling_models.html "
-                    "to learn more."
-                )
-            else:
-                raise RuntimeError(
-                    f"Model {self.model_config.model} does not support "
-                    "any pooling tasks. See "
-                    "https://docs.vllm.ai/en/latest/models/pooling_models.html "
-                    "to learn more."
-                )
+            raise RuntimeError(
+                f"Model {self.model_config.model} does not support "
+                "any pooling tasks. See "
+                "https://docs.vllm.ai/en/latest/models/pooling_models.html "
+                "to learn more."
+            )

         output_size = dict[PoolingTask, float]()
         for task in supported_pooling_tasks:
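
For reference, a minimal offline sketch of the behavior this patch enables, modeled on the updated tests above. The model name and engine flags are taken from those tests and are assumptions, not part of the patch itself; exact defaults may differ per release.

# Illustrative only: running an ALL-pooling task (token_embed) with chunked
# prefill and prefix caching enabled, mirroring tests/entrypoints/pooling.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-Embedding-0.6B",  # assumed pooling model from the new test
    runner="pooling",
    enable_chunked_prefill=True,
    enable_prefix_caching=True,
)

outputs = llm.encode(
    ["This product was excellent and exceeded my expectations"],
    pooling_task="token_embed",
    use_tqdm=False,
)
# One vector per prompt token: shape (num_prompt_tokens, hidden_size).
print(outputs[0].outputs.data.shape)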