Add MLlama fast image processor #41391
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
molbap left a comment:
Nice work! Left a smol review for the `image_transforms.py` fix first! Will do the rest in a follow-up
src/transformers/image_transforms.py (Outdated)
```diff
            return (
                {(i, j): images[i][j].unsqueeze(0) for i in range(len(images)) for j in range(len(images[i]))},
                *[
                    {
                        (i, j): paired_list[i][j].unsqueeze(0)
                        for i in range(len(paired_list))
                        for j in range(len(paired_list[i]))
                    }
                    for paired_list in paired_inputs
                ],
                {(i, j): ((i, j), 0) for i in range(len(images)) for j in range(len(images[i]))},
            )
        else:
-           return {i: images[i].unsqueeze(0) for i in range(len(images))}, {i: (i, 0) for i in range(len(images))}
+           return (
+               {i: images[i].unsqueeze(0) for i in range(len(images))},
+               *[{i: paired_list[i].unsqueeze(0) for i in range(len(paired_list))} for paired_list in paired_inputs],
+               {i: (i, 0) for i in range(len(images))},
+           )
```
I managed to understand this after some time with pen & paper; it would be nice to rewrite it a bit for a clearer logic flow 😀
I would suggest writing another, private helper function, something like `build_ungrouped_outputs`, that unrolls this logic and returns a tuple with the images dictionary, the *unpacked paired dictionaries, and the index map.
Also, for the dictionary iterators, we can iterate over keys established once (in the non-nested case) with `keys = list(range(len(images)))` to avoid several `range(len(...))` calls.
Basically: naming and moving these into another function that has the same return value and return type, as in the sketch below.
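A minimal sketch of the suggested shape (the name `build_ungrouped_outputs` is the reviewer's proposal, not final code; the PR eventually landed a very similar `_disable_grouping_output_flat`, quoted further down):

```python
def _build_ungrouped_outputs(images, *paired_inputs):
    """Unroll the disable_grouping logic: one dict per input, plus an index map."""
    # establish the keys once instead of repeated range(len(...)) calls
    keys = list(range(len(images)))
    images_dict = {i: images[i].unsqueeze(0) for i in keys}
    paired_dicts = [{i: paired_list[i].unsqueeze(0) for i in keys} for paired_list in paired_inputs]
    index_map = {i: (i, 0) for i in keys}
    return images_dict, *paired_dicts, index_map
```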
Indeed this needs some cleaning up! And you're right we can optimize it a bit, and that's the whole point :)
cool, let me know when you've done another pass!
zucchini-nlp left a comment:
I agree that we need to refactor the grouping fn a bit, or maybe we can reorder the code so we don't need any grouping. Up to you :)
```python
def _validate_size(size: SizeDict) -> None:
    if not (size.height and size.width):
        raise ValueError(f"Argument `size` must be a dictionary with keys 'height' and 'width'. Got: {size}")
    if size.height != size.width:
        raise ValueError(f"Argument `size` must have the same height and width, got {size}")


def _validate_mllama_preprocess_arguments(do_resize, size, do_pad, max_image_tiles):
    if not do_pad:
        raise ValueError("MllamaImageProcessor doesn't support `do_pad=False` mode.")
    if not do_resize:
        raise ValueError("MllamaImageProcessor doesn't support `do_resize=False` mode.")
    if max_image_tiles is None or max_image_tiles <= 0:
        raise ValueError(f"MllamaImageProcessor `max_image_tiles` must be a positive integer, got {max_image_tiles}.")
    _validate_size(size)
```
hmm, first time seeing custom validation for kwargs in image processing. I am merging #40793 today; maybe we can think of a cleaner way to validate kwargs with the hub's validators later.
Good thing is that hub validation runs at `__setattr__` if we use dataclasses, but with typed dicts it is much more primitive currently.
agreed, should not be custom here - it is not standard, and it's low-level so should be abstracted away.
Probably not needed indeed! Overall I'm not against custom validation (maybe with warnings instead of errors, though), as some image processors have different constraints than others, so it's hard to abstract away. In that case, setting resize or pad to False will still resize and pad with no warning.
Also, if I'm understanding #40793 correctly, this would only validate the kwargs when loading with `from_pretrained`? Or will it also work when adding kwargs to the processor call?
The hub validation currently checks only the type hints from the TypedDict, on every processing call. To do the check on every `from_pretrained()` we'd have to be sure that all hub configs are saved correctly, because it would raise errors otherwise; I don't want to break configs that are badly serialized. My next idea is to add a global validation over all values after the per-field type hints (currently the `validate_kwargs` and `validate_processor_arguments` fns).
Also, we can add per-field custom validation if we put it in the metadata, like `my_field: Annotated[int, custom_validation_fn()]`. Though I just found out yesterday that TypedDict does not preserve this metadata, and I am looking for a way to recover it.
yay for Annotated! https://docs.python.org/3/library/typing.html#typing.get_type_hints with include_extras=True might be what you're looking for?
yeah, that didn't work either. I found that it works when the metadata is a string, but not with callables. A workaround is to wrap the callable with a hub utility; that worked in the past and was removed at some point while iterating
ah I understand, you were resolving it at load time with importlib or something like that?
nope, I am also using `get_type_hints` with extras, as in `get_type_hints(ImagesKwargs, include_extras=True)`. Some quirks of `Annotated` ig, haven't had a chance to dig into the root reason yet
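For reference, a minimal sketch of the pattern under discussion. In plain Python (no stringified annotations), `get_type_hints(..., include_extras=True)` does surface `Annotated` metadata, including callables, so the quirk above presumably comes from how the hints are produced or serialized; `positive` here is a hypothetical validator:

```python
from typing import Annotated, TypedDict, get_type_hints

def positive(value: int) -> None:
    # hypothetical per-field validator
    if value <= 0:
        raise ValueError(f"expected a positive integer, got {value}")

class ImagesKwargs(TypedDict, total=False):
    max_image_tiles: Annotated[int, positive]

hints = get_type_hints(ImagesKwargs, include_extras=True)
# the callable survives in __metadata__ here: (<function positive at ...>,)
print(hints["max_image_tiles"].__metadata__)
```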
```python
# same aspect ratio for all images in the batch
num_tiles_height, num_tiles_width = grouped_aspect_ratios[shape][0]
stacked_images = split_to_tiles(stacked_images, num_tiles_height, num_tiles_width)
processed_images_grouped[shape] = stacked_images
```
I think we can also do the splitting before "rescale" in the prev for-loop, since rescale simply multiplies by a value and doesn't depend on image shape. That way we don't need to group aspect ratios together with images
You're right, my bad! It can be much simpler, and there's no need to group aspect ratios together afterwards. However, I think we could keep the option to group additional `*args`, as it will be useful for maskformer in this PR #41393
molbap left a comment:
Left some additional comments!
src/transformers/image_transforms.py (Outdated)
```python
def group_images_by_shape(
    images: Union[list["torch.Tensor"], "torch.Tensor"],
    *paired_inputs,
    disable_grouping: bool,
```
nit, but `disable_grouping` is a tri-state, so it should be `Optional` at least
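i.e. something like the sketch below; the `None` semantics in the comment are an assumption about how the automatic case would be decided, not part of the quoted diff:

```python
from typing import Optional, Union

def group_images_by_shape(
    images: Union[list["torch.Tensor"], "torch.Tensor"],
    *paired_inputs,
    disable_grouping: Optional[bool] = None,
    # None = decide automatically (e.g. from the device), True/False = explicit override
):
    ...
```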
```diff
    num_channels=3,
    image_size=18,
-   num_images=18,
+   num_images=1,
```
for my information, why is this test config dropped to 1 image?
Missed this when I took over the PR, thanks for pointing it out!
```python
def convert_to_rgb(
    self,
    image: ImageInput,
) -> ImageInput:
    """
    Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image,
    otherwise returns the image as is.

    Args:
        image (ImageInput):
            The image to convert.

    Returns:
        ImageInput: The converted image.
    """
    return convert_to_rgb(image)
```
would be nice to directly rely on the imported func
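e.g. something along these lines, a sketch of one way to act on this (import paths are assumed from the repo layout, and the base class is taken from the fast-processor convention rather than this diff):

```python
from transformers.image_processing_utils_fast import BaseImageProcessorFast
from transformers.image_transforms import convert_to_rgb

class MllamaImageProcessorFast(BaseImageProcessorFast):
    # rely on the imported function directly instead of re-wrapping it in a method
    convert_to_rgb = staticmethod(convert_to_rgb)
```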
```python
for image_processing_class in self.image_processor_list:
    image_processing = image_processing_class(**self.image_processor_dict)
    self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
    self.assertTrue(hasattr(image_processing, "do_resize"))
    self.assertTrue(hasattr(image_processing, "size"))
    self.assertTrue(hasattr(image_processing, "do_rescale"))
    self.assertTrue(hasattr(image_processing, "rescale_factor"))
    self.assertTrue(hasattr(image_processing, "do_normalize"))
    self.assertTrue(hasattr(image_processing, "image_mean"))
    self.assertTrue(hasattr(image_processing, "image_std"))
    self.assertTrue(hasattr(image_processing, "do_pad"))
    self.assertTrue(hasattr(image_processing, "max_image_tiles"))
```
Thinking for a follow-up: we can automate these (see the sketch below)
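One possible shape for that follow-up, a sketch with the attribute list taken from the assertions above:

```python
EXPECTED_ATTRIBUTES = [
    "do_convert_rgb", "do_resize", "size", "do_rescale", "rescale_factor",
    "do_normalize", "image_mean", "image_std", "do_pad", "max_image_tiles",
]

def test_image_processor_properties(self):
    for image_processing_class in self.image_processor_list:
        image_processing = image_processing_class(**self.image_processor_dict)
        for attribute in EXPECTED_ATTRIBUTES:
            with self.subTest(attribute=attribute):
                self.assertTrue(hasattr(image_processing, attribute))
```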
```python
aspect_ratio_mask = torch.zeros((batch_size, max_num_images, max_image_tiles), dtype=torch.long)

# Set the first tile to 1 for all aspect ratios
# because in original implementation aspect ratios are padded with (1, 1),
# but original code examples are not built to handle batches, so we might remove it later
aspect_ratio_mask[:, :, 0] = 1

# Set the aspect ratio mask for the rest of the tiles
for i, sample_aspect_ratios in enumerate(aspect_ratios):
    for j, (num_tiles_w, num_tiles_h) in enumerate(sample_aspect_ratios):
        aspect_ratio_mask[i, j, : num_tiles_w * num_tiles_h] = 1

return aspect_ratio_mask
```
There are a couple of double for-loops in the util functions here; it's a minor optim, but we could precompute + broadcast instead, which would be more efficient especially for large batches IMO (see the sketch below)
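A rough sketch of the precompute + broadcast idea for the mask above (the signature is illustrative; tile counts are still gathered per sample since `aspect_ratios` is a ragged list of lists, but the mask writes become a single broadcast comparison):

```python
import torch

def build_aspect_ratio_mask(aspect_ratios, max_num_images, max_image_tiles):
    batch_size = len(aspect_ratios)
    # precompute the number of valid tiles per (sample, image); 0 where padded
    num_tiles = torch.zeros((batch_size, max_num_images), dtype=torch.long)
    for i, sample_aspect_ratios in enumerate(aspect_ratios):
        for j, (num_tiles_w, num_tiles_h) in enumerate(sample_aspect_ratios):
            num_tiles[i, j] = num_tiles_w * num_tiles_h
    # broadcast: mask[i, j, k] = 1 iff k < num_tiles[i, j]
    mask = (torch.arange(max_image_tiles) < num_tiles.unsqueeze(-1)).long()
    # keep the first tile on everywhere, matching the (1, 1) padding convention above
    mask[:, :, 0] = 1
    return mask
```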
Thanks @molbap and @zucchini-nlp! I addressed most of your remarks, it should be ready for another review :)
molbap left a comment:
Looks good, left some open questions that are not blockers
```python
def split_to_tiles(images: "torch.Tensor", num_tiles_height: int, num_tiles_width: int) -> "torch.Tensor":
    # Split image into number of required tiles (width x height)
    batch_size, num_channels, height, width = images.size()
    images = images.view(
        batch_size,
        num_channels,
        num_tiles_height,
        height // num_tiles_height,
        num_tiles_width,
        width // num_tiles_width,
    )
    # Permute dimensions to reorder the axes
    image = images.permute(0, 2, 4, 1, 3, 5).contiguous()
    # Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
    image = image.view(
        batch_size,
        num_tiles_width * num_tiles_height,
        num_channels,
        height // num_tiles_height,
        width // num_tiles_width,
    )
    return image
```
On this, gave it some thought. We're viewing the tensors to have the strides match, permuting, then calling contiguous and then re-viewing. It looks very similar to what an Unfold would do, getting a strided view directly. Out of scope for this PR but to keep in mind wrt optimizations.
What do you think?
Definitely something to explore! However in this case it looks like Unfold doesn't work with uint8...
Ah yes it needs the division flexibility maybe, wasn't aware
```python
def _disable_grouping_output_flat(images, *paired_inputs):
    """Build the disable_grouping output tuple for a flat list structure."""
    idx_range = range(len(images))
    images_dict = {i: images[i].unsqueeze(0) for i in idx_range}
    paired_dicts = [{i: paired_list[i].unsqueeze(0) for i in idx_range} for paired_list in paired_inputs]
    index_map = {i: (i, 0) for i in idx_range}
    return images_dict, *paired_dicts, index_map
```
clearer!
```diff
    if disable_grouping:
        if is_nested:
-           return {(i, j): images[i][j].unsqueeze(0) for i in range(len(images)) for j in range(len(images[i]))}, {
-               (i, j): ((i, j), 0) for i in range(len(images)) for j in range(len(images[i]))
-           }
+           return _disable_grouping_output_nested(images, *paired_inputs)
        else:
-           return {i: images[i].unsqueeze(0) for i in range(len(images))}, {i: (i, 0) for i in range(len(images))}
+           return _disable_grouping_output_flat(images, *paired_inputs)
```
the double if sounds more like a match use-case (pun intended 🤓), but fine as it is, it's clear to read
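The match-shaped version the reviewer is punning on, purely as an illustration (it would bump the minimum Python to 3.10, which is presumably part of why it's fine as is):

```python
match (disable_grouping, is_nested):
    case (True, True):
        return _disable_grouping_output_nested(images, *paired_inputs)
    case (True, False):
        return _disable_grouping_output_flat(images, *paired_inputs)
    case _:
        pass  # fall through to the grouping path
```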
[For maintainers] Suggested jobs to run (before merge): run-slow: auto, idefics2, llama4, mllama
* Merge conflict
* add fast processor
* add fast processor
* make style
* add new convert rgb
* use nested group by shape in mllama fast, add support for multiple inputs in group by shape
* refactor after review

Co-authored-by: Vincent <[email protected]>
What does this PR do?
Finishes #37539
It also adds support for grouping other inputs that mirror the shape of images in
`group_images_by_shape`; this could also be useful for some other future video processors @zucchini-nlp ;)
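A hypothetical usage sketch of the new paired-inputs support, pieced together from the snippets above (import path from the diff header; names, shapes, and the exact return layout are illustrative):

```python
import torch
from transformers.image_transforms import group_images_by_shape

images = [torch.rand(3, 224, 224), torch.rand(3, 336, 336), torch.rand(3, 224, 224)]
# one paired entry per image, mirroring the structure of `images`
aspect_ratios = [torch.tensor([2, 2]), torch.tensor([1, 4]), torch.tensor([2, 2])]

# paired inputs come back grouped under the same shape keys as the images,
# followed by the usual index map for restoring the original order
grouped_images, grouped_ratios, index_map = group_images_by_shape(
    images, aspect_ratios, disable_grouping=False
)
```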