From aebfb23d56cecb1861fa5fee62e9dfb921f6d3b2 Mon Sep 17 00:00:00 2001 From: Kyle Huang Date: Wed, 19 Feb 2025 13:59:19 -0800 Subject: [PATCH 01/10] init commit --- vllm/model_executor/models/paligemma.py | 227 ++++++++++++++++++------ 1 file changed, 170 insertions(+), 57 deletions(-) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 65d810dc23bc..a92ad6839eb4 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -17,6 +17,15 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import cached_get_tokenizer + +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs) +from vllm.multimodal.processing import (BaseMultiModalProcessor,BaseProcessingInfo, + PromptReplacement, + PromptReplacementDetails) +from transformers import (BatchFeature, PaliGemmaConfig) +from vllm.multimodal.parse import (MultiModalDataItems, ImageProcessorItems) + +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal, SupportsPP @@ -47,77 +56,77 @@ class PaliGemmaImageEmbeddingInputs(TypedDict): PaliGemmaImageEmbeddingInputs] -def get_max_paligemma_image_tokens(ctx: InputContext): - hf_config = ctx.get_hf_config(PaliGemmaConfig) - vision_config = hf_config.vision_config +# def get_max_paligemma_image_tokens(ctx: InputContext): +# hf_config = ctx.get_hf_config(PaliGemmaConfig) +# vision_config = hf_config.vision_config - return get_max_siglip_image_tokens(vision_config) +# return get_max_siglip_image_tokens(vision_config) -def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - hf_config = ctx.get_hf_config(PaliGemmaConfig) - vision_config = hf_config.vision_config - num_images = mm_counts["image"] +# def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, +# mm_counts: Mapping[str, int]): +# hf_config = ctx.get_hf_config(PaliGemmaConfig) +# vision_config = hf_config.vision_config +# num_images = mm_counts["image"] - seq_data, ranges = dummy_seq_data_for_siglip( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - ) +# seq_data, ranges = dummy_seq_data_for_siglip( +# vision_config, +# seq_len, +# num_images, +# image_token_id=hf_config.image_token_index, +# ) - mm_data = dummy_image_for_siglip(vision_config, num_images) - return DummyData(seq_data, mm_data, ranges) +# mm_data = dummy_image_for_siglip(vision_config, num_images) +# return DummyData(seq_data, mm_data, ranges) -def input_processor_for_paligemma(ctx: InputContext, - inputs: DecoderOnlyInputs): +# def input_processor_for_paligemma(ctx: InputContext, +# inputs: DecoderOnlyInputs): - """ - The correct prompt format needs to be: - '' * image_feature_size + '' + prompt + '\n' +# """ +# The correct prompt format needs to be: +# '' * image_feature_size + '' + prompt + '\n' - See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55 - """ # noqa +# See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55 +# """ # noqa - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs +# multi_modal_data = 
inputs.get("multi_modal_data") +# if multi_modal_data is None or "image" not in multi_modal_data: +# return inputs - model_config = ctx.model_config - hf_config = ctx.get_hf_config(PaliGemmaConfig) +# model_config = ctx.model_config +# hf_config = ctx.get_hf_config(PaliGemmaConfig) - tokenizer = cached_get_tokenizer(model_config.tokenizer) - image_feature_size = hf_config.text_config.num_image_tokens - image_token_str = tokenizer.decode(hf_config.image_token_index) - bos_token = tokenizer.decode(hf_config.bos_token_id) - image_token_str_pad = image_token_str * image_feature_size - image_token_ids_pad = [hf_config.image_token_index] * image_feature_size +# tokenizer = cached_get_tokenizer(model_config.tokenizer) +# image_feature_size = hf_config.text_config.num_image_tokens +# image_token_str = tokenizer.decode(hf_config.image_token_index) +# bos_token = tokenizer.decode(hf_config.bos_token_id) +# image_token_str_pad = image_token_str * image_feature_size +# image_token_ids_pad = [hf_config.image_token_index] * image_feature_size - orig_prompt = inputs.get("prompt") - orig_prompt_ids = inputs.get("prompt_token_ids") +# orig_prompt = inputs.get("prompt") +# orig_prompt_ids = inputs.get("prompt_token_ids") - if orig_prompt is not None and image_token_str in orig_prompt: - logger.warning( - "The image token '%s' was detected in the prompt and " - "will be removed. Please follow the proper prompt format" - " documented on HuggingFace.", image_token_str) - orig_prompt = orig_prompt.replace(image_token_str, "") - orig_prompt_ids.remove(hf_config.image_token_index) +# if orig_prompt is not None and image_token_str in orig_prompt: +# logger.warning( +# "The image token '%s' was detected in the prompt and " +# "will be removed. Please follow the proper prompt format" +# " documented on HuggingFace.", image_token_str) +# orig_prompt = orig_prompt.replace(image_token_str, "") +# orig_prompt_ids.remove(hf_config.image_token_index) - new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n" +# new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n" - # The PaliGemma 2 tokenizer does not include a starting BOS token - if orig_prompt_ids[0] != hf_config.bos_token_id: - orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids +# # The PaliGemma 2 tokenizer does not include a starting BOS token +# if orig_prompt_ids[0] != hf_config.bos_token_id: +# orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids - new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline +# new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline - # NOTE: Create a defensive copy of the original inputs - return token_inputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) +# # NOTE: Create a defensive copy of the original inputs +# return token_inputs(prompt_token_ids=new_token_ids, +# prompt=new_prompt, +# multi_modal_data=multi_modal_data) class PaliGemmaMultiModalProjector(nn.Module): @@ -131,11 +140,115 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states = self.linear(image_features) return hidden_states +class PaliGemmaProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(PaliGemmaConfig) + + def get_model_config(self): + return self.ctx.model_config + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return 
{"image": self.get_num_image_tokens()} + + def get_num_image_tokens(self) -> int: + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + return get_max_siglip_image_tokens(vision_config) + +class PaliGemmaDummyInputsBuilder(BaseDummyInputsBuilder[PaliGemmaProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + # HF processor always adds placeholders even when there's no image + tokenizer = self.info.get_tokenizer() + prompt_ids = tokenizer.encode(prompt) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + hf_config = self.info.get_hf_config() + vision_config = hf_config.vision_config + max_image_size = vision_config.image_size + + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=max_image_size, + height=max_image_size, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="", + mm_data=mm_data, + ) + +class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingInfo]): + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image") + ) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + hf_config = self.info.get_hf_config() + image_token_id = hf_config.image_token_index + #model_config = self.info.get_model_config() + #tokenizer = cached_get_tokenizer(model_config.tokenizer) + #bos_token = tokenizer.decode(hf_config.bos_token_id) + + tokenizer = self.info.get_tokenizer() + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [image_token_id] * num_image_tokens + + bos_token_id = tokenizer.bos_token_id + assert isinstance(bos_token_id, int) + + return [ + PromptReplacement( + modality="image", + target=[bos_token_id], + replacement=PromptReplacementDetails( + full=image_tokens + [bos_token_id], + features=image_tokens, + ), + ) + ] -@MULTIMODAL_REGISTRY.register_image_input_mapper() -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma) -@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma) +@MULTIMODAL_REGISTRY.register_processor(PaliGemmaMultiModalProcessor, + info=PaliGemmaProcessingInfo, + dummy_inputs=PaliGemmaDummyInputsBuilder) class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): packed_modules_mapping = { From 4f192aa9c3f7a8c4c9c5d237df4bb7a608d3e135 Mon Sep 17 00:00:00 2001 From: Kyle Huang Date: Wed, 19 Feb 2025 15:26:22 -0800 Subject: [PATCH 02/10] test --- vllm/model_executor/models/paligemma.py | 39 +++++++++++++------------ 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index a92ad6839eb4..185cea3b7aea 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -164,24 +164,6 @@ def get_num_image_tokens(self) -> int: return 
get_max_siglip_image_tokens(vision_config) class PaliGemmaDummyInputsBuilder(BaseDummyInputsBuilder[PaliGemmaProcessingInfo]): - - def _call_hf_processor( - self, - prompt: str, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - ) -> BatchFeature: - if not mm_data: - # HF processor always adds placeholders even when there's no image - tokenizer = self.info.get_tokenizer() - prompt_ids = tokenizer.encode(prompt) - return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") - - return super()._call_hf_processor( - prompt=prompt, - mm_data=mm_data, - mm_kwargs=mm_kwargs, - ) def get_dummy_processor_inputs( self, @@ -200,13 +182,32 @@ def get_dummy_processor_inputs( height=max_image_size, num_images=num_images) } - + #print("kh get_dummy_processor_inputs",max_image_size,num_images) return ProcessorInputs( prompt_text="", mm_data=mm_data, ) class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + # HF processor always adds placeholders even when there's no image + tokenizer = self.info.get_tokenizer() + prompt_ids = tokenizer.encode(prompt) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + def _get_mm_fields_config( self, hf_inputs: BatchFeature, From 689cca694c1a48b26545496e6ac7c4ce057f84ec Mon Sep 17 00:00:00 2001 From: Kyle Huang Date: Wed, 19 Feb 2025 19:38:10 -0800 Subject: [PATCH 03/10] code clean Signed-off-by: Kyle Huang --- vllm/model_executor/models/paligemma.py | 135 +++++------------------- 1 file changed, 27 insertions(+), 108 deletions(-) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 185cea3b7aea..637cb8b2cb59 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -5,32 +5,25 @@ import torch from torch import nn -from transformers import PaliGemmaConfig +from transformers import BatchFeature, PaliGemmaConfig from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors -from vllm.multimodal.utils import cached_get_tokenizer - -from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs) -from vllm.multimodal.processing import (BaseMultiModalProcessor,BaseProcessingInfo, - PromptReplacement, +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, PromptReplacementDetails) -from transformers import (BatchFeature, PaliGemmaConfig) -from vllm.multimodal.parse import (MultiModalDataItems, ImageProcessorItems) - from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal, SupportsPP -from .siglip import (SiglipVisionModel, dummy_image_for_siglip, - 
dummy_seq_data_for_siglip, get_max_siglip_image_tokens) +from .siglip import SiglipVisionModel, get_max_siglip_image_tokens from .utils import (AutoWeightsLoader, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -56,79 +49,6 @@ class PaliGemmaImageEmbeddingInputs(TypedDict): PaliGemmaImageEmbeddingInputs] -# def get_max_paligemma_image_tokens(ctx: InputContext): -# hf_config = ctx.get_hf_config(PaliGemmaConfig) -# vision_config = hf_config.vision_config - -# return get_max_siglip_image_tokens(vision_config) - - -# def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, -# mm_counts: Mapping[str, int]): -# hf_config = ctx.get_hf_config(PaliGemmaConfig) -# vision_config = hf_config.vision_config -# num_images = mm_counts["image"] - -# seq_data, ranges = dummy_seq_data_for_siglip( -# vision_config, -# seq_len, -# num_images, -# image_token_id=hf_config.image_token_index, -# ) - -# mm_data = dummy_image_for_siglip(vision_config, num_images) -# return DummyData(seq_data, mm_data, ranges) - - -# def input_processor_for_paligemma(ctx: InputContext, -# inputs: DecoderOnlyInputs): - -# """ -# The correct prompt format needs to be: -# '' * image_feature_size + '' + prompt + '\n' - -# See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55 -# """ # noqa - -# multi_modal_data = inputs.get("multi_modal_data") -# if multi_modal_data is None or "image" not in multi_modal_data: -# return inputs - -# model_config = ctx.model_config -# hf_config = ctx.get_hf_config(PaliGemmaConfig) - -# tokenizer = cached_get_tokenizer(model_config.tokenizer) -# image_feature_size = hf_config.text_config.num_image_tokens -# image_token_str = tokenizer.decode(hf_config.image_token_index) -# bos_token = tokenizer.decode(hf_config.bos_token_id) -# image_token_str_pad = image_token_str * image_feature_size -# image_token_ids_pad = [hf_config.image_token_index] * image_feature_size - -# orig_prompt = inputs.get("prompt") -# orig_prompt_ids = inputs.get("prompt_token_ids") - -# if orig_prompt is not None and image_token_str in orig_prompt: -# logger.warning( -# "The image token '%s' was detected in the prompt and " -# "will be removed. 
Please follow the proper prompt format" -# " documented on HuggingFace.", image_token_str) -# orig_prompt = orig_prompt.replace(image_token_str, "") -# orig_prompt_ids.remove(hf_config.image_token_index) - -# new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n" - -# # The PaliGemma 2 tokenizer does not include a starting BOS token -# if orig_prompt_ids[0] != hf_config.bos_token_id: -# orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids - -# new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline - -# # NOTE: Create a defensive copy of the original inputs -# return token_inputs(prompt_token_ids=new_token_ids, -# prompt=new_prompt, -# multi_modal_data=multi_modal_data) - - class PaliGemmaMultiModalProjector(nn.Module): def __init__(self, vision_hidden_size: int, projection_dim: int): @@ -140,14 +60,12 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states = self.linear(image_features) return hidden_states + class PaliGemmaProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config(PaliGemmaConfig) - - def get_model_config(self): - return self.ctx.model_config - + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} @@ -163,8 +81,10 @@ def get_num_image_tokens(self) -> int: vision_config = hf_config.vision_config return get_max_siglip_image_tokens(vision_config) -class PaliGemmaDummyInputsBuilder(BaseDummyInputsBuilder[PaliGemmaProcessingInfo]): - + +class PaliGemmaDummyInputsBuilder( + BaseDummyInputsBuilder[PaliGemmaProcessingInfo]): + def get_dummy_processor_inputs( self, seq_len: int, @@ -182,13 +102,15 @@ def get_dummy_processor_inputs( height=max_image_size, num_images=num_images) } - #print("kh get_dummy_processor_inputs",max_image_size,num_images) + return ProcessorInputs( prompt_text="", mm_data=mm_data, ) - -class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingInfo]): + + +class PaliGemmaMultiModalProcessor( + BaseMultiModalProcessor[PaliGemmaProcessingInfo]): def _call_hf_processor( self, @@ -207,16 +129,14 @@ def _call_hf_processor( mm_data=mm_data, mm_kwargs=mm_kwargs, ) - + def _get_mm_fields_config( self, hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - pixel_values=MultiModalFieldConfig.batched("image") - ) - + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + def _get_prompt_replacements( self, mm_items: MultiModalDataItems, @@ -225,10 +145,7 @@ def _get_prompt_replacements( ) -> list[PromptReplacement]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index - #model_config = self.info.get_model_config() - #tokenizer = cached_get_tokenizer(model_config.tokenizer) - #bos_token = tokenizer.decode(hf_config.bos_token_id) - + tokenizer = self.info.get_tokenizer() num_image_tokens = self.info.get_num_image_tokens() image_tokens = [image_token_id] * num_image_tokens @@ -247,9 +164,11 @@ def _get_prompt_replacements( ) ] -@MULTIMODAL_REGISTRY.register_processor(PaliGemmaMultiModalProcessor, - info=PaliGemmaProcessingInfo, - dummy_inputs=PaliGemmaDummyInputsBuilder) + +@MULTIMODAL_REGISTRY.register_processor( + PaliGemmaMultiModalProcessor, + info=PaliGemmaProcessingInfo, + dummy_inputs=PaliGemmaDummyInputsBuilder) class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): packed_modules_mapping = { From 2c1218a32b0375acf30203baf1079438e51dc9ec Mon Sep 17 00:00:00 2001 From: Kyle Huang 
Date: Fri, 21 Feb 2025 21:37:36 -0800 Subject: [PATCH 04/10] apply newline at the end of the prompt --- vllm/model_executor/models/paligemma.py | 88 ++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 637cb8b2cb59..9e832229156a 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Sequence from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) @@ -17,8 +18,15 @@ NestedTensors) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptReplacementDetails) + BaseProcessingInfo, + BoundPromptReplacement, + PlaceholderFeaturesInfo, + PromptReplacement, + PromptReplacementDetails, + decode_tokens, encode_tokens, + find_text_matches, find_token_matches, + replace_text_matches, + replace_token_matches) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -164,6 +172,82 @@ def _get_prompt_replacements( ) ] + def _apply_prompt_replacements( + self, + token_ids: list[int], + mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], + mm_item_counts: Mapping[str, int], + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + tokenizer = self.info.get_tokenizer() + + mm_token_matches = { + modality: find_token_matches(token_ids, prompt_repls) + for modality, prompt_repls in mm_prompt_repls.items() + } + mm_match_counts = { + modality: len(matches) + for modality, matches in mm_token_matches.items() + } + + # If the search text does not represent a special token, + # it may have different token IDs in the prompt, because + # the tokens may go across the boundaries of the search text. + # ---- + # e.g. when searching for "foo" in "food", if "food" itself makes + # up a token, then the token ID of "foo" will not appear at all + # ---- + # Since it is inefficient to search for all possible tokenizations + # of the search text in the prompt, we instead perform string + # replacement on the decoded token IDs, then encode them back. 
+ if all( + mm_match_counts.get(modality, 0) >= item_count + for modality, item_count in mm_item_counts.items() + ): # yapf: disable + token_ids = replace_token_matches( + token_ids, + mm_token_matches, + mm_item_counts, + ) + + text = decode_tokens(tokenizer, token_ids) + matched_repls = { + modality: [match.prompt_repl for match in token_matches] + for modality, token_matches in mm_token_matches.items() + } + else: + text = decode_tokens(tokenizer, token_ids) + + mm_text_matches = { + modality: find_text_matches(text, prompt_repls) + for modality, prompt_repls in mm_prompt_repls.items() + } + text = replace_text_matches( + text, + mm_text_matches, + mm_item_counts, + ) + + token_ids = encode_tokens(tokenizer, + text, + add_special_tokens=False) + matched_repls = { + modality: [match.prompt_repl for match in token_matches] + for modality, token_matches in mm_text_matches.items() + } + + placeholders = self._find_mm_placeholders( + matched_repls, + token_ids, + mm_item_counts, + ) + + # Force to add newline at the end of prompt due to paligemma's format + if len(token_ids) and token_ids[-1] != 109: + token_ids.append(109) + text += "\n" + + return token_ids, text, placeholders + @MULTIMODAL_REGISTRY.register_processor( PaliGemmaMultiModalProcessor, From 0df67c6b6df178fd3a1a32b9c2d4e98c7d1086e8 Mon Sep 17 00:00:00 2001 From: Kyle Huang Date: Mon, 24 Feb 2025 13:45:42 -0800 Subject: [PATCH 05/10] adding for paligemma2 --- vllm/model_executor/models/paligemma.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index b00e324f91a0..35f4c8de6d77 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -131,6 +131,12 @@ def _call_hf_processor( # HF processor always adds placeholders even when there's no image tokenizer = self.info.get_tokenizer() prompt_ids = tokenizer.encode(prompt) + # Paligemma2 is NOT adding token at the beginning of the prompt + # Adding token (value 2) to adapt with prompt replacement + if len(prompt_ids) == 0: + prompt_ids = [2] + elif prompt_ids[0] != 2: + prompt_ids = [2] + prompt_ids return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") return super()._call_hf_processor( From dbb49569d8d761dff8d940cc87f2d593679c0e57 Mon Sep 17 00:00:00 2001 From: Kyle Huang Date: Tue, 25 Feb 2025 21:10:38 -0800 Subject: [PATCH 06/10] formatting --- .../multimodal/processing/test_common.py | 2 + vllm/model_executor/models/paligemma.py | 136 ++++++------------ 2 files changed, 47 insertions(+), 91 deletions(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 331ffe82ec85..094d7afdc95a 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -173,6 +173,8 @@ def _test_processing_correctness( "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", "fixie-ai/ultravox-v0_5-llama-3_2-1b", + "google/paligemma-3b-mix-224", + "google/paligemma2-3b-ft-docci-448", ]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 35f4c8de6d77..79c6f2410a74 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from collections.abc import Sequence from typing import (Iterable, 
List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) @@ -14,22 +13,15 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputs, MultiModalKwargs, NestedTensors) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - BoundPromptReplacement, - PlaceholderFeaturesInfo, - PromptReplacement, - PromptReplacementDetails, - decode_tokens, encode_tokens, - find_text_matches, find_token_matches, - replace_text_matches, - replace_token_matches) + BaseProcessingInfo, PromptReplacement, + PromptReplacementDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from .interfaces import SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel, get_max_siglip_image_tokens @@ -132,11 +124,11 @@ def _call_hf_processor( tokenizer = self.info.get_tokenizer() prompt_ids = tokenizer.encode(prompt) # Paligemma2 is NOT adding token at the beginning of the prompt - # Adding token (value 2) to adapt with prompt replacement + # Adding token (tokenizer.bos_token_id) to adapt with prompt replacement if len(prompt_ids) == 0: - prompt_ids = [2] - elif prompt_ids[0] != 2: - prompt_ids = [2] + prompt_ids + prompt_ids = [tokenizer.bos_token_id] + elif prompt_ids[0] != tokenizer.bos_token_id: + prompt_ids = [tokenizer.bos_token_id] + prompt_ids return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") return super()._call_hf_processor( @@ -152,6 +144,20 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: return dict(pixel_values=MultiModalFieldConfig.batched("image")) + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + tokenizer = self.info.get_tokenizer() + # Adding token (tokenizer.bos_token_id) to match with _call_hf_processor + if len(prompt_tokens) == 0: + tokens_with_bos = [tokenizer.bos_token_id] + elif prompt_tokens[0] != tokenizer.bos_token_id: + tokens_with_bos = [tokenizer.bos_token_id] + prompt_tokens + else: + tokens_with_bos = prompt_tokens + return tokens_with_bos + def _get_prompt_replacements( self, mm_items: MultiModalDataItems, @@ -168,92 +174,40 @@ def _get_prompt_replacements( bos_token_id = tokenizer.bos_token_id assert isinstance(bos_token_id, int) + # Adding token at the beginning based on add_bos_token variable. 
+ # Always adding token at the end of the images tokens and before the text prompt + # according to Paligemma's format + if tokenizer.add_bos_token: + replacement_tokens = [bos_token_id] + image_tokens + [bos_token_id] + else: + replacement_tokens = image_tokens + [bos_token_id] + return [ PromptReplacement( modality="image", target=[bos_token_id], replacement=PromptReplacementDetails( - full=image_tokens + [bos_token_id], + full=replacement_tokens, features=image_tokens, ), ) ] - def _apply_prompt_replacements( + def apply( self, - token_ids: list[int], - mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], - mm_item_counts: Mapping[str, int], - ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: - tokenizer = self.info.get_tokenizer() - - mm_token_matches = { - modality: find_token_matches(token_ids, prompt_repls) - for modality, prompt_repls in mm_prompt_repls.items() - } - mm_match_counts = { - modality: len(matches) - for modality, matches in mm_token_matches.items() - } - - # If the search text does not represent a special token, - # it may have different token IDs in the prompt, because - # the tokens may go across the boundaries of the search text. - # ---- - # e.g. when searching for "foo" in "food", if "food" itself makes - # up a token, then the token ID of "foo" will not appear at all - # ---- - # Since it is inefficient to search for all possible tokenizations - # of the search text in the prompt, we instead perform string - # replacement on the decoded token IDs, then encode them back. - if all( - mm_match_counts.get(modality, 0) >= item_count - for modality, item_count in mm_item_counts.items() - ): # yapf: disable - token_ids = replace_token_matches( - token_ids, - mm_token_matches, - mm_item_counts, - ) - - text = decode_tokens(tokenizer, token_ids) - matched_repls = { - modality: [match.prompt_repl for match in token_matches] - for modality, token_matches in mm_token_matches.items() - } - else: - text = decode_tokens(tokenizer, token_ids) - - mm_text_matches = { - modality: find_text_matches(text, prompt_repls) - for modality, prompt_repls in mm_prompt_repls.items() - } - text = replace_text_matches( - text, - mm_text_matches, - mm_item_counts, - ) - - token_ids = encode_tokens(tokenizer, - text, - add_special_tokens=False) - matched_repls = { - modality: [match.prompt_repl for match in token_matches] - for modality, token_matches in mm_text_matches.items() - } - - placeholders = self._find_mm_placeholders( - matched_repls, - token_ids, - mm_item_counts, - ) - - # Force to add newline at the end of prompt due to paligemma's format - if len(token_ids) and token_ids[-1] != 109: - token_ids.append(109) - text += "\n" - - return token_ids, text, placeholders + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputs: + multimodelinputs = super().apply(prompt, mm_data, + hf_processor_mm_kwargs) + # Force to add newline at the end of prompt according to paligemma's format + prompt_token_ids = multimodelinputs["prompt_token_ids"] + if len(prompt_token_ids) and prompt_token_ids[-1] != 108: + prompt_token_ids.append(108) + multimodelinputs["prompt_token_ids"] = prompt_token_ids + multimodelinputs["prompt"] += "\n" + return multimodelinputs @MULTIMODAL_REGISTRY.register_processor( From 670e5bf16b967922e6f6b17539a199f9bfd3ca09 Mon Sep 17 00:00:00 2001 From: Kyle Huang Date: Wed, 26 Feb 2025 09:40:12 -0800 Subject: [PATCH 07/10] use tokenizer to get newline token id (108) --- 
vllm/model_executor/models/paligemma.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 79c6f2410a74..c36a1621040f 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -199,16 +199,17 @@ def apply( mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], ) -> MultiModalInputs: - multimodelinputs = super().apply(prompt, mm_data, - hf_processor_mm_kwargs) - # Force to add newline at the end of prompt according to paligemma's format - prompt_token_ids = multimodelinputs["prompt_token_ids"] - if len(prompt_token_ids) and prompt_token_ids[-1] != 108: - prompt_token_ids.append(108) - multimodelinputs["prompt_token_ids"] = prompt_token_ids - multimodelinputs["prompt"] += "\n" - return multimodelinputs + mm_inputs = super().apply(prompt, mm_data,hf_processor_mm_kwargs) + prompt_token_ids = mm_inputs["prompt_token_ids"] + newline_prompt = "\n" + newline_token_id = self.info.get_tokenizer().encode(newline_prompt)[-1] # 108 + # Force to add newline at the end of prompt according to paligemma's format + if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id: + prompt_token_ids.append(newline_token_id) + mm_inputs["prompt_token_ids"] = prompt_token_ids + mm_inputs["prompt"] += newline_prompt + return mm_inputs @MULTIMODAL_REGISTRY.register_processor( PaliGemmaMultiModalProcessor, From 3935f21da2d1378901f50f1da9fab48f7ee97db4 Mon Sep 17 00:00:00 2001 From: Kyle Huang Date: Wed, 26 Feb 2025 20:09:51 -0800 Subject: [PATCH 08/10] add google/paligemma models Signed-off-by: Kyle Huang --- tests/models/multimodal/processing/test_common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 094d7afdc95a..04688ba69d8e 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -172,7 +172,8 @@ def _test_processing_correctness( "Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", - "fixie-ai/ultravox-v0_5-llama-3_2-1b", + "fixie-ai/ultravox-v0_4", + "openai/whisper-large-v3", "google/paligemma-3b-mix-224", "google/paligemma2-3b-ft-docci-448", ]) From bd93e4463bd10929243bc9f1f2efdbd5a315f04d Mon Sep 17 00:00:00 2001 From: Kyle Huang Date: Wed, 26 Feb 2025 20:21:24 -0800 Subject: [PATCH 09/10] remove blank Signed-off-by: Kyle Huang --- tests/models/multimodal/processing/test_common.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 0bc2c1dbc6e3..9a009d403493 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -176,7 +176,6 @@ def _test_processing_correctness( "openai/whisper-large-v3", "google/paligemma-3b-mix-224", "google/paligemma2-3b-ft-docci-448", - ]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) From 70633e0881070c587d44296eecba0443736b467f Mon Sep 17 00:00:00 2001 From: Kyle Huang Date: Wed, 26 Feb 2025 21:08:17 -0800 Subject: [PATCH 10/10] formatting Signed-off-by: Kyle Huang --- vllm/model_executor/models/paligemma.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/paligemma.py 
b/vllm/model_executor/models/paligemma.py
index 93e2ab55839a..ca85a2635750 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -122,8 +122,8 @@ def _call_hf_processor(
             # HF processor always adds placeholders even when there's no image
             tokenizer = self.info.get_tokenizer()
             prompt_ids = tokenizer.encode(prompt)
-            # Paligemma2 is NOT adding <bos> token at the beginning of the prompt
-            # Adding <bos> token (tokenizer.bos_token_id) to adapt with prompt replacement
+            # Paligemma2 is NOT adding <bos> token at the beginning
+            # Adding <bos> (tokenizer.bos_token_id) for prompt replacement
             if len(prompt_ids) == 0:
                 prompt_ids = [tokenizer.bos_token_id]
             elif prompt_ids[0] != tokenizer.bos_token_id:
@@ -148,7 +148,8 @@ def _apply_hf_processor_tokens_only(
         prompt_tokens: list[int],
     ) -> list[int]:
         tokenizer = self.info.get_tokenizer()
-        # Adding <bos> token (tokenizer.bos_token_id) to match with _call_hf_processor
+        # Adding <bos> (tokenizer.bos_token_id) to match with
+        # self._call_hf_processor
         if len(prompt_tokens) == 0:
             tokens_with_bos = [tokenizer.bos_token_id]
         elif prompt_tokens[0] != tokenizer.bos_token_id:
@@ -174,8 +175,8 @@ def _get_prompt_replacements(
         assert isinstance(bos_token_id, int)
 
         # Adding <bos> token at the beginning based on add_bos_token variable.
-        # Always adding <bos> token at the end of the images tokens and before the text prompt
-        # according to Paligemma's format
+        # Always adding <bos> token at the end of the images tokens
+        # and before the text prompt according to Paligemma's format
         if tokenizer.add_bos_token:
             replacement_tokens = [bos_token_id] + image_tokens + [bos_token_id]
         else:
@@ -198,18 +199,20 @@ def apply(
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> MultiModalInputs:
-        mm_inputs = super().apply(prompt, mm_data,hf_processor_mm_kwargs)
+        mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
         prompt_token_ids = mm_inputs["prompt_token_ids"]
-        newline_prompt = "\n"
-        newline_token_id = self.info.get_tokenizer().encode(newline_prompt)[-1]  # 108
-        # Force to add newline at the end of prompt according to paligemma's format
+        tokenizer = self.info.get_tokenizer()
+        newline_prompt = "\n"
+        newline_token_id = tokenizer.encode(newline_prompt)[-1]  # 108
+        # Force to add newline at the end of prompt for paligemma's format
         if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id:
             prompt_token_ids.append(newline_token_id)
         mm_inputs["prompt_token_ids"] = prompt_token_ids
         mm_inputs["prompt"] += newline_prompt
         return mm_inputs
 
+
 @MULTIMODAL_REGISTRY.register_processor(
     PaliGemmaMultiModalProcessor,
     info=PaliGemmaProcessingInfo,
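
For reference, a minimal sketch (not part of any commit above) of the prompt layout the refactored processor is expected to produce for a single-image request: the image placeholder tokens, then <bos>, then the user prompt tokens, then a trailing newline. The newline id 108 is the value used in the patches; the ids assumed for <image> (257152) and <bos> (2) come from the PaliGemma tokenizer and may differ per checkpoint.

# Illustrative sketch only -- mirrors the layout described by the patches:
#     <image> * num_image_tokens + <bos> + prompt + "\n"
# Assumed ids: 257152 (<image>), 2 (<bos>); 108 ("\n") is taken from the
# patches above.

def expected_prompt_ids(prompt_ids: list[int],
                        num_image_tokens: int,
                        image_token_id: int = 257152,
                        bos_token_id: int = 2,
                        newline_token_id: int = 108) -> list[int]:
    ids = [image_token_id] * num_image_tokens   # image placeholder features
    ids.append(bos_token_id)                    # <bos> right after the image tokens
    ids.extend(prompt_ids)                      # user prompt (no leading <bos>)
    if not ids or ids[-1] != newline_token_id:  # PaliGemma prompts end with "\n"
        ids.append(newline_token_id)
    return ids


# Example: a 224x224 PaliGemma checkpoint uses 256 image tokens per image;
# the prompt ids here are arbitrary placeholders.
print(expected_prompt_ids([2094, 603, 736], num_image_tokens=256)[-5:])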