diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py
index 4cd2dbdb4f98..202516f4c209 100644
--- a/tests/models/encoder_decoder/vision_language/test_mllama.py
+++ b/tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -7,11 +7,11 @@
 from transformers import (AutoConfig, AutoModelForVision2Seq,
                           AutoTokenizer, BatchEncoding)
 
+from vllm import LLM, SamplingParams
 from vllm.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
                                      global_force_attn_backend_context_manager)
-from vllm.model_executor.models.mllama import (MLLAMA_IMAGE_TOKEN_ID,
-                                               MllamaForConditionalGeneration)
+from vllm.model_executor.models.mllama import MllamaForConditionalGeneration
 from vllm.multimodal.image import rescale_image_size
 from vllm.sequence import SampleLogprobs
 
@@ -21,6 +21,7 @@ from ...utils import check_logprobs_close
 
 _LIMIT_IMAGE_PER_PROMPT = 3
 
+MLLAMA_IMAGE_TOKEN_ID = 128256
 LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
 
@@ -396,6 +397,64 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
     )
 
 
+@large_gpu_test(min_gb=48)
+@pytest.mark.core_model
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [32])
+def test_explicit_implicit_prompt(
+    image_assets: _ImageAssets,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+):
+    stop_sign = image_assets[0].pil_image
+    # yapf: disable
+    prompts = [
+        # explicit prompt
+        {
+            "encoder_prompt": {
+                "prompt": "<|image|>",
+                "multi_modal_data": {"image": stop_sign},
+            },
+            "decoder_prompt": {
+                "prompt_token_ids": [128000, 791, 2262, 315, 279, 2217, 220, 128256, 374],  # noqa: E501
+            }
+        },
+        {
+            "encoder_prompt": "Not <|image|>",
+            "decoder_prompt": "The color of the sky is blue but sometimes it can also be",  # noqa: E501
+        },
+        # implicit prompt
+        {
+            "prompt": "<|begin_of_text|>The content of the image <|image|> is",  # noqa: E501
+            "multi_modal_data": {"image": stop_sign},
+        },
+        {
+            "prompt": "The color of the sky is blue but sometimes it can also be",  # noqa: E501
+        },
+    ]
+    # yapf: enable
+    llm = LLM(
+        model=model,
+        dtype=dtype,
+        max_model_len=4096,
+        max_num_seqs=2,
+        tensor_parallel_size=1,
+        enforce_eager=True,
+    )
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_tokens=max_tokens,
+    )
+    outputs = llm.generate(prompts, sampling_params)
+    n_prompts = len(prompts)
+    explicit_outputs = outputs[:n_prompts // 2]
+    implicit_outputs = outputs[n_prompts // 2:]
+    for exp_output, imp_output in zip(explicit_outputs, implicit_outputs):
+        assert exp_output.outputs[0].text == imp_output.outputs[0].text
+
+
 @large_gpu_test(min_gb=48)
 @pytest.mark.core_model
 @pytest.mark.parametrize("model", models)
@@ -458,6 +517,10 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
                               images=images)
 
 
+class DummyModel:
+    image_token_id = MLLAMA_IMAGE_TOKEN_ID
+
+
 @pytest.mark.core_model
 @pytest.mark.parametrize(
     "input_indices_and_output",
@@ -499,7 +562,7 @@ def test_get_cross_attention_mask(input_indices_and_output) -> None:
         use_cuda_graph=False,
     )
 
-    dummy: dict[str, str] = {}
+    dummy = DummyModel()
 
     cross_attention_mask, kv_range_for_decode = MllamaForConditionalGeneration\
         .get_cross_attention_mask(dummy,
@@ -556,7 +619,7 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
         use_cuda_graph=False,
    )
 
-    dummy: dict[str, str] = {}
+    dummy = DummyModel()
 
     full_text_row_masked_out_mask = MllamaForConditionalGeneration\
         .get_full_text_row_masked_out_mask(dummy,
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 6244056c7474..67ef8b17ab8c 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -85,6 +85,14 @@ def _test_processing_correctness(
         partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
     }
 
+    tokenizer_encode_kwargs = {}
+    if model_config.hf_config.model_type == "mllama":
+        # For Mllama, the tokenizer always adds a bos_token at the beginning
+        # of the prompt by default, which makes the hf_processor output
+        # incorrect token ids. So we need to use `add_special_tokens=False`
+        # here and leave the bos_token to be added by the processor.
+        tokenizer_encode_kwargs = {"add_special_tokens": False}
+
     for batch_idx in range(num_batches):
         mm_data = {
             k:
@@ -122,7 +130,7 @@ def _test_processing_correctness(
             f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
 
         baseline_tokenized_result = baseline_processor.apply(
-            tokenizer.encode(prompt),
+            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
             mm_data=mm_data,
             hf_processor_mm_kwargs={},
         )
@@ -131,7 +139,7 @@ def _test_processing_correctness(
             f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
 
         cached_tokenized_result = cached_processor.apply(
-            tokenizer.encode(prompt),
+            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
             mm_data=mm_data,
             hf_processor_mm_kwargs={},
         )
@@ -155,6 +163,7 @@ def _test_processing_correctness(
         "llava-hf/llava-v1.6-mistral-7b-hf",
         "llava-hf/LLaVA-NeXT-Video-7B-hf",
         "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
+        "meta-llama/Llama-3.2-11B-Vision-Instruct",
         "TIGER-Lab/Mantis-8B-siglip-llama3",
         "mistral-community/pixtral-12b",
         "openbmb/MiniCPM-o-2_6",
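The `add_special_tokens=False` workaround above can be checked in isolation. Below is a minimal sketch, assuming a Llama-style tokenizer that prepends `<|begin_of_text|>` by default; the checkpoint name is only an example and must be available locally for this to run.

```python
# Hypothetical sketch of why mllama needs add_special_tokens=False here:
# the HF MllamaProcessor inserts the BOS token itself, so encoding with the
# default settings would end up duplicating it.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct")

prompt = "<|image|>What is the content of this image?"

with_bos = tokenizer.encode(prompt)  # default: BOS is prepended
without_bos = tokenizer.encode(prompt, add_special_tokens=False)

assert with_bos[0] == tokenizer.bos_token_id
assert without_bos == with_bos[1:]
```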
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py
index 656f2f2b766e..bc5856990da6 100644
--- a/vllm/inputs/preprocess.py
+++ b/vllm/inputs/preprocess.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import asyncio
-from typing import List, Mapping, Optional, Union
+from typing import List, Mapping, Optional, Tuple, Union, cast
 
 from typing_extensions import assert_never
 
@@ -9,7 +9,8 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
+                                    MultiModalInputs)
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 
@@ -495,6 +496,51 @@ def _build_enc_dec_llm_inputs(
             decoder=decoder_inputs,
         )
 
+    def _separate_enc_dec_inputs_from_mm_processor_outputs(
+        self,
+        inputs: SingletonInputs,
+        decoder_inputs_to_override: Optional[SingletonInputs] = None,
+    ) -> Tuple[SingletonInputs, SingletonInputs]:
+        """
+        For encoder/decoder models only:
+        Separate encoder/decoder inputs from a MultiModalEncDecInputs.
+        """
+        encoder_inputs: SingletonInputs
+        decoder_inputs: SingletonInputs
+        if inputs["type"] == "multimodal":
+            # Multimodal data inputs
+            assert ("encoder_prompt" in inputs
+                    and "encoder_prompt_token_ids" in inputs)
+            inputs = cast(MultiModalEncDecInputs, inputs)
+            encoder_inputs = token_inputs(
+                prompt=inputs["encoder_prompt"],
+                prompt_token_ids=inputs["encoder_prompt_token_ids"],
+            )
+            if decoder_inputs_to_override is not None:
+                decoder_inputs = MultiModalInputs(
+                    type="multimodal",
+                    prompt=decoder_inputs_to_override.get("prompt", ""),
+                    prompt_token_ids=decoder_inputs_to_override[
+                        "prompt_token_ids"],
+                    mm_kwargs=inputs["mm_kwargs"],
+                    mm_placeholders=inputs["mm_placeholders"],
+                )
+            else:
+                decoder_inputs = MultiModalInputs(
+                    type="multimodal",
+                    prompt=inputs["prompt"],
+                    prompt_token_ids=inputs["prompt_token_ids"],
+                    mm_kwargs=inputs["mm_kwargs"],
+                    mm_placeholders=inputs["mm_placeholders"],
+                )
+        elif inputs["type"] == "token":
+            # Text-only inputs
+            encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
+            decoder_inputs = decoder_inputs_to_override or inputs
+        else:
+            assert_never(inputs)  # type: ignore[arg-type]
+        return encoder_inputs, decoder_inputs
+
     def _process_encoder_decoder_prompt(
         self,
         prompt: PromptType,
@@ -539,7 +585,6 @@ def _process_encoder_decoder_prompt(
                 prompt["encoder_prompt"],
                 request_id=request_id,
             )
-
             if (decoder_input := prompt["decoder_prompt"]) is None:
                 decoder_inputs = None
             else:
@@ -547,13 +592,28 @@ def _process_encoder_decoder_prompt(
                     decoder_input,
                     request_id=request_id,
                 )
+                # For multimodal models, override the decoder prompt produced
+                # by the processor with the explicit decoder prompt.
+                if self.model_config.is_multimodal_model and (
+                        self._can_process_multimodal()):
+                    encoder_inputs, decoder_inputs = (
+                        self._separate_enc_dec_inputs_from_mm_processor_outputs(
+                            encoder_inputs, decoder_inputs))
         else:
-            encoder_inputs = self._prompt_to_llm_inputs(
+            inputs = self._prompt_to_llm_inputs(
                 prompt,
                 request_id=request_id,
             )
+            if self.model_config.is_multimodal_model and (
+                    self._can_process_multimodal()):
+                # Encoder-Decoder Multimodal model
+                encoder_inputs, decoder_inputs = (
+                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
+                        inputs))
+            else:
+                encoder_inputs = inputs
 
-            decoder_inputs = None
+                decoder_inputs = None
 
         return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
 
@@ -583,13 +643,29 @@ async def _process_encoder_decoder_prompt_async(
                 encoder_inputs, decoder_inputs = await asyncio.gather(
                     encoder_task, decoder_task)
+
+                # For multimodal models, override the decoder prompt produced
+                # by the processor with the explicit decoder prompt.
+                if self.model_config.is_multimodal_model and (
+                        self._can_process_multimodal()):
+                    encoder_inputs, decoder_inputs = (
+                        self._separate_enc_dec_inputs_from_mm_processor_outputs(
+                            encoder_inputs, decoder_inputs))
         else:
-            encoder_inputs = await self._prompt_to_llm_inputs_async(
+            inputs = await self._prompt_to_llm_inputs_async(
                 prompt,
                 request_id=request_id,
             )
+            if self.model_config.is_multimodal_model and (
+                    self._can_process_multimodal()):
+                # Encoder-Decoder Multimodal model
+                encoder_inputs, decoder_inputs = (
+                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
+                        inputs))
+            else:
+                encoder_inputs = inputs
 
-            decoder_inputs = None
+                decoder_inputs = None
 
         return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
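The new `_separate_enc_dec_inputs_from_mm_processor_outputs` helper above is what lets the sync and async paths share one splitting rule. The following is a simplified, standalone sketch of that control flow, using plain dicts in place of vLLM's `SingletonInputs`/`token_inputs` types; treat it as an approximation for illustration, not the real implementation.

```python
from typing import Optional, Tuple

def separate_enc_dec_inputs(
    inputs: dict,
    decoder_inputs_to_override: Optional[dict] = None,
) -> Tuple[dict, dict]:
    """Illustrative stand-in for the vLLM helper, using plain dicts."""
    if inputs["type"] == "multimodal":
        # The processor produced MultiModalEncDecInputs: the encoder side
        # carries only the expanded image tokens...
        encoder_inputs = {
            "type": "token",
            "prompt": inputs["encoder_prompt"],
            "prompt_token_ids": inputs["encoder_prompt_token_ids"],
        }
        # ...while the decoder side keeps the (possibly overridden) text
        # prompt together with the multi-modal kwargs and placeholders.
        source = decoder_inputs_to_override or inputs
        decoder_inputs = {
            "type": "multimodal",
            "prompt": source.get("prompt", ""),
            "prompt_token_ids": source["prompt_token_ids"],
            "mm_kwargs": inputs["mm_kwargs"],
            "mm_placeholders": inputs["mm_placeholders"],
        }
    else:
        # Text-only request: the encoder gets an empty prompt.
        encoder_inputs = {"type": "token", "prompt": "", "prompt_token_ids": []}
        decoder_inputs = decoder_inputs_to_override or inputs
    return encoder_inputs, decoder_inputs
```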
diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
index cd421443981f..87b7a7631e42 100644
--- a/vllm/inputs/registry.py
+++ b/vllm/inputs/registry.py
@@ -350,7 +350,8 @@ def dummy_data_for_profiling(
         )
         processor = mm_registry.create_processor(model_config, tokenizer)
         profiler = MultiModalProfiler(processor)
-        dummy_data = profiler.get_dummy_data(seq_len)
+        dummy_data = profiler.get_dummy_data(
+            seq_len, is_encoder_data=is_encoder_data)
     else:
         model_cls, _ = get_model_architecture(model_config)
         if is_encoder_data:
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index d1cb04cdb242..3ca22d346b79 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -23,14 +23,15 @@
 import torch.nn.functional as F
 import torch.utils.checkpoint
 import transformers.models.mllama.configuration_mllama as config_mllama
-from PIL import Image
+from PIL.Image import Image
 from torch import nn
+from transformers import BatchFeature, MllamaConfig
 from transformers.modeling_outputs import (BaseModelOutput,
                                            CausalLMOutputWithPast)
 from transformers.models.mllama.image_processing_mllama import (
     get_optimal_tiled_canvas)
 from transformers.models.mllama.processing_mllama import (
-    get_cross_attention_token_mask)
+    MllamaProcessor, get_cross_attention_token_mask)
 
 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
@@ -38,8 +39,6 @@
 from vllm.attention.selector import _Backend
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
-                         InputContext, TokenInputs, token_inputs)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -54,8 +53,13 @@
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.sequence import SequenceData
-from vllm.utils import is_list_of
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
+from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
+                                   MultiModalDataDict, MultiModalDataItems)
+from vllm.multimodal.processing import (BaseProcessingInfo,
+                                        EncDecMultiModalProcessor,
+                                        PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 
 from .clip import CLIPMLP
 from .interfaces import SupportsMultiModal
@@ -63,8 +67,6 @@
 from .utils import maybe_prefix
 
 logger = init_logger(__name__)
-MLLAMA_IMAGE_TOKEN_ID = 128256
-MLLAMA_IMAGE_TOKEN = "<|image|>"
 
 
 class MllamaImagePixelInputs(TypedDict):
@@ -81,158 +83,191 @@ class MllamaImagePixelInputs(TypedDict):
 # TODO: support LlamaImageEmbeddingInputs
 
 
-def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
-    num_images = 0
-    for token_id in prompt_token_ids[::-1]:
-        if token_id == MLLAMA_IMAGE_TOKEN_ID:
-            num_images += 1
-        elif num_images > 0:
-            break
-    return num_images
-
-
-def input_processor_for_mllama(
-    ctx: InputContext,
-    inputs: EncoderDecoderInputs,
-) -> EncoderDecoderInputs:
-    # Example input to processor:
-    # {
-    #     'encoder': {
-    #         'type': 'token',
-    #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
-    #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
-    #         'multi_modal_data': {'image': },  # noqa: E501
-    #     },
-    #     'decoder': {
-    #         'type': 'token',
-    #         'prompt_token_ids': [128000],
-    #     },
-    # }
-
-    # move encoder prompt to decoder
-    dec_inputs = TokenInputs(**inputs["encoder"])
-
-    multi_modal_data = dec_inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        # text-only
-        return EncoderDecoderInputs(
-            encoder=token_inputs([]),
-            decoder=dec_inputs,
+def calc_token_per_chunk(image_size: int) -> int:
+    assert image_size % 14 == 0, "chunk size should be multiple of 14"
+    token_per_chunk = (image_size // 14)**2 + 1
+    return token_per_chunk
+
+
+class MllamaProcessingInfo(BaseProcessingInfo):
+
+    def get_hf_config(self) -> MllamaConfig:
+        return self.ctx.get_hf_config(MllamaConfig)
+
+    def get_hf_processor(self) -> MllamaProcessor:
+        return self.ctx.get_hf_processor(MllamaProcessor)
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}
+
+    def get_token_per_chunk_from_config(self) -> int:
+        image_size = self.get_hf_config().vision_config.image_size
+        return calc_token_per_chunk(image_size)
+
+    def get_mm_max_tokens_per_item(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> Mapping[str, int]:
+        vision_config = self.get_hf_config().vision_config
+        token_per_chunk = self.get_token_per_chunk_from_config()
+        mm_max_tokens = vision_config.max_num_tiles * token_per_chunk
+        return {"image": mm_max_tokens}
+
+    def get_num_tiles_per_image(self, image_height: int,
+                                image_width: int) -> int:
+        vision_config = self.get_hf_config().vision_config
+        max_num_tiles = vision_config.max_num_tiles
+        image_size = vision_config.image_size
+        tiled_height, tiled_width = get_optimal_tiled_canvas(
+            image_height,
+            image_width,
+            max_num_tiles,
+            tile_size=image_size,
+        )
+        num_tiles_height = tiled_height // image_size
+        num_tiles_width = tiled_width // image_size
+        return num_tiles_height * num_tiles_width
+
+    def get_image_size_with_most_features(self) -> ImageSize:
+        vision_config = self.get_hf_config().vision_config
+        image_size = vision_config.image_size
+        max_num_tiles = vision_config.max_num_tiles
+        # Results in the max possible feature size (h:w = max_num_tiles:1)
+        return ImageSize(height=max_num_tiles * image_size, width=image_size)
+
+
+class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]):
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        num_images = mm_counts.get("image", 0)
+
+        target_width, target_height = \
+            self.info.get_image_size_with_most_features()
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=target_width,
+                                   height=target_height,
+                                   num_images=num_images)
+        }
+
+        hf_processor = self.info.get_hf_processor()
+        image_token: str = hf_processor.image_token
+
+        return ProcessorInputs(
+            prompt_text=image_token * num_images,
+            mm_data=mm_data,
         )
-    image_data = multi_modal_data["image"]
-    if isinstance(image_data, Image.Image):
-        image_data = [image_data]
-
-    assert is_list_of(image_data, Image.Image)
-
-    num_image_tokens = dec_inputs['prompt_token_ids'].count(
-        MLLAMA_IMAGE_TOKEN_ID)
-    if num_image_tokens != len(image_data):
-        raise ValueError(
-            f"The number of image tokens ({num_image_tokens}) must be"
-            f" the same as the number of images ({len(image_data)})")
-
-    # Since only the last group of consecutive images
-    # are attended by the decoded tokens, we only need to
-    # get the number of tiles for those images.
-    num_decode_images = _get_num_image_in_last_group(
-        dec_inputs["prompt_token_ids"])
-
-    hf_config = ctx.model_config.hf_config
-    vision_config = hf_config.vision_config
-
-    num_tiles = 0
-    for image in image_data[::-1]:
-        width, height = image.size
-        tile_size = vision_config.image_size
-        canvas_height, canvas_width = get_optimal_tiled_canvas(
-            image_height=height,
-            image_width=width,
-            max_image_tiles=vision_config.max_num_tiles,
-            tile_size=tile_size,
+
+class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
+                                ):
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        tokenizer = self.info.get_tokenizer()
+        if mm_data:
+            num_tiles = [
+                self.info.get_num_tiles_per_image(img.height, img.width)
+                for img in mm_data["images"]
+            ]
+            processed_outputs = super()._call_hf_processor(
+                prompt, mm_data, mm_kwargs)
+            processed_outputs["num_tiles"] = torch.tensor(num_tiles)
+            for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
+                processed_outputs[k] = processed_outputs[k].squeeze(0)
+            # Example input to encoder and decoder:
+            # {
+            #     'encoder': {
+            #         'type': 'token',
+            #         'prompt_token_ids': [128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
+            #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
+            #         'multi_modal_data': {'image': },  # noqa: E501
+            #     },
+            #     'decoder': {
+            #         'type': 'token',
+            #         'prompt_token_ids': [128000],
+            #     },
+            # }
+            processed_token_ids = processed_outputs.pop("input_ids")
+            start_idx, end_idx = 0, processed_token_ids.size(1)
+            processed_prompt_text = tokenizer.decode(processed_token_ids[0])
+
+            hf_processor = self.info.get_hf_processor()
+            bos_token = hf_processor.bos_token
+            # Remove the bos_token from the start of the prompt,
+            # because the encoder prompt is expected to start with an
+            # image token.
+            if processed_prompt_text.startswith(bos_token):
+                start_idx += 1
+            # Remove the bos_token from the end of the prompt,
+            # because the text part is empty in this case.
+            if processed_prompt_text.endswith(bos_token):
+                end_idx -= 1
+            processed_outputs[
+                "input_ids"] = processed_token_ids[:, start_idx:end_idx]
+        else:
+            processed_outputs = tokenizer(prompt,
+                                          add_special_tokens=False,
+                                          return_tensors="pt")
+        return processed_outputs
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            aspect_ratio_ids=MultiModalFieldConfig.batched("image"),
+            aspect_ratio_mask=MultiModalFieldConfig.batched("image"),
+            num_tiles=MultiModalFieldConfig.batched("image"),
         )
-        num_tiles_height = canvas_height // tile_size
-        num_tiles_width = canvas_width // tile_size
-        num_tiles += num_tiles_height * num_tiles_width
-        num_decode_images -= 1
-        if num_decode_images == 0:
-            break
-
-    # Set encoder prompt length based on the number of tiles.
-    # This tells the block manager to allocate correct number
-    # of slots for encoder tokens.
-    assert vision_config.image_size % 14 == 0, \
-        "chunk size should be multiple of 14"
-    token_per_chunk = (vision_config.image_size // 14)**2 + 1
-    num_tokens = num_tiles * token_per_chunk
-
-    # Example output from processor:
-    # {
-    #     'encoder': {
-    #         'type': 'token',
-    #         'prompt_token_ids': [128256, 128256, ..., 128256],
-    #         'prompt': '<|image|><|image|>...<|image|>',
-    #         'multi_modal_data': {'image': },  # noqa: E501
-    #     },
-    #     'decoder': {
-    #         'type': 'token',
-    #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
-    #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
-    #         'multi_modal_data': {'image': },  # noqa: E501
-    #     },
-    # }
-    return EncoderDecoderInputs(
-        encoder=token_inputs(
-            prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_tokens,
-            prompt=MLLAMA_IMAGE_TOKEN * num_tokens,
-            multi_modal_data=multi_modal_data,
-        ),
-        decoder=dec_inputs,
-    )
-
-
-def get_max_mllama_image_tokens(ctx: InputContext) -> int:
-    hf_config = ctx.model_config.hf_config
-    token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
-    return hf_config.vision_config.max_num_tiles * token_per_chunk
-
-
-def dummy_decoder_seq_data(seq_len: int, num_images: int):
-    # <|image|> * num_images + 0 * (seq_len - num_images)
-    assert seq_len >= num_images, \
-        "seq_len should be greater than or equal to num_images"
-
-    return SequenceData.from_prompt_token_counts(
-        (MLLAMA_IMAGE_TOKEN_ID, num_images),
-        (0, seq_len - num_images),
-    )
-
-
-def dummy_encoder_seq_data(ctx: InputContext, num_images: int):
-    num_tokens = get_max_mllama_image_tokens(ctx) * num_images
-
-    return SequenceData.from_prompt_token_counts(
-        (MLLAMA_IMAGE_TOKEN_ID, num_tokens))
-
-
-def dummy_image(num_images: int, ):
-    width = height = 1024
-    image = Image.new("RGB", (width, height), color=0)
-    return {"image": image if num_images == 1 else [image] * num_images}
-
-
-def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int,
-                                  mm_counts: Mapping[str, int]):
-    num_images = mm_counts["image"]
-    return DummyData(dummy_decoder_seq_data(seq_len, num_images))
-
-
-def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int,
-                                  mm_counts: Mapping[str, int]):
-    num_images = mm_counts["image"]
-    return DummyData(dummy_encoder_seq_data(ctx, num_images),
-                     dummy_image(num_images))
+
+    def create_encoder_prompt(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+    ) -> Union[str, list[int]]:
+        data = mm_data.get("image", [])
+        num_images = 1 if isinstance(data, Image) else len(data)
+        image_token_id = self.info.get_hf_config().image_token_index
+        return [image_token_id] * num_images
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+        token_per_chunk = self.info.get_token_per_chunk_from_config()
+        image_token_id = self.info.get_hf_config().image_token_index
+
+        def get_replacement_mllama(item_idx):
+            images = mm_items.get_items("image", ImageProcessorItems)
+            image_size = images.get_image_size(item_idx)
+            num_tile = self.info.get_num_tiles_per_image(
+                image_height=image_size.height,
+                image_width=image_size.width,
+            )
+            num_tokens = num_tile * token_per_chunk
+            return [image_token_id] * num_tokens
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target=[image_token_id],
+                replacement=get_replacement_mllama,
+            )
+        ]
 
 
 def _prepare_aspect_ratio_attention_mask(
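The `_get_prompt_replacements` hook above expands each `<|image|>` placeholder into `num_tiles * token_per_chunk` encoder tokens. A back-of-the-envelope sketch of that arithmetic follows; the `image_size=560` and `max_num_tiles=4` values are the stock Llama 3.2 Vision settings and the 2x2 tiling assumed for a 1024x1024 input is what `get_optimal_tiled_canvas` is expected to choose, so treat both as assumptions to verify against the actual checkpoint config.

```python
# Illustrative arithmetic only; mirrors calc_token_per_chunk() and
# get_num_tiles_per_image() above without importing vLLM or transformers.

def token_per_chunk(image_size: int) -> int:
    # One image_size x image_size tile split into 14x14 patches,
    # plus one extra token per tile.
    assert image_size % 14 == 0
    return (image_size // 14) ** 2 + 1

image_size = 560     # assumed: MllamaVisionConfig.image_size
max_num_tiles = 4    # assumed: MllamaVisionConfig.max_num_tiles

per_tile = token_per_chunk(image_size)   # (560 // 14) ** 2 + 1 = 1601

# Assumed tiling for a 1024x1024 image: a 2x2 canvas of 560px tiles
# (1120x1120) is the smallest canvas that covers it within 4 tiles.
num_tiles = 2 * 2

encoder_tokens = num_tiles * per_tile    # 4 * 1601 = 6404
print(per_tile, encoder_tokens)
```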
@@ -1107,11 +1142,9 @@ def forward(
         return hidden_states
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper()
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama)
-@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
+@MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor,
+                                        info=MllamaProcessingInfo,
+                                        dummy_inputs=MllamaDummyInputsBuilder)
 class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -1120,7 +1153,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
-        config = vllm_config.model_config.hf_config
+        config: MllamaConfig = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         self.quant_config = quant_config
         self.vocab_size = config.text_config.vocab_size
@@ -1130,6 +1163,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.pad_token_id = \
             config.pad_token_id if config.pad_token_id is not None else -1
         self.image_size = config.vision_config.image_size
+        self.image_token_id = config.image_token_index
 
         self.vision_model = MllamaVisionModel(config.vision_config,
                                               quant_config,
@@ -1204,48 +1238,12 @@ def _parse_and_validate_image_input(self, **kwargs: object):
         if pixel_values is not None:
             assert aspect_ratio_ids is not None
             assert aspect_ratio_mask is not None
-            max_num_images = max([len(x[0]) for x in pixel_values])
-            if max_num_images == 0:
-                raise ValueError("No images provided.")
-            max_num_tiles = max(
-                max([len(x) for x in y[0]]) for y in pixel_values)
-            device = next(self.multi_modal_projector.parameters()).device
-            bsz = len(pixel_values)
-            out_num_tiles = []
-            out_images = torch.zeros(
-                bsz,
-                max_num_images,
-                max_num_tiles,
-                3,
-                self.image_size,
-                self.image_size,
-                dtype=torch.float32,
-                device=device,
-            )
-            out_ar_ids = torch.ones(bsz,
-                                    max_num_images,
-                                    dtype=torch.int64,
-                                    device=device)
-            out_ar_mask = torch.zeros(bsz,
-                                      max_num_images,
-                                      max_num_tiles,
-                                      dtype=torch.int64,
-                                      device=device)
-            for b in range(len(pixel_values)):
-                _num_tiles = []
-                for i in range(len(pixel_values[b][0])):
-                    img = pixel_values[b][0][i]
-                    out_images[b, i, :img.shape[0]] = img
-                    out_ar_ids[b, i] = aspect_ratio_ids[b][0][i]
-                    out_ar_mask[b, i] = aspect_ratio_mask[b][0][i]
-                    _num_tiles.append(img.shape[0])
-                out_num_tiles.append(_num_tiles)
 
             return MllamaImagePixelInputs(
                 type="pixel_values",
-                data=out_images,
-                aspect_ratio_ids=out_ar_ids,
-                aspect_ratio_mask=out_ar_mask,
+                data=pixel_values,
+                aspect_ratio_ids=aspect_ratio_ids,
+                aspect_ratio_mask=aspect_ratio_mask,
             )
 
         if image_embeds is not None:
@@ -1312,7 +1310,7 @@ def get_cross_attention_mask(
             batch_token_ids.append(token_ids[start:start + seq_len])
             start += seq_len
         sparse_mask = [
-            get_cross_attention_token_mask(t, MLLAMA_IMAGE_TOKEN_ID)
+            get_cross_attention_token_mask(t, self.image_token_id)
             for t in batch_token_ids
         ]
@@ -1384,8 +1382,8 @@ def forward(
         # block manager to allocate blocks for those images only.
         # See input_processor_for_mllama() for more details.
         num_tiles_tensor = kwargs.pop("num_tiles")
-        num_tiles = [t[0].tolist() for t in num_tiles_tensor]
-        num_tokens_per_tile = (self.image_size // 14)**2 + 1
+        num_tiles = [t.tolist() for t in num_tiles_tensor]
+        num_tokens_per_tile = calc_token_per_chunk(self.image_size)
         actual_encoder_seq_lens = [
             sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
         ]
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index 5f9593ee8b20..25ca8d1e71f7 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -739,3 +739,19 @@ class MultiModalInputs(TypedDict):
     For each modality, information about the placeholder tokens in
     :code:`prompt_token_ids`.
     """
+
+
+class MultiModalEncDecInputs(MultiModalInputs):
+    """
+    Represents the outputs of :class:`vllm.multimodal.EncDecMultiModalProcessor`
+    ready to be passed to vLLM internals.
+    """
+
+    encoder_prompt: str
+    """The processed encoder prompt text."""
+
+    encoder_prompt_token_ids: list[int]
+    """The processed token IDs of the encoder prompt."""
+
+    encoder_token_type_ids: NotRequired[list[int]]
+    """The token type IDs of the encoder prompt."""
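For orientation, a `MultiModalEncDecInputs` payload for an mllama request would look roughly like the dict below. The decoder-side ids are copied from the worked example in the mllama processor comments (128000 = `<|begin_of_text|>`, 128256 = `<|image|>`); the encoder token count and the empty `mm_kwargs`/`mm_placeholders` are placeholders rather than real processor output.

```python
# Hypothetical example of the new TypedDict; field names follow the
# definitions added above, values are illustrative placeholders.
example: dict = {
    "type": "multimodal",
    # Decoder side: the original user prompt, copied through unchanged.
    "prompt": "<|image|><|begin_of_text|>What is the content of this image?",
    "prompt_token_ids":
    [128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],
    # Encoder side: only the expanded image tokens (count is an assumption).
    "encoder_prompt": "<|image|>" * 6404,
    "encoder_prompt_token_ids": [128256] * 6404,
    # Processed tensors, hashes and placeholder ranges (contents elided).
    "mm_kwargs": {},
    "mm_hashes": None,
    "mm_placeholders": {},
}
```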
+ """ + encoder_prompt = self.create_encoder_prompt(prompt, mm_data) + encoder_inputs = super().apply( + encoder_prompt, + mm_data, + hf_processor_mm_kwargs, + ) + + # We assumed the decoder prompt text is copied from + # the original encoder prompt without extra process + tokenizer = self.info.get_tokenizer() + if isinstance(prompt, str): + decoder_prompt = prompt + decoder_prompt_ids = encode_tokens(tokenizer, + prompt, + add_special_tokens=False) + else: + decoder_prompt = decode_tokens(tokenizer, prompt) + decoder_prompt_ids = prompt + + mm_inputs = MultiModalEncDecInputs( + encoder_prompt=encoder_inputs["prompt"], + encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"], + **encoder_inputs) + mm_inputs.update({ + "prompt": decoder_prompt, + "prompt_token_ids": decoder_prompt_ids + }) + return mm_inputs diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 5dd754854044..81c92b38f8e9 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -144,7 +144,11 @@ def _get_dummy_mm_inputs( hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, ) - def get_dummy_data(self, seq_len: int) -> DummyData: + def get_dummy_data( + self, + seq_len: int, + is_encoder_data: bool = False, + ) -> DummyData: # Avoid circular import from vllm.sequence import SequenceData @@ -183,16 +187,18 @@ def get_dummy_data(self, seq_len: int) -> DummyData: total_len = len(prompt_token_ids) # V0 does not support chunked prefill. - if total_len > seq_len and not envs.VLLM_USE_V1: - logger.warning( - "The context length (%d) of the model is too short " - "to hold the multi-modal embeddings in the worst case " - "(%d tokens in total, out of which %s are reserved for " - "multi-modal embeddings). This may cause certain multi-modal " - "inputs to fail during inference, even when the input text is " - "short. To avoid this, you should increase `max_model_len`, " - "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len, - total_len, total_placeholders_by_modality) + if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data: + if total_len > seq_len: + logger.warning( + "The context length (%d) of the model is too short " + "to hold the multi-modal embeddings in the worst case " + "(%d tokens in total, out of which %s are reserved for " + "multi-modal embeddings). This may cause certain " + "multi-modal inputs to fail during inference, even when " + "the input text is short. To avoid this, you should " + "increase `max_model_len`, reduce `max_num_seqs`, " + "and/or reduce `mm_counts`.", seq_len, total_len, + total_placeholders_by_modality) return DummyData( seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),