
Commit eb28242

Add MLlama fast image processor (#41391)
* Merge conflict
* add fast processor
* add fast processor
* make style
* add new convert rgb
* use nested group by shape in mllama fast, add support for multiple inputs in group by shape
* refactor after review

Co-authored-by: Vincent <[email protected]>
1 parent 65cb8fa commit eb28242

File tree: 10 files changed, +715 −276 lines changed


docs/source/en/model_doc/mllama.md

Lines changed: 5 additions & 1 deletion
```diff
@@ -67,7 +67,7 @@ processor = AutoProcessor.from_pretrained(model_id)
 messages = [
     [
         {
-        "role": "user",
+            "role": "user",
             "content": [
                 {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
                 {"type": "text", "text": "What does the image show?"}
@@ -113,6 +113,10 @@ print(processor.decode(output[0], skip_special_tokens=True))
 
 [[autodoc]] MllamaImageProcessor
 
+## MllamaImageProcessorFast
+
+[[autodoc]] MllamaImageProcessorFast
+
 ## MllamaForConditionalGeneration
 
 [[autodoc]] MllamaForConditionalGeneration
```
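With the docs entry in place, the fast processor is reachable through the auto classes like any other fast image processor. A minimal sketch — the checkpoint name is only illustrative, and the listed output keys follow the existing slow Mllama processor:

```python
from PIL import Image
from transformers import AutoImageProcessor

# use_fast=True selects the torch-backed MllamaImageProcessorFast once it is
# registered in the auto mapping (see image_processing_auto.py below).
processor = AutoImageProcessor.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct", use_fast=True
)

image = Image.new("RGB", (448, 336), color="white")
outputs = processor(images=image, return_tensors="pt")
# Expected keys (per the slow processor): pixel_values, aspect_ratio_ids,
# aspect_ratio_mask, num_tiles
print({k: getattr(v, "shape", v) for k, v in outputs.items()})
```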

src/transformers/image_processing_utils_fast.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -221,19 +221,19 @@ def is_fast(self) -> bool:
 
     def pad(
         self,
-        images: "torch.Tensor",
+        images: list["torch.Tensor"],
         pad_size: SizeDict = None,
         fill_value: Optional[int] = 0,
         padding_mode: Optional[str] = "constant",
         return_mask: bool = False,
         disable_grouping: Optional[bool] = False,
         **kwargs,
-    ) -> "torch.Tensor":
+    ) -> Union[tuple["torch.Tensor", "torch.Tensor"], "torch.Tensor"]:
         """
         Pads images to `(pad_size["height"], pad_size["width"])` or to the largest size in the batch.
 
         Args:
-            images (`torch.Tensor`):
+            images (`list[torch.Tensor]`):
                 Images to pad.
             pad_size (`SizeDict`, *optional*):
                 Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
@@ -248,7 +248,7 @@ def pad(
                 Whether to disable grouping of images by size.
 
         Returns:
-            `torch.Tensor`: The resized image.
+            `Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]`: The padded images and pixel masks if `return_mask` is `True`.
         """
         if pad_size is not None:
             if not (pad_size.height and pad_size.width):
```
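The widened annotations describe the padding contract rather than change behaviour: a list of `(C, H, W)` tensors is padded to a common size, and `return_mask=True` additionally yields a pixel mask of valid pixels. A rough, self-contained sketch of that contract (not the library method itself):

```python
import torch
import torch.nn.functional as F

def pad_to_largest(images: list[torch.Tensor], fill_value: int = 0, return_mask: bool = False):
    """Illustrative stand-in for the `pad` contract: pad each (C, H, W) image
    to the largest height/width in the batch, optionally returning a pixel mask."""
    max_h = max(img.shape[-2] for img in images)
    max_w = max(img.shape[-1] for img in images)
    padded, masks = [], []
    for img in images:
        pad_h, pad_w = max_h - img.shape[-2], max_w - img.shape[-1]
        # F.pad takes (left, right, top, bottom) for the last two dimensions
        padded.append(F.pad(img, (0, pad_w, 0, pad_h), value=fill_value))
        mask = torch.zeros(max_h, max_w, dtype=torch.int64)
        mask[: img.shape[-2], : img.shape[-1]] = 1
        masks.append(mask)
    out = torch.stack(padded)
    return (out, torch.stack(masks)) if return_mask else out

imgs = [torch.rand(3, 224, 196), torch.rand(3, 200, 224)]
pixel_values, pixel_mask = pad_to_largest(imgs, return_mask=True)
print(pixel_values.shape, pixel_mask.shape)  # torch.Size([2, 3, 224, 224]) torch.Size([2, 224, 224])
```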

src/transformers/image_transforms.py

Lines changed: 83 additions & 18 deletions
```diff
@@ -797,25 +797,61 @@ def flip_channel_order(
     return image
 
 
+def split_to_tiles(images: "torch.Tensor", num_tiles_height: int, num_tiles_width: int) -> "torch.Tensor":
+    # Split image into number of required tiles (width x height)
+    batch_size, num_channels, height, width = images.size()
+    images = images.view(
+        batch_size,
+        num_channels,
+        num_tiles_height,
+        height // num_tiles_height,
+        num_tiles_width,
+        width // num_tiles_width,
+    )
+    # Permute dimensions to reorder the axes
+    image = images.permute(0, 2, 4, 1, 3, 5).contiguous()
+    # Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
+    image = image.view(
+        batch_size,
+        num_tiles_width * num_tiles_height,
+        num_channels,
+        height // num_tiles_height,
+        width // num_tiles_width,
+    )
+    return image
+
+
 def _cast_tensor_to_float(x):
     if x.is_floating_point():
         return x
     return x.float()
 
 
-def _group_images_by_shape(nested_images, is_nested: bool = False):
-    """Helper function to flatten a single level of nested image structures and group by shape."""
+def _group_images_by_shape(nested_images, *paired_inputs, is_nested: bool = False):
+    """Helper function to flatten a single level of nested image and batch structures and group by shape."""
     grouped_images = defaultdict(list)
     grouped_images_index = {}
-    nested_images = [nested_images] if not is_nested else nested_images
-    for i, sublist in enumerate(nested_images):
-        for j, image in enumerate(sublist):
+    paired_grouped_values = [defaultdict(list) for _ in paired_inputs]
+
+    # Normalize inputs to consistent nested structure
+    normalized_images = [nested_images] if not is_nested else nested_images
+    normalized_paired = []
+    for paired_input in paired_inputs:
+        normalized_paired.append([paired_input] if not is_nested else paired_input)
+
+    # Process each image and group by shape
+    for i, (sublist, *paired_sublists) in enumerate(zip(normalized_images, *normalized_paired)):
+        for j, (image, *paired_values) in enumerate(zip(sublist, *paired_sublists)):
             key = (i, j) if is_nested else j
             shape = image.shape[1:]
+
+            # Add to grouped structures
             grouped_images[shape].append(image)
+            for paired_index, paired_value in enumerate(paired_values):
+                paired_grouped_values[paired_index][shape].append(paired_value)
             grouped_images_index[key] = (shape, len(grouped_images[shape]) - 1)
 
-    return grouped_images, grouped_images_index
+    return grouped_images, *paired_grouped_values, grouped_images_index
 
 
 def _reconstruct_nested_structure(indices, processed_images):
```
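`split_to_tiles` is hoisted here unchanged from the Llama4 fast processor so the new Mllama fast processor can reuse it. A quick shape check of the reshape/permute it performs:

```python
import torch
from transformers.image_transforms import split_to_tiles

images = torch.rand(2, 3, 448, 448)  # (batch, channels, height, width)
tiles = split_to_tiles(images, num_tiles_height=2, num_tiles_width=2)
print(tiles.shape)  # torch.Size([2, 4, 3, 224, 224]) — a 2x2 grid of 224x224 tiles per image
```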
```diff
@@ -844,13 +880,35 @@ def _reconstruct_nested_structure(indices, processed_images):
     return result
 
 
+def _disable_grouping_output_nested(images, *paired_inputs):
+    """Build the disable_grouping output tuple for a single-level nested structure."""
+    outer_range = range(len(images))
+    inner_ranges = [range(len(images[i])) for i in outer_range]
+
+    # Precompute all (i, j) pairs
+    ij_pairs = [(i, j) for i in outer_range for j in inner_ranges[i]]
+
+    images_dict = {(i, j): images[i][j].unsqueeze(0) for (i, j) in ij_pairs}
+    paired_dicts = [{(i, j): paired_list[i][j].unsqueeze(0) for (i, j) in ij_pairs} for paired_list in paired_inputs]
+    index_map = {(i, j): ((i, j), 0) for (i, j) in ij_pairs}
+    return images_dict, *paired_dicts, index_map
+
+
+def _disable_grouping_output_flat(images, *paired_inputs):
+    """Build the disable_grouping output tuple for a flat list structure."""
+    idx_range = range(len(images))
+    images_dict = {i: images[i].unsqueeze(0) for i in idx_range}
+    paired_dicts = [{i: paired_list[i].unsqueeze(0) for i in idx_range} for paired_list in paired_inputs]
+    index_map = {i: (i, 0) for i in idx_range}
+    return images_dict, *paired_dicts, index_map
+
+
 def group_images_by_shape(
     images: Union[list["torch.Tensor"], "torch.Tensor"],
-    disable_grouping: bool,
+    *paired_inputs,
+    disable_grouping: Optional[bool],
     is_nested: bool = False,
-) -> tuple[
-    dict[tuple[int, int], list["torch.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]
-]:
+) -> tuple[dict, ...]:
     """
     Groups images by shape.
     Returns a dictionary with the shape as key and a list of images with that shape as value,
@@ -862,15 +920,22 @@ def group_images_by_shape(
     Args:
         images (Union[list["torch.Tensor"], "torch.Tensor"]):
             A list of images or a single tensor
+        *paired_inputs (Any):
+            Zero or more lists that mirror the structure of `images` (flat list, or list of lists when
+            `is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
+            same shape key. These paired values are grouped alongside `images` but are not stacked in the output, so
+            they do not need to be tensors.
         disable_grouping (bool):
             Whether to disable grouping. If None, will be set to True if the images are on CPU, and False otherwise.
             This choice is based on empirical observations, as detailed here: https://github.com/huggingface/transformers/pull/38157
         is_nested (bool, *optional*, defaults to False):
             Whether the images are nested.
 
     Returns:
-        tuple[dict[tuple[int, int], list["torch.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]]:
-            - A dictionary with shape as key and list of images with that shape as value
+        tuple[dict, ...]:
+            - A dictionary with shape as key and list/batch of images with that shape as value
+            - Zero or more dictionaries (one per argument in `*paired_inputs`) grouped consistently with `images`; these carry
+              the corresponding per-item values and are not stacked
             - A dictionary mapping original indices to (shape, index) tuples
     """
     # If disable grouping is not explicitly provided, we favor disabling it if the images are on CPU, and enabling it otherwise.
@@ -880,19 +945,19 @@ def group_images_by_shape(
 
     if disable_grouping:
         if is_nested:
-            return {(i, j): images[i][j].unsqueeze(0) for i in range(len(images)) for j in range(len(images[i]))}, {
-                (i, j): ((i, j), 0) for i in range(len(images)) for j in range(len(images[i]))
-            }
+            return _disable_grouping_output_nested(images, *paired_inputs)
         else:
-            return {i: images[i].unsqueeze(0) for i in range(len(images))}, {i: (i, 0) for i in range(len(images))}
+            return _disable_grouping_output_flat(images, *paired_inputs)
 
     # Handle single level nested structure
-    grouped_images, grouped_images_index = _group_images_by_shape(images, is_nested)
+    grouped_images, *paired_grouped_values, grouped_images_index = _group_images_by_shape(
+        images, *paired_inputs, is_nested=is_nested
+    )
 
     # Stack images with the same shape
     grouped_images = {shape: torch.stack(images_list, dim=0) for shape, images_list in grouped_images.items()}
 
-    return grouped_images, grouped_images_index
+    return grouped_images, *paired_grouped_values, grouped_images_index
 
 
 def reorder_images(
```
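The `*paired_inputs` extension lets callers group per-image metadata (for Mllama, aspect ratios) under the same shape keys as the pixel data and restore the original order afterwards. A small sketch of the call pattern; the aspect-ratio values are made up for illustration:

```python
import torch
from transformers.image_transforms import group_images_by_shape, reorder_images

images = [torch.rand(3, 224, 224), torch.rand(3, 448, 448), torch.rand(3, 224, 224)]
aspect_ratios = [(1, 1), (2, 2), (1, 1)]  # one entry per image; grouped but never stacked

grouped_images, grouped_ratios, grouped_index = group_images_by_shape(
    images, aspect_ratios, disable_grouping=False
)
# Each shape key maps to a stacked batch of images plus the matching list of ratios.
for shape, batch in grouped_images.items():
    print(shape, batch.shape, grouped_ratios[shape])

# Processed batches can be scattered back to the original flat order afterwards.
processed = {shape: batch * 2 for shape, batch in grouped_images.items()}
restored = reorder_images(processed, grouped_index)
print(len(restored))  # 3, in the original order
```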

src/transformers/models/auto/image_processing_auto.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -134,7 +134,7 @@
         ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
         ("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
         ("mlcd", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
-        ("mllama", ("MllamaImageProcessor", None)),
+        ("mllama", ("MllamaImageProcessor", "MllamaImageProcessorFast")),
         ("mm-grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")),
         ("mobilenet_v1", ("MobileNetV1ImageProcessor", "MobileNetV1ImageProcessorFast")),
         ("mobilenet_v2", ("MobileNetV2ImageProcessor", "MobileNetV2ImageProcessorFast")),
```

src/transformers/models/idefics2/image_processing_idefics2.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -43,7 +43,6 @@
 
 
 if is_vision_available():
-    import PIL
     from PIL import Image
 
 
@@ -142,7 +141,7 @@ def convert_to_rgb(image: ImageInput) -> ImageInput:
         image (Image):
             The image to convert.
     """
-    if not isinstance(image, PIL.Image.Image):
+    if not isinstance(image, Image.Image):
         return image
 
     # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
```
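The context comment refers to the usual alpha-compositing workaround: paste the image onto a white background before dropping the alpha channel, instead of calling `convert("RGB")` directly. A sketch of that approach (what helpers of this kind typically do, not a copy of the changed file):

```python
from PIL import Image

def convert_to_rgb(image: Image.Image) -> Image.Image:
    """Composite an image with an alpha channel onto a white background, then drop alpha."""
    if image.mode == "RGB":
        return image
    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    # alpha_composite blends transparent regions against the white background
    composited = Image.alpha_composite(background, image_rgba)
    return composited.convert("RGB")
```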

src/transformers/models/llama4/image_processing_llama4_fast.py

Lines changed: 1 addition & 24 deletions
```diff
@@ -28,6 +28,7 @@
     group_images_by_shape,
     reorder_images,
 )
+from ...image_transforms import split_to_tiles
 from ...image_utils import ImageInput, PILImageResampling, SizeDict
 from ...processing_utils import ImagesKwargs, Unpack
 from ...utils import (
@@ -92,30 +93,6 @@ def get_max_res_without_distortion(
     return new_height, new_width
 
 
-def split_to_tiles(images: torch.Tensor, num_tiles_height: int, num_tiles_width: int) -> torch.Tensor:
-    # Split image into number of required tiles (width x height)
-    batch_size, num_channels, height, width = images.size()
-    images = images.view(
-        batch_size,
-        num_channels,
-        num_tiles_height,
-        height // num_tiles_height,
-        num_tiles_width,
-        width // num_tiles_width,
-    )
-    # Permute dimensions to reorder the axes
-    image = images.permute(0, 2, 4, 1, 3, 5).contiguous()
-    # Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
-    image = image.view(
-        batch_size,
-        num_tiles_width * num_tiles_height,
-        num_channels,
-        height // num_tiles_height,
-        width // num_tiles_width,
-    )
-    return image
-
-
 @lru_cache(maxsize=1)
 def find_supported_resolutions(max_num_chunks: int, patch_size: SizeDict) -> torch.Tensor:
     """
```

src/transformers/models/mllama/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -20,6 +20,7 @@
 if TYPE_CHECKING:
     from .configuration_mllama import *
     from .image_processing_mllama import *
+    from .image_processing_mllama_fast import *
     from .modeling_mllama import *
     from .processing_mllama import *
 else:
```

src/transformers/models/mllama/image_processing_mllama.py

Lines changed: 2 additions & 27 deletions
```diff
@@ -43,7 +43,6 @@
 
 
 if is_vision_available():
-    import PIL
     from PIL import Image
 
 
@@ -407,30 +406,6 @@ def pack_images(
     return stacked_images, all_num_tiles
 
 
-def pack_aspect_ratios(aspect_ratios: list[list[tuple[int, int]]], pad_value: int = 1) -> np.ndarray:
-    """
-    Stack a list of aspect ratios into a numpy array.
-
-    Args:
-        aspect_ratios (`list[list[tuple[int, int]]]`):
-            A list of aspect ratios.
-        pad_value (`int`, *optional*, defaults to 1):
-            The value to pad the aspect ratios with.
-
-    Returns:
-        `np.ndarray`:
-            The aspect ratios stacked into a numpy array with shape (batch_size, max_num_images, 2).
-    """
-    batch_size = len(aspect_ratios)
-    max_num_images = max(len(row) for row in aspect_ratios)
-
-    aspect_ratios_stacked = np.full((batch_size, max_num_images, 2), pad_value, dtype=np.int64)
-    for i, row in enumerate(aspect_ratios):
-        if len(row) > 0:
-            aspect_ratios_stacked[i, : len(row)] = np.array(row)
-    return aspect_ratios_stacked
-
-
 def convert_aspect_ratios_to_ids(aspect_ratios: list[list[tuple[int, int]]], max_image_tiles: int) -> np.ndarray:
     """
     Convert aspect ratio tuples to unique ids.
@@ -511,7 +486,7 @@ def convert_to_rgb(image: ImageInput) -> ImageInput:
         image (Image):
             The image to convert.
     """
-    if not isinstance(image, PIL.Image.Image):
+    if not isinstance(image, Image.Image):
         return image
 
     # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
@@ -718,7 +693,7 @@ def preprocess(
         # iterate over images in a batch sample
         for image in images:
             # default PIL images to channels_last
-            if input_data_format is None and isinstance(image, PIL.Image.Image):
+            if input_data_format is None and isinstance(image, Image.Image):
                 input_data_format = ChannelDimension.LAST
 
             # convert to numpy array for processing
```
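`pack_aspect_ratios` is removed because nothing calls it anymore; aspect ratios flow through `convert_aspect_ratios_to_ids` instead. Purely as an illustration of the id scheme, each per-image tile grid is looked up in a table of supported grids, 1-based, with 0 used as padding — the supported-grid table and its ordering below are hypothetical, not the real module's:

```python
import numpy as np

def aspect_ratios_to_ids(aspect_ratios, supported, pad_value=0):
    """Map each (tiles_h, tiles_w) pair to its 1-based index in `supported`;
    rows are padded with `pad_value` up to the largest number of images."""
    batch_size = len(aspect_ratios)
    max_num_images = max(len(row) for row in aspect_ratios)
    ids = np.full((batch_size, max_num_images), pad_value, dtype=np.int64)
    for i, row in enumerate(aspect_ratios):
        for j, ratio in enumerate(row):
            ids[i, j] = supported.index(tuple(ratio)) + 1
    return ids

supported = [(1, 1), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (2, 2), (4, 1)]  # illustrative order only
print(aspect_ratios_to_ids([[(1, 1), (2, 2)], [(1, 2)]], supported))
# [[1 7]
#  [2 0]]
```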
