diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py index 5fb87c345ef0..4028c38ff227 100644 --- a/src/transformers/image_processing_utils_fast.py +++ b/src/transformers/image_processing_utils_fast.py @@ -79,8 +79,6 @@ def validate_fast_preprocess_arguments( do_normalize: Optional[bool] = None, image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, - do_pad: Optional[bool] = None, - size_divisibility: Optional[int] = None, do_center_crop: Optional[bool] = None, crop_size: Optional[SizeDict] = None, do_resize: Optional[bool] = None, @@ -99,8 +97,6 @@ def validate_fast_preprocess_arguments( do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, - do_pad=do_pad, - size_divisibility=size_divisibility, do_center_crop=do_center_crop, crop_size=crop_size, do_resize=do_resize, @@ -181,6 +177,8 @@ class DefaultFastImageProcessorKwargs(TypedDict, total=False): do_normalize: Optional[bool] image_mean: Optional[Union[float, list[float]]] image_std: Optional[Union[float, list[float]]] + do_pad: Optional[bool] + pad_size: Optional[dict[str, int]] do_convert_rgb: Optional[bool] return_tensors: Optional[Union[str, TensorType]] data_format: Optional[ChannelDimension] @@ -199,6 +197,8 @@ class BaseImageProcessorFast(BaseImageProcessor): crop_size = None do_resize = None do_center_crop = None + do_pad = None + pad_size = None do_rescale = None rescale_factor = 1 / 255 do_normalize = None @@ -222,6 +222,9 @@ def __init__(self, **kwargs: Unpack[DefaultFastImageProcessorKwargs]): ) crop_size = kwargs.pop("crop_size", self.crop_size) self.crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None + pad_size = kwargs.pop("pad_size", self.pad_size) + self.pad_size = get_size_dict(size=pad_size, param_name="pad_size") if pad_size is not None else None + for key in self.valid_kwargs.__annotations__: kwarg = kwargs.pop(key, None) if kwarg is not None: @@ -239,6 +242,74 @@ def is_fast(self) -> bool: """ return True + def pad( + self, + images: "torch.Tensor", + pad_size: SizeDict = None, + fill_value: Optional[int] = 0, + padding_mode: Optional[str] = "constant", + return_mask: Optional[bool] = False, + disable_grouping: Optional[bool] = False, + **kwargs, + ) -> "torch.Tensor": + """ + Pads images to `(pad_size["height"], pad_size["width"])` or to the largest size in the batch. + + Args: + images (`torch.Tensor`): + Images to pad. + pad_size (`SizeDict`, *optional*): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + fill_value (`int`, *optional*, defaults to `0`): + The constant value used to fill the padded area. + padding_mode (`str`, *optional*, defaults to "constant"): + The padding mode to use. Can be any of the modes supported by + `torch.nn.functional.pad` (e.g. constant, reflection, replication). + return_mask (`bool`, *optional*, defaults to `False`): + Whether to return a pixel mask to denote padded regions. + disable_grouping (`bool`, *optional*, defaults to `False`): + Whether to disable grouping of images by size. + + Returns: + `torch.Tensor`: The resized image. + """ + if pad_size is not None: + if not (pad_size.height and pad_size.width): + raise ValueError(f"Pad size must contain 'height' and 'width' keys only. Got pad_size={pad_size}.") + pad_size = (pad_size.height, pad_size.width) + else: + pad_size = get_max_height_width(images) + + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + processed_images_grouped = {} + processed_masks_grouped = {} + for shape, stacked_images in grouped_images.items(): + image_size = stacked_images.shape[-2:] + padding_height = pad_size[0] - image_size[0] + padding_width = pad_size[1] - image_size[1] + if padding_height < 0 or padding_width < 0: + raise ValueError( + f"Padding dimensions are negative. Please make sure that the `pad_size` is larger than the " + f"image size. Got pad_size={pad_size}, image_size={image_size}." + ) + if image_size != pad_size: + padding = (0, 0, padding_width, padding_height) + stacked_images = F.pad(stacked_images, padding, fill=fill_value, padding_mode=padding_mode) + processed_images_grouped[shape] = stacked_images + + if return_mask: + # keep only one from the channel dimension in pixel mask + stacked_masks = torch.zeros_like(stacked_images, dtype=torch.int64)[..., 0, :, :] + stacked_masks[..., : image_size[0], : image_size[1]] = 1 + processed_masks_grouped[shape] = stacked_masks + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + if return_mask: + processed_masks = reorder_images(processed_masks_grouped, grouped_images_index) + return processed_images, processed_masks + + return processed_images + def resize( self, image: "torch.Tensor", @@ -577,6 +648,7 @@ def _further_process_kwargs( self, size: Optional[SizeDict] = None, crop_size: Optional[SizeDict] = None, + pad_size: Optional[SizeDict] = None, default_to_square: Optional[bool] = None, image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, @@ -593,6 +665,8 @@ def _further_process_kwargs( size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if crop_size is not None: crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) + if pad_size is not None: + pad_size = SizeDict(**get_size_dict(size=pad_size, param_name="pad_size")) if isinstance(image_mean, list): image_mean = tuple(image_mean) if isinstance(image_std, list): @@ -602,6 +676,7 @@ def _further_process_kwargs( kwargs["size"] = size kwargs["crop_size"] = crop_size + kwargs["pad_size"] = pad_size kwargs["image_mean"] = image_mean kwargs["image_std"] = image_std kwargs["data_format"] = data_format @@ -714,6 +789,8 @@ def _preprocess( do_normalize: bool, image_mean: Optional[Union[float, list[float]]], image_std: Optional[Union[float, list[float]]], + do_pad: Optional[bool], + pad_size: Optional[SizeDict], disable_grouping: Optional[bool], return_tensors: Optional[Union[str, TensorType]], **kwargs, @@ -739,10 +816,12 @@ def _preprocess( stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std ) processed_images_grouped[shape] = stacked_images - processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + if do_pad: + processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping) + + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) def to_dict(self): diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index cb7c4bbf422a..2079c21f3b0c 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -525,7 +525,7 @@ def validate_preprocess_arguments( image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, do_pad: Optional[bool] = None, - size_divisibility: Optional[int] = None, + pad_size: Optional[Union[dict[str, int], int]] = None, do_center_crop: Optional[bool] = None, crop_size: Optional[dict[str, int]] = None, do_resize: Optional[bool] = None, @@ -544,10 +544,15 @@ def validate_preprocess_arguments( if do_rescale and rescale_factor is None: raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.") - if do_pad and size_divisibility is None: - # Here, size_divisor might be passed as the value of size + if do_pad and pad_size is None: + # Processors pad images using different args depending on the model, so the below check is pointless + # but we keep it for BC for now. TODO: remove in v5 + # Usually padding can be called with: + # - "pad_size/size" if we're padding to specific values + # - "size_divisor" if we're padding to any value divisible by X + # - "None" if we're padding to the maximum size image in batch raise ValueError( - "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `True`." + "Depending on the model, `size_divisor` or `pad_size` or `size` must be specified if `do_pad` is `True`." ) if do_normalize and (image_mean is None or image_std is None): diff --git a/src/transformers/models/bridgetower/image_processing_bridgetower.py b/src/transformers/models/bridgetower/image_processing_bridgetower.py index 28145b337a68..cb39ed097561 100644 --- a/src/transformers/models/bridgetower/image_processing_bridgetower.py +++ b/src/transformers/models/bridgetower/image_processing_bridgetower.py @@ -480,8 +480,6 @@ def preprocess( do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, - do_pad=do_pad, - size_divisibility=size_divisor, do_center_crop=do_center_crop, crop_size=crop_size, do_resize=do_resize, diff --git a/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py b/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py index 64610ec4462a..4a7450c84498 100644 --- a/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py +++ b/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py @@ -25,7 +25,6 @@ SizeDict, TensorType, Unpack, - get_max_height_width, group_images_by_shape, reorder_images, ) @@ -99,13 +98,9 @@ class BridgeTowerFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): size_divisor (`int`, *optional*, defaults to 32): The size by which to make sure both the height and width can be divided. Only has an effect if `do_resize` is set to `True`. Can be overridden by the `size_divisor` parameter in the `preprocess` method. - do_pad (`bool`, *optional*, defaults to `True`): - Whether to pad the image to the `(max_height, max_width)` of the images in the batch. Can be overridden by - the `do_pad` parameter in the `preprocess` method. """ size_divisor: Optional[int] - do_pad: Optional[bool] @auto_docstring @@ -224,59 +219,6 @@ def _pad_image( ) return padded_image - def pad( - self, - images: list["torch.Tensor"], - constant_values: Union[float, Iterable[float]] = 0, - return_pixel_mask: bool = True, - disable_grouping: Optional[bool] = False, - ) -> tuple: - """ - Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width - in the batch and optionally returns their corresponding pixel mask. - - Args: - image (`torch.Tensor`): - Image to pad. - constant_values (`float` or `Iterable[float]`, *optional*): - The value to use for the padding if `mode` is `"constant"`. - return_pixel_mask (`bool`, *optional*, defaults to `True`): - Whether to return a pixel mask. - disable_grouping (`bool`, *optional*, defaults to `False`): - Whether to disable grouping of images by size. - return_tensors (`str` or `TensorType`, *optional*): - The type of tensors to return. Can be one of: - - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. - """ - pad_size = get_max_height_width(images) - - grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) - processed_images_grouped = {} - processed_masks_grouped = {} - for shape, stacked_images in grouped_images.items(): - stacked_images = self._pad_image( - stacked_images, - pad_size, - constant_values=constant_values, - ) - processed_images_grouped[shape] = stacked_images - - if return_pixel_mask: - stacked_masks = make_pixel_mask(image=stacked_images, output_size=pad_size) - processed_masks_grouped[shape] = stacked_masks - - processed_images = reorder_images(processed_images_grouped, grouped_images_index) - - processed_masks = None - if return_pixel_mask: - processed_masks = reorder_images(processed_masks_grouped, grouped_images_index) - - return processed_images, processed_masks - def _preprocess( self, images: list["torch.Tensor"], @@ -325,7 +267,7 @@ def _preprocess( data = {} if do_pad: processed_images, processed_masks = self.pad( - processed_images, return_pixel_mask=True, disable_grouping=disable_grouping + processed_images, return_mask=True, disable_grouping=disable_grouping ) processed_masks = torch.stack(processed_masks, dim=0) if return_tensors else processed_masks data["pixel_mask"] = processed_masks diff --git a/src/transformers/models/bridgetower/processing_bridgetower.py b/src/transformers/models/bridgetower/processing_bridgetower.py index 030c578c49cd..6d7059c4c5a5 100644 --- a/src/transformers/models/bridgetower/processing_bridgetower.py +++ b/src/transformers/models/bridgetower/processing_bridgetower.py @@ -16,10 +16,17 @@ Processor class for BridgeTower. """ -from ...processing_utils import ProcessingKwargs, ProcessorMixin +from typing import Optional + +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin + + +class BridgeTowerImagesKwargs(ImagesKwargs): + size_divisor: Optional[int] class BridgeTowerProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: BridgeTowerImagesKwargs _defaults = { "text_kwargs": { "add_special_tokens": True, diff --git a/src/transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py b/src/transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py index 6b7c8327dc89..afe76134bc8d 100644 --- a/src/transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +++ b/src/transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py @@ -227,6 +227,7 @@ def _preprocess( image_std: Optional[Union[float, list[float]]], disable_grouping: Optional[bool], return_tensors: Optional[Union[str, TensorType]], + **kwargs, ) -> BatchFeature: if crop_to_patches: grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) diff --git a/src/transformers/models/conditional_detr/image_processing_conditional_detr_fast.py b/src/transformers/models/conditional_detr/image_processing_conditional_detr_fast.py index 06ef3f431050..86e51f2b4a60 100644 --- a/src/transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +++ b/src/transformers/models/conditional_detr/image_processing_conditional_detr_fast.py @@ -74,23 +74,12 @@ class ConditionalDetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): Controls whether to convert the annotations to the format expected by the CONDITIONAL_DETR model. Converts the bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`. Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method. - do_pad (`bool`, *optional*, defaults to `True`): - Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` - method. If `True`, padding will be applied to the bottom and right of the image with zeros. - If `pad_size` is provided, the image will be padded to the specified dimensions. - Otherwise, the image will be padded to the maximum height and width of the batch. - pad_size (`dict[str, int]`, *optional*): - The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size - provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest - height and width in the batch. return_segmentation_masks (`bool`, *optional*, defaults to `False`): Whether to return segmentation masks. """ format: Optional[Union[str, AnnotationFormat]] do_convert_annotations: Optional[bool] - do_pad: Optional[bool] - pad_size: Optional[dict[str, int]] return_segmentation_masks: Optional[bool] @@ -629,7 +618,7 @@ def _preprocess( image_mean: Optional[Union[float, list[float]]], image_std: Optional[Union[float, list[float]]], do_pad: bool, - pad_size: Optional[dict[str, int]], + pad_size: Optional[SizeDict], format: Optional[Union[str, AnnotationFormat]], return_tensors: Optional[Union[str, TensorType]], **kwargs, @@ -698,7 +687,7 @@ def _preprocess( if do_pad: # depends on all resized image shapes so we need another loop if pad_size is not None: - padded_size = (pad_size["height"], pad_size["width"]) + padded_size = (pad_size.height, pad_size.width) else: padded_size = get_max_height_width(images) diff --git a/src/transformers/models/convnext/image_processing_convnext_fast.py b/src/transformers/models/convnext/image_processing_convnext_fast.py index 130fcc19639e..0866b230a52e 100644 --- a/src/transformers/models/convnext/image_processing_convnext_fast.py +++ b/src/transformers/models/convnext/image_processing_convnext_fast.py @@ -155,6 +155,7 @@ def _preprocess( image_std: Optional[Union[float, list[float]]], disable_grouping: Optional[bool], return_tensors: Optional[Union[str, TensorType]], + **kwargs, ) -> BatchFeature: # Group images by size for batched resizing grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) diff --git a/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py b/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py index 1a9444cbf9db..7ab4e98012ac 100644 --- a/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py +++ b/src/transformers/models/deepseek_vl/image_processing_deepseek_vl.py @@ -89,6 +89,8 @@ class DeepseekVLImageProcessor(BaseImageProcessor): Can be overridden by the `image_std` parameter in the `preprocess` method. do_convert_rgb (`bool`, *optional*, defaults to `True`): Whether to convert the image to RGB. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image to square or not. """ model_input_names = ["pixel_values"] @@ -105,6 +107,7 @@ def __init__( image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, do_convert_rgb: Optional[bool] = None, + do_pad: Optional[bool] = True, **kwargs, ) -> None: super().__init__(**kwargs) @@ -121,6 +124,7 @@ def __init__( self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD self.do_convert_rgb = do_convert_rgb + self.do_pad = do_pad self.min_size = min_size if image_mean is None: self.background_color = (127, 127, 127) @@ -131,7 +135,6 @@ def resize( self, image: np.ndarray, size: Union[dict[str, int], int], - background_color: Optional[tuple[int, int, int]] = None, resample: PILImageResampling = PILImageResampling.BICUBIC, data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, @@ -145,8 +148,6 @@ def resize( Image to resize. size (`dict[str, int]` or `int`): The size to resize the image to. If a dictionary, it should have the keys `"height"` and `"width"`. - background_color (`tuple[int, int, int]`): - The background color to use for the padding. resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`. data_format (`ChannelDimension` or `str`, *optional*): @@ -165,7 +166,6 @@ def resize( Returns: `np.ndarray`: The resized image. """ - background_color = background_color if background_color is not None else self.background_color if input_data_format is None: input_data_format = infer_channel_dimension_format(image) @@ -194,12 +194,6 @@ def resize( input_data_format=input_data_format, **kwargs, ) - # Expand and pad the images to obtain a square image of dimensions `size x size` - image = self.pad_to_square( - image=image, - background_color=background_color, - input_data_format=input_data_format, - ) return image @filter_out_non_signature_kwargs() @@ -216,6 +210,8 @@ def preprocess( image_std: Optional[Union[float, list[float]]] = None, return_tensors: Optional[Union[str, TensorType]] = None, do_convert_rgb: Optional[bool] = None, + background_color: Optional[Union[int, tuple[int, int, int]]] = None, + do_pad: Optional[bool] = None, data_format: ChannelDimension = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> PIL.Image.Image: @@ -247,6 +243,10 @@ def preprocess( Image standard deviation to normalize the image by if `do_normalize` is set to `True`. do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): Whether to convert the image to RGB. + background_color (`tuple[int, int, int]`): + The background color to use for the padding. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image to square or not. return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. @@ -274,6 +274,8 @@ def preprocess( image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_pad = do_pad if do_pad is not None else self.do_pad + background_color = background_color if background_color is not None else self.background_color size = size if size is not None else self.size size = get_size_dict(size, default_to_square=False) @@ -319,6 +321,17 @@ def preprocess( for image in images ] + if do_pad: + # Expand and pad the images to obtain a square image of dimensions `size x size` + images = [ + self.pad_to_square( + image=image, + background_color=background_color, + input_data_format=input_data_format, + ) + for image in images + ] + if do_rescale: images = [ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) diff --git a/src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py b/src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py index 2204606d4211..7764a8250159 100644 --- a/src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +++ b/src/transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py @@ -62,6 +62,7 @@ class DeepseekVLImageProcessorFast(BaseImageProcessorFast): do_resize = True do_rescale = True do_normalize = True + do_pad = True valid_kwargs = DeepseekVLFastImageProcessorKwargs def __init__(self, **kwargs: Unpack[DeepseekVLFastImageProcessorKwargs]): diff --git a/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py index 45e19da0d14c..7c7d6df82424 100644 --- a/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py @@ -102,6 +102,8 @@ class DeepseekVLHybridImageProcessor(BaseImageProcessor): number of channels in the image. Can be overridden by the `high_res_image_std` parameter in the `preprocess` method. do_convert_rgb (`bool`, *optional*, defaults to `True`): Whether to convert the image to RGB. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image to square or not. """ model_input_names = ["pixel_values", "high_res_pixel_values"] @@ -122,6 +124,7 @@ def __init__( high_res_image_mean: Optional[Union[float, list[float]]] = None, high_res_image_std: Optional[Union[float, list[float]]] = None, do_convert_rgb: Optional[bool] = None, + do_pad: bool = True, **kwargs, ) -> None: super().__init__(**kwargs) @@ -147,6 +150,7 @@ def __init__( self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD self.do_convert_rgb = do_convert_rgb + self.do_pad = do_pad self.min_size = min_size if image_mean is None: self.background_color = (127, 127, 127) @@ -162,7 +166,6 @@ def resize( self, image: np.ndarray, size: Union[dict[str, int], int], - background_color: Optional[tuple[int, int, int]] = None, resample: PILImageResampling = PILImageResampling.BICUBIC, data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, @@ -176,8 +179,6 @@ def resize( Image to resize. size (`dict[str, int]` or `int`): The size to resize the image to. If a dictionary, it should have the keys `"height"` and `"width"`. - background_color (`tuple[int, int, int]`): - The background color to use for the padding. resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`. data_format (`ChannelDimension` or `str`, *optional*): @@ -196,7 +197,6 @@ def resize( Returns: `np.ndarray`: The resized image. """ - background_color = background_color if background_color is not None else self.background_color if input_data_format is None: input_data_format = infer_channel_dimension_format(image) @@ -225,12 +225,6 @@ def resize( input_data_format=input_data_format, **kwargs, ) - # Expand and pad the images to obtain a square image of dimensions `size x size` - image = self.pad_to_square( - image=image, - background_color=background_color, - input_data_format=input_data_format, - ) return image @filter_out_non_signature_kwargs() @@ -253,6 +247,8 @@ def preprocess( data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, do_convert_rgb: Optional[bool] = None, + do_pad: Optional[bool] = None, + background_color: Optional[tuple[int, int, int]] = None, ) -> PIL.Image.Image: """ Preprocess an image or batch of images. @@ -309,6 +305,10 @@ def preprocess( - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): Whether to convert the image to RGB. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image to square or not. + background_color (`tuple[int, int, int]`): + The background color to use for the padding. """ do_resize = do_resize if do_resize is not None else self.do_resize do_rescale = do_rescale if do_rescale is not None else self.do_rescale @@ -321,6 +321,8 @@ def preprocess( high_res_image_mean = high_res_image_mean if high_res_image_mean is not None else self.high_res_image_mean high_res_image_std = high_res_image_std if high_res_image_std is not None else self.high_res_image_std do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_pad = do_pad if do_pad is not None else self.do_pad + background_color = background_color if background_color is not None else self.background_color size = size if size is not None else self.size size_dict = get_size_dict(size) @@ -372,17 +374,28 @@ def preprocess( high_res_image = self.resize( image=high_res_image, size=high_res_size_dict, - background_color=self.high_res_background_color, resample=high_res_resample, input_data_format=input_data_format, ) + if do_pad: + # Expand and pad the images to obtain a square image of dimensions `size x size` + high_res_image = self.pad_to_square( + image=high_res_image, + background_color=background_color, + input_data_format=input_data_format, + ) image = self.resize( image=high_res_image, size=size_dict, - background_color=self.background_color, resample=resample, input_data_format=input_data_format, ) + if do_pad: + image = self.pad_to_square( + image=image, + background_color=background_color, + input_data_format=input_data_format, + ) if do_rescale: image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) diff --git a/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py index d55610331f30..3770cf18303e 100644 --- a/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +++ b/src/transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py @@ -86,6 +86,7 @@ class DeepseekVLHybridImageProcessorFast(BaseImageProcessorFast): do_resize = True do_rescale = True do_normalize = True + do_pad = True valid_kwargs = DeepseekVLHybridFastImageProcessorKwargs high_res_image_mean = OPENAI_CLIP_MEAN high_res_image_std = OPENAI_CLIP_STD diff --git a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py index 6c36cfa50daa..c6cf71b09613 100644 --- a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py @@ -488,6 +488,8 @@ class DeepseekVLHybridImageProcessor(DeepseekVLImageProcessor): number of channels in the image. Can be overridden by the `high_res_image_std` parameter in the `preprocess` method. do_convert_rgb (`bool`, *optional*, defaults to `True`): Whether to convert the image to RGB. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image to square or not. """ model_input_names = ["pixel_values", "high_res_pixel_values"] @@ -508,6 +510,7 @@ def __init__( high_res_image_mean: Optional[Union[float, list[float]]] = None, high_res_image_std: Optional[Union[float, list[float]]] = None, do_convert_rgb: Optional[bool] = None, + do_pad: bool = True, **kwargs, ) -> None: high_res_size = high_res_size if high_res_size is not None else {"height": 1024, "width": 1024} @@ -531,6 +534,7 @@ def __init__( image_mean=image_mean, image_std=image_std, do_convert_rgb=do_convert_rgb, + do_pad=do_pad, **kwargs, ) @@ -559,6 +563,8 @@ def preprocess( data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, do_convert_rgb: Optional[bool] = None, + do_pad: Optional[bool] = None, + background_color: Optional[tuple[int, int, int]] = None, ): """ Preprocess an image or batch of images. @@ -615,6 +621,10 @@ def preprocess( - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): Whether to convert the image to RGB. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image to square or not. + background_color (`tuple[int, int, int]`): + The background color to use for the padding. """ do_resize = do_resize if do_resize is not None else self.do_resize do_rescale = do_rescale if do_rescale is not None else self.do_rescale @@ -627,6 +637,8 @@ def preprocess( high_res_image_mean = high_res_image_mean if high_res_image_mean is not None else self.high_res_image_mean high_res_image_std = high_res_image_std if high_res_image_std is not None else self.high_res_image_std do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_pad = do_pad if do_pad is not None else self.do_pad + background_color = background_color if background_color is not None else self.background_color size = size if size is not None else self.size size_dict = get_size_dict(size) @@ -678,17 +690,28 @@ def preprocess( high_res_image = self.resize( image=high_res_image, size=high_res_size_dict, - background_color=self.high_res_background_color, resample=high_res_resample, input_data_format=input_data_format, ) + if do_pad: + # Expand and pad the images to obtain a square image of dimensions `size x size` + high_res_image = self.pad_to_square( + image=high_res_image, + background_color=background_color, + input_data_format=input_data_format, + ) image = self.resize( image=high_res_image, size=size_dict, - background_color=self.background_color, resample=resample, input_data_format=input_data_format, ) + if do_pad: + image = self.pad_to_square( + image=image, + background_color=background_color, + input_data_format=input_data_format, + ) if do_rescale: image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) diff --git a/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py b/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py index b6cd0a7075f3..2bfbedddc5d0 100644 --- a/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +++ b/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py @@ -65,23 +65,12 @@ class DeformableDetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): Controls whether to convert the annotations to the format expected by the DEFORMABLE_DETR model. Converts the bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`. Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method. - do_pad (`bool`, *optional*, defaults to `True`): - Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` - method. If `True`, padding will be applied to the bottom and right of the image with zeros. - If `pad_size` is provided, the image will be padded to the specified dimensions. - Otherwise, the image will be padded to the maximum height and width of the batch. - pad_size (`dict[str, int]`, *optional*): - The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size - provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest - height and width in the batch. return_segmentation_masks (`bool`, *optional*, defaults to `False`): Whether to return segmentation masks. """ format: Optional[Union[str, AnnotationFormat]] do_convert_annotations: Optional[bool] - do_pad: Optional[bool] - pad_size: Optional[dict[str, int]] return_segmentation_masks: Optional[bool] @@ -620,7 +609,7 @@ def _preprocess( image_mean: Optional[Union[float, list[float]]], image_std: Optional[Union[float, list[float]]], do_pad: bool, - pad_size: Optional[dict[str, int]], + pad_size: Optional[SizeDict], format: Optional[Union[str, AnnotationFormat]], return_tensors: Optional[Union[str, TensorType]], **kwargs, @@ -689,7 +678,7 @@ def _preprocess( if do_pad: # depends on all resized image shapes so we need another loop if pad_size is not None: - padded_size = (pad_size["height"], pad_size["width"]) + padded_size = (pad_size.height, pad_size.width) else: padded_size = get_max_height_width(images) diff --git a/src/transformers/models/depth_pro/image_processing_depth_pro_fast.py b/src/transformers/models/depth_pro/image_processing_depth_pro_fast.py index 581577b5b25f..d27220c3d2be 100644 --- a/src/transformers/models/depth_pro/image_processing_depth_pro_fast.py +++ b/src/transformers/models/depth_pro/image_processing_depth_pro_fast.py @@ -78,6 +78,7 @@ def _preprocess( image_std: Optional[Union[float, list[float]]], disable_grouping: Optional[bool], return_tensors: Optional[Union[str, TensorType]], + **kwargs, ) -> BatchFeature: # Group images by size for batched scaling grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) diff --git a/src/transformers/models/detr/image_processing_detr_fast.py b/src/transformers/models/detr/image_processing_detr_fast.py index 9877729434e1..ba216a6f2d49 100644 --- a/src/transformers/models/detr/image_processing_detr_fast.py +++ b/src/transformers/models/detr/image_processing_detr_fast.py @@ -286,23 +286,12 @@ class DetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): Controls whether to convert the annotations to the format expected by the DETR model. Converts the bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`. Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method. - do_pad (`bool`, *optional*, defaults to `True`): - Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` - method. If `True`, padding will be applied to the bottom and right of the image with zeros. - If `pad_size` is provided, the image will be padded to the specified dimensions. - Otherwise, the image will be padded to the maximum height and width of the batch. - pad_size (`dict[str, int]`, *optional*): - The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size - provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest - height and width in the batch. return_segmentation_masks (`bool`, *optional*, defaults to `False`): Whether to return segmentation masks. """ format: Optional[Union[str, AnnotationFormat]] do_convert_annotations: Optional[bool] - do_pad: Optional[bool] - pad_size: Optional[dict[str, int]] return_segmentation_masks: Optional[bool] @@ -641,7 +630,7 @@ def _preprocess( image_mean: Optional[Union[float, list[float]]], image_std: Optional[Union[float, list[float]]], do_pad: bool, - pad_size: Optional[dict[str, int]], + pad_size: Optional[SizeDict], format: Optional[Union[str, AnnotationFormat]], return_tensors: Optional[Union[str, TensorType]], **kwargs, @@ -710,7 +699,7 @@ def _preprocess( if do_pad: # depends on all resized image shapes so we need another loop if pad_size is not None: - padded_size = (pad_size["height"], pad_size["width"]) + padded_size = (pad_size.height, pad_size.width) else: padded_size = get_max_height_width(images) diff --git a/src/transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py b/src/transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py index bfb9d1074f14..fba0d3089438 100644 --- a/src/transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +++ b/src/transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py @@ -70,6 +70,7 @@ def _preprocess( image_std: Optional[Union[float, list[float]]], disable_grouping: Optional[bool], return_tensors: Optional[Union[str, TensorType]], + **kwargs, ) -> BatchFeature: # Group images by size for batched resizing grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) diff --git a/src/transformers/models/donut/image_processing_donut.py b/src/transformers/models/donut/image_processing_donut.py index 570981decf61..7dec96422c5d 100644 --- a/src/transformers/models/donut/image_processing_donut.py +++ b/src/transformers/models/donut/image_processing_donut.py @@ -215,10 +215,6 @@ def pad_image( padding = ((pad_top, pad_bottom), (pad_left, pad_right)) return pad(image, padding, data_format=data_format, input_data_format=input_data_format) - def pad(self, *args, **kwargs): - logger.info("pad is deprecated and will be removed in version 4.27. Please use pad_image instead.") - return self.pad_image(*args, **kwargs) - def thumbnail( self, image: np.ndarray, @@ -412,8 +408,6 @@ def preprocess( do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, - do_pad=do_pad, - size_divisibility=size, # There is no pad divisibility in this processor, but pad requires the size arg. do_resize=do_resize, size=size, resample=resample, diff --git a/src/transformers/models/donut/image_processing_donut_fast.py b/src/transformers/models/donut/image_processing_donut_fast.py index 8ec023554417..23714affe1e8 100644 --- a/src/transformers/models/donut/image_processing_donut_fast.py +++ b/src/transformers/models/donut/image_processing_donut_fast.py @@ -49,15 +49,10 @@ class DonutFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): Whether to resize the image using thumbnail method. do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`): Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees. - do_pad (`bool`, *optional*, defaults to `self.do_pad`): - Whether to pad the image. If `random_padding` is set to `True`, each image is padded with a random - amount of padding on each size, up to the largest image size in the batch. Otherwise, all images are - padded to the largest image size in the batch. """ do_thumbnail: Optional[bool] do_align_long_axis: Optional[bool] - do_pad: Optional[bool] @auto_docstring diff --git a/src/transformers/models/dpt/image_processing_dpt.py b/src/transformers/models/dpt/image_processing_dpt.py index bad8ef3b3c40..9b28950d2ded 100644 --- a/src/transformers/models/dpt/image_processing_dpt.py +++ b/src/transformers/models/dpt/image_processing_dpt.py @@ -541,8 +541,6 @@ def preprocess( do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, - do_pad=do_pad, - size_divisibility=size_divisor, do_resize=do_resize, size=size, resample=resample, diff --git a/src/transformers/models/dpt/image_processing_dpt_fast.py b/src/transformers/models/dpt/image_processing_dpt_fast.py index 1387127b4cf0..7fce8a9f64db 100644 --- a/src/transformers/models/dpt/image_processing_dpt_fast.py +++ b/src/transformers/models/dpt/image_processing_dpt_fast.py @@ -64,9 +64,6 @@ class DPTFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): ensure_multiple_of (`int`, *optional*, defaults to 1): If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overridden by `ensure_multiple_of` in `preprocess`. - do_pad (`bool`, *optional*, defaults to `False`): - Whether to apply center padding. This was introduced in the DINOv2 paper, which uses the model in - combination with DPT. size_divisor (`int`, *optional*): If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the DINOv2 paper, which uses the model in combination with DPT. @@ -81,7 +78,6 @@ class DPTFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): ensure_multiple_of: Optional[int] size_divisor: Optional[int] - do_pad: Optional[bool] keep_aspect_ratio: Optional[bool] do_reduce_labels: Optional[bool] diff --git a/src/transformers/models/dpt/modular_dpt.py b/src/transformers/models/dpt/modular_dpt.py index f86b5601dada..7ae6bb40c3af 100644 --- a/src/transformers/models/dpt/modular_dpt.py +++ b/src/transformers/models/dpt/modular_dpt.py @@ -94,9 +94,6 @@ class DPTFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): ensure_multiple_of (`int`, *optional*, defaults to 1): If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overridden by `ensure_multiple_of` in `preprocess`. - do_pad (`bool`, *optional*, defaults to `False`): - Whether to apply center padding. This was introduced in the DINOv2 paper, which uses the model in - combination with DPT. size_divisor (`int`, *optional*): If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the DINOv2 paper, which uses the model in combination with DPT. @@ -111,7 +108,6 @@ class DPTFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): ensure_multiple_of: Optional[int] size_divisor: Optional[int] - do_pad: Optional[bool] keep_aspect_ratio: Optional[bool] do_reduce_labels: Optional[bool] diff --git a/src/transformers/models/fuyu/image_processing_fuyu.py b/src/transformers/models/fuyu/image_processing_fuyu.py index 29af98ed5072..e52d9dc8ee91 100644 --- a/src/transformers/models/fuyu/image_processing_fuyu.py +++ b/src/transformers/models/fuyu/image_processing_fuyu.py @@ -455,8 +455,6 @@ def preprocess( do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, - do_pad=do_pad, - size_divisibility=size, # There is no pad divisibility in this processor, but pad requires the size arg. do_resize=do_resize, size=size, resample=resample, diff --git a/src/transformers/models/gemma3/image_processing_gemma3_fast.py b/src/transformers/models/gemma3/image_processing_gemma3_fast.py index 6ce7b508b270..3826f40bd997 100644 --- a/src/transformers/models/gemma3/image_processing_gemma3_fast.py +++ b/src/transformers/models/gemma3/image_processing_gemma3_fast.py @@ -194,8 +194,6 @@ def _preprocess( pan_and_scan_max_num_crops: Optional[int], pan_and_scan_min_ratio_to_activate: Optional[float], interpolation: Optional["F.InterpolationMode"], - do_center_crop: bool, - crop_size: SizeDict, do_rescale: bool, rescale_factor: float, do_normalize: bool, @@ -203,6 +201,7 @@ def _preprocess( image_std: Optional[Union[float, list[float]]], disable_grouping: Optional[bool], return_tensors: Optional[Union[str, TensorType]], + **kwargs, ) -> BatchFeature: # Group images by size for batched processing processed_images_grouped = {} diff --git a/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py b/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py index 6652e018263c..38b87aed623f 100644 --- a/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +++ b/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py @@ -173,6 +173,7 @@ def _preprocess( image_std: Optional[Union[float, list[float]]], disable_grouping: Optional[bool], return_tensors: Optional[Union[str, TensorType]], + **kwargs, ) -> BatchFeature: if crop_to_patches: grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) diff --git a/src/transformers/models/grounding_dino/image_processing_grounding_dino_fast.py b/src/transformers/models/grounding_dino/image_processing_grounding_dino_fast.py index 9869e8eb4801..59866c9a410e 100644 --- a/src/transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +++ b/src/transformers/models/grounding_dino/image_processing_grounding_dino_fast.py @@ -68,23 +68,12 @@ class GroundingDinoFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): Controls whether to convert the annotations to the format expected by the GROUNDING_DINO model. Converts the bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`. Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method. - do_pad (`bool`, *optional*, defaults to `True`): - Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` - method. If `True`, padding will be applied to the bottom and right of the image with zeros. - If `pad_size` is provided, the image will be padded to the specified dimensions. - Otherwise, the image will be padded to the maximum height and width of the batch. - pad_size (`dict[str, int]`, *optional*): - The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size - provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest - height and width in the batch. return_segmentation_masks (`bool`, *optional*, defaults to `False`): Whether to return segmentation masks. """ format: Optional[Union[str, AnnotationFormat]] do_convert_annotations: Optional[bool] - do_pad: Optional[bool] - pad_size: Optional[dict[str, int]] return_segmentation_masks: Optional[bool] @@ -651,7 +640,7 @@ def _preprocess( image_mean: Optional[Union[float, list[float]]], image_std: Optional[Union[float, list[float]]], do_pad: bool, - pad_size: Optional[dict[str, int]], + pad_size: Optional[SizeDict], format: Optional[Union[str, AnnotationFormat]], return_tensors: Optional[Union[str, TensorType]], **kwargs, @@ -720,7 +709,7 @@ def _preprocess( if do_pad: # depends on all resized image shapes so we need another loop if pad_size is not None: - padded_size = (pad_size["height"], pad_size["width"]) + padded_size = (pad_size.height, pad_size.width) else: padded_size = get_max_height_width(images) diff --git a/src/transformers/models/idefics2/image_processing_idefics2_fast.py b/src/transformers/models/idefics2/image_processing_idefics2_fast.py index a22b95cfea97..5348bda389ed 100644 --- a/src/transformers/models/idefics2/image_processing_idefics2_fast.py +++ b/src/transformers/models/idefics2/image_processing_idefics2_fast.py @@ -109,12 +109,9 @@ class Idefics2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): """ do_image_splitting (`bool`, *optional*, defaults to `False`): Whether to split the image into a sequence 4 equal sub-images concatenated with the original image. - do_pad (`bool`, *optional*, defaults to `True`): - Whether to pad images to the largest height and width in the batch. """ do_image_splitting: Optional[bool] - do_pad: Optional[bool] @auto_docstring diff --git a/src/transformers/models/idefics3/image_processing_idefics3_fast.py b/src/transformers/models/idefics3/image_processing_idefics3_fast.py index a6047ba77a87..5b0c0e6180f9 100644 --- a/src/transformers/models/idefics3/image_processing_idefics3_fast.py +++ b/src/transformers/models/idefics3/image_processing_idefics3_fast.py @@ -171,9 +171,6 @@ def make_pixel_mask(image: "torch.Tensor", output_size: tuple[int, int]) -> "tor class Idefics3FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): """ - do_pad (`bool`, *optional*): - Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest - number of patches in the batch. Padding will be applied to the bottom and right with zeros. do_image_splitting (`bool`, *optional*, defaults to `True`): Whether to split the image into sub-images concatenated with the original image. They are split into patches such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`. @@ -183,7 +180,6 @@ class Idefics3FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): Whether to return the row and column information of the images. """ - do_pad: Optional[bool] do_image_splitting: Optional[bool] max_image_size: Optional[dict[str, int]] return_row_col_info: Optional[bool] diff --git a/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py index f4f482c56313..a2cd3cf351d2 100644 --- a/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py @@ -61,12 +61,10 @@ def _preprocess( do_convert_rgb: bool, do_resize: bool, size: SizeDict, - size_divisor: Optional[int], interpolation: Optional["F.InterpolationMode"], do_center_crop: bool, crop_size: SizeDict, do_rescale: bool, - do_pad: bool, rescale_factor: float, do_normalize: bool, image_mean: Optional[Union[float, list[float]]], @@ -81,9 +79,7 @@ def _preprocess( if do_convert_rgb: stacked_videos = self.convert_to_rgb(stacked_videos) if do_resize: - stacked_videos = self.resize( - stacked_videos, size=size, size_divisor=size_divisor, interpolation=interpolation - ) + stacked_videos = self.resize(stacked_videos, size=size, interpolation=interpolation) resized_videos_grouped[shape] = stacked_videos resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index) diff --git a/src/transformers/models/internvl/video_processing_internvl.py b/src/transformers/models/internvl/video_processing_internvl.py index 3c0ee8de1bef..a2e06d3b7ec4 100644 --- a/src/transformers/models/internvl/video_processing_internvl.py +++ b/src/transformers/models/internvl/video_processing_internvl.py @@ -110,12 +110,10 @@ def _preprocess( do_convert_rgb: bool, do_resize: bool, size: SizeDict, - size_divisor: Optional[int], interpolation: Optional["F.InterpolationMode"], do_center_crop: bool, crop_size: SizeDict, do_rescale: bool, - do_pad: bool, rescale_factor: float, do_normalize: bool, image_mean: Optional[Union[float, list[float]]], @@ -130,9 +128,7 @@ def _preprocess( if do_convert_rgb: stacked_videos = self.convert_to_rgb(stacked_videos) if do_resize: - stacked_videos = self.resize( - stacked_videos, size=size, size_divisor=size_divisor, interpolation=interpolation - ) + stacked_videos = self.resize(stacked_videos, size=size, interpolation=interpolation) resized_videos_grouped[shape] = stacked_videos resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index) diff --git a/src/transformers/models/janus/image_processing_janus.py b/src/transformers/models/janus/image_processing_janus.py index 3669e707928b..16659bd85354 100644 --- a/src/transformers/models/janus/image_processing_janus.py +++ b/src/transformers/models/janus/image_processing_janus.py @@ -86,6 +86,8 @@ class JanusImageProcessor(BaseImageProcessor): Can be overridden by the `image_std` parameter in the `preprocess` method. do_convert_rgb (`bool`, *optional*, defaults to `True`): Whether to convert the image to RGB. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image to square or not. """ model_input_names = ["pixel_values"] @@ -102,6 +104,7 @@ def __init__( image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, do_convert_rgb: Optional[bool] = None, + do_pad: Optional[bool] = True, **kwargs, ) -> None: super().__init__(**kwargs) @@ -118,6 +121,7 @@ def __init__( self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD self.do_convert_rgb = do_convert_rgb + self.do_pad = do_pad self.min_size = min_size if image_mean is None: self.background_color = (127, 127, 127) @@ -128,7 +132,6 @@ def resize( self, image: np.ndarray, size: Union[dict[str, int], int], - background_color: Optional[tuple[int, int, int]] = None, resample: PILImageResampling = PILImageResampling.BICUBIC, data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, @@ -142,8 +145,6 @@ def resize( Image to resize. size (`dict[str, int]` or `int`): The size to resize the image to. If a dictionary, it should have the keys `"height"` and `"width"`. - background_color (`tuple[int, int, int]`): - The background color to use for the padding. resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`. data_format (`ChannelDimension` or `str`, *optional*): @@ -162,7 +163,6 @@ def resize( Returns: `np.ndarray`: The resized image. """ - background_color = background_color if background_color is not None else self.background_color if input_data_format is None: input_data_format = infer_channel_dimension_format(image) @@ -191,12 +191,6 @@ def resize( input_data_format=input_data_format, **kwargs, ) - # Expand and pad the images to obtain a square image of dimensions `size x size` - image = self.pad_to_square( - image=image, - background_color=background_color, - input_data_format=input_data_format, - ) return image @filter_out_non_signature_kwargs() @@ -213,6 +207,8 @@ def preprocess( image_std: Optional[Union[float, list[float]]] = None, return_tensors: Optional[Union[str, TensorType]] = None, do_convert_rgb: Optional[bool] = None, + background_color: Optional[Union[int, tuple[int, int, int]]] = None, + do_pad: Optional[bool] = None, data_format: ChannelDimension = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> PIL.Image.Image: @@ -244,6 +240,10 @@ def preprocess( Image standard deviation to normalize the image by if `do_normalize` is set to `True`. do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): Whether to convert the image to RGB. + background_color (`tuple[int, int, int]`): + The background color to use for the padding. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image to square or not. return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. @@ -271,6 +271,8 @@ def preprocess( image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_pad = do_pad if do_pad is not None else self.do_pad + background_color = background_color if background_color is not None else self.background_color size = size if size is not None else self.size size = get_size_dict(size, default_to_square=False) @@ -316,6 +318,17 @@ def preprocess( for image in images ] + if do_pad: + # Expand and pad the images to obtain a square image of dimensions `size x size` + images = [ + self.pad_to_square( + image=image, + background_color=background_color, + input_data_format=input_data_format, + ) + for image in images + ] + if do_rescale: images = [ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) diff --git a/src/transformers/models/janus/image_processing_janus_fast.py b/src/transformers/models/janus/image_processing_janus_fast.py index eedf18e2c19f..3e9483f21bfe 100644 --- a/src/transformers/models/janus/image_processing_janus_fast.py +++ b/src/transformers/models/janus/image_processing_janus_fast.py @@ -68,6 +68,7 @@ class JanusImageProcessorFast(BaseImageProcessorFast): do_resize = True do_rescale = True do_normalize = True + do_pad = True valid_kwargs = JanusFastImageProcessorKwargs def __init__(self, **kwargs: Unpack[JanusFastImageProcessorKwargs]): diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 261e994262aa..611ad5018345 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -29,23 +29,28 @@ from ...generation import ClassifierFreeGuidanceLogitsProcessor, GenerationMixin, GenerationMode, LogitsProcessorList from ...generation.utils import GenerateDecoderOnlyOutput from ...image_processing_utils import BatchFeature, get_size_dict -from ...image_transforms import resize, to_channel_dimension_format +from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format from ...image_utils import ( ChannelDimension, ImageInput, PILImageResampling, get_image_size, infer_channel_dimension_format, + is_scaled_image, make_flat_list_of_images, to_numpy_array, + valid_images, + validate_preprocess_arguments, ) from ...modeling_outputs import ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( + TensorType, TransformersKwargs, auto_docstring, can_return_tuple, + filter_out_non_signature_kwargs, is_torch_available, is_vision_available, logging, @@ -1328,6 +1333,8 @@ class JanusImageProcessor(BlipImageProcessor): Can be overridden by the `image_std` parameter in the `preprocess` method. do_convert_rgb (`bool`, *optional*, defaults to `True`): Whether to convert the image to RGB. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image to square or not. """ def __init__( @@ -1342,10 +1349,12 @@ def __init__( image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, do_convert_rgb: Optional[bool] = None, + do_pad: Optional[bool] = True, **kwargs, ): super().__init__(**kwargs) + self.do_pad = do_pad self.min_size = min_size if image_mean is None: self.background_color = (127, 127, 127) @@ -1430,7 +1439,6 @@ def resize( self, image: np.ndarray, size: Union[dict[str, int], int], - background_color: Optional[tuple[int, int, int]] = None, resample: PILImageResampling = PILImageResampling.BICUBIC, data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, @@ -1444,8 +1452,6 @@ def resize( Image to resize. size (`dict[str, int]` or `int`): The size to resize the image to. If a dictionary, it should have the keys `"height"` and `"width"`. - background_color (`tuple[int, int, int]`): - The background color to use for the padding. resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`. data_format (`ChannelDimension` or `str`, *optional*): @@ -1464,7 +1470,6 @@ def resize( Returns: `np.ndarray`: The resized image. """ - background_color = background_color if background_color is not None else self.background_color if input_data_format is None: input_data_format = infer_channel_dimension_format(image) @@ -1493,14 +1498,164 @@ def resize( input_data_format=input_data_format, **kwargs, ) - # Expand and pad the images to obtain a square image of dimensions `size x size` - image = self.pad_to_square( - image=image, - background_color=background_color, - input_data_format=input_data_format, - ) return image + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + do_convert_rgb: Optional[bool] = None, + background_color: Optional[Union[int, tuple[int, int, int]]] = None, + do_pad: Optional[bool] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Controls the size of the image after `resize`. The shortest edge of the image is resized to + `size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image + is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest + edge equal to `int(size["shortest_edge"] * (1333 / 800))`. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to normalize the image by if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to normalize the image by if `do_normalize` is set to `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + background_color (`tuple[int, int, int]`): + The background color to use for the padding. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image to square or not. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_pad = do_pad if do_pad is not None else self.do_pad + background_color = background_color if background_color is not None else self.background_color + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + images = self.fetch_images(images) + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + # PIL RGBA images are converted to RGB + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_pad: + # Expand and pad the images to obtain a square image of dimensions `size x size` + images = [ + self.pad_to_square( + image=image, + background_color=background_color, + input_data_format=input_data_format, + ) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) + + return encoded_outputs + def postprocess( self, images: ImageInput, diff --git a/src/transformers/models/llava/image_processing_llava_fast.py b/src/transformers/models/llava/image_processing_llava_fast.py index 02324f6393cd..cf62f250bc2f 100644 --- a/src/transformers/models/llava/image_processing_llava_fast.py +++ b/src/transformers/models/llava/image_processing_llava_fast.py @@ -56,14 +56,7 @@ from torchvision.transforms import functional as F -class LlavaFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): - """ - Args: - do_pad (`bool`, *optional*): - Whether to pad the image to a square based on the longest edge. - """ - - do_pad: Optional[bool] +class LlavaFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): ... @auto_docstring @@ -147,6 +140,7 @@ def _preprocess( image_std: Optional[Union[float, list[float]]], disable_grouping: Optional[bool], return_tensors: Optional[Union[str, TensorType]], + **kwargs, ) -> BatchFeature: # Group images by size for batched resizing grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) diff --git a/src/transformers/models/llava_next/image_processing_llava_next_fast.py b/src/transformers/models/llava_next/image_processing_llava_next_fast.py index 3dda73507006..201a65260589 100644 --- a/src/transformers/models/llava_next/image_processing_llava_next_fast.py +++ b/src/transformers/models/llava_next/image_processing_llava_next_fast.py @@ -59,13 +59,9 @@ class LlavaNextFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): A list of possible resolutions to use for processing high resolution images. The best resolution is selected based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess` method. - do_pad (`bool`, *optional*): - Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest - number of patches in the batch. Padding will be applied to the bottom and right with zeros. """ image_grid_pinpoints: Optional[list[list[int]]] - do_pad: Optional[bool] @auto_docstring diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py index 46ef482dad36..4392d64e9ebf 100644 --- a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py @@ -56,13 +56,9 @@ class LlavaOnevisionFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): A list of possible resolutions to use for processing high resolution images. The best resolution is selected based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess` method. - do_pad (`bool`, *optional*): - Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest - number of patches in the batch. Padding will be applied to the bottom and right with zeros. """ image_grid_pinpoints: Optional[list[list[int]]] - do_pad: Optional[bool] @auto_docstring diff --git a/src/transformers/models/llava_onevision/modular_llava_onevision.py b/src/transformers/models/llava_onevision/modular_llava_onevision.py index 9d6d3a53f7c8..45dfac3b37ef 100644 --- a/src/transformers/models/llava_onevision/modular_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modular_llava_onevision.py @@ -72,13 +72,9 @@ class LlavaOnevisionFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): A list of possible resolutions to use for processing high resolution images. The best resolution is selected based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess` method. - do_pad (`bool`, *optional*): - Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest - number of patches in the batch. Padding will be applied to the bottom and right with zeros. """ image_grid_pinpoints: Optional[list[list[int]]] - do_pad: Optional[bool] class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast): diff --git a/src/transformers/models/mask2former/image_processing_mask2former_fast.py b/src/transformers/models/mask2former/image_processing_mask2former_fast.py index b94f0d8c308c..c61d531eb077 100644 --- a/src/transformers/models/mask2former/image_processing_mask2former_fast.py +++ b/src/transformers/models/mask2former/image_processing_mask2former_fast.py @@ -84,23 +84,12 @@ class Mask2FormerFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): The background label will be replaced by `ignore_index`. num_labels (`int`, *optional*): The number of labels in the segmentation map. - do_pad (`bool`, *optional*, defaults to `True`): - Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` - method. If `True`, padding will be applied to the bottom and right of the image with zeros. - If `pad_size` is provided, the image will be padded to the specified dimensions. - Otherwise, the image will be padded to the maximum height and width of the batch. - pad_size (`Dict[str, int]`, *optional*): - The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size - provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest - height and width in the batch. """ size_divisor: Optional[int] ignore_index: Optional[int] do_reduce_labels: Optional[bool] num_labels: Optional[int] - do_pad: Optional[bool] - pad_size: Optional[dict[str, int]] def convert_segmentation_map_to_binary_masks_fast( @@ -334,8 +323,8 @@ def _preprocess( segmentation_maps: Optional["torch.Tensor"], instance_id_to_semantic_id: Optional[dict[int, int]], do_resize: Optional[bool], - size: Optional[dict[str, int]], - pad_size: Optional[dict[str, int]], + size: Optional[SizeDict], + pad_size: Optional[SizeDict], size_divisor: Optional[int], interpolation: Optional[Union["PILImageResampling", "F.InterpolationMode"]], do_rescale: Optional[bool], @@ -383,7 +372,7 @@ def _preprocess( resized_segmentation_maps_grouped, grouped_segmentation_maps_index ) if pad_size is not None: - padded_size = (pad_size["height"], pad_size["width"]) + padded_size = (pad_size.height, pad_size.width) else: padded_size = get_max_height_width(resized_images) diff --git a/src/transformers/models/maskformer/image_processing_maskformer_fast.py b/src/transformers/models/maskformer/image_processing_maskformer_fast.py index ad5cb946d38d..0b1c95aa1012 100644 --- a/src/transformers/models/maskformer/image_processing_maskformer_fast.py +++ b/src/transformers/models/maskformer/image_processing_maskformer_fast.py @@ -120,23 +120,12 @@ class MaskFormerFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): The background label will be replaced by `ignore_index`. num_labels (`int`, *optional*): The number of labels in the segmentation map. - do_pad (`bool`, *optional*, defaults to `True`): - Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` - method. If `True`, padding will be applied to the bottom and right of the image with zeros. - If `pad_size` is provided, the image will be padded to the specified dimensions. - Otherwise, the image will be padded to the maximum height and width of the batch. - pad_size (`Dict[str, int]`, *optional*): - The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size - provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest - height and width in the batch. """ size_divisor: Optional[int] ignore_index: Optional[int] do_reduce_labels: Optional[bool] num_labels: Optional[int] - do_pad: Optional[bool] - pad_size: Optional[dict[str, int]] @auto_docstring @@ -335,8 +324,8 @@ def _preprocess( segmentation_maps: Optional["torch.Tensor"], instance_id_to_semantic_id: Optional[dict[int, int]], do_resize: Optional[bool], - size: Optional[dict[str, int]], - pad_size: Optional[dict[str, int]], + size: Optional[SizeDict], + pad_size: Optional[SizeDict], size_divisor: Optional[int], interpolation: Optional[Union["PILImageResampling", "F.InterpolationMode"]], do_rescale: Optional[bool], @@ -384,7 +373,7 @@ def _preprocess( resized_segmentation_maps_grouped, grouped_segmentation_maps_index ) if pad_size is not None: - padded_size = (pad_size["height"], pad_size["width"]) + padded_size = (pad_size.height, pad_size.width) else: padded_size = get_max_height_width(resized_images) diff --git a/src/transformers/models/nougat/image_processing_nougat.py b/src/transformers/models/nougat/image_processing_nougat.py index 38b4b8fa4a50..0c0a51464b43 100644 --- a/src/transformers/models/nougat/image_processing_nougat.py +++ b/src/transformers/models/nougat/image_processing_nougat.py @@ -461,8 +461,6 @@ def preprocess( do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, - do_pad=do_pad, - size_divisibility=size, # There is no pad divisibility in this processor, but pad requires the size arg. do_resize=do_resize, size=size, resample=resample, diff --git a/src/transformers/models/nougat/image_processing_nougat_fast.py b/src/transformers/models/nougat/image_processing_nougat_fast.py index 136d7f171575..ebe37389f3f6 100644 --- a/src/transformers/models/nougat/image_processing_nougat_fast.py +++ b/src/transformers/models/nougat/image_processing_nougat_fast.py @@ -63,14 +63,11 @@ class NougatFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): Whether to resize the image using thumbnail method. do_align_long_axis (`bool`, *optional*, defaults to `False`): Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees. - do_pad (`bool`, *optional*, defaults to `True`): - Whether to pad the images to the largest image size in the batch. """ do_crop_margin: Optional[bool] do_thumbnail: Optional[bool] do_align_long_axis: Optional[bool] - do_pad: Optional[bool] @auto_docstring diff --git a/src/transformers/models/oneformer/image_processing_oneformer_fast.py b/src/transformers/models/oneformer/image_processing_oneformer_fast.py index a61745e87e58..10869f50f622 100644 --- a/src/transformers/models/oneformer/image_processing_oneformer_fast.py +++ b/src/transformers/models/oneformer/image_processing_oneformer_fast.py @@ -530,24 +530,9 @@ def pad( Returns: `BatchFeature`: Padded images and optional pixel masks. """ - pad_size = get_max_height_width(images) - - padded_images = [] - pixel_masks = [] - - for image in images: - padded_image = self._pad_image_fast( - image=image, - output_size=pad_size, - constant_values=0, - ) - padded_images.append(padded_image) - - if return_pixel_mask: - input_height, input_width = image.shape[1], image.shape[2] - mask = torch.zeros(pad_size, dtype=torch.int64, device=image.device) - mask[:input_height, :input_width] = 1 - pixel_masks.append(mask) + outputs = super().pad(images, return_mask=return_pixel_mask) + padded_images = outputs[0] if return_pixel_mask else outputs + pixel_masks = outputs[1] if return_pixel_mask else None if return_tensors: padded_images = torch.stack(padded_images, dim=0) diff --git a/src/transformers/models/ovis2/image_processing_ovis2_fast.py b/src/transformers/models/ovis2/image_processing_ovis2_fast.py index e5940421828d..f12a9c70ee57 100644 --- a/src/transformers/models/ovis2/image_processing_ovis2_fast.py +++ b/src/transformers/models/ovis2/image_processing_ovis2_fast.py @@ -202,6 +202,7 @@ def _preprocess( image_std: Optional[Union[float, list[float]]], disable_grouping: Optional[bool], return_tensors: Optional[Union[str, TensorType]], + **kwargs, ) -> BatchFeature: if crop_to_patches and max_patches > 1: grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) diff --git a/src/transformers/models/owlv2/image_processing_owlv2_fast.py b/src/transformers/models/owlv2/image_processing_owlv2_fast.py index fd46f12f28ee..926da9b27ffc 100644 --- a/src/transformers/models/owlv2/image_processing_owlv2_fast.py +++ b/src/transformers/models/owlv2/image_processing_owlv2_fast.py @@ -60,14 +60,7 @@ from .image_processing_owlv2 import _scale_boxes, box_iou -class Owlv2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): - r""" - do_pad (`bool`, *optional*, defaults to `True`): - Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` - method. If `True`, padding will be applied to the bottom and right of the image with grey pixels. - """ - - do_pad: Optional[bool] +class Owlv2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): ... @auto_docstring @@ -289,7 +282,12 @@ def pad( images: list["torch.Tensor"], disable_grouping: Optional[bool], constant_value: float = 0.5, + **kwargs, ) -> list["torch.Tensor"]: + """ + Unlike the Base class `self.pad` where all images are padded to the maximum image size, + Owlv2 pads an image to square. + """ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) processed_images_grouped = {} for shape, stacked_images in grouped_images.items(): @@ -389,7 +387,7 @@ def _preprocess( processed_images = reorder_images(processed_images_grouped, grouped_images_index) if do_pad: - processed_images = self.pad(processed_images, disable_grouping=disable_grouping) + processed_images = self.pad(processed_images, constant_value=0.5, disable_grouping=disable_grouping) grouped_images, grouped_images_index = group_images_by_shape( processed_images, disable_grouping=disable_grouping diff --git a/src/transformers/models/owlv2/modular_owlv2.py b/src/transformers/models/owlv2/modular_owlv2.py index 799b9bbaa704..7fe4d75ee9ea 100644 --- a/src/transformers/models/owlv2/modular_owlv2.py +++ b/src/transformers/models/owlv2/modular_owlv2.py @@ -52,14 +52,7 @@ from torchvision.transforms import functional as F -class Owlv2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): - r""" - do_pad (`bool`, *optional*, defaults to `True`): - Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` - method. If `True`, padding will be applied to the bottom and right of the image with grey pixels. - """ - - do_pad: Optional[bool] +class Owlv2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): ... @auto_docstring @@ -102,7 +95,12 @@ def pad( images: list["torch.Tensor"], disable_grouping: Optional[bool], constant_value: float = 0.5, + **kwargs, ) -> list["torch.Tensor"]: + """ + Unlike the Base class `self.pad` where all images are padded to the maximum image size, + Owlv2 pads an image to square. + """ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) processed_images_grouped = {} for shape, stacked_images in grouped_images.items(): @@ -202,7 +200,7 @@ def _preprocess( processed_images = reorder_images(processed_images_grouped, grouped_images_index) if do_pad: - processed_images = self.pad(processed_images, disable_grouping=disable_grouping) + processed_images = self.pad(processed_images, constant_value=0.5, disable_grouping=disable_grouping) grouped_images, grouped_images_index = group_images_by_shape( processed_images, disable_grouping=disable_grouping diff --git a/src/transformers/models/pixtral/image_processing_pixtral_fast.py b/src/transformers/models/pixtral/image_processing_pixtral_fast.py index 5d42bb097476..585405627023 100644 --- a/src/transformers/models/pixtral/image_processing_pixtral_fast.py +++ b/src/transformers/models/pixtral/image_processing_pixtral_fast.py @@ -162,6 +162,7 @@ def _preprocess( image_std: Optional[Union[float, list[float]]], disable_grouping: Optional[bool], return_tensors: Optional[Union[str, TensorType]], + **kwargs, ) -> BatchFeature: patch_size = get_size_dict(patch_size, default_to_square=True) patch_size = SizeDict(**patch_size) diff --git a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py index a5fad19b1a1b..4f0f68240f9a 100644 --- a/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py @@ -378,8 +378,6 @@ def preprocess( do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, - do_pad=do_pad, - size_divisibility=size_divisor, do_resize=do_resize, size=size, resample=resample, diff --git a/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py b/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py index 9927a8d02209..eefc45bf9f9a 100644 --- a/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py +++ b/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py @@ -59,23 +59,12 @@ class RTDetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): Controls whether to convert the annotations to the format expected by the RT_DETR model. Converts the bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`. Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method. - do_pad (`bool`, *optional*, defaults to `True`): - Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` - method. If `True`, padding will be applied to the bottom and right of the image with zeros. - If `pad_size` is provided, the image will be padded to the specified dimensions. - Otherwise, the image will be padded to the maximum height and width of the batch. - pad_size (`dict[str, int]`, *optional*): - The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size - provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest - height and width in the batch. return_segmentation_masks (`bool`, *optional*, defaults to `False`): Whether to return segmentation masks. """ format: Optional[Union[str, AnnotationFormat]] do_convert_annotations: Optional[bool] - do_pad: Optional[bool] - pad_size: Optional[dict[str, int]] return_segmentation_masks: Optional[bool] @@ -424,7 +413,7 @@ def _preprocess( image_mean: Optional[Union[float, list[float]]], image_std: Optional[Union[float, list[float]]], do_pad: bool, - pad_size: Optional[dict[str, int]], + pad_size: Optional[SizeDict], format: Optional[Union[str, AnnotationFormat]], return_tensors: Optional[Union[str, TensorType]], **kwargs, @@ -483,7 +472,7 @@ def _preprocess( if do_pad: # depends on all resized image shapes so we need another loop if pad_size is not None: - padded_size = (pad_size["height"], pad_size["width"]) + padded_size = (pad_size.height, pad_size.width) else: padded_size = get_max_height_width(images) diff --git a/src/transformers/models/rt_detr/modular_rt_detr.py b/src/transformers/models/rt_detr/modular_rt_detr.py index e661b7189042..938f070d3672 100644 --- a/src/transformers/models/rt_detr/modular_rt_detr.py +++ b/src/transformers/models/rt_detr/modular_rt_detr.py @@ -175,7 +175,7 @@ def _preprocess( image_mean: Optional[Union[float, list[float]]], image_std: Optional[Union[float, list[float]]], do_pad: bool, - pad_size: Optional[dict[str, int]], + pad_size: Optional[SizeDict], format: Optional[Union[str, AnnotationFormat]], return_tensors: Optional[Union[str, TensorType]], **kwargs, @@ -234,7 +234,7 @@ def _preprocess( if do_pad: # depends on all resized image shapes so we need another loop if pad_size is not None: - padded_size = (pad_size["height"], pad_size["width"]) + padded_size = (pad_size.height, pad_size.width) else: padded_size = get_max_height_width(images) diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index 33a3661c5e6d..c9b54f561fb6 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -516,8 +516,6 @@ def preprocess( do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, - do_pad=do_pad, - size_divisibility=pad_size, # Here _preprocess needs do_pad and pad_size. do_resize=do_resize, size=size, resample=resample, diff --git a/src/transformers/models/sam/image_processing_sam_fast.py b/src/transformers/models/sam/image_processing_sam_fast.py index 77b4b490e136..1bfb6adf5234 100644 --- a/src/transformers/models/sam/image_processing_sam_fast.py +++ b/src/transformers/models/sam/image_processing_sam_fast.py @@ -26,8 +26,6 @@ from ...image_processing_utils_fast import ( BaseImageProcessorFast, DefaultFastImageProcessorKwargs, - group_images_by_shape, - reorder_images, ) from ...image_utils import ( IMAGENET_DEFAULT_MEAN, @@ -40,7 +38,6 @@ ) from ...processing_utils import Unpack from ...utils import ( - TensorType, auto_docstring, is_torch_available, is_torchvision_available, @@ -62,12 +59,6 @@ class SamFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): r""" - do_pad (`bool`, *optional*, defaults to `True`): - Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` - method. If `True`, padding will be applied to the bottom and right of the image with zeros. - pad_size (`dict[str, int]`, *optional*): - The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size - provided for preprocessing. mask_size (`dict[str, int]`, *optional*): The size `{"longest_edge": int}` to resize the segmentation maps to. mask_pad_size (`dict[str, int]`, *optional*): @@ -76,8 +67,6 @@ class SamFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): """ mask_size: Optional[dict[str, int]] - do_pad: Optional[bool] - pad_size: Optional[dict[str, int]] mask_pad_size: Optional[dict[str, int]] @@ -102,15 +91,6 @@ class SamImageProcessorFast(BaseImageProcessorFast): def __init__(self, **kwargs: Unpack[SamFastImageProcessorKwargs]): super().__init__(**kwargs) - def pad_image(self, images: "torch.Tensor", pad_size: SizeDict): - """Pad images to the specified size.""" - output_height, output_width = pad_size.height, pad_size.width - input_height, input_width = images.shape[-2:] - pad_width = output_width - input_width - pad_height = output_height - input_height - padding = (0, 0, pad_width, pad_height) - return F_t.pad(images, padding) - def _get_preprocess_shape(self, old_shape: tuple[int, int], longest_edge: int): """ Compute the output size given input size and target long side length. @@ -231,7 +211,7 @@ def _preprocess_image_like_inputs( ) original_sizes = [image.shape[-2:] for image in images] images_kwargs = kwargs.copy() - pixel_values = self._preprocess(images, **images_kwargs) + pixel_values = self._preprocess(images, **images_kwargs)["pixel_values"] reshaped_input_sizes = [image.shape[-2:] for image in images] data = { "pixel_values": pixel_values, @@ -262,54 +242,10 @@ def _preprocess_image_like_inputs( processed_segmentation_maps = self._preprocess( images=processed_segmentation_maps, **segmentation_maps_kwargs ) - data["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64) + data["labels"] = processed_segmentation_maps["pixel_values"].squeeze(1).to(torch.int64) return BatchFeature(data=data, tensor_type=kwargs["return_tensors"]) - def _preprocess( - self, - images: list["torch.Tensor"], - do_resize: bool, - size: SizeDict, - interpolation: Optional["F_t.InterpolationMode"], - do_rescale: bool, - rescale_factor: float, - do_normalize: bool, - image_mean: Optional[Union[float, list[float]]], - image_std: Optional[Union[float, list[float]]], - do_pad: bool, - pad_size: SizeDict, - disable_grouping: Optional[bool], - return_tensors: Optional[Union[str, TensorType]], - **kwargs, - ) -> Union["torch.Tensor", list["torch.Tensor"]]: - # Group images by size for batched resizing - grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) - resized_images_grouped = {} - for shape, stacked_images in grouped_images.items(): - if do_resize: - stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) - resized_images_grouped[shape] = stacked_images - resized_images = reorder_images(resized_images_grouped, grouped_images_index) - - # Group images by size for further processing - # Needed in case do_resize is False, or resize returns images with different sizes - grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) - processed_images_grouped = {} - for shape, stacked_images in grouped_images.items(): - # Fused rescale and normalize - stacked_images = self.rescale_and_normalize( - stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std - ) - if do_pad: - stacked_images = self.pad_image(stacked_images, pad_size) - processed_images_grouped[shape] = stacked_images - - processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images - - return processed_images - def generate_crop_boxes( self, image: "torch.Tensor", diff --git a/src/transformers/models/sam2/image_processing_sam2_fast.py b/src/transformers/models/sam2/image_processing_sam2_fast.py index 4b65bec77b57..8cb5381f0977 100644 --- a/src/transformers/models/sam2/image_processing_sam2_fast.py +++ b/src/transformers/models/sam2/image_processing_sam2_fast.py @@ -504,14 +504,6 @@ def _preprocess_image_like_inputs( return BatchFeature(data=data, tensor_type=kwargs["return_tensors"]) - def _preprocess( - self, - images: list["torch.Tensor"], - return_tensors: Optional[Union[str, TensorType]], - **kwargs, - ) -> "torch.Tensor": - return super()._preprocess(images, return_tensors=return_tensors, **kwargs).pixel_values - def generate_crop_boxes( self, image: "torch.Tensor", @@ -713,6 +705,17 @@ def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, cro """ return _post_process_for_mask_generation(all_masks, all_scores, all_boxes, crops_nms_thresh) + def pad_image(self): + raise NotImplementedError("No pad_image for SAM 2.") + + def _preprocess( + self, + images: list["torch.Tensor"], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> "torch.Tensor": + return super()._preprocess(images, return_tensors=return_tensors, **kwargs).pixel_values + def _apply_non_overlapping_constraints(self, pred_masks: torch.Tensor) -> torch.Tensor: """ Apply non-overlapping constraints to the object scores in pred_masks. Here we diff --git a/src/transformers/models/smolvlm/image_processing_smolvlm_fast.py b/src/transformers/models/smolvlm/image_processing_smolvlm_fast.py index 6f4bbd209bca..4e24bc279543 100644 --- a/src/transformers/models/smolvlm/image_processing_smolvlm_fast.py +++ b/src/transformers/models/smolvlm/image_processing_smolvlm_fast.py @@ -52,9 +52,6 @@ class SmolVLMFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): """ - do_pad (`bool`, *optional*): - Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest - number of patches in the batch. Padding will be applied to the bottom and right with zeros. do_image_splitting (`bool`, *optional*, defaults to `True`): Whether to split the image into sub-images concatenated with the original image. They are split into patches such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`. @@ -64,7 +61,6 @@ class SmolVLMFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): Whether to return the row and column information of the images. """ - do_pad: Optional[bool] do_image_splitting: Optional[bool] max_image_size: Optional[dict[str, int]] return_row_col_info: Optional[bool] diff --git a/src/transformers/models/smolvlm/video_processing_smolvlm.py b/src/transformers/models/smolvlm/video_processing_smolvlm.py index 44d7ab9cef37..eda3bdb1c811 100644 --- a/src/transformers/models/smolvlm/video_processing_smolvlm.py +++ b/src/transformers/models/smolvlm/video_processing_smolvlm.py @@ -98,7 +98,8 @@ def get_resize_output_image_size( class SmolVLMVideoProcessorInitKwargs(VideosKwargs): - max_image_size: dict[str, int] = None + max_image_size: Optional[dict[str, int]] + do_pad: Optional[bool] class SmolVLMVideoProcessor(BaseVideoProcessor): diff --git a/src/transformers/models/swin2sr/image_processing_swin2sr.py b/src/transformers/models/swin2sr/image_processing_swin2sr.py index d36ce936c2f1..76c5e907da1c 100644 --- a/src/transformers/models/swin2sr/image_processing_swin2sr.py +++ b/src/transformers/models/swin2sr/image_processing_swin2sr.py @@ -31,6 +31,7 @@ validate_preprocess_arguments, ) from ...utils import TensorType, filter_out_non_signature_kwargs, logging +from ...utils.deprecation import deprecate_kwarg logger = logging.get_logger(__name__) @@ -56,7 +57,7 @@ def __init__( do_rescale: bool = True, rescale_factor: Union[int, float] = 1 / 255, do_pad: bool = True, - pad_size: int = 8, + size_divisor: int = 8, **kwargs, ) -> None: super().__init__(**kwargs) @@ -64,7 +65,22 @@ def __init__( self.do_rescale = do_rescale self.rescale_factor = rescale_factor self.do_pad = do_pad - self.pad_size = pad_size + pad_size = kwargs.get("pad_size") + self.size_divisor = size_divisor if size_divisor is not None else pad_size + + @property + def pad_size(self): + logger.warning( + "`self.pad_size` attribute is deprecated and will be removed in v5. Use `self.size_divisor` instead", + ) + return self.size_divisor + + @pad_size.setter + def pad_size(self, value): + logger.warning( + "`self.pad_size` attribute is deprecated and will be removed in v5. Use `self.size_divisor` instead", + ) + self.size_divisor = value def pad( self, @@ -108,13 +124,14 @@ def pad( ) @filter_out_non_signature_kwargs() + @deprecate_kwarg("pad_size", version="v5", new_name="size_divisor") def preprocess( self, images: ImageInput, do_rescale: Optional[bool] = None, rescale_factor: Optional[float] = None, do_pad: Optional[bool] = None, - pad_size: Optional[int] = None, + size_divisor: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, @@ -132,7 +149,7 @@ def preprocess( Rescale factor to rescale the image by if `do_rescale` is set to `True`. do_pad (`bool`, *optional*, defaults to `True`): Whether to pad the image to make the height and width divisible by `window_size`. - pad_size (`int`, *optional*, defaults to 32): + size_divisor (`int`, *optional*, defaults to 32): The size of the sliding window for the local attention. return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: @@ -157,7 +174,7 @@ def preprocess( do_rescale = do_rescale if do_rescale is not None else self.do_rescale rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor do_pad = do_pad if do_pad is not None else self.do_pad - pad_size = pad_size if pad_size is not None else self.pad_size + size_divisor = size_divisor if size_divisor is not None else self.size_divisor images = make_flat_list_of_images(images) @@ -169,8 +186,6 @@ def preprocess( validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, - do_pad=do_pad, - size_divisibility=pad_size, # Here the pad function simply requires pad_size. ) # All transformations expect numpy arrays. @@ -193,7 +208,7 @@ def preprocess( ] if do_pad: - images = [self.pad(image, size=pad_size, input_data_format=input_data_format) for image in images] + images = [self.pad(image, size=size_divisor, input_data_format=input_data_format) for image in images] images = [ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images diff --git a/src/transformers/models/swin2sr/image_processing_swin2sr_fast.py b/src/transformers/models/swin2sr/image_processing_swin2sr_fast.py index cc8235f1141e..f99ab99274f5 100644 --- a/src/transformers/models/swin2sr/image_processing_swin2sr_fast.py +++ b/src/transformers/models/swin2sr/image_processing_swin2sr_fast.py @@ -31,9 +31,13 @@ is_torch_available, is_torchvision_available, is_torchvision_v2_available, + logging, ) +from ...utils.deprecation import deprecate_kwarg +logger = logging.get_logger(__name__) + if is_torch_available(): import torch @@ -46,14 +50,12 @@ class Swin2SRFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): """ - do_pad (`bool`, *optional*, defaults to `True`): - Whether to pad the image to make the height and width divisible by `window_size`. - pad_size (`int`, *optional*, defaults to `8`): - The size of the sliding window for the local attention. + size_divisor (`int`, *optional*, defaults to `8`): + The size of the sliding window for the local attention. It will be used to pad the image + to the size divisible by `size_divisor` """ - do_pad: Optional[bool] - pad_size: Optional[int] + size_divisor: Optional[int] @auto_docstring @@ -61,31 +63,48 @@ class Swin2SRImageProcessorFast(BaseImageProcessorFast): do_rescale = True rescale_factor = 1 / 255 do_pad = True - pad_size = 8 + size_divisor = 8 valid_kwargs = Swin2SRFastImageProcessorKwargs def __init__(self, **kwargs: Unpack[Swin2SRFastImageProcessorKwargs]): + pad_size = kwargs.pop("pad_size", None) + kwargs.setdefault("size_divisor", pad_size) super().__init__(**kwargs) + @property + def pad_size(self): + logger.warning( + "`self.pad_size` attribute is deprecated and will be removed in v5. Use `self.size_divisor` instead", + ) + return self.size_divisor + + @pad_size.setter + def pad_size(self, value): + logger.warning( + "`self.pad_size` attribute is deprecated and will be removed in v5. Use `self.size_divisor` instead", + ) + self.size_divisor = value + def preprocess(self, images: ImageInput, **kwargs: Unpack[Swin2SRFastImageProcessorKwargs]) -> BatchFeature: return super().preprocess(images, **kwargs) - def pad(self, images: "torch.Tensor", size: int) -> "torch.Tensor": + @deprecate_kwarg("size", version="v5", new_name="size_divisor") + def pad(self, images: "torch.Tensor", size_divisor: int) -> "torch.Tensor": """ - Pad an image to make the height and width divisible by `size`. + Pad an image to make the height and width divisible by `size_divisor`. Args: images (`torch.Tensor`): Images to pad. - size (`int`): + size_divisor (`int`): The size to make the height and width divisible by. Returns: `torch.Tensor`: The padded images. """ height, width = get_image_size(images, ChannelDimension.FIRST) - pad_height = (height // size + 1) * size - height - pad_width = (width // size + 1) * size - width + pad_height = (height // size_divisor + 1) * size_divisor - height + pad_width = (width // size_divisor + 1) * size_divisor - width return F.pad( images, @@ -93,13 +112,14 @@ def pad(self, images: "torch.Tensor", size: int) -> "torch.Tensor": padding_mode="symmetric", ) + @deprecate_kwarg("pad_size", version="v5", new_name="size_divisor") def _preprocess( self, images: list["torch.Tensor"], do_rescale: bool, rescale_factor: float, do_pad: bool, - pad_size: int, + size_divisor: int, disable_grouping: Optional[bool], return_tensors: Optional[Union[str, TensorType]], **kwargs, @@ -110,7 +130,7 @@ def _preprocess( if do_rescale: stacked_images = self.rescale(stacked_images, scale=rescale_factor) if do_pad: - stacked_images = self.pad(stacked_images, size=pad_size) + stacked_images = self.pad(stacked_images, size_divisor=size_divisor) processed_image_grouped[shape] = stacked_images processed_images = reorder_images(processed_image_grouped, grouped_images_index) processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images diff --git a/src/transformers/models/tvp/image_processing_tvp.py b/src/transformers/models/tvp/image_processing_tvp.py index 7fa758b6f484..d3f698873d55 100644 --- a/src/transformers/models/tvp/image_processing_tvp.py +++ b/src/transformers/models/tvp/image_processing_tvp.py @@ -294,8 +294,6 @@ def _preprocess_image( do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, - do_pad=do_pad, - size_divisibility=pad_size, # here the pad() method simply requires the pad_size argument. do_center_crop=do_center_crop, crop_size=crop_size, do_resize=do_resize, diff --git a/src/transformers/models/tvp/image_processing_tvp_fast.py b/src/transformers/models/tvp/image_processing_tvp_fast.py index a3bad696c36d..b96e4991f619 100644 --- a/src/transformers/models/tvp/image_processing_tvp_fast.py +++ b/src/transformers/models/tvp/image_processing_tvp_fast.py @@ -16,7 +16,7 @@ from typing import Optional, Union -from ...image_processing_utils import BatchFeature, get_size_dict +from ...image_processing_utils import BatchFeature from ...image_processing_utils_fast import ( BaseImageProcessorFast, DefaultFastImageProcessorKwargs, @@ -55,10 +55,6 @@ class TvpFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): r""" do_flip_channel_order (`bool`, *optional*): Whether to flip the channel order of the image from RGB to BGR. - do_pad (`bool`, *optional*): - Whether to pad the image. - pad_size (`Dict[str, int]` or `SizeDict`, *optional*): - Size dictionary specifying the desired height and width for padding. constant_values (`float` or `List[float]`, *optional*): Value used to fill the padding area when `pad_mode` is `'constant'`. pad_mode (`str`, *optional*): @@ -66,8 +62,6 @@ class TvpFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): """ do_flip_channel_order: Optional[bool] - do_pad: Optional[bool] - pad_size: Optional[SizeDict] constant_values: Optional[Union[float, list[float]]] pad_mode: Optional[str] @@ -103,21 +97,6 @@ def preprocess( ) -> BatchFeature: return super().preprocess(videos, **kwargs) - def _further_process_kwargs( - self, - pad_size: Optional[SizeDict] = None, - **kwargs, - ) -> dict: - """ - Update kwargs that need further processing before being validated - Can be overridden by subclasses to customize the processing of kwargs. - """ - if pad_size is not None: - pad_size = SizeDict(**get_size_dict(pad_size, param_name="pad_size")) - kwargs["pad_size"] = pad_size - - return super()._further_process_kwargs(**kwargs) - def _prepare_images_structure( self, images: ImageInput, @@ -135,31 +114,6 @@ def _prepare_images_structure( """ return make_nested_list_of_images(images, **kwargs) - def _pad_frames( - self, - frames: "torch.Tensor", - pad_size: Union[SizeDict, dict], - constant_values: Union[float, list[float]], - pad_mode: str, - ) -> "torch.Tensor": - """Pad frames to the specified size.""" - height, width = pad_size.height, pad_size.width - - if frames.shape[-2:] == (height, width): - return frames - - # Calculate padding - current_height, current_width = frames.shape[-2:] - pad_bottom = height - current_height - pad_right = width - current_width - - if pad_bottom < 0 or pad_right < 0: - raise ValueError("The padding size must be greater than frame size") - - # Apply padding - padding = [0, 0, pad_right, pad_bottom] # [left, top, right, bottom] - return F.pad(frames, padding, fill=constant_values, padding_mode=pad_mode) - def resize( self, image: "torch.Tensor", @@ -238,7 +192,7 @@ def _preprocess( do_rescale: bool, rescale_factor: float, do_pad: bool, - pad_size: Union[SizeDict, dict], + pad_size: SizeDict, constant_values: Union[float, list[float]], pad_mode: str, do_normalize: bool, @@ -275,7 +229,8 @@ def _preprocess( # Pad if needed if do_pad: - stacked_frames = self._pad_frames(stacked_frames, pad_size, constant_values, pad_mode) + stacked_frames = self.pad(stacked_frames, pad_size, fill_value=constant_values, pad_mode=pad_mode) + stacked_frames = torch.stack(stacked_frames, dim=0) # Flip channel order if needed (RGB to BGR) if do_flip_channel_order: diff --git a/src/transformers/models/vilt/image_processing_vilt_fast.py b/src/transformers/models/vilt/image_processing_vilt_fast.py index 3e6571f159e1..1c169994ba3f 100644 --- a/src/transformers/models/vilt/image_processing_vilt_fast.py +++ b/src/transformers/models/vilt/image_processing_vilt_fast.py @@ -51,16 +51,12 @@ class ViltFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): """ Args: - do_pad (`bool`, *optional*, defaults to `True`): - Whether to pad the image. If `True`, will pad the images in the batch to the largest height and width - in the batch. Padding will be applied to the bottom and right with zeros. size_divisor (`int`, *optional*, defaults to 32): The size to make the height and width divisible by. rescale_factor (`float`, *optional*, defaults to 1/255): The factor to rescale the image by. """ - do_pad: Optional[bool] size_divisor: Optional[int] rescale_factor: Optional[float] diff --git a/src/transformers/models/vilt/processing_vilt.py b/src/transformers/models/vilt/processing_vilt.py index 5b5126ad4a85..f4f9fc9a746d 100644 --- a/src/transformers/models/vilt/processing_vilt.py +++ b/src/transformers/models/vilt/processing_vilt.py @@ -17,11 +17,17 @@ """ import warnings +from typing import Optional -from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin + + +class ViltImagesKwargs(ImagesKwargs): + size_divisor: Optional[int] class ViltProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: ViltImagesKwargs _defaults = { "text_kwargs": { "add_special_tokens": True, diff --git a/src/transformers/models/vitmatte/image_processing_vitmatte.py b/src/transformers/models/vitmatte/image_processing_vitmatte.py index 891fdb457359..6e65a634d23d 100644 --- a/src/transformers/models/vitmatte/image_processing_vitmatte.py +++ b/src/transformers/models/vitmatte/image_processing_vitmatte.py @@ -34,6 +34,7 @@ validate_preprocess_arguments, ) from ...utils import TensorType, filter_out_non_signature_kwargs, logging +from ...utils.deprecation import deprecate_kwarg logger = logging.get_logger(__name__) @@ -60,9 +61,9 @@ class VitMatteImageProcessor(BaseImageProcessor): Standard deviation to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. do_pad (`bool`, *optional*, defaults to `True`): - Whether to pad the image to make the width and height divisible by `size_divisibility`. Can be overridden + Whether to pad the image to make the width and height divisible by `size_divisor`. Can be overridden by the `do_pad` parameter in the `preprocess` method. - size_divisibility (`int`, *optional*, defaults to 32): + size_divisor (`int`, *optional*, defaults to 32): The width and height of the image will be padded to be divisible by this number. """ @@ -76,7 +77,7 @@ def __init__( image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, do_pad: bool = True, - size_divisibility: int = 32, + size_divisor: int = 32, **kwargs, ) -> None: super().__init__(**kwargs) @@ -86,7 +87,22 @@ def __init__( self.rescale_factor = rescale_factor self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD - self.size_divisibility = size_divisibility + size_divisibility = kwargs.get("size_divisibility") + self.size_divisor = size_divisibility if size_divisibility is not None else size_divisor + + @property + def size_divisibility(self): + logger.warning( + "`self.size_divisibility` attribute is deprecated and will be removed in v5. Use `self.size_divisor` instead" + ) + return self.size_divisor + + @size_divisibility.setter + def size_divisibility(self, value): + logger.warning( + "`self.size_divisibility` attribute is deprecated and will be removed in v5. Use `self.size_divisor` instead" + ) + self.size_divisor = value def pad_image( self, @@ -130,6 +146,7 @@ def pad_image( return image @filter_out_non_signature_kwargs() + @deprecate_kwarg("size_divisibility", version="v5", new_name="size_divisor") def preprocess( self, images: ImageInput, @@ -140,7 +157,7 @@ def preprocess( image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, do_pad: Optional[bool] = None, - size_divisibility: Optional[int] = None, + size_divisor: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, @@ -166,7 +183,7 @@ def preprocess( Image standard deviation to use if `do_normalize` is set to `True`. do_pad (`bool`, *optional*, defaults to `self.do_pad`): Whether to pad the image. - size_divisibility (`int`, *optional*, defaults to `self.size_divisibility`): + size_divisor (`int`, *optional*, defaults to `self.size_divisor`): The size divisibility to pad the image to if `do_pad` is set to `True`. return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: @@ -193,7 +210,7 @@ def preprocess( rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std - size_divisibility = size_divisibility if size_divisibility is not None else self.size_divisibility + size_divisor = size_divisor if size_divisor is not None else self.size_divisor images = make_flat_list_of_images(images) trimaps = make_flat_list_of_images(trimaps, expected_ndims=2) @@ -215,8 +232,6 @@ def preprocess( do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, - do_pad=do_pad, - size_divisibility=size_divisibility, ) # All transformations expect numpy arrays. @@ -258,7 +273,7 @@ def preprocess( if do_pad: images = [ - self.pad_image(image, size_divisibility=size_divisibility, input_data_format=input_data_format) + self.pad_image(image, size_divisibility=size_divisor, input_data_format=input_data_format) for image in images ] diff --git a/src/transformers/models/vitmatte/image_processing_vitmatte_fast.py b/src/transformers/models/vitmatte/image_processing_vitmatte_fast.py index e2cd7d331253..014a6939af5c 100644 --- a/src/transformers/models/vitmatte/image_processing_vitmatte_fast.py +++ b/src/transformers/models/vitmatte/image_processing_vitmatte_fast.py @@ -57,15 +57,11 @@ class VitMatteFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): """ - do_pad (`bool`, *optional*, defaults to `True`): - Whether to pad the image to make the width and height divisible by `size_divisibility`. Can be overridden - by the `do_pad` parameter in the `preprocess` method. - size_divisibility (`int`, *optional*, defaults to 32): + size_divisor (`int`, *optional*, defaults to 32): The width and height of the image will be padded to be divisible by this number. """ - do_pad: Optional[bool] - size_divisibility: int + size_divisor: Optional[int] @auto_docstring @@ -76,12 +72,28 @@ class VitMatteImageProcessorFast(BaseImageProcessorFast): image_mean: Optional[Union[float, list[float]]] = IMAGENET_STANDARD_MEAN image_std: Optional[Union[float, list[float]]] = IMAGENET_STANDARD_STD do_pad: bool = True - size_divisibility: int = 32 + size_divisor: int = 32 valid_kwargs = VitMatteFastImageProcessorKwargs def __init__(self, **kwargs: Unpack[VitMatteFastImageProcessorKwargs]) -> None: + size_divisibility = kwargs.pop("size_divisibility", None) + kwargs.setdefault("size_divisor", size_divisibility) super().__init__(**kwargs) + @property + def size_divisibility(self): + logger.warning( + "`self.size_divisibility` attribute is deprecated and will be removed in v5. Use `self.size_divisor` instead" + ) + return self.size_divisor + + @size_divisibility.setter + def size_divisibility(self, value): + logger.warning( + "`self.size_divisibility` attribute is deprecated and will be removed in v5. Use `self.size_divisor` instead" + ) + self.size_divisor = value + def _pad_image( self, images: "torch.tensor", @@ -150,10 +162,9 @@ def _preprocess( image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, do_pad: Optional[bool] = None, - size_divisibility: Optional[int] = None, + size_divisor: Optional[int] = None, disable_grouping: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, ) -> BatchFeature: grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) grouped_trimaps, grouped_trimaps_index = group_images_by_shape(trimaps, disable_grouping=disable_grouping) @@ -170,7 +181,7 @@ def _preprocess( ) stacked_images = torch.cat([stacked_images, stacked_trimaps], dim=1) if do_pad: - stacked_images = self._pad_image(stacked_images, self.size_divisibility) + stacked_images = self._pad_image(stacked_images, size_divisor) processed_images_grouped[shape] = stacked_images processed_images = reorder_images(processed_images_grouped, grouped_images_index) diff --git a/src/transformers/models/yolos/image_processing_yolos_fast.py b/src/transformers/models/yolos/image_processing_yolos_fast.py index 4bea14b508ea..81fb0b008e0d 100644 --- a/src/transformers/models/yolos/image_processing_yolos_fast.py +++ b/src/transformers/models/yolos/image_processing_yolos_fast.py @@ -64,23 +64,12 @@ class YolosFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): Controls whether to convert the annotations to the format expected by the YOLOS model. Converts the bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`. Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method. - do_pad (`bool`, *optional*, defaults to `True`): - Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` - method. If `True`, padding will be applied to the bottom and right of the image with zeros. - If `pad_size` is provided, the image will be padded to the specified dimensions. - Otherwise, the image will be padded to the maximum height and width of the batch. - pad_size (`dict[str, int]`, *optional*): - The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size - provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest - height and width in the batch. return_segmentation_masks (`bool`, *optional*, defaults to `False`): Whether to return segmentation masks. """ format: Optional[Union[str, AnnotationFormat]] do_convert_annotations: Optional[bool] - do_pad: Optional[bool] - pad_size: Optional[dict[str, int]] return_segmentation_masks: Optional[bool] @@ -668,7 +657,7 @@ def _preprocess( image_mean: Optional[Union[float, list[float]]], image_std: Optional[Union[float, list[float]]], do_pad: bool, - pad_size: Optional[dict[str, int]], + pad_size: Optional[SizeDict], format: Optional[Union[str, AnnotationFormat]], return_tensors: Optional[Union[str, TensorType]], **kwargs, @@ -737,7 +726,7 @@ def _preprocess( if do_pad: # depends on all resized image shapes so we need another loop if pad_size is not None: - padded_size = (pad_size["height"], pad_size["width"]) + padded_size = (pad_size.height, pad_size.width) else: padded_size = get_max_height_width(images) diff --git a/src/transformers/models/zoedepth/image_processing_zoedepth_fast.py b/src/transformers/models/zoedepth/image_processing_zoedepth_fast.py index 793c386fdc75..c89ec8b2ebf1 100644 --- a/src/transformers/models/zoedepth/image_processing_zoedepth_fast.py +++ b/src/transformers/models/zoedepth/image_processing_zoedepth_fast.py @@ -70,8 +70,6 @@ class ZoeDepthFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): """ - do_pad (`bool`, *optional*, defaults to `True`): - Whether to apply pad the input. keep_aspect_ratio (`bool`, *optional*, defaults to `True`): If `True`, the image is resized by choosing the smaller of the height and width scaling factors and using it for both dimensions. This ensures that the image is scaled down as little as possible while still fitting @@ -85,7 +83,6 @@ class ZoeDepthFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): Can be overridden by `ensure_multiple_of` in `preprocess`. """ - do_pad: Optional[bool] keep_aspect_ratio: Optional[bool] ensure_multiple_of: Optional[int] diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 3130d0ded34f..86cdb372034c 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -168,8 +168,6 @@ class methods and docstrings. Whether to resize the image. size (`dict[str, int]`, *optional*): Resize the shorter side of the input to `size["shortest_edge"]`. - size_divisor (`int`, *optional*): - The size by which to make sure both the height and width can be divided. crop_size (`dict[str, int]`, *optional*): Desired output size when applying center-cropping. resample (`PILImageResampling`, *optional*): @@ -200,7 +198,6 @@ class methods and docstrings. do_resize: Optional[bool] size: Optional[dict[str, int]] - size_divisor: Optional[int] crop_size: Optional[dict[str, int]] resample: Optional[Union["PILImageResampling", int]] do_rescale: Optional[bool] @@ -229,8 +226,6 @@ class VideosKwargs(TypedDict, total=False): Resize the shorter side of the input to `size["shortest_edge"]`. default_to_square (`bool`, *optional*, defaults to `self.default_to_square`): Whether to default to a square when resizing, if size is an int. - size_divisor (`int`, *optional*): - The size by which to make sure both the height and width can be divided. resample (`PILImageResampling`, *optional*): Resampling filter to use if resizing the video. do_rescale (`bool`, *optional*): @@ -243,8 +238,6 @@ class VideosKwargs(TypedDict, total=False): Mean to use if normalizing the video. image_std (`float` or `list[float]`, *optional*): Standard deviation to use if normalizing the video. - do_pad (`bool`, *optional*): - Whether to pad the video to the `(max_height, max_width)` of the videos in the batch. do_center_crop (`bool`, *optional*): Whether to center crop the video. do_sample_frames (`bool`, *optional*): @@ -268,7 +261,6 @@ class VideosKwargs(TypedDict, total=False): do_convert_rgb: Optional[bool] do_resize: Optional[bool] size: Optional[dict[str, int]] - size_divisor: Optional[int] default_to_square: Optional[bool] resample: Optional["PILImageResampling"] do_rescale: Optional[bool] @@ -276,7 +268,6 @@ class VideosKwargs(TypedDict, total=False): do_normalize: Optional[bool] image_mean: Optional[Union[float, list[float]]] image_std: Optional[Union[float, list[float]]] - do_pad: Optional[bool] do_center_crop: Optional[bool] crop_size: Optional[dict[str, int]] data_format: Optional[ChannelDimension] @@ -655,6 +646,18 @@ def to_dict(self, legacy_serialization=True) -> dict[str, Any]: if "chat_template" in output: del output["chat_template"] + def cast_array_to_list(dictionary): + """ + Numpy arrays are not serialiazable but can be in pre-processing dicts. + This function casts arrays to list, recusring through the nested configs as well. + """ + for key, value in dictionary.items(): + if isinstance(value, np.ndarray): + dictionary[key] = value.tolist() + elif isinstance(value, dict): + dictionary[key] = cast_array_to_list(value) + return dictionary + # Serialize attributes as a dict output = { k: v.to_dict() if isinstance(v, PushToHubMixin) else v @@ -667,6 +670,7 @@ def to_dict(self, legacy_serialization=True) -> dict[str, Any]: ) # remove `PushToHubMixin` objects ) } + output = cast_array_to_list(output) # Special case, add `audio_tokenizer` dict which points to model weights and path if not legacy_serialization and "audio_tokenizer" in output: diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index a9d9a8cba788..0847859450ea 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -131,6 +131,23 @@ class ImageProcessorArgs: "shape": None, } + do_pad = { + "description": """ + Whether to pad the image. Padding is done either to the largest size in the batch + or to a fixed square size per image. The exact padding strategy depends on the model. + """, + "shape": None, + } + + pad_size = { + "description": """ + The size in `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. Applied only when `do_pad=True.` + """, + "shape": None, + } + do_rescale = { "description": """ Whether to rescale the image. diff --git a/src/transformers/video_processing_utils.py b/src/transformers/video_processing_utils.py index 43d9e2bfd26e..9f6545ebe10e 100644 --- a/src/transformers/video_processing_utils.py +++ b/src/transformers/video_processing_utils.py @@ -95,8 +95,6 @@ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): Whether to center crop the video to the specified `crop_size`. Can be overridden by `do_center_crop` in the `preprocess` method. - do_pad (`bool`, *optional*): - Whether to pad the video to the `(max_height, max_width)` of the videos in the batch. crop_size (`dict[str, int]` *optional*, defaults to `self.crop_size`): Size of the output video after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` method. @@ -164,7 +162,6 @@ class BaseVideoProcessor(BaseImageProcessorFast): crop_size = None do_resize = None do_center_crop = None - do_pad = None do_rescale = None rescale_factor = 1 / 255 do_normalize = None @@ -401,12 +398,10 @@ def _preprocess( do_convert_rgb: bool, do_resize: bool, size: SizeDict, - size_divisor: Optional[int], interpolation: Optional["F.InterpolationMode"], do_center_crop: bool, crop_size: SizeDict, do_rescale: bool, - do_pad: bool, rescale_factor: float, do_normalize: bool, image_mean: Optional[Union[float, list[float]]], @@ -421,9 +416,7 @@ def _preprocess( if do_convert_rgb: stacked_videos = self.convert_to_rgb(stacked_videos) if do_resize: - stacked_videos = self.resize( - stacked_videos, size=size, size_divisor=size_divisor, interpolation=interpolation - ) + stacked_videos = self.resize(stacked_videos, size=size, interpolation=interpolation) resized_videos_grouped[shape] = stacked_videos resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index) diff --git a/tests/models/gemma3n/test_processing_gemma3n.py b/tests/models/gemma3n/test_processing_gemma3n.py index 320b821d6f79..2fbe7e79d3e5 100644 --- a/tests/models/gemma3n/test_processing_gemma3n.py +++ b/tests/models/gemma3n/test_processing_gemma3n.py @@ -66,15 +66,12 @@ def test_save_load_pretrained_default(self): tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor ) - processor.save_pretrained(self.tmpdirname) + processor.save_pretrained(self.tmpdirname, legacy_serialization=False) processor = Gemma3nProcessor.from_pretrained(self.tmpdirname) self.assertIsInstance(processor.tokenizer, GemmaTokenizerFast) self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab()) - # `disable_grouping` is a new attribute that got added on main while gemma3n was being released - so was - # not part of the saved processor - del processor.feature_extractor.disable_grouping self.assertIsInstance(processor.feature_extractor, Gemma3nAudioFeatureExtractor) self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string()) @@ -86,7 +83,7 @@ def test_save_load_pretrained_additional_features(self): processor = Gemma3nProcessor( tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor ) - processor.save_pretrained(self.tmpdirname) + processor.save_pretrained(self.tmpdirname, legacy_serialization=False) tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS-BOS)", eos_token="(EOS-EOS)") feature_extractor_add_kwargs = self.get_feature_extractor(dither=5.0, padding_value=1.0) @@ -98,9 +95,6 @@ def test_save_load_pretrained_additional_features(self): self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab()) self.assertIsInstance(processor.tokenizer, GemmaTokenizerFast) - # `disable_grouping` is a new attribute that got added on main while gemma3n was being released - so was - # not part of the saved processor - del processor.feature_extractor.disable_grouping self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) self.assertIsInstance(processor.feature_extractor, Gemma3nAudioFeatureExtractor) diff --git a/tests/models/janus/test_processing_janus.py b/tests/models/janus/test_processing_janus.py index 7e1b025721dc..73212e3ec4b3 100644 --- a/tests/models/janus/test_processing_janus.py +++ b/tests/models/janus/test_processing_janus.py @@ -457,7 +457,7 @@ def test_processor_postprocess(self): orig_image_input = self.prepare_image_inputs() orig_image = np.array(orig_image_input).transpose(2, 0, 1) - inputs = processor(text=input_str, images=orig_image, do_resize=False, return_tensors="np") + inputs = processor(text=input_str, images=orig_image, do_resize=False, do_pad=False, return_tensors="np") normalized_image_input = inputs.pixel_values unnormalized_images = processor.postprocess(normalized_image_input, return_tensors="np")["pixel_values"] diff --git a/tests/models/swin2sr/test_image_processing_swin2sr.py b/tests/models/swin2sr/test_image_processing_swin2sr.py index eecb023c29a0..2cf3edaf4386 100644 --- a/tests/models/swin2sr/test_image_processing_swin2sr.py +++ b/tests/models/swin2sr/test_image_processing_swin2sr.py @@ -48,7 +48,7 @@ def __init__( do_rescale=True, rescale_factor=1 / 255, do_pad=True, - pad_size=8, + size_divisor=8, ): self.parent = parent self.batch_size = batch_size @@ -59,14 +59,14 @@ def __init__( self.do_rescale = do_rescale self.rescale_factor = rescale_factor self.do_pad = do_pad - self.pad_size = pad_size + self.size_divisor = size_divisor def prepare_image_processor_dict(self): return { "do_rescale": self.do_rescale, "rescale_factor": self.rescale_factor, "do_pad": self.do_pad, - "pad_size": self.pad_size, + "size_divisor": self.size_divisor, } def expected_output_image_shape(self, images): @@ -79,8 +79,8 @@ def expected_output_image_shape(self, images): else: input_height, input_width = img.shape[-2:] - pad_height = (input_height // self.pad_size + 1) * self.pad_size - input_height - pad_width = (input_width // self.pad_size + 1) * self.pad_size - input_width + pad_height = (input_height // self.size_divisor + 1) * self.size_divisor - input_height + pad_width = (input_width // self.size_divisor + 1) * self.size_divisor - input_width return self.num_channels, input_height + pad_height, input_width + pad_width @@ -116,11 +116,12 @@ def test_image_processor_properties(self): self.assertTrue(hasattr(image_processing, "do_rescale")) self.assertTrue(hasattr(image_processing, "rescale_factor")) self.assertTrue(hasattr(image_processing, "do_pad")) - self.assertTrue(hasattr(image_processing, "pad_size")) + self.assertTrue(hasattr(image_processing, "size_divisor")) + self.assertTrue(hasattr(image_processing, "pad_size")) # deprecated but should be available def calculate_expected_size(self, image): old_height, old_width = get_image_size(image) - size = self.image_processor_tester.pad_size + size = self.image_processor_tester.size_divisor pad_height = (old_height // size + 1) * size - old_height pad_width = (old_width // size + 1) * size - old_width diff --git a/tests/models/tvp/test_image_processing_tvp.py b/tests/models/tvp/test_image_processing_tvp.py index 28581290e9d1..c2c8b81dfc0a 100644 --- a/tests/models/tvp/test_image_processing_tvp.py +++ b/tests/models/tvp/test_image_processing_tvp.py @@ -222,15 +222,15 @@ def test_call_numpy(self): # Test not batched input expected_height, expected_width = self.image_processor_tester.get_expected_values(video_inputs) encoded_videos = image_processing(test_inputs[0], return_tensors="pt").pixel_values - self.assertEqual( - encoded_videos.shape, - ( + self.assertListEqual( + list(encoded_videos.shape), + [ 1, self.image_processor_tester.num_frames, self.image_processor_tester.num_channels, expected_height, expected_width, - ), + ], ) # Test batched @@ -238,15 +238,15 @@ def test_call_numpy(self): video_inputs, batched=True ) encoded_videos = image_processing(test_inputs, return_tensors="pt").pixel_values - self.assertEqual( - encoded_videos.shape, - ( + self.assertListEqual( + list(encoded_videos.shape), + [ self.image_processor_tester.batch_size, self.image_processor_tester.num_frames, self.image_processor_tester.num_channels, expected_height, expected_width, - ), + ], ) def test_call_numpy_4_channels(self): @@ -276,15 +276,15 @@ def test_call_numpy_4_channels(self): encoded_videos = image_processing( test_inputs[0], return_tensors="pt", image_mean=0, image_std=1, input_data_format="channels_first" ).pixel_values - self.assertEqual( - encoded_videos.shape, - ( + self.assertListEqual( + list(encoded_videos.shape), + [ 1, self.image_processor_tester.num_frames, self.image_processor_tester.num_channels, expected_height, expected_width, - ), + ], ) # Test batched @@ -294,15 +294,15 @@ def test_call_numpy_4_channels(self): encoded_videos = image_processing( test_inputs, return_tensors="pt", image_mean=0, image_std=1, input_data_format="channels_first" ).pixel_values - self.assertEqual( - encoded_videos.shape, - ( + self.assertListEqual( + list(encoded_videos.shape), + [ self.image_processor_tester.batch_size, self.image_processor_tester.num_frames, self.image_processor_tester.num_channels, expected_height, expected_width, - ), + ], ) self.image_processor_tester.num_channels = 3 diff --git a/tests/models/vitmatte/test_image_processing_vitmatte.py b/tests/models/vitmatte/test_image_processing_vitmatte.py index dc5597b1918b..a103c33a9cca 100644 --- a/tests/models/vitmatte/test_image_processing_vitmatte.py +++ b/tests/models/vitmatte/test_image_processing_vitmatte.py @@ -60,7 +60,7 @@ def __init__( do_rescale=True, rescale_factor=0.5, do_pad=True, - size_divisibility=10, + size_divisor=10, do_normalize=True, image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5], @@ -74,7 +74,7 @@ def __init__( self.do_rescale = do_rescale self.rescale_factor = rescale_factor self.do_pad = do_pad - self.size_divisibility = size_divisibility + self.size_divisor = size_divisor self.do_normalize = do_normalize self.image_mean = image_mean self.image_std = image_std @@ -87,7 +87,7 @@ def prepare_image_processor_dict(self): "do_rescale": self.do_rescale, "rescale_factor": self.rescale_factor, "do_pad": self.do_pad, - "size_divisibility": self.size_divisibility, + "size_divisor": self.size_divisor, } def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): @@ -125,6 +125,8 @@ def test_image_processor_properties(self): self.assertTrue(hasattr(image_processing, "do_rescale")) self.assertTrue(hasattr(image_processing, "rescale_factor")) self.assertTrue(hasattr(image_processing, "do_pad")) + self.assertTrue(hasattr(image_processing, "size_divisor")) + # Check size_divisibility for BC, the image proccessor has to have an atribute self.assertTrue(hasattr(image_processing, "size_divisibility")) def test_call_numpy(self): @@ -141,8 +143,8 @@ def test_call_numpy(self): encoded_images = image_processing(images=image, trimaps=trimap, return_tensors="pt").pixel_values # Verify that width and height can be divided by size_divisibility and that correct dimensions got merged - self.assertTrue(encoded_images.shape[-1] % self.image_processor_tester.size_divisibility == 0) - self.assertTrue(encoded_images.shape[-2] % self.image_processor_tester.size_divisibility == 0) + self.assertTrue(encoded_images.shape[-1] % self.image_processor_tester.size_divisor == 0) + self.assertTrue(encoded_images.shape[-2] % self.image_processor_tester.size_divisor == 0) self.assertTrue(encoded_images.shape[-3] == 4) def test_call_pytorch(self): @@ -160,8 +162,8 @@ def test_call_pytorch(self): encoded_images = image_processing(images=image, trimaps=trimap, return_tensors="pt").pixel_values # Verify that width and height can be divided by size_divisibility and that correct dimensions got merged - self.assertTrue(encoded_images.shape[-1] % self.image_processor_tester.size_divisibility == 0) - self.assertTrue(encoded_images.shape[-2] % self.image_processor_tester.size_divisibility == 0) + self.assertTrue(encoded_images.shape[-1] % self.image_processor_tester.size_divisor == 0) + self.assertTrue(encoded_images.shape[-2] % self.image_processor_tester.size_divisor == 0) self.assertTrue(encoded_images.shape[-3] == 4) # create batched tensors @@ -180,8 +182,8 @@ def test_call_pytorch(self): encoded_images = image_processing(images=image, trimaps=trimap, return_tensors="pt").pixel_values # Verify that width and height can be divided by size_divisibility and that correct dimensions got merged - self.assertTrue(encoded_images.shape[-1] % self.image_processor_tester.size_divisibility == 0) - self.assertTrue(encoded_images.shape[-2] % self.image_processor_tester.size_divisibility == 0) + self.assertTrue(encoded_images.shape[-1] % self.image_processor_tester.size_divisor == 0) + self.assertTrue(encoded_images.shape[-2] % self.image_processor_tester.size_divisor == 0) self.assertTrue(encoded_images.shape[-3] == 4) def test_call_pil(self): @@ -198,8 +200,8 @@ def test_call_pil(self): encoded_images = image_processing(images=image, trimaps=trimap, return_tensors="pt").pixel_values # Verify that width and height can be divided by size_divisibility and that correct dimensions got merged - self.assertTrue(encoded_images.shape[-1] % self.image_processor_tester.size_divisibility == 0) - self.assertTrue(encoded_images.shape[-2] % self.image_processor_tester.size_divisibility == 0) + self.assertTrue(encoded_images.shape[-1] % self.image_processor_tester.size_divisor == 0) + self.assertTrue(encoded_images.shape[-2] % self.image_processor_tester.size_divisor == 0) self.assertTrue(encoded_images.shape[-3] == 4) def test_call_numpy_4_channels(self): @@ -224,8 +226,8 @@ def test_call_numpy_4_channels(self): ).pixel_values # Verify that width and height can be divided by size_divisibility and that correct dimensions got merged - self.assertTrue(encoded_images.shape[-1] % self.image_processor_tester.size_divisibility == 0) - self.assertTrue(encoded_images.shape[-2] % self.image_processor_tester.size_divisibility == 0) + self.assertTrue(encoded_images.shape[-1] % self.image_processor_tester.size_divisor == 0) + self.assertTrue(encoded_images.shape[-2] % self.image_processor_tester.size_divisor == 0) self.assertTrue(encoded_images.shape[-3] == 5) def test_padding_slow(self):