
Conversation

@killTheHostage commented Sep 25, 2025

Purpose

This PR adds support for a new model architecture: LLaVA-OneVision 1.5.
The corresponding HuggingFace model address is: https://huggingface.co/lmms-lab/LLaVA-OneVision-1.5-8B-Instruct
The corresponding GitHub repository is: https://github.com/EvolvingLMMs-Lab/LLaVA-OneVision-1.5
The associated HuggingFace Transformers pull request is: huggingface/transformers#41095

Test Plan

Serving test

Start the server with vllm serve ./llavaonevision1_5_8B, which defaults to V1 mode on a single GPU, then send API requests with curl. The request is as follows:

curl --location --request POST 'http://127.0.0.1:8000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--data-raw '{
   "model": "/home/ubuntu/vllm_test/LLaVAOneVision1_5-8B/",
   "messages": [
      {
         "role": "user",
         "content": [
            {
               "type": "text",
               "text": "what is this in this picture? please tell me more"
            },
            {
               "type": "image_url",
               "image_url": {
                  "url": "http://127.0.0.1:8080"
               }
            }
         ]
      }
   ],
   "max_tokens": 1280,
   "temperature": 0,
   "stream": false
}'

In the request, http://127.0.0.1:8080 points to a small image server written with Flask; a minimal sketch of such a server is shown after the image. The test image is as follows:
(image attachment)
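
For reference, a server like this can be as simple as the following Flask sketch (the filename test.png and the root route are assumptions, not the exact script used for this test):

# Minimal Flask image server sketch; "test.png" is an assumed filename.
from flask import Flask, send_file

app = Flask(__name__)


@app.route("/")
def serve_image():
    # Serve a single local image at the root path so that
    # http://127.0.0.1:8080 resolves directly to the picture.
    return send_file("test.png", mimetype="image/png")


if __name__ == "__main__":
    app.run(host="127.0.0.1", port=8080)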
I also tested vllm serve ./llavaonevision1_5_8B --tensor-parallel-size 2 in V1 mode with two GPUs.

Offline test

The testing uses the updated offline demo from this PR, located at examples/offline_inference/vision_language.py. However, some modifications were made for local testing; the modified code is as follows:

# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use vLLM for running offline inference with
the correct prompt format on vision language models for text generation.

For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
import os
import random
from dataclasses import asdict
from typing import NamedTuple, Optional
from PIL import Image

from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.lora.request import LoRARequest
from vllm.utils import FlexibleArgumentParser

os.environ['VLLM_USE_V1'] = '0'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Local path to the downloaded LLaVA-OneVision-1.5 checkpoint used for testing.
LOCAL_MODEL_PATH = "./llavaonevision1_5_8B"

class ModelRequestData(NamedTuple):
    engine_args: EngineArgs
    prompts: list[str]
    stop_token_ids: Optional[list[int]] = None
    lora_requests: Optional[list[LoRARequest]] = None

def run_llavaov1_5(questions: list[str], modality: str) -> ModelRequestData:

    model_name = LOCAL_MODEL_PATH
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=1,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs={
            "min_pixels": 4 * 28 * 28,
            "max_pixels": 16384 * 28 * 28,
        },
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    if modality == "image":
        placeholder = "<|image_pad|>"
    elif modality == "video":
        placeholder = "<|video_pad|>"

    prompts = [
        ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
         f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
         f"{question}<|im_end|>\n"
         "<|im_start|>assistant\n") for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )



model_example_map = {
    "llavaov1_5": run_llavaov1_5,
}


def get_multi_modal_input(args):
    """
    return {
        "data": image or video,
        "question": question,
    }
    """
    if args.modality == "image":
        # Input image and question
        image = Image.open("/gemini/data-1/jieyu/vllm_test/test.png").convert("RGB")
        img_questions = [
            "What is the content of this image?"
        ]

        return {
            "data": image,
            "questions": img_questions,
        }

    if args.modality == "video":
        # Input video and question
        video = VideoAsset(name="sample_demo_1.mp4",
                           num_frames=args.num_frames).np_ndarrays
        vid_questions = ["Why is this video funny?"]

        return {
            "data": video,
            "questions": vid_questions,
        }

    msg = f"Modality {args.modality} is not supported."
    raise ValueError(msg)


def apply_image_repeat(image_repeat_prob, num_prompts, data,
                       prompts: list[str], modality):
    """Repeats images with provided probability of "image_repeat_prob". 
    Used to simulate hit/miss for the MM preprocessor cache.
    """
    assert (image_repeat_prob <= 1.0 and image_repeat_prob >= 0)
    no_yes = [0, 1]
    probs = [1.0 - image_repeat_prob, image_repeat_prob]

    inputs = []
    cur_image = data
    for i in range(num_prompts):
        if image_repeat_prob is not None:
            res = random.choices(no_yes, probs)[0]
            if res == 0:
                # No repeat => Modify one pixel
                cur_image = cur_image.copy()
                new_val = (i // 256 // 256, i // 256, i % 256)
                cur_image.putpixel((0, 0), new_val)

        inputs.append({
            "prompt": prompts[i % len(prompts)],
            "multi_modal_data": {
                modality: cur_image
            }
        })

    return inputs


def main(args):
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    modality = args.modality
    mm_input = get_multi_modal_input(args)
    data = mm_input["data"]
    questions = mm_input["questions"]

    req_data = model_example_map[model](questions, modality)

    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
    llm = LLM(**engine_args)

    # To maintain code compatibility in this script, we add LoRA here.
    # You can also add LoRA using:
    # llm.generate(prompts, lora_request=lora_request,...)
    if req_data.lora_requests:
        for lora_request in req_data.lora_requests:
            llm.llm_engine.add_lora(lora_request=lora_request)

    # Don't want to check the flag multiple times, so just hijack `prompts`.
    prompts = req_data.prompts if args.use_different_prompt_per_request else [
        req_data.prompts[0]
    ]

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=req_data.stop_token_ids)

    assert args.num_prompts > 0
    if args.num_prompts == 1:
        # Single inference
        inputs = {
            "prompt": prompts[0],
            "multi_modal_data": {
                modality: data
            },
        }
    else:
        # Batch inference
        if args.image_repeat_prob is not None:
            # Repeat images with specified probability of "image_repeat_prob"
            inputs = apply_image_repeat(args.image_repeat_prob,
                                        args.num_prompts, data, prompts,
                                        modality)
        else:
            # Use the same image for all prompts
            inputs = [{
                "prompt": prompts[i % len(prompts)],
                "multi_modal_data": {
                    modality: data
                },
            } for i in range(args.num_prompts)]

    if args.time_generate:
        import time
        start_time = time.time()
        outputs = llm.generate(inputs, sampling_params=sampling_params)
        elapsed_time = time.time() - start_time
        print("-- generate time = {}".format(elapsed_time))

    else:
        outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models for text generation')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="llavaov1_5",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=4,
                        help='Number of prompts to run.')
    parser.add_argument('--modality',
                        type=str,
                        default="image",
                        choices=['image', 'video'],
                        help='Modality of the input.')
    parser.add_argument('--num-frames',
                        type=int,
                        default=16,
                        help='Number of frames to extract from the video.')
    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help="Set the seed when initializing `vllm.LLM`.")

    parser.add_argument(
        '--image-repeat-prob',
        type=float,
        default=None,
        help='Simulates the hit-ratio for multi-modal preprocessor cache'
        ' (if enabled)')

    parser.add_argument(
        '--disable-mm-preprocessor-cache',
        action='store_true',
        help='If True, disables caching of multi-modal preprocessor/mapper.')

    parser.add_argument(
        '--time-generate',
        action='store_true',
        help='If True, then print the total generate() call time')

    parser.add_argument(
        '--use-different-prompt-per-request',
        action='store_true',
        help='If True, then use different prompt (with the same multi-modal '
        'data) for each request.')

    args = parser.parse_args()
    main(args)

I used the offline inference script to test V0 compatibility.

Test Result

The image captures a scene with a tower in the background and cherry blossom trees in the foreground.
The image captures a beautiful scene of cherry blossoms in full bloom, with a tall tower in the background. The blossoms are pink and appear to be in the foreground, with the tower visible through the branches. The sky is clear and blue, creating a serene atmosphere.
The image captures a beautiful scene of a tower, possibly a communication or observation tower, partially obscured by the branches of a tree adorned with pink blossoms. The blossoms suggest that the photo was taken during springtime. The tower is tall and slender, with a conical top, and it stands out against the clear
The image captures a beautiful scene of a tower, possibly a communication or observation tower, partially obscured by branches of a tree adorned with pink blossoms. The blossoms suggest that the photo was taken during springtime. The tower's design is modern, with a cylindrical shape and a pointed top, which is a common architectural

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of the fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify bot added the documentation (Improvements or additions to documentation) and new-model (Requests to new models) labels on Sep 25, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request adds support for the LLaVA-OneVision 1.5 model. The changes include the model implementation, an example for offline inference, and updates to documentation and the model registry. The implementation looks mostly good, but I've identified a critical correctness issue in the video frame padding logic, a significant performance issue in the vision attention mechanism due to redundant distributed operations, and a typo in the model's module name that affects consistency. Addressing these points will improve the correctness, performance, and maintainability of the new model support.

Comment on lines +655 to +657
        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        padded_num_frames = num_frames + num_frames % temporal_patch_size
gemini-code-assist bot (Contributor):

critical

The padding logic for padded_num_frames is incorrect for LlavaOnevision1_5. The current implementation num_frames + num_frames % temporal_patch_size is copied from Qwen2VLImageProcessor and does not correctly pad the frames to be divisible by temporal_patch_size. The LlavaOnevision1_5ImageProcessor overrides this with the correct logic. This bug can lead to incorrect calculations for the number of vision tokens.

The correct logic should pad the number of frames to the next multiple of temporal_patch_size if it's not already a multiple. I've also updated the comment to point to the correct implementation.

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/pull/41095/files#diff-974f33e383120f288a5003517042578059639561a842790263503943399a0170R105
        padded_num_frames = num_frames + (temporal_patch_size - num_frames % temporal_patch_size) % temporal_patch_size
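
To make the difference concrete, here is a quick standalone check (illustrative values only; the actual temporal_patch_size comes from the model config):

# Compare the copied Qwen2-VL-style padding with a plain round-up to the
# next multiple of temporal_patch_size; the values below are illustrative.
for temporal_patch_size in (2, 3, 4):
    for num_frames in range(1, 9):
        qwen_style = num_frames + num_frames % temporal_patch_size
        round_up = num_frames + (
            temporal_patch_size - num_frames % temporal_patch_size
        ) % temporal_patch_size
        if qwen_style % temporal_patch_size != 0:
            print(f"patch={temporal_patch_size}, frames={num_frames}: "
                  f"{qwen_style} is not divisible; round-up gives {round_up}")

For temporal_patch_size == 2 (the Qwen2-VL default) the two formulas agree, which is presumably why the copied code works in that setting; they diverge for larger patch sizes.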

killTheHostage (Author) replied:

The processing logic of LlavaOnevision1_5ImageProcessor is not significantly different from that of Qwen2VLImageProcessor, so the code was directly copied.

Comment on lines +248 to +264
    def _split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor]:
        seq_len, _ = qkv.shape
        if self.tp_size > 1:
            qkv = self._all_gather_tensor(qkv, self.qkv.hidden_size,
                                          self.tp_size)
        qkv = qkv.reshape(qkv.shape[0], 1, -1)
        q, k, v = qkv.chunk(3, dim=2)
        if self.tp_size > 1:
            splitter = partial(dist_utils.split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]
        new_shape = (seq_len, self.num_attn_heads_per_partition,
                     self.hidden_size_per_attn_head)
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v
gemini-code-assist bot (Contributor):

high

The _split_qkv method is unnecessarily complex and inefficient. It performs an all_gather operation (via _all_gather_tensor) to reconstruct the full QKV tensor and then splits it again, which is redundant because QKVParallelLinear already provides sharded outputs. This introduces unnecessary communication overhead.

This can be simplified to directly chunk and reshape the sharded qkv tensor. The _all_gather_tensor method at lines 228-246 can then be removed.

    def _split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        seq_len, _ = qkv.shape
        q, k, v = qkv.chunk(3, dim=-1)
        new_shape = (seq_len, self.num_attn_heads_per_partition,
                     self.hidden_size_per_attn_head)
        q = q.contiguous().view(*new_shape)
        k = k.contiguous().view(*new_shape)
        v = v.contiguous().view(*new_shape)
        return q, k, v
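
As a shape-level sanity check of why the local chunk suffices (illustrative numbers only, and assuming num_kv_heads == num_heads as in this vision attention), the per-rank output of QKVParallelLinear already has width 3 * hidden_size // tp_size, so chunking it locally recovers the sharded q, k and v:

import torch

# Illustrative values only; not taken from the model config.
hidden_size, num_heads, tp_size, seq_len = 1024, 16, 2, 7
head_dim = hidden_size // num_heads
heads_per_partition = num_heads // tp_size

# Stand-in for the per-rank output of QKVParallelLinear:
# [q_shard | k_shard | v_shard], each of width hidden_size // tp_size.
qkv_local = torch.randn(seq_len, 3 * hidden_size // tp_size)

q, k, v = qkv_local.chunk(3, dim=-1)
new_shape = (seq_len, heads_per_partition, head_dim)
q, k, v = (x.contiguous().view(*new_shape) for x in (q, k, v))
assert q.shape == k.shape == v.shape == new_shape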

"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501
"LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501
"LlavaOnevision1_5_ForConditionalGeneration": ("llava_onevison1_5", "LlavaOnevision1_5_ForConditionalGeneration"), # noqa: E501
gemini-code-assist bot (Contributor):

high

There's a typo in the module name: llava_onevison1_5 should be llava_onevision1_5. This should be corrected here, and the corresponding model file vllm/model_executor/models/llava_onevison1_5.py should be renamed to vllm/model_executor/models/llava_onevision1_5.py for consistency.

Suggested change
"LlavaOnevision1_5_ForConditionalGeneration": ("llava_onevison1_5", "LlavaOnevision1_5_ForConditionalGeneration"), # noqa: E501
"LlavaOnevision1_5_ForConditionalGeneration": ("llava_onevision1_5", "LlavaOnevision1_5_ForConditionalGeneration"), # noqa: E501

A Member replied:

Can you address this?

@DarkLight1337 (Member) left a comment

Thanks for implementing this model, can you also add it to tests/models/registry.py and tests/models/multimodal/generation/test_common.py to check the correctness?

@DarkLight1337 added the multi-modality (Related to multi-modality (#4194)) label on Sep 25, 2025
@killTheHostage (Author) commented Oct 1, 2025

tests/models/multimodal/generation/test_common.py

Hello, while running pytest on tests/models/multimodal/generation/test_common.py I identified an issue: the current HuggingFace Transformers implementation does not yet support LLaVA-OneVision-1.5 (the corresponding pull request is still under review), which causes the test to fail. Besides waiting for official HuggingFace support, are there any alternative approaches available?
