From 1e12d9c113ea5afb5179c814c4d99811794d8be1 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Thu, 6 Mar 2025 21:04:13 +0800
Subject: [PATCH 1/4] fix and decouple encoder profiling

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 vllm/multimodal/profiling.py | 115 +++++++++++++++++++++++------
 1 file changed, 76 insertions(+), 39 deletions(-)

diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index 3178b0f8c3e6..37816f0654fb 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from dataclasses import dataclass, field
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, cast
 
 import numpy as np
 import numpy.typing as npt
@@ -13,7 +13,8 @@
 from vllm.inputs import DummyData
 from vllm.logger import init_logger
 
-from .inputs import MultiModalDataDict, MultiModalInputs
+from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
+                     MultiModalInputs)
 from .processing import BaseMultiModalProcessor, BaseProcessingInfo
 
 logger = init_logger(__name__)
@@ -144,13 +145,72 @@ def _get_dummy_mm_inputs(
             hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
         )
 
+    def create_encoder_dummy_data(
+        self,
+        seq_len: int,
+        mm_inputs: MultiModalEncDecInputs,
+        total_placeholders_by_modality: Mapping[str, int],
+    ) -> DummyData:
+        # Avoid circular import
+        from vllm.sequence import SequenceData
+
+        # For encoder-decoder models, use encoder prompt token ids instead of
+        # decoder prompt to construct dummy seq_data for encoder profiling.
+        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
+        encoder_prompt_token_ids.extend(
+            [0] * (seq_len - len(encoder_prompt_token_ids)))
+
+        return DummyData(
+            seq_data=SequenceData.from_seqs(encoder_prompt_token_ids),
+            multi_modal_data=None,
+            multi_modal_placeholders=None,
+        )
+
+    def create_decoder_dummy_data(
+        self,
+        seq_len: int,
+        mm_inputs: MultiModalInputs,
+        total_placeholders_by_modality: Mapping[str, int],
+    ) -> DummyData:
+        # Avoid circular import
+        from vllm.sequence import SequenceData
+
+        prompt_token_ids = mm_inputs["prompt_token_ids"]
+        total_len = len(prompt_token_ids)
+
+        # V0 does not support chunked prefill.
+        if total_len > seq_len and not envs.VLLM_USE_V1:
+            logger.warning(
+                "The context length (%d) of the model is too short "
+                "to hold the multi-modal embeddings in the worst case "
+                "(%d tokens in total, out of which %s are reserved for "
+                "multi-modal embeddings). This may cause certain "
+                "multi-modal inputs to fail during inference, even when "
+                "the input text is short. To avoid this, you should "
+                "increase `max_model_len`, reduce `max_num_seqs`, "
+                "and/or reduce `mm_counts`.", seq_len, total_len,
+                total_placeholders_by_modality)
+
+            return DummyData(
+                seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
+                multi_modal_data=None,
+                multi_modal_placeholders=None,
+            )
+
+        prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
+
+        return DummyData(
+            seq_data=SequenceData.from_seqs(prompt_token_ids),
+            multi_modal_data=mm_inputs["mm_kwargs"],
+            multi_modal_placeholders=mm_inputs["mm_placeholders"],
+        )
+
     def get_dummy_data(
         self,
         seq_len: int,
         is_encoder_data: bool = False,
     ) -> DummyData:
         # Avoid circular import
-        from vllm.sequence import SequenceData
 
         mm_counts = self.get_mm_limits()
 
@@ -167,11 +227,6 @@ def get_dummy_data(
 
         mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
         placeholders_by_modality = mm_inputs["mm_placeholders"]
-        # For encoder-decoder models, use encoder prompt token ids instead of
-        # decoder prompt to construct dummy seq_data for encoder profiling.
-        prompt_token_ids = (
-            mm_inputs["prompt_token_ids"] if not is_encoder_data else
-            mm_inputs["encoder_prompt_token_ids"])  # type: ignore
 
         total_placeholders_by_modality = {
             modality: sum(item["length"] for item in placeholders)
@@ -188,35 +243,17 @@ def get_dummy_data(
                 f"is not the expected {expected_placeholders_by_modality} "
                 "tokens.")
 
-        total_len = len(prompt_token_ids)
-
-        # V0 does not support chunked prefill.
-        if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
-            if total_len > seq_len and not is_encoder_data:
-                logger.warning(
-                    "The context length (%d) of the model is too short "
-                    "to hold the multi-modal embeddings in the worst case "
-                    "(%d tokens in total, out of which %s are reserved for "
-                    "multi-modal embeddings). This may cause certain "
-                    "multi-modal inputs to fail during inference, even when "
-                    "the input text is short. To avoid this, you should "
-                    "increase `max_model_len`, reduce `max_num_seqs`, "
-                    "and/or reduce `mm_counts`.", seq_len, total_len,
-                    total_placeholders_by_modality)
-
-            num_tokens_to_pad = max(total_len, seq_len) - total_len
-            prompt_token_ids.extend([0] * num_tokens_to_pad)
-
-            return DummyData(
-                seq_data=SequenceData.from_seqs(prompt_token_ids),
-                multi_modal_data=None,
-                multi_modal_placeholders=None,
+        if is_encoder_data:
+            mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
+            dummy_data = self.create_encoder_dummy_data(
+                seq_len,
+                mm_inputs,
+                total_placeholders_by_modality,
             )
-
-        prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
-
-        return DummyData(
-            seq_data=SequenceData.from_seqs(prompt_token_ids),
-            multi_modal_data=mm_inputs["mm_kwargs"],
-            multi_modal_placeholders=placeholders_by_modality,
-        )
+        else:
+            dummy_data = self.create_decoder_dummy_data(
+                seq_len,
+                mm_inputs,
+                total_placeholders_by_modality,
+            )
+        return dummy_data

From 591eeec7654b2d7dbe5c375426f87f9f6364285f Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Thu, 6 Mar 2025 21:12:20 +0800
Subject: [PATCH 2/4] clean up

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 vllm/multimodal/profiling.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index 37816f0654fb..5e2304b0625f 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -210,7 +210,6 @@ def get_dummy_data(
         seq_len: int,
         is_encoder_data: bool = False,
     ) -> DummyData:
-        # Avoid circular import
 
         mm_counts = self.get_mm_limits()
 

From f84a559dc38856f7528fb39a38efac12e17f771b Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 7 Mar 2025 22:39:13 +0800
Subject: [PATCH 3/4] refactor

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 vllm/inputs/registry.py      |   6 +-
 vllm/multimodal/profiling.py | 108 ++++++++++++++++-------------
 2 files changed, 52 insertions(+), 62 deletions(-)

diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
index babfc4fb809c..a89be6e88821 100644
--- a/vllm/inputs/registry.py
+++ b/vllm/inputs/registry.py
@@ -335,8 +335,10 @@ def dummy_data_for_profiling(
                                                      tokenizer,
                                                      disable_cache=True)
             profiler = MultiModalProfiler(processor)
-            dummy_data = profiler.get_dummy_data(
-                seq_len, is_encoder_data=is_encoder_data)
+            dummy_data_factory = (profiler.get_encoder_dummy_data
+                                  if is_encoder_data else
+                                  profiler.get_decoder_dummy_data)
+            dummy_data = dummy_data_factory(seq_len)
         else:
             model_cls, _ = get_model_architecture(model_config)
             if is_encoder_data:
diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index 5e2304b0625f..77420563d2a0 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -145,20 +145,59 @@ def _get_dummy_mm_inputs(
             hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
         )
 
-    def create_encoder_dummy_data(
+    def get_and_validate_mm_inputs(
+        self,
+        seq_len: int,
+    ) -> tuple[MultiModalInputs, Mapping[str, int]]:
+        mm_counts = self.get_mm_limits()
+
+        info = self.processing_info
+        mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(
+            seq_len, mm_counts)
+
+        if mm_counts.keys() != mm_max_tokens_per_item.keys():
+            raise AssertionError(
+                "The keys returned by `get_supported_mm_limits` "
+                f"({set(mm_counts.keys())}) should be the same as those "
+                "returned by `get_mm_max_tokens_per_item` "
+                f"({set(mm_max_tokens_per_item.keys())})")
+
+        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
+        placeholders_by_modality = mm_inputs["mm_placeholders"]
+
+        total_placeholders_by_modality = {
+            modality: sum(item["length"] for item in placeholders)
+            for modality, placeholders in placeholders_by_modality.items()
+        }
+        expected_placeholders_by_modality = {
+            modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
+            for modality in placeholders_by_modality
+        }
+        if total_placeholders_by_modality != expected_placeholders_by_modality:
+            raise AssertionError(
+                f"The processed dummy data has a total of "
+                f"{total_placeholders_by_modality} placeholder tokens, which "
+                f"is not the expected {expected_placeholders_by_modality} "
+                "tokens.")
+        return mm_inputs, total_placeholders_by_modality
+
+    def get_encoder_dummy_data(
         self,
         seq_len: int,
-        mm_inputs: MultiModalEncDecInputs,
-        total_placeholders_by_modality: Mapping[str, int],
     ) -> DummyData:
         # Avoid circular import
         from vllm.sequence import SequenceData
 
+        mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len)
+        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
+
         # For encoder-decoder models, use encoder prompt token ids instead of
         # decoder prompt to construct dummy seq_data for encoder profiling.
         encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
-        encoder_prompt_token_ids.extend(
-            [0] * (seq_len - len(encoder_prompt_token_ids)))
+
+        total_len = len(encoder_prompt_token_ids)
+        num_tokens_to_pad = max(total_len, seq_len) - total_len
+        encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
 
         return DummyData(
             seq_data=SequenceData.from_seqs(encoder_prompt_token_ids),
@@ -166,15 +205,16 @@ def create_encoder_dummy_data(
             multi_modal_placeholders=None,
         )
 
-    def create_decoder_dummy_data(
+    def get_decoder_dummy_data(
         self,
         seq_len: int,
-        mm_inputs: MultiModalInputs,
-        total_placeholders_by_modality: Mapping[str, int],
     ) -> DummyData:
         # Avoid circular import
         from vllm.sequence import SequenceData
 
+        (mm_inputs, total_placeholders_by_modality
+         ) = self.get_and_validate_mm_inputs(seq_len)
+
         prompt_token_ids = mm_inputs["prompt_token_ids"]
         total_len = len(prompt_token_ids)
 
@@ -204,55 +244,3 @@ def get_decoder_dummy_data(
             multi_modal_data=mm_inputs["mm_kwargs"],
             multi_modal_placeholders=mm_inputs["mm_placeholders"],
         )
-
-    def get_dummy_data(
-        self,
-        seq_len: int,
-        is_encoder_data: bool = False,
-    ) -> DummyData:
-
-        mm_counts = self.get_mm_limits()
-
-        info = self.processing_info
-        mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(
-            seq_len, mm_counts)
-
-        if mm_counts.keys() != mm_max_tokens_per_item.keys():
-            raise AssertionError(
-                "The keys returned by `get_supported_mm_limits` "
-                f"({set(mm_counts.keys())}) should be the same as those "
-                "returned by `get_mm_max_tokens_per_item` "
-                f"({set(mm_max_tokens_per_item.keys())})")
-
-        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
-        placeholders_by_modality = mm_inputs["mm_placeholders"]
-
-        total_placeholders_by_modality = {
-            modality: sum(item["length"] for item in placeholders)
-            for modality, placeholders in placeholders_by_modality.items()
-        }
-        expected_placeholders_by_modality = {
-            modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
-            for modality in placeholders_by_modality
-        }
-        if total_placeholders_by_modality != expected_placeholders_by_modality:
-            raise AssertionError(
-                f"The processed dummy data has a total of "
-                f"{total_placeholders_by_modality} placeholder tokens, which "
-                f"is not the expected {expected_placeholders_by_modality} "
-                "tokens.")
-
-        if is_encoder_data:
-            mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
-            dummy_data = self.create_encoder_dummy_data(
-                seq_len,
-                mm_inputs,
-                total_placeholders_by_modality,
-            )
-        else:
-            dummy_data = self.create_decoder_dummy_data(
-                seq_len,
-                mm_inputs,
-                total_placeholders_by_modality,
-            )
-        return dummy_data

From 78b7a15a88755761c720f2819e91c785c1a15708 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Sat, 8 Mar 2025 22:42:54 +0800
Subject: [PATCH 4/4] fix processing test

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 tests/multimodal/test_processing.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index ba3df86f715a..a358eee5ddb6 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -873,7 +873,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
         exc_ctx = pytest.raises(ValueError, match="this model only supports")
 
     with exc_ctx:
-        profiler.get_dummy_data(model_config.max_model_len)
+        profiler.get_decoder_dummy_data(model_config.max_model_len)
 
 
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
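
Usage note: after this series, callers pick the encoder or decoder entry point explicitly instead of passing is_encoder_data to a single get_dummy_data call. Below is a minimal sketch of a migrated call site; it mirrors the dummy_data_for_profiling change in vllm/inputs/registry.py, and the mm_registry, model_config, tokenizer, seq_len and is_encoder_data objects are assumed to already exist as they do in that function.

    # Sketch only: mm_registry, model_config, tokenizer, seq_len and
    # is_encoder_data are assumed to be set up as in
    # vllm/inputs/registry.py::dummy_data_for_profiling.
    from vllm.multimodal.profiling import MultiModalProfiler

    processor = mm_registry.create_processor(model_config,
                                             tokenizer,
                                             disable_cache=True)
    profiler = MultiModalProfiler(processor)

    # Old API (removed by this series):
    #   dummy_data = profiler.get_dummy_data(
    #       seq_len, is_encoder_data=is_encoder_data)

    # New API: one entry point per role.
    if is_encoder_data:
        dummy_data = profiler.get_encoder_dummy_data(seq_len)
    else:
        dummy_data = profiler.get_decoder_dummy_data(seq_len)

Both entry points share the placeholder validation now factored into get_and_validate_mm_inputs, which is why the test in patch 4 simply calls get_decoder_dummy_data directly.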