
Commit 2566dca

[Bugfix] Fix deepseek-ocr multi-image inference and add merge_by_field_config=True with tensor schema support (#27361)
Signed-off-by: Isotr0py <[email protected]>
1 parent b4fda58 commit 2566dca

File tree

4 files changed (+112, -66 lines):

examples/offline_inference/vision_language_multi_image.py
tests/models/multimodal/processing/test_common.py
vllm/model_executor/models/deepseek_ocr.py
vllm/transformers_utils/processors/deepseek_ocr.py

examples/offline_inference/vision_language_multi_image.py

Lines changed: 48 additions & 2 deletions
@@ -44,6 +44,7 @@ class ModelRequestData(NamedTuple):
     stop_token_ids: list[int] | None = None
     chat_template: str | None = None
     lora_requests: list[LoRARequest] | None = None
+    sampling_params: SamplingParams | None = None


 # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
@@ -201,6 +202,46 @@ def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData:
     )


+def load_deepseek_ocr(question: str, image_urls: list[str]) -> ModelRequestData:
+    from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor
+
+    model_name = "deepseek-ai/DeepSeek-OCR"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_num_seqs=2,
+        limit_mm_per_prompt={"image": len(image_urls)},
+        logits_processors=[NGramPerReqLogitsProcessor],
+    )
+
+    placeholder = "<image>\n" * len(image_urls)
+    prompt = placeholder + question
+
+    # The following sampling params config is taken from
+    # the official Deepseek-OCR inference example.
+    # (IMPORTANT) Use the custom logits processor and avoid skipping
+    # special tokens for this model for the optimal OCR performance.
+    sampling_params = SamplingParams(
+        temperature=0.0,
+        max_tokens=8192,
+        # ngram logit processor args
+        extra_args=dict(
+            ngram_size=30,
+            window_size=90,
+            # whitelist: <td>, </td>
+            whitelist_token_ids={128821, 128822},
+        ),
+        skip_special_tokens=False,
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+        image_data=[fetch_image(url) for url in image_urls],
+        sampling_params=sampling_params,
+    )
+
+
 def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
     model_name = "google/gemma-3-4b-it"

@@ -1253,6 +1294,7 @@ def load_glm4_5v_fp8(question: str, image_urls: list[str]) -> ModelRequestData:
     "bee": load_bee,
     "command_a_vision": load_command_a_vision,
     "deepseek_vl_v2": load_deepseek_vl2,
+    "deepseek_ocr": load_deepseek_ocr,
     "gemma3": load_gemma3,
     "h2ovl_chat": load_h2ovl,
     "hyperclovax_seed_vision": load_hyperclovax_seed_vision,
@@ -1325,8 +1367,12 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: int | None)
     engine_args = asdict(req_data.engine_args) | {"seed": seed}
     llm = LLM(**engine_args)

-    sampling_params = SamplingParams(
-        temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
+    sampling_params = (
+        SamplingParams(
+            temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
+        )
+        if req_data.sampling_params is None
+        else req_data.sampling_params
     )
     outputs = llm.chat(
         [
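
For reference, the new load_deepseek_ocr loader plus the sampling_params field on ModelRequestData let run_chat use DeepSeek-OCR's recommended decoding settings instead of the generic defaults. Below is a minimal standalone sketch of the same setup outside the example script (not part of this commit): the image paths and the "Free OCR." instruction are hypothetical placeholders, and the engine and sampling settings simply mirror load_deepseek_ocr above.

# Illustrative sketch only, not part of this commit.
# Multi-image DeepSeek-OCR offline inference using the same settings as
# load_deepseek_ocr() above; image paths are hypothetical placeholders.
from PIL import Image

from vllm import LLM, SamplingParams
from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor

images = [Image.open("page1.png"), Image.open("page2.png")]  # placeholder files

llm = LLM(
    model="deepseek-ai/DeepSeek-OCR",
    max_num_seqs=2,
    limit_mm_per_prompt={"image": len(images)},
    logits_processors=[NGramPerReqLogitsProcessor],
)

# One <image> placeholder per image, followed by the OCR instruction.
prompt = "<image>\n" * len(images) + "Free OCR."

sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=8192,
    skip_special_tokens=False,
    # Arguments consumed by the per-request n-gram logits processor.
    extra_args=dict(
        ngram_size=30,
        window_size=90,
        whitelist_token_ids={128821, 128822},  # <td>, </td>
    ),
)

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": images}},
    sampling_params,
)
print(outputs[0].outputs[0].text)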

tests/models/multimodal/processing/test_common.py

Lines changed: 1 addition & 0 deletions
@@ -332,6 +332,7 @@ def _test_processing_correctness_one(
         "facebook/chameleon-7b",
         "CohereLabs/command-a-vision-07-2025",
         "deepseek-ai/deepseek-vl2-tiny",
+        "deepseek-ai/DeepSeek-OCR",
         "baidu/ERNIE-4.5-VL-28B-A3B-PT",
         "adept/fuyu-8b",
         "google/gemma-3-4b-it",

vllm/model_executor/models/deepseek_ocr.py

Lines changed: 58 additions & 55 deletions
@@ -4,6 +4,7 @@

 import math
 from collections.abc import Iterable, Mapping, Sequence
+from typing import Annotated, Literal

 import torch
 import torch.nn as nn
@@ -53,6 +54,7 @@
     count_tiles,
 )
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from vllm.v1.sample.logits_processor import (
     AdapterLogitsProcessor,
     RequestLogitsProcessor,
@@ -65,6 +67,28 @@
 _IMAGE_TOKEN = "<image>"


+class DeepseekOCRImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: Batch size
+        - n: Number of images
+        - p: Number of patches
+        - base_size: Base size of the processor
+        - image_size: Image size of the processor
+    """
+
+    type: Literal["pixel_values"]
+    data: Annotated[
+        torch.Tensor,
+        TensorShape("bn", 3, "base_size", "base_size", dynamic_dims={"bnp"}),
+    ]
+    images_crop: Annotated[
+        torch.Tensor,
+        TensorShape("bnp", 3, "image_size", "image_size", dynamic_dims={"bnp"}),
+    ]
+    images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
+
+
 class NoRepeatNGramLogitsProcessor:
     def __init__(
         self,
@@ -260,10 +284,15 @@ def _get_mm_fields_config(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        images_spatial_crop = hf_inputs.get("images_spatial_crop", torch.empty((0, 2)))
+        is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
+        patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
             images_spatial_crop=MultiModalFieldConfig.batched("image"),
-            images_crop=MultiModalFieldConfig.batched("image"),
+            images_crop=MultiModalFieldConfig.flat_from_sizes(
+                "image", patches_per_image
+            ),
         )

     def _get_prompt_updates(
@@ -302,42 +331,15 @@ def get_replacement_deepseek_vl2(item_idx: int):
             )
         ]

-    # TODO(Isotr0py): Check if we still need this workaround for
-    # deepseek-ocr processor.
-    # def _cached_apply_hf_processor(
-    #     self,
-    #     prompt: str | list[int],
-    #     mm_data_items: MultiModalDataItems,
-    #     hf_processor_mm_kwargs: Mapping[str, object],
-    #     tokenization_kwargs: Mapping[str, object],
-    #     mm_uuids: MultiModalUUIDDict | None = None,
-    # ) -> tuple[list[int], MultiModalKwargs, bool]:
-    #     # The processor logic is different for len(images) <= 2 vs > 2
-    #     # Since the processing cache assumes that the processor output is
-    #     # invariant of how many images are passed per prompt, we only
-    #     # perform caching for the most common case
-    #     if mm_data_items.get_count("image", strict=False) > 2:
-    #         # This code path corresponds to the cache being disabled
-    #         return self._apply_hf_processor_main(
-    #             prompt=prompt,
-    #             mm_items=mm_data_items,
-    #             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-    #             enable_hf_prompt_update=True,
-    #         )
-
-    #     return super()._cached_apply_hf_processor(
-    #         prompt=prompt,
-    #         mm_data_items=mm_data_items,
-    #         hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-    #     )
-

 @MULTIMODAL_REGISTRY.register_processor(
     DeepseekOCRMultiModalProcessor,
     info=DeepseekOCRProcessingInfo,
     dummy_inputs=DeepseekOCRDummyInputsBuilder,
 )
 class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             # map prefix for language backbone
@@ -389,6 +391,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.vision_model = DeepCLIPVisionTransformer(
             config=clip_vision_config,
             quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "vision_model"),
         )

         self.projector = MlpProjector(self.projector_config)
@@ -426,7 +429,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.language_model.make_empty_intermediate_tensors
         )

-    def _parse_and_validate_image_input(self, **kwargs: object):
+    def _parse_and_validate_image_input(
+        self, **kwargs: object
+    ) -> DeepseekOCRImagePixelInputs | None:
         pixel_values = kwargs.pop("pixel_values", None)
         images_spatial_crop = kwargs.pop("images_spatial_crop", None)
         images_crop = kwargs.pop("images_crop", None)
@@ -435,23 +440,16 @@ def _parse_and_validate_image_input(self, **kwargs: object):
             return None

         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
-                )
-
-            if not isinstance(images_spatial_crop, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image sizes. "
-                    f"Got type: {type(images_spatial_crop)}"
-                )
-
-            if not isinstance(images_crop, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of image crop. Got type: {type(images_crop)}"
-                )
-
-            return [pixel_values, images_crop, images_spatial_crop]
+            base_size = self.vision_config.image_size
+            return DeepseekOCRImagePixelInputs(
+                type="pixel_values",
+                data=pixel_values,
+                images_crop=images_crop,
+                images_spatial_crop=images_spatial_crop,
+                resolve_bindings={
+                    "base_size": base_size,
+                },
+            )

         raise AssertionError("This line should be unreachable.")

@@ -518,10 +516,13 @@ def _pixel_values_to_embedding(
     ) -> NestedTensors:
         images_in_this_batch = []

+        is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
+        patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
+        images_crop = images_crop.split(patches_per_image.tolist())
         for jdx in range(images_spatial_crop.size(0)):
-            patches = images_crop[jdx][0].to(torch.bfloat16)
-            image_ori = pixel_values[jdx]
-            crop_shape = images_spatial_crop[jdx][0]
+            patches = images_crop[jdx]
+            image_ori = pixel_values[[jdx]]
+            crop_shape = images_spatial_crop[jdx]

             global_features = self._encode_global_features(image_ori)
             local_features = self._encode_local_features(patches, crop_shape)
@@ -540,10 +541,12 @@ def _pixel_values_to_embedding(

         return images_in_this_batch

-    def _process_image_input(self, image_input) -> torch.Tensor:
-        pixel_values = image_input[0].to(torch.bfloat16)
-        images_crop = image_input[1]
-        images_spatial_crop = image_input[2].to(dtype=torch.long)
+    def _process_image_input(
+        self, image_input: DeepseekOCRImagePixelInputs
+    ) -> torch.Tensor:
+        pixel_values = image_input.data
+        images_crop = image_input.images_crop
+        images_spatial_crop = image_input.images_spatial_crop.to(dtype=torch.long)

         vision_features = self._pixel_values_to_embedding(
             pixel_values=pixel_values,
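
The key change above is that images_crop is registered with MultiModalFieldConfig.flat_from_sizes instead of batched: crops from all images in a request are concatenated along the first dimension, and _pixel_values_to_embedding regroups them per image with split(). A toy sketch of that bookkeeping in plain torch (the 640x640 crop resolution is illustrative, not taken from the model config):

# Toy illustration only, not part of this commit.
import torch

# Suppose one request carries 3 images: the first is tiled into a 2x3 grid of
# local crops, the second is small enough to skip tiling, the third is tiled 1x2.
images_spatial_crop = torch.tensor([[2, 3], [1, 1], [1, 2]])

# An image contributes local crops only when it is tiled along either axis.
is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
print(patches_per_image.tolist())  # [6, 0, 2]

# images_crop is stored flat across all images (8 crops total here), mirroring
# MultiModalFieldConfig.flat_from_sizes("image", patches_per_image), and is
# regrouped per image with split() exactly as in _pixel_values_to_embedding.
images_crop = torch.randn(int(patches_per_image.sum()), 3, 640, 640)
per_image_crops = images_crop.split(patches_per_image.tolist())
print([c.shape[0] for c in per_image_crops])  # [6, 0, 2]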

vllm/transformers_utils/processors/deepseek_ocr.py

Lines changed: 5 additions & 9 deletions
@@ -411,20 +411,16 @@ def tokenize_with_images(
         images_seq_mask = images_seq_mask[:-1]

         if len(images_list) == 0:
-            pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
-            images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
-            images_crop = torch.zeros(
-                (1, 3, self.image_size, self.image_size)
-            ).unsqueeze(0)
+            pixel_values = torch.zeros((0, 3, self.base_size, self.base_size))
+            images_spatial_crop = torch.zeros((0, 2), dtype=torch.long)
+            images_crop = torch.zeros((0, 3, self.image_size, self.image_size))
         else:
             pixel_values = torch.stack(images_list, dim=0)
             images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
             if images_crop_list:
-                images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
+                images_crop = torch.stack(images_crop_list, dim=0)
             else:
-                images_crop = torch.zeros(
-                    (1, 3, self.image_size, self.image_size)
-                ).unsqueeze(0)
+                images_crop = torch.zeros((0, 3, self.image_size, self.image_size))

         input_ids = input_ids.unsqueeze(0)
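
With the processor now emitting genuinely empty tensors for the zero-image and zero-crop cases, shapes stay consistent with the tensor schema above: pixel_values and images_crop keep their trailing (3, H, W) dimensions with a zero-length leading axis, and images_spatial_crop is (0, 2) rather than a (1, 1) placeholder. A small sanity check of why such empty tensors compose cleanly downstream (the 640x640 crop resolution is again illustrative):

# Toy check only, not part of this commit.
import torch

image_size = 640  # illustrative crop resolution

no_crops = torch.zeros((0, 3, image_size, image_size))
some_crops = torch.randn((4, 3, image_size, image_size))

# A zero-length leading dimension is a no-op under concatenation, so prompts
# without images no longer inject a fake batch element.
print(torch.cat([no_crops, some_crops]).shape)  # torch.Size([4, 3, 640, 640])

# Splitting with a size of 0 yields a well-formed empty chunk, matching the
# patches_per_image == 0 case for images that were not tiled.
chunks = some_crops.split([4, 0])
print([c.shape[0] for c in chunks])  # [4, 0]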
