|
8 | 8 | # Licensed under Apache 2.0 License [see LICENSE for details]
|
9 | 9 | # --------------------------------------------------------
|
10 | 10 | from collections.abc import Mapping, Sequence
|
11 |
| -from typing import Optional |
| 11 | +from typing import Optional, Union |
12 | 12 |
|
13 | 13 | import torch
|
14 | 14 | from PIL import Image
|
15 | 15 | from transformers import PretrainedConfig
|
16 | 16 |
|
17 |
| -from vllm.logger import init_logger |
18 | 17 | from vllm.model_executor.layers.quantization import QuantizationConfig
|
19 | 18 | from vllm.multimodal import MULTIMODAL_REGISTRY
|
20 | 19 | from vllm.multimodal.inputs import MultiModalKwargs
|
21 | 20 | from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
22 | 21 | MultiModalDataItems)
|
23 |
| -from vllm.multimodal.processing import (ProcessingCache, PromptReplacement, |
24 |
| - PromptUpdate, PromptUpdateDetails) |
25 |
| -from vllm.multimodal.profiling import BaseDummyInputsBuilder |
| 22 | +from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, |
| 23 | + PromptUpdateDetails) |
26 | 24 | from vllm.transformers_utils.tokenizer import AnyTokenizer
|
27 | 25 |
|
28 | 26 | from .intern_vit import InternVisionModel
|
|
32 | 30 | InternVLMultiModalProcessor, build_transform,
|
33 | 31 | find_closest_aspect_ratio, get_internvl_target_ratios)
|
34 | 32 |
|
35 |
| -logger = init_logger(__name__) |
36 |
| - |
37 | 33 |
|
38 | 34 | def resolve_h2ovl_min_max_num(
|
39 | 35 | *,
|
@@ -465,29 +461,6 @@ def get_max_image_tokens(self, use_msac: Optional[bool] = None) -> int:
|
465 | 461 | class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
|
466 | 462 | ):
|
467 | 463 |
|
468 |
| - def __init__(self, |
469 |
| - info: H2OVLProcessingInfo, |
470 |
| - dummy_inputs: "BaseDummyInputsBuilder[H2OVLProcessingInfo]", |
471 |
| - *, |
472 |
| - cache: Optional[ProcessingCache] = None, |
473 |
| - enable_sanity_checks: bool = True) -> None: |
474 |
| - super().__init__( |
475 |
| - info, |
476 |
| - dummy_inputs, |
477 |
| - cache=cache, |
478 |
| - enable_sanity_checks=enable_sanity_checks, |
479 |
| - ) |
480 |
| - |
481 |
| - mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt |
482 |
| - if self.cache is not None and mm_limit["image"] >= 2: |
483 |
| - # The processor output depends on the number of images passed, |
484 |
| - # making it incompatible with processing cache which is supposed |
485 |
| - # to be invariant of how many images are passed per prompt |
486 |
| - self.cache = None |
487 |
| - logger.warning_once( |
488 |
| - f"{type(self).__name__} does not support processing cache with " |
489 |
| - "multi-image support enabled.") |
490 |
| - |
491 | 464 | def _get_prompt_updates(
|
492 | 465 | self,
|
493 | 466 | mm_items: MultiModalDataItems,
|
@@ -543,6 +516,31 @@ def get_replacement_internvl(item_idx: int):
|
543 | 516 | )
|
544 | 517 | ]
|
545 | 518 |
|
def _cached_apply_hf_processor(
    self,
    prompt: Union[str, list[int]],
    mm_data_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
    """Apply the HF processor, bypassing the cache for multi-image input.

    The processor takes a different code path when more than one image
    is supplied, whereas the processing cache assumes the output for an
    item does not depend on how many images share the prompt. Caching is
    therefore kept only for the common case of at most one image.

    Args:
        prompt: The prompt, either as text or as token IDs.
        mm_data_items: The parsed multimodal items for this prompt.
        hf_processor_mm_kwargs: Extra keyword arguments forwarded to
            the HF processor.

    Returns:
        The result of the parent class's cached implementation when
        there is at most one image, otherwise the result of
        ``_apply_hf_processor_main`` called directly.
    """
    num_images = mm_data_items.get_count("image", strict=False)

    # Zero or one image: the output is independent of the image count,
    # so the regular cached implementation can be used.
    if num_images <= 1:
        return super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

    # Several images: take the same route as when the cache is disabled.
    return self._apply_hf_processor_main(
        prompt=prompt,
        mm_items=mm_data_items,
        hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        enable_hf_prompt_update=True,
    )
| 543 | + |
546 | 544 |
|
547 | 545 | @MULTIMODAL_REGISTRY.register_processor(
|
548 | 546 | H2OVLMultiModalProcessor,
|
|
0 commit comments