@@ -797,25 +797,61 @@ def flip_channel_order(
797797 return image
798798
799799
def split_to_tiles(images: "torch.Tensor", num_tiles_height: int, num_tiles_width: int) -> "torch.Tensor":
    """
    Cut every image of a batch into a regular grid of equally sized tiles.

    Args:
        images (`torch.Tensor`):
            Batch of images of shape `(batch_size, num_channels, height, width)`.
        num_tiles_height (`int`):
            Number of tiles along the height axis.
        num_tiles_width (`int`):
            Number of tiles along the width axis.

    Returns:
        `torch.Tensor`: Tensor of shape
        `(batch_size, num_tiles_height * num_tiles_width, num_channels, height // num_tiles_height, width // num_tiles_width)`.
        Tiles are ordered row by row, i.e. tile index = tile_row * num_tiles_width + tile_column.
    """
    batch_size, num_channels, height, width = images.shape
    tile_height = height // num_tiles_height
    tile_width = width // num_tiles_width

    # Split the height axis into (tile row, row within tile) and the width axis into
    # (tile column, column within tile).
    # NOTE(review): presumably `height`/`width` must be exact multiples of the tile counts - confirm with callers.
    grid = images.view(batch_size, num_channels, num_tiles_height, tile_height, num_tiles_width, tile_width)

    # Move both tile axes in front of the channel axis:
    # (batch, tile row, tile column, channel, row within tile, column within tile)
    tiles = grid.permute(0, 2, 4, 1, 3, 5).contiguous()

    # Merge the tile row / tile column axes into a single tile axis
    return tiles.view(batch_size, num_tiles_height * num_tiles_width, num_channels, tile_height, tile_width)
822+
823+
def _cast_tensor_to_float(x):
    """Return `x` itself when it already has a floating-point dtype, otherwise return `x.float()`."""
    return x if x.is_floating_point() else x.float()
804828
805829
def _group_images_by_shape(nested_images, *paired_inputs, is_nested: bool = False):
    """
    Helper function to flatten a single level of nested image and batch structures and group by shape.

    Args:
        nested_images:
            A flat sequence of images, or a sequence of sequences of images when `is_nested=True`. Each image
            must expose a `.shape` attribute; everything after its first dimension is used as the grouping key.
        *paired_inputs:
            Zero or more sequences mirroring the structure of `nested_images`. Each value is grouped under the
            shape key of the image found at the same position.
        is_nested (`bool`, *optional*, defaults to `False`):
            Whether `nested_images` (and every paired input) contains one level of nesting.

    Returns:
        `tuple`: `(grouped_images, *paired_grouped_values, grouped_images_index)` where
        - `grouped_images` maps a shape to the list of images having that shape,
        - each entry of `paired_grouped_values` maps a shape to the list of paired values of those images,
        - `grouped_images_index` maps the original index (`j`, or `(i, j)` when nested) to
          `(shape, position inside the group)`.

    Raises:
        ValueError: If a paired input does not contain exactly as many elements as `nested_images` (at either
            nesting level).
    """
    grouped_images = defaultdict(list)
    grouped_images_index = {}
    paired_grouped_values = [defaultdict(list) for _ in paired_inputs]

    # Normalize inputs to a consistent nested structure (a sequence of sub-sequences)
    normalized_images = nested_images if is_nested else [nested_images]
    normalized_paired = [paired_input if is_nested else [paired_input] for paired_input in paired_inputs]

    # `zip` stops at the shortest iterable, so a paired input of a different length would silently drop
    # images from the groups and from the index map. Fail loudly instead.
    for position, paired in enumerate(normalized_paired):
        if len(paired) != len(normalized_images):
            raise ValueError(
                f"Paired input {position} has {len(paired)} sub-lists but `images` has {len(normalized_images)}."
            )

    # Process each image and group by shape
    for i, (sublist, *paired_sublists) in enumerate(zip(normalized_images, *normalized_paired)):
        for position, paired_sublist in enumerate(paired_sublists):
            if len(paired_sublist) != len(sublist):
                raise ValueError(
                    f"Paired input {position} has {len(paired_sublist)} elements but the matching list of "
                    f"images has {len(sublist)}."
                )

        for j, (image, *paired_values) in enumerate(zip(sublist, *paired_sublists)):
            key = (i, j) if is_nested else j
            shape = image.shape[1:]

            # Add to grouped structures
            grouped_images[shape].append(image)
            for paired_index, paired_value in enumerate(paired_values):
                paired_grouped_values[paired_index][shape].append(paired_value)
            # The image was just appended, so its position is the last one of its group
            grouped_images_index[key] = (shape, len(grouped_images[shape]) - 1)

    return grouped_images, *paired_grouped_values, grouped_images_index
819855
820856
821857def _reconstruct_nested_structure (indices , processed_images ):
@@ -844,13 +880,35 @@ def _reconstruct_nested_structure(indices, processed_images):
844880 return result
845881
846882
def _disable_grouping_output_nested(images, *paired_inputs):
    """
    Build the disable_grouping output tuple for a single-level nested structure.

    Every image becomes its own group keyed by its `(i, j)` position and gains a leading dimension of size 1
    through `unsqueeze(0)`. The same is applied to the value found at `[i][j]` in each paired input, so paired
    values must expose `.unsqueeze` on this code path.

    Returns:
        `tuple`: `(images_dict, *paired_dicts, index_map)` where `index_map[(i, j)] == ((i, j), 0)`.
    """
    images_dict = {}
    paired_dicts = [{} for _ in paired_inputs]
    index_map = {}

    for outer in range(len(images)):
        for inner in range(len(images[outer])):
            position = (outer, inner)
            images_dict[position] = images[outer][inner].unsqueeze(0)
            for paired_dict, paired_list in zip(paired_dicts, paired_inputs):
                paired_dict[position] = paired_list[outer][inner].unsqueeze(0)
            # Each group holds exactly one element, hence the constant in-group index 0
            index_map[position] = (position, 0)

    return images_dict, *paired_dicts, index_map
895+
896+
def _disable_grouping_output_flat(images, *paired_inputs):
    """
    Build the disable_grouping output tuple for a flat list structure.

    Every image becomes its own group keyed by its position `i` and gains a leading dimension of size 1
    through `unsqueeze(0)`. The same is applied to the value found at index `i` in each paired input, so paired
    values must expose `.unsqueeze` on this code path.

    Returns:
        `tuple`: `(images_dict, *paired_dicts, index_map)` where `index_map[i] == (i, 0)`.
    """
    images_dict = {}
    paired_dicts = [{} for _ in paired_inputs]
    index_map = {}

    for position in range(len(images)):
        images_dict[position] = images[position].unsqueeze(0)
        for paired_dict, paired_list in zip(paired_dicts, paired_inputs):
            paired_dict[position] = paired_list[position].unsqueeze(0)
        # Each group holds exactly one element, hence the constant in-group index 0
        index_map[position] = (position, 0)

    return images_dict, *paired_dicts, index_map
904+
905+
847906def group_images_by_shape (
848907 images : Union [list ["torch.Tensor" ], "torch.Tensor" ],
849- disable_grouping : bool ,
908+ * paired_inputs ,
909+ disable_grouping : Optional [bool ],
850910 is_nested : bool = False ,
851- ) -> tuple [
852- dict [tuple [int , int ], list ["torch.Tensor" ]], dict [Union [int , tuple [int , int ]], tuple [tuple [int , int ], int ]]
853- ]:
911+ ) -> tuple [dict , ...]:
854912 """
855913 Groups images by shape.
856914 Returns a dictionary with the shape as key and a list of images with that shape as value,
@@ -862,15 +920,22 @@ def group_images_by_shape(
862920 Args:
863921 images (Union[list["torch.Tensor"], "torch.Tensor"]):
864922 A list of images or a single tensor
923+ *paired_inputs (Any):
924+ Zero or more lists that mirror the structure of `images` (flat list, or list of lists when
925+ `is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
926+ same shape key. These paired values are grouped alongside `images` but are not stacked in the output, so
927+ they do not need to be tensors.
865928 disable_grouping (bool):
866929 Whether to disable grouping. If None, will be set to True if the images are on CPU, and False otherwise.
867930 This choice is based on empirical observations, as detailed here: https://github.com/huggingface/transformers/pull/38157
868931 is_nested (bool, *optional*, defaults to False):
869932 Whether the images are nested.
870933
871934 Returns:
872- tuple[dict[tuple[int, int], list["torch.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]]:
873- - A dictionary with shape as key and list of images with that shape as value
935+ tuple[dict, ...]:
936+ - A dictionary with shape as key and list/batch of images with that shape as value
937+ - Zero or more dictionaries (one per argument in `*paired_inputs`) grouped consistently with `images`; these carry
938+ the corresponding per-item values and are not stacked
874939 - A dictionary mapping original indices to (shape, index) tuples
875940 """
876941 # If disable grouping is not explicitly provided, we favor disabling it if the images are on CPU, and enabling it otherwise.
@@ -880,19 +945,19 @@ def group_images_by_shape(
880945
881946 if disable_grouping :
882947 if is_nested :
883- return {(i , j ): images [i ][j ].unsqueeze (0 ) for i in range (len (images )) for j in range (len (images [i ]))}, {
884- (i , j ): ((i , j ), 0 ) for i in range (len (images )) for j in range (len (images [i ]))
885- }
948+ return _disable_grouping_output_nested (images , * paired_inputs )
886949 else :
887- return { i : images [ i ]. unsqueeze ( 0 ) for i in range ( len ( images ))}, { i : ( i , 0 ) for i in range ( len ( images ))}
950+ return _disable_grouping_output_flat ( images , * paired_inputs )
888951
889952 # Handle single level nested structure
890- grouped_images , grouped_images_index = _group_images_by_shape (images , is_nested )
953+ grouped_images , * paired_grouped_values , grouped_images_index = _group_images_by_shape (
954+ images , * paired_inputs , is_nested = is_nested
955+ )
891956
892957 # Stack images with the same shape
893958 grouped_images = {shape : torch .stack (images_list , dim = 0 ) for shape , images_list in grouped_images .items ()}
894959
895- return grouped_images , grouped_images_index
960+ return grouped_images , * paired_grouped_values , grouped_images_index
896961
897962
898963def reorder_images (
0 commit comments