From 762b092be80025fd2a8e59392a263184e03f8efd Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 4 Jan 2025 16:39:59 +0000 Subject: [PATCH 1/7] Fix single-image input for Pixtral Signed-off-by: DarkLight1337 --- vllm/model_executor/models/llava.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index d522378e0beb..075dc5bac30f 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -469,6 +469,10 @@ def sampler(self): return get_sampler() def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + # The image size may be different for Pixtral-HF + if self.config.vision_config.model_type == "pixtral": + return data + h = w = self.config.vision_config.image_size expected_dims = (3, h, w) actual_dims = tuple(data.shape[1:]) From c4d6836ef67708402383bd2c7423e9a8c3e68efa Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 7 Jan 2025 11:10:11 +0000 Subject: [PATCH 2/7] Fix wrong patch size Signed-off-by: DarkLight1337 --- vllm/model_executor/models/pixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 9e1d38512c0b..b74bb3c8a3f8 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -774,7 +774,7 @@ def get_num_image_tokens( ) -> int: return get_pixtral_hf_image_feature_size( image_size=self.vision_config.image_size, - patch_size=self.get_image_size(), + patch_size=self.vision_config.patch_size, ) def get_max_image_tokens(self) -> int: From 7d394b58f17ddfcc9ddf136e45661b3109d22922 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 7 Jan 2025 13:19:07 +0000 Subject: [PATCH 3/7] Add and update multi-image examples Signed-off-by: DarkLight1337 --- ...e_inference_vision_language_multi_image.py | 41 ++++++++++++++++--- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 6af8d7768e75..8ba145fe2947 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -23,7 +23,7 @@ class ModelRequestData(NamedTuple): llm: LLM prompt: str - stop_token_ids: Optional[List[str]] + stop_token_ids: Optional[List[int]] image_data: List[Image] chat_template: Optional[str] @@ -44,12 +44,14 @@ def load_aria(question, image_urls: List[str]) -> ModelRequestData: prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n" "<|im_start|>assistant\n") stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519] + return ModelRequestData( llm=llm, prompt=prompt, stop_token_ids=stop_token_ids, image_data=[fetch_image(url) for url in image_urls], - chat_template=None) + chat_template=None, + ) def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData: @@ -166,7 +168,8 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData: limit_mm_per_prompt={"image": len(image_urls)}, ) - prompt = f"<|image|><|image|><|begin_of_text|>{question}" + placeholders = "<|image|>" * len(image_urls) + prompt = f"{placeholders}<|begin_of_text|>{question}" return ModelRequestData( llm=llm, prompt=prompt, @@ -209,6 +212,31 @@ def load_nvlm_d(question: str, image_urls: List[str]): ) +def load_pixtral_hf(question: str, image_urls: List[str]) -> ModelRequestData: + model_name = "mistral-community/pixtral-12b" + + # 
Adjust this as necessary to fit in GPU + llm = LLM( + model=model_name, + max_model_len=8192, + max_num_seqs=2, + tensor_parallel_size=2, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = "[IMG]" * len(image_urls) + prompt = f"[INST]{question}\n{placeholders}[/INST]" + stop_token_ids = None + + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=stop_token_ids, + image_data=[fetch_image(url) for url in image_urls], + chat_template=None, + ) + + def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData: # num_crops is an override kwarg to the multimodal image processor; # For some models, e.g., Phi-3.5-vision-instruct, it is recommended @@ -244,7 +272,8 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData: ) -def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData: +def load_qwen_vl_chat(question: str, + image_urls: List[str]) -> ModelRequestData: model_name = "Qwen/Qwen-VL-Chat" llm = LLM( model=model_name, @@ -274,6 +303,7 @@ def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData: stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] + return ModelRequestData( llm=llm, prompt=prompt, @@ -348,7 +378,8 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData: "mllama": load_mllama, "NVLM_D": load_nvlm_d, "phi3_v": load_phi3v, - "qwen_vl_chat": load_qwenvl_chat, + "pixtral_hf": load_phi3v, + "qwen_vl_chat": load_qwen_vl_chat, "qwen2_vl": load_qwen2_vl, } From aac372e2a68ff00ca21006c901856487928dd43c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 7 Jan 2025 13:22:15 +0000 Subject: [PATCH 4/7] Fix Signed-off-by: DarkLight1337 --- vllm/model_executor/models/llava.py | 9 ++++++--- vllm/model_executor/models/utils.py | 9 +++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index d4e6406589e9..7ea3d78c4b8e 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -31,6 +31,7 @@ ProcessingMixin, PromptReplacement) from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of from .clip import CLIPVisionModel from .interfaces import SupportsMultiModal, SupportsPP @@ -521,7 +522,7 @@ def sampler(self): return get_sampler() def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - # The image size may be different for Pixtral-HF + # Only the longest edge is equal to image_size for Pixtral-HF if self.config.vision_config.model_type == "pixtral": return data @@ -550,10 +551,12 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + pixel_values = flatten_bn(pixel_values, + concat=is_list_of(pixel_values, list)) + return LlavaImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), + data=self._validate_pixel_values(pixel_values), ) if image_embeds is not None: diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 31017f16d3c9..4ed3b237ae0e 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -281,6 +281,15 @@ def flatten_bn( ... 
+@overload +def flatten_bn( + x: Union[List[torch.Tensor], torch.Tensor], + *, + concat: bool = False, +) -> Union[List[torch.Tensor], torch.Tensor]: + ... + + def flatten_bn( x: Union[List[torch.Tensor], torch.Tensor], *, From 77e05886dd438e053363e192e5ef41e90c95c22c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 7 Jan 2025 13:41:21 +0000 Subject: [PATCH 5/7] Fix wrong function Signed-off-by: DarkLight1337 --- examples/offline_inference_vision_language_multi_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 8ba145fe2947..cf2e90a325c6 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -378,7 +378,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData: "mllama": load_mllama, "NVLM_D": load_nvlm_d, "phi3_v": load_phi3v, - "pixtral_hf": load_phi3v, + "pixtral_hf": load_pixtral_hf, "qwen_vl_chat": load_qwen_vl_chat, "qwen2_vl": load_qwen2_vl, } From ac26cd29c572a73fab6af91095ec460a23a76045 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 7 Jan 2025 13:47:15 +0000 Subject: [PATCH 6/7] Fix Signed-off-by: DarkLight1337 --- vllm/model_executor/models/llava.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 7ea3d78c4b8e..3a10c4469794 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -31,7 +31,6 @@ ProcessingMixin, PromptReplacement) from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs from vllm.sequence import IntermediateTensors -from vllm.utils import is_list_of from .clip import CLIPVisionModel from .interfaces import SupportsMultiModal, SupportsPP @@ -521,11 +520,18 @@ def sampler(self): return get_sampler() - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + def _validate_pixel_values( + self, + data: Union[torch.Tensor, List[torch.Tensor]], + ) -> Union[torch.Tensor, List[torch.Tensor]]: # Only the longest edge is equal to image_size for Pixtral-HF if self.config.vision_config.model_type == "pixtral": return data + if not isinstance(data, torch.Tensor): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(data)}") + h = w = self.config.vision_config.image_size expected_dims = (3, h, w) actual_dims = tuple(data.shape[1:]) @@ -551,12 +557,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. 
" f"Got type: {type(pixel_values)}") - pixel_values = flatten_bn(pixel_values, - concat=is_list_of(pixel_values, list)) - return LlavaImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values(pixel_values), + data=self._validate_pixel_values(flatten_bn(pixel_values)), ) if image_embeds is not None: From 23dfc45de7edf3d6288211384dbcb2d6ff4e9191 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 7 Jan 2025 17:28:09 +0000 Subject: [PATCH 7/7] Fix Signed-off-by: DarkLight1337 --- vllm/model_executor/models/llava.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 3a10c4469794..305f1364dba2 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -520,18 +520,7 @@ def sampler(self): return get_sampler() - def _validate_pixel_values( - self, - data: Union[torch.Tensor, List[torch.Tensor]], - ) -> Union[torch.Tensor, List[torch.Tensor]]: - # Only the longest edge is equal to image_size for Pixtral-HF - if self.config.vision_config.model_type == "pixtral": - return data - - if not isinstance(data, torch.Tensor): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(data)}") - + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) actual_dims = tuple(data.shape[1:]) @@ -557,9 +546,16 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + if self.config.vision_config.model_type == "pixtral": + return LlavaImagePixelInputs( + type="pixel_values", + data=flatten_bn(pixel_values), + ) + return LlavaImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values(flatten_bn(pixel_values)), + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), ) if image_embeds is not None: