From b9272cb4e4fad60e711b5b6c58a3785d96440ae8 Mon Sep 17 00:00:00 2001
From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Date: Sat, 8 Mar 2025 17:33:05 -0500
Subject: [PATCH] =?UTF-8?q?Revert=20"[Bugfix]=20Fix=20profiling=20OOM=20an?=
 =?UTF-8?q?d=20decouple=20encoder=20multimodal=20profiling=20=E2=80=A6"?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This reverts commit 609ef61fea63c8afe2dbe60bbeb6badff29a6156.
---
 tests/multimodal/test_processing.py |  2 +-
 vllm/inputs/registry.py             |  6 +--
 vllm/multimodal/profiling.py        | 84 +++++++++++------------------
 3 files changed, 33 insertions(+), 59 deletions(-)

diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index a358eee5ddb6..ba3df86f715a 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -873,7 +873,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
         exc_ctx = pytest.raises(ValueError, match="this model only supports")
 
     with exc_ctx:
-        profiler.get_decoder_dummy_data(model_config.max_model_len)
+        profiler.get_dummy_data(model_config.max_model_len)
 
 
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
index a0bd8f278fd0..32d7a8b3dd7b 100644
--- a/vllm/inputs/registry.py
+++ b/vllm/inputs/registry.py
@@ -336,10 +336,8 @@ def dummy_data_for_profiling(
                 tokenizer,
                 disable_cache=True)
             profiler = MultiModalProfiler(processor)
-            dummy_data_factory = (profiler.get_encoder_dummy_data
-                                  if is_encoder_data else
-                                  profiler.get_decoder_dummy_data)
-            dummy_data = dummy_data_factory(seq_len)
+            dummy_data = profiler.get_dummy_data(
+                seq_len, is_encoder_data=is_encoder_data)
         else:
             model_cls, _ = get_model_architecture(model_config)
             if is_encoder_data:
diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index b791fb83478f..57f045dae6bd 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from dataclasses import dataclass, field
-from typing import Generic, TypeVar, cast
+from typing import Generic, TypeVar
 
 import numpy as np
 import numpy.typing as npt
@@ -13,8 +13,7 @@
 from vllm.inputs import DummyData
 from vllm.logger import init_logger
 
-from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
-                     MultiModalInputs)
+from .inputs import MultiModalDataDict, MultiModalInputs
 from .processing import BaseMultiModalProcessor, BaseProcessingInfo
 
 logger = init_logger(__name__)
@@ -143,10 +142,14 @@ def _get_dummy_mm_inputs(
             hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
         )
 
-    def get_and_validate_mm_inputs(
+    def get_dummy_data(
         self,
         seq_len: int,
-    ) -> tuple[MultiModalInputs, Mapping[str, int]]:
+        is_encoder_data: bool = False,
+    ) -> DummyData:
+        # Avoid circular import
+        from vllm.sequence import SequenceData
+
         mm_counts = self.get_mm_limits()
 
         info = self.processing_info
@@ -162,6 +165,11 @@ def get_and_validate_mm_inputs(
 
         mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
         placeholders_by_modality = mm_inputs["mm_placeholders"]
+        # For encoder-decoder models, use encoder prompt token ids instead of
+        # decoder prompt to construct dummy seq_data for encoder profiling.
+        prompt_token_ids = (
+            mm_inputs["prompt_token_ids"] if not is_encoder_data else
+            mm_inputs["encoder_prompt_token_ids"])  # type: ignore
 
         total_placeholders_by_modality = {
             modality: sum(item["length"] for item in placeholders)
@@ -177,60 +185,28 @@ def get_and_validate_mm_inputs(
                 f"{total_placeholders_by_modality} placeholder tokens, which "
                 f"is not the expected {expected_placeholders_by_modality} "
                 "tokens.")
-        return mm_inputs, total_placeholders_by_modality
-
-    def get_encoder_dummy_data(
-        self,
-        seq_len: int,
-    ) -> DummyData:
-        # Avoid circular import
-        from vllm.sequence import SequenceData
-
-        mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len)
-        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
-
-        # For encoder-decoder models, use encoder prompt token ids instead of
-        # decoder prompt to construct dummy seq_data for encoder profiling.
-        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
-
-        total_len = len(encoder_prompt_token_ids)
-        num_tokens_to_pad = max(total_len, seq_len) - total_len
-        encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
-
-        return DummyData(
-            seq_data=SequenceData.from_seqs(encoder_prompt_token_ids),
-            multi_modal_data=None,
-            multi_modal_placeholders=None,
-        )
-
-    def get_decoder_dummy_data(
-        self,
-        seq_len: int,
-    ) -> DummyData:
-        # Avoid circular import
-        from vllm.sequence import SequenceData
-
-        (mm_inputs, total_placeholders_by_modality
-         ) = self.get_and_validate_mm_inputs(seq_len)
 
-        prompt_token_ids = mm_inputs["prompt_token_ids"]
         total_len = len(prompt_token_ids)
 
         # V0 does not support chunked prefill.
-        if total_len > seq_len and not envs.VLLM_USE_V1:
-            logger.warning(
-                "The context length (%d) of the model is too short "
-                "to hold the multi-modal embeddings in the worst case "
-                "(%d tokens in total, out of which %s are reserved for "
-                "multi-modal embeddings). This may cause certain "
-                "multi-modal inputs to fail during inference, even when "
-                "the input text is short. To avoid this, you should "
-                "increase `max_model_len`, reduce `max_num_seqs`, "
-                "and/or reduce `mm_counts`.", seq_len, total_len,
-                total_placeholders_by_modality)
+        if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
+            if total_len > seq_len and not is_encoder_data:
+                logger.warning(
+                    "The context length (%d) of the model is too short "
+                    "to hold the multi-modal embeddings in the worst case "
+                    "(%d tokens in total, out of which %s are reserved for "
+                    "multi-modal embeddings). This may cause certain "
+                    "multi-modal inputs to fail during inference, even when "
+                    "the input text is short. To avoid this, you should "
+                    "increase `max_model_len`, reduce `max_num_seqs`, "
+                    "and/or reduce `mm_counts`.", seq_len, total_len,
+                    total_placeholders_by_modality)
+
+            num_tokens_to_pad = max(total_len, seq_len) - total_len
+            prompt_token_ids.extend([0] * num_tokens_to_pad)
 
             return DummyData(
-                seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
+                seq_data=SequenceData.from_seqs(prompt_token_ids),
                 multi_modal_data=None,
                 multi_modal_placeholders=None,
             )
@@ -240,5 +216,5 @@ def get_decoder_dummy_data(
         return DummyData(
             seq_data=SequenceData.from_seqs(prompt_token_ids),
             multi_modal_data=mm_inputs["mm_kwargs"],
-            multi_modal_placeholders=mm_inputs["mm_placeholders"],
+            multi_modal_placeholders=placeholders_by_modality,
         )
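
Below is a short usage sketch, not part of the patch itself: a minimal illustration of how the restored `MultiModalProfiler.get_dummy_data` entry point is driven for decoder and encoder profiling after this revert. The helper function and the way the processor is obtained are assumptions for illustration; only the `MultiModalProfiler(processor)` construction and the `get_dummy_data(seq_len, is_encoder_data=...)` call shape come from the diff above (the test hunk calls the decoder form as `profiler.get_dummy_data(model_config.max_model_len)`).

# Hypothetical sketch (not from this patch): exercising the restored API.
from vllm.multimodal.profiling import MultiModalProfiler


def build_profiling_dummy_data(processor, max_model_len: int):
    # `processor` is a multimodal processor built elsewhere (e.g. via the
    # multimodal registry, as in dummy_data_for_profiling); how it is
    # created is outside the scope of this patch.
    profiler = MultiModalProfiler(processor)

    # Decoder profiling: DummyData whose seq_data is built from the decoder
    # prompt token ids, together with the dummy multimodal kwargs and
    # placeholders.
    decoder_dummy = profiler.get_dummy_data(max_model_len)

    # Encoder profiling (encoder-decoder models): is_encoder_data=True makes
    # get_dummy_data select the encoder prompt token ids and pad them with
    # zeros up to max_model_len; no multimodal kwargs are attached.
    encoder_dummy = profiler.get_dummy_data(max_model_len,
                                            is_encoder_data=True)

    return decoder_dummy, encoder_dummy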