 
 import math
 from collections.abc import Iterable, Mapping, Sequence
+from typing import Annotated, Literal
 
 import torch
 import torch.nn as nn
     count_tiles,
 )
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from vllm.v1.sample.logits_processor import (
     AdapterLogitsProcessor,
     RequestLogitsProcessor,
 _IMAGE_TOKEN = "<image>"
 
 
+class DeepseekOCRImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - bnp: Batch size * number of images * number of patches
+        - base_size: Base (global view) image size of the processor
+        - image_size: Size of each local crop patch
+    """
+
+    type: Literal["pixel_values"]
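+    # Global view of each image at the processor's base resolution.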
+    data: Annotated[
+        torch.Tensor,
+        TensorShape("bn", 3, "base_size", "base_size", dynamic_dims={"bnp"}),
+    ]
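+    # Local crop patches flattened across all images in the batch; may be
+    # empty when no image is tiled.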
+    images_crop: Annotated[
+        torch.Tensor,
+        TensorShape("bnp", 3, "image_size", "image_size", dynamic_dims={"bnp"}),
+    ]
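+    # Tiling grid per image (number of crops along each axis); its product
+    # gives the number of local patches for that image.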
+    images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
+
+
 class NoRepeatNGramLogitsProcessor:
     def __init__(
         self,
@@ -260,10 +284,15 @@ def _get_mm_fields_config(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        images_spatial_crop = hf_inputs.get("images_spatial_crop", torch.empty((0, 2)))
+        is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
+        patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
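+        # flat_from_sizes slices the flattened images_crop tensor back into
+        # per-image chunks of these sizes; untiled (1 x 1) images contribute
+        # no local crops.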
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
             images_spatial_crop=MultiModalFieldConfig.batched("image"),
-            images_crop=MultiModalFieldConfig.batched("image"),
+            images_crop=MultiModalFieldConfig.flat_from_sizes(
+                "image", patches_per_image
+            ),
         )
 
     def _get_prompt_updates(
@@ -302,42 +331,15 @@ def get_replacement_deepseek_vl2(item_idx: int):
             )
         ]
 
-    # TODO(Isotr0py): Check if we still need this workaround for
-    # deepseek-ocr processor.
-    # def _cached_apply_hf_processor(
-    #     self,
-    #     prompt: str | list[int],
-    #     mm_data_items: MultiModalDataItems,
-    #     hf_processor_mm_kwargs: Mapping[str, object],
-    #     tokenization_kwargs: Mapping[str, object],
-    #     mm_uuids: MultiModalUUIDDict | None = None,
-    # ) -> tuple[list[int], MultiModalKwargs, bool]:
-    #     # The processor logic is different for len(images) <= 2 vs > 2
-    #     # Since the processing cache assumes that the processor output is
-    #     # invariant of how many images are passed per prompt, we only
-    #     # perform caching for the most common case
-    #     if mm_data_items.get_count("image", strict=False) > 2:
-    #         # This code path corresponds to the cache being disabled
-    #         return self._apply_hf_processor_main(
-    #             prompt=prompt,
-    #             mm_items=mm_data_items,
-    #             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-    #             enable_hf_prompt_update=True,
-    #         )
-
-    #     return super()._cached_apply_hf_processor(
-    #         prompt=prompt,
-    #         mm_data_items=mm_data_items,
-    #         hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-    #     )
-
 
 @MULTIMODAL_REGISTRY.register_processor(
     DeepseekOCRMultiModalProcessor,
     info=DeepseekOCRProcessingInfo,
     dummy_inputs=DeepseekOCRDummyInputsBuilder,
 )
 class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
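+    # Merge multimodal kwargs by field config: images_crop then arrives
+    # flattened across the batch (per flat_from_sizes above) and is split
+    # per image in _pixel_values_to_embedding.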
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             # map prefix for language backbone
@@ -389,6 +391,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.vision_model = DeepCLIPVisionTransformer(
             config=clip_vision_config,
             quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "vision_model"),
         )
 
         self.projector = MlpProjector(self.projector_config)
@@ -426,7 +429,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.language_model.make_empty_intermediate_tensors
         )
 
-    def _parse_and_validate_image_input(self, **kwargs: object):
+    def _parse_and_validate_image_input(
+        self, **kwargs: object
+    ) -> DeepseekOCRImagePixelInputs | None:
         pixel_values = kwargs.pop("pixel_values", None)
         images_spatial_crop = kwargs.pop("images_spatial_crop", None)
         images_crop = kwargs.pop("images_crop", None)
@@ -435,23 +440,16 @@ def _parse_and_validate_image_input(self, **kwargs: object):
             return None
 
         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
-                )
-
-            if not isinstance(images_spatial_crop, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image sizes. "
-                    f"Got type: {type(images_spatial_crop)}"
-                )
-
-            if not isinstance(images_crop, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of image crop. Got type: {type(images_crop)}"
-                )
-
-            return [pixel_values, images_crop, images_spatial_crop]
+            base_size = self.vision_config.image_size
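+            # Constructing the TensorSchema validates the tensors against the
+            # TensorShape annotations above; resolve_bindings pins "base_size"
+            # to the vision config's image size.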
+            return DeepseekOCRImagePixelInputs(
+                type="pixel_values",
+                data=pixel_values,
+                images_crop=images_crop,
+                images_spatial_crop=images_spatial_crop,
+                resolve_bindings={
+                    "base_size": base_size,
+                },
+            )
 
         raise AssertionError("This line should be unreachable.")
 
@@ -518,10 +516,13 @@ def _pixel_values_to_embedding(
     ) -> NestedTensors:
         images_in_this_batch = []
 
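+        # Recover per-image crop patches: the split sizes mirror the
+        # flat_from_sizes layout from _get_mm_fields_config, so untiled
+        # images get an empty patch tensor.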
+        is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
+        patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
+        images_crop = images_crop.split(patches_per_image.tolist())
         for jdx in range(images_spatial_crop.size(0)):
-            patches = images_crop[jdx][0].to(torch.bfloat16)
-            image_ori = pixel_values[jdx]
-            crop_shape = images_spatial_crop[jdx][0]
+            patches = images_crop[jdx]
+            image_ori = pixel_values[[jdx]]
+            crop_shape = images_spatial_crop[jdx]
 
             global_features = self._encode_global_features(image_ori)
             local_features = self._encode_local_features(patches, crop_shape)
@@ -540,10 +541,12 @@ def _pixel_values_to_embedding(
 
         return images_in_this_batch
 
-    def _process_image_input(self, image_input) -> torch.Tensor:
-        pixel_values = image_input[0].to(torch.bfloat16)
-        images_crop = image_input[1]
-        images_spatial_crop = image_input[2].to(dtype=torch.long)
+    def _process_image_input(
+        self, image_input: DeepseekOCRImagePixelInputs
+    ) -> torch.Tensor:
+        pixel_values = image_input.data
+        images_crop = image_input.images_crop
+        images_spatial_crop = image_input.images_spatial_crop.to(dtype=torch.long)
 
         vision_features = self._pixel_values_to_embedding(
             pixel_values=pixel_values,