49
49
PromptInsertion , PromptUpdate )
50
50
from vllm .multimodal .profiling import BaseDummyInputsBuilder , ProcessorInputs
51
51
from vllm .sequence import IntermediateTensors
52
- from vllm .utils import flatten_2d_lists
53
52
54
53
from .interfaces import (MultiModalEmbeddings , SupportsLoRA ,
55
54
SupportsMultiModal , SupportsPP , SupportsQuant )
72
71
73
72
class MolmoImageInputs (TypedDict ):
74
73
images : Union [torch .Tensor , list [torch .Tensor ]]
75
- """Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
74
+ """Shape: `(batch_size * num_images , num_crops, num_patch, patch_dim)`"""
76
75
77
76
image_masks : Optional [Union [torch .Tensor , list [torch .Tensor ]]]
78
- """Shape: `(batch_size, num_crops, num_patch)`"""
77
+ """Shape: `(batch_size * num_images , num_crops, num_patch)`"""
79
78
80
79
feat_is_patch : Union [torch .Tensor , list [torch .Tensor ]]
81
80
"""
82
81
A boolean mask indicating which image features correspond
83
82
to patch tokens.
84
83
85
- Shape: `(batch_size, num_crops, num_patch)`
84
+ Shape: `(batch_size * num_images , num_crops, num_patch)`
86
85
"""
87
86
88
87
embed_is_patch : Union [torch .Tensor , list [torch .Tensor ]]
89
88
"""
90
89
A boolean mask indicating which image embeddings correspond
91
90
to patch tokens.
92
91
93
- Shape: `(batch_size, num_embeds)`
92
+ Shape: `(batch_size * num_images , num_embeds)`
94
93
"""
95
94
96
95
num_crops : Union [torch .Tensor , list [torch .Tensor ]]
@@ -696,9 +695,10 @@ def encode_image(self, images: torch.Tensor) -> torch.Tensor:
696
695
return image_features
697
696
698
697
def forward (
699
- self , images : torch .Tensor , image_masks : torch .Tensor
700
- ) -> Tuple [torch .Tensor , Optional [torch .Tensor ]]:
701
-
698
+ self ,
699
+ images : torch .Tensor ,
700
+ image_masks : torch .Tensor ,
701
+ ) -> torch .Tensor :
702
702
# image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501
703
703
batch_size , num_image = images .shape [:2 ]
704
704
images = images .to (device = self .device , dtype = self .dtype )
@@ -1491,6 +1491,8 @@ def _parse_and_validate_image_input(
1491
1491
f"Got type: { type (img_patch_id )} " )
1492
1492
self .img_patch_id = img_patch_id .flatten ().unique ().item ()
1493
1493
1494
+ embed_is_patch = flatten_bn (embed_is_patch )
1495
+
1494
1496
return MolmoImageInputs (
1495
1497
images = images ,
1496
1498
image_masks = image_masks ,
@@ -1502,13 +1504,17 @@ def _parse_and_validate_image_input(
1502
1504
def _process_image_input (
1503
1505
self ,
1504
1506
image_input : MolmoImageInputs ,
1505
- ) -> Union [torch .Tensor , list [torch .Tensor ]]:
1506
- if isinstance (image_input ["images" ], list ):
1507
+ ) -> list [torch .Tensor ]:
1508
+ images = image_input ["images" ]
1509
+ image_masks = image_input ["image_masks" ]
1510
+ feat_is_patch = image_input ["feat_is_patch" ]
1511
+ num_crops = image_input ["num_crops" ]
1512
+
1513
+ if isinstance (images , list ):
1507
1514
# Call the vision backbone on the whole batch at once
1508
- images_flat = flatten_bn (image_input ["images" ], concat = True )
1509
- image_masks_flat = (None if (image_masks :=
1510
- image_input ["image_masks" ]) is None
1511
- else flatten_bn (image_masks , concat = True ))
1515
+ images_flat = flatten_bn (images , concat = True )
1516
+ image_masks_flat = (None if image_masks is None else flatten_bn (
1517
+ image_masks , concat = True ))
1512
1518
1513
1519
image_features_flat = self .vision_backbone (
1514
1520
images = images_flat .unsqueeze (0 ),
@@ -1517,63 +1523,19 @@ def _process_image_input(
1517
1523
).squeeze (0 )
1518
1524
1519
1525
# Reconstruct the batch dimension
1520
- image_features = image_features_flat . split (
1521
- image_input [ "num_crops" ]. sum ( - 1 ). tolist () )
1526
+ num_crops_per_image = [ nc . sum (). item () for nc in num_crops ]
1527
+ image_features = image_features_flat . split ( num_crops_per_image )
1522
1528
else :
1523
1529
image_features = self .vision_backbone (
1524
- images = image_input [ " images" ] ,
1525
- image_masks = image_input [ " image_masks" ] ,
1530
+ images = images ,
1531
+ image_masks = image_masks ,
1526
1532
)
1527
1533
1528
- return image_features
1529
-
1530
- def _get_mm_embeds (
1531
- self ,
1532
- features : torch .Tensor , # Shape: (num_crop, num_patch, d)
1533
- feat_is_patch : torch .Tensor , # Shape: (num_crop, num_patch)
1534
- num_crops : torch .Tensor , # Shape: (num_images,)
1535
- embed_is_patch : torch .Tensor , # Shape: (num_embeds,)
1536
- ) -> tuple [torch .Tensor , ...]:
1537
- """
1538
- Scatter the patch features into a contiguous tensor that corresponds
1539
- to the embedding tokens defined by the multimodal processor.
1540
-
1541
- Note:
1542
- The original code only considers patch tokens as feature
1543
- tokens, but our processor considers all image-related tokens
1544
- as feature tokens because the feature tokens need to be
1545
- consecutive in `input_ids`.
1546
-
1547
- Example:
1548
- A simplified example for one item in the batch:
1549
-
1550
- .. code-block::
1551
-
1552
- Embedding tokens (from HF processor):
1553
- [<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
1554
-
1555
- embed_is_patch (from HF processor):
1556
- [ False True True False True True False False ]
1557
-
1558
- Encoder outputs (from model):
1559
- [ p1 p2 0 p3 p4 0 ]
1560
-
1561
- feat_is_patch (from HF processor):
1562
- [ True True False True True False ]
1563
-
1564
- The resulting embedding tensor is:
1565
- [ nan p1 p2 nan p3 p4 nan nan ]
1566
- """
1567
- num_crops_per_image = num_crops .tolist ()
1568
- feats_per_image = features .split (num_crops_per_image )
1569
- f_is_patch_per_image = feat_is_patch .split (num_crops_per_image )
1570
-
1571
- features = torch .cat ([
1534
+ # Only the features corresponding to patch tokens are relevant
1535
+ return [
1572
1536
feats [f_is_patch ]
1573
- for feats , f_is_patch in zip (feats_per_image , f_is_patch_per_image )
1574
- ])
1575
-
1576
- return scatter_patch_features (features , embed_is_patch )
1537
+ for feats , f_is_patch in zip (image_features , feat_is_patch )
1538
+ ]
1577
1539
1578
1540
def get_multimodal_embeddings (
1579
1541
self , ** kwargs : object ) -> Optional [MultiModalEmbeddings ]:
@@ -1583,13 +1545,10 @@ def get_multimodal_embeddings(
1583
1545
1584
1546
image_features = self ._process_image_input (image_input )
1585
1547
1586
- return flatten_2d_lists (
1587
- self ._get_mm_embeds (* args ) for args in zip (
1588
- image_features ,
1589
- image_input ["feat_is_patch" ],
1590
- image_input ["num_crops" ],
1591
- image_input ["embed_is_patch" ],
1592
- ))
1548
+ return scatter_patch_features (
1549
+ image_features ,
1550
+ image_input ["embed_is_patch" ],
1551
+ )
1593
1552
1594
1553
def get_input_embeddings (
1595
1554
self ,
0 commit comments