Commit 05fb671

[Bugfix] Clean up multi-modal processors (#14417)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 12c29a8 commit 05fb671

File tree

12 files changed: 79 additions & 76 deletions


vllm/config.py

Lines changed: 9 additions & 0 deletions
@@ -2405,6 +2405,15 @@ def compute_hash(self) -> str:
         hash_str = hashlib.md5(str(factors).encode()).hexdigest()
         return hash_str
 
+    def get_limit_per_prompt(self, modality: str) -> int:
+        """
+        Get the maximum number of input items allowed per prompt
+        for the given modality.
+
+        If not set by the user, this defaults to `1`.
+        """
+        return self.limit_per_prompt.get(modality, 1)
+
     # TODO: Add configs to init vision tower or not.
 
 

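The helper added above gives every call site one place to look up the per-modality limit and its default of 1. A minimal sketch of the intended behavior, using a hypothetical `DemoMMConfig` stand-in rather than the real `MultiModalConfig`:

```python
# Sketch only: DemoMMConfig is a hypothetical stand-in for MultiModalConfig,
# mirroring the limit_per_prompt lookup added in this commit.
from dataclasses import dataclass, field


@dataclass
class DemoMMConfig:
    limit_per_prompt: dict[str, int] = field(default_factory=dict)

    def get_limit_per_prompt(self, modality: str) -> int:
        # Fall back to 1 when the user did not configure this modality
        return self.limit_per_prompt.get(modality, 1)


cfg = DemoMMConfig(limit_per_prompt={"image": 4})
assert cfg.get_limit_per_prompt("image") == 4
assert cfg.get_limit_per_prompt("video") == 1  # unset modality -> default of 1
```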
vllm/model_executor/models/deepseek_vl2.py

Lines changed: 27 additions & 29 deletions
@@ -14,7 +14,6 @@
 from transformers import BatchFeature
 
 from vllm.config import VllmConfig
-from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -25,8 +24,8 @@
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, ProcessingCache,
-                                        PromptReplacement, PromptUpdate)
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
@@ -42,8 +41,6 @@
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
 
-logger = init_logger(__name__)
-
 # The image token id may be various
 _IMAGE_TOKEN = "<image>"
 
@@ -216,30 +213,6 @@ def get_dummy_processor_inputs(
 class DeepseekVL2MultiModalProcessor(
         BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]):
 
-    def __init__(
-            self,
-            info: DeepseekVL2ProcessingInfo,
-            dummy_inputs: "BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]",
-            *,
-            cache: Optional[ProcessingCache] = None,
-            enable_sanity_checks: bool = True) -> None:
-        super().__init__(
-            info,
-            dummy_inputs,
-            cache=cache,
-            enable_sanity_checks=enable_sanity_checks,
-        )
-
-        mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt
-        if self.cache is not None and mm_limit["image"] > 2:
-            # The processor output depends on the number of images passed,
-            # making it incompatible with processing cache which is supposed
-            # to be invariant of how many images are passed per prompt
-            self.cache = None
-            logger.warning_once(
-                f"{type(self).__name__} does not support processing cache with "
-                "image limit larger than 2.")
-
     def _call_hf_processor(
         self,
         prompt: str,
@@ -316,6 +289,31 @@ def get_replacement_deepseek_vl2(item_idx: int):
             )
         ]
 
+    def _cached_apply_hf_processor(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> tuple[list[int], MultiModalKwargs, bool]:
+        # The processor logic is different for len(images) <= 2 vs > 2
+        # Since the processing cache assumes that the processor output is
+        # invariant of how many images are passed per prompt, we only
+        # perform caching for the most common case
+        if mm_data_items.get_count("image", strict=False) > 2:
+            # This code path corresponds to the cache being disabled
+            return self._apply_hf_processor_main(
+                prompt=prompt,
+                mm_items=mm_data_items,
+                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+                enable_hf_prompt_update=True,
+            )
+
+        return super()._cached_apply_hf_processor(
+            prompt=prompt,
+            mm_data_items=mm_data_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+        )
+
 
 @MULTIMODAL_REGISTRY.register_processor(
     DeepseekVL2MultiModalProcessor,

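The change above replaces the constructor-time cache disabling (and its warning) with an override of `_cached_apply_hf_processor` that bypasses the cache only for prompts carrying more than two images. A minimal, self-contained sketch of that count-gated caching pattern, with class and method names invented for this example (they are not vLLM APIs):

```python
# Illustrative sketch: cache results only in the regime the cache assumes,
# and compute directly otherwise. Not the vLLM processor API.
class CachedProcessor:

    def __init__(self) -> None:
        self._cache: dict[tuple[str, int], str] = {}

    def uncached_apply(self, prompt: str, num_images: int) -> str:
        # Placeholder for the expensive HF processor call
        return f"{prompt}|{num_images} image(s)"

    def cached_apply(self, prompt: str, num_images: int) -> str:
        key = (prompt, num_images)
        if key not in self._cache:
            self._cache[key] = self.uncached_apply(prompt, num_images)
        return self._cache[key]


class TwoImageCachedProcessor(CachedProcessor):

    def cached_apply(self, prompt: str, num_images: int) -> str:
        if num_images > 2:
            # Output is not cache-invariant beyond two images, so skip caching
            return self.uncached_apply(prompt, num_images)
        return super().cached_apply(prompt, num_images)
```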
vllm/model_executor/models/h2ovl.py

Lines changed: 28 additions & 30 deletions
@@ -8,21 +8,19 @@
 # Licensed under Apache 2.0 License [see LICENSE for details]
 # --------------------------------------------------------
 from collections.abc import Mapping, Sequence
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from PIL import Image
 from transformers import PretrainedConfig
 
-from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    MultiModalDataItems)
-from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
-                                        PromptUpdate, PromptUpdateDetails)
-from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
+                                        PromptUpdateDetails)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 from .intern_vit import InternVisionModel
@@ -32,8 +30,6 @@
                        InternVLMultiModalProcessor, build_transform,
                        find_closest_aspect_ratio, get_internvl_target_ratios)
 
-logger = init_logger(__name__)
-
 
 def resolve_h2ovl_min_max_num(
     *,
@@ -465,29 +461,6 @@ def get_max_image_tokens(self, use_msac: Optional[bool] = None) -> int:
 class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
                                ):
 
-    def __init__(self,
-                 info: H2OVLProcessingInfo,
-                 dummy_inputs: "BaseDummyInputsBuilder[H2OVLProcessingInfo]",
-                 *,
-                 cache: Optional[ProcessingCache] = None,
-                 enable_sanity_checks: bool = True) -> None:
-        super().__init__(
-            info,
-            dummy_inputs,
-            cache=cache,
-            enable_sanity_checks=enable_sanity_checks,
-        )
-
-        mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt
-        if self.cache is not None and mm_limit["image"] >= 2:
-            # The processor output depends on the number of images passed,
-            # making it incompatible with processing cache which is supposed
-            # to be invariant of how many images are passed per prompt
-            self.cache = None
-            logger.warning_once(
-                f"{type(self).__name__} does not support processing cache with "
-                "multi-image support enabled.")
-
     def _get_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
@@ -543,6 +516,31 @@ def get_replacement_internvl(item_idx: int):
             )
         ]
 
+    def _cached_apply_hf_processor(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> tuple[list[int], MultiModalKwargs, bool]:
+        # The processor logic is different for len(images) <= 1 vs > 1
+        # Since the processing cache assumes that the processor output is
+        # invariant of how many images are passed per prompt, we only
+        # perform caching for the most common case
+        if mm_data_items.get_count("image", strict=False) > 1:
+            # This code path corresponds to the cache being disabled
+            return self._apply_hf_processor_main(
+                prompt=prompt,
+                mm_items=mm_data_items,
+                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+                enable_hf_prompt_update=True,
+            )
+
+        return super()._cached_apply_hf_processor(
+            prompt=prompt,
+            mm_data_items=mm_data_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+        )
+
 
 @MULTIMODAL_REGISTRY.register_processor(
     H2OVLMultiModalProcessor,

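H2OVL gets the same treatment, with the cache bypass kicking in above one image per prompt instead of disabling caching for the whole processor. From the user side, the value returned by `get_limit_per_prompt` is typically configured via `limit_mm_per_prompt`; a hedged example of setting it through vLLM's `LLM` entry point (model name and limit value are placeholders, adjust to your deployment):

```python
# Hedged usage example: raising the per-prompt image limit for a
# multi-image model. Placeholder model name and limit value.
from vllm import LLM

llm = LLM(
    model="h2oai/h2ovl-mississippi-2b",
    limit_mm_per_prompt={"image": 4},  # read back via get_limit_per_prompt("image")
)
```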
vllm/model_executor/models/llava_next_video.py

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ def _get_max_video_frames(self, max_tokens: int) -> int:
 
     def get_num_frames_with_most_features(self, seq_len: int) -> int:
         mm_config = self.ctx.get_mm_config()
-        max_videos = mm_config.limit_per_prompt.get("video", 1)
+        max_videos = mm_config.get_limit_per_prompt("video")
 
         max_total_frames = self._get_max_video_frames(seq_len)
 

vllm/model_executor/models/llava_onevision.py

Lines changed: 2 additions & 2 deletions
@@ -206,8 +206,8 @@ def _get_max_video_frames(self, max_tokens: int) -> int:
 
     def get_num_frames_with_most_features(self, seq_len: int) -> int:
         mm_config = self.ctx.get_mm_config()
-        max_images = mm_config.limit_per_prompt.get("image", 1)
-        max_videos = mm_config.limit_per_prompt.get("video", 1)
+        max_images = mm_config.get_limit_per_prompt("image")
+        max_videos = mm_config.get_limit_per_prompt("video")
 
         max_image_tokens = self.get_max_image_tokens() * max_images
         max_total_frames = self._get_max_video_frames(seq_len -

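The profiling helpers touched in these files use the per-modality limits to split the sequence budget between images and video frames. A rough, standalone sketch of that arithmetic with made-up per-item token costs (the real methods derive these costs from the model's processor configuration, so the numbers here are purely illustrative):

```python
# Rough sketch of the frame-budgeting arithmetic, with made-up token costs;
# the real helpers compute per-image/per-frame costs from the HF processor.
def frames_with_most_features(seq_len: int, max_images: int, max_videos: int,
                              tokens_per_image: int = 576,
                              tokens_per_frame: int = 196) -> int:
    # Reserve room for the maximum number of images first
    max_image_tokens = tokens_per_image * max_images
    # Spend whatever remains on video frames, split across all videos
    max_total_frames = max(seq_len - max_image_tokens, 0) // tokens_per_frame
    return max(max_total_frames // max(max_videos, 1), 1)


print(frames_with_most_features(seq_len=32768, max_images=1, max_videos=1))
```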
vllm/model_executor/models/minicpmo.py

Lines changed: 3 additions & 3 deletions
@@ -201,9 +201,9 @@ def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
 
     def get_num_frames_with_most_features(self, seq_len: int) -> int:
         mm_config = self.ctx.get_mm_config()
-        max_images = mm_config.limit_per_prompt.get("image", 1)
-        max_videos = mm_config.limit_per_prompt.get("video", 1)
-        max_audios = mm_config.limit_per_prompt.get("audio", 1)
+        max_images = mm_config.get_limit_per_prompt("image")
+        max_videos = mm_config.get_limit_per_prompt("video")
+        max_audios = mm_config.get_limit_per_prompt("audio")
 
         # count <image_idx></image_idx> tokens
         # which are not in get_max_image_tokens

vllm/model_executor/models/minicpmv.py

Lines changed: 2 additions & 2 deletions
@@ -446,8 +446,8 @@ def get_max_video_frames(self, max_tokens: int) -> int:
 
     def get_num_frames_with_most_features(self, seq_len: int) -> int:
         mm_config = self.ctx.get_mm_config()
-        max_images = mm_config.limit_per_prompt.get("image", 1)
-        max_videos = mm_config.limit_per_prompt.get("video", 1)
+        max_images = mm_config.get_limit_per_prompt("image")
+        max_videos = mm_config.get_limit_per_prompt("video")
 
         # count <image_idx></image_idx> tokens
         # which are not in get_max_image_tokens

vllm/model_executor/models/pixtral.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
     image_token_id = mm_encoder.special_ids.img
 
     mm_config = ctx.get_mm_config()
-    num_images = mm_config.limit_per_prompt.get("image", 1)
+    num_images = mm_config.get_limit_per_prompt("image")
 
     # dummy size
     size = 256

vllm/model_executor/models/qwen2_vl.py

Lines changed: 2 additions & 2 deletions
@@ -911,8 +911,8 @@ def _get_max_video_frames(self, max_tokens: int) -> int:
 
     def get_num_frames_with_most_features(self, seq_len: int) -> int:
         mm_config = self.ctx.get_mm_config()
-        max_images = mm_config.limit_per_prompt.get("image", 1)
-        max_videos = mm_config.limit_per_prompt.get("video", 1)
+        max_images = mm_config.get_limit_per_prompt("image")
+        max_videos = mm_config.get_limit_per_prompt("video")
 
         max_image_tokens = self.get_max_image_tokens() * max_images
         max_total_frames = self._get_max_video_frames(seq_len -

vllm/multimodal/processing.py

Lines changed: 2 additions & 2 deletions
@@ -984,10 +984,10 @@ def _to_mm_items(
         before passing them to :meth:`_get_hf_mm_data`.
         """
         mm_items = self.data_parser.parse_mm_data(mm_data)
+        mm_config = self.info.ctx.get_mm_config()
 
-        mm_limits = self.info.ctx.get_mm_config().limit_per_prompt
         for modality, items in mm_items.items():
-            limit = mm_limits.get(modality, 1)
+            limit = mm_config.get_limit_per_prompt(modality)
             if len(items) > limit:
                 raise ValueError(
                     f"You set {modality}={limit} (or defaulted to 1) in "

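The `_to_mm_items` validation now routes through the same helper, so oversized inputs still fail fast with the same error. A standalone sketch of that check, assuming plain dicts in place of vLLM's parsed item types (the error text here paraphrases the truncated message above):

```python
# Standalone sketch of the per-modality check in _to_mm_items; mm_items is a
# plain dict here rather than vLLM's MultiModalDataItems, and the error text
# paraphrases the real (truncated) message.
def validate_mm_counts(mm_items: dict[str, list], limits: dict[str, int]) -> None:
    for modality, items in mm_items.items():
        limit = limits.get(modality, 1)  # mirrors get_limit_per_prompt
        if len(items) > limit:
            raise ValueError(
                f"You set {modality}={limit} (or defaulted to 1), but passed "
                f"{len(items)} {modality} items in the same prompt.")


validate_mm_counts({"image": ["img1", "img2"]}, {"image": 2})  # passes quietly
```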