Commit 8e837f6

Consistent naming for images kwargs (#40834)
* use consistent naming for padding
* no validation on pad size
* add warnings
* fix
* fix copies
* another fix
* fix some tests
* fix more tests
* fix last tests
* fix copies
* better docstring
* delete print
1 parent eb04363 commit 8e837f6

72 files changed: +619 additions, −574 deletions


src/transformers/image_processing_utils_fast.py

Lines changed: 85 additions & 6 deletions
```diff
@@ -79,8 +79,6 @@ def validate_fast_preprocess_arguments(
     do_normalize: Optional[bool] = None,
     image_mean: Optional[Union[float, list[float]]] = None,
     image_std: Optional[Union[float, list[float]]] = None,
-    do_pad: Optional[bool] = None,
-    size_divisibility: Optional[int] = None,
     do_center_crop: Optional[bool] = None,
     crop_size: Optional[SizeDict] = None,
     do_resize: Optional[bool] = None,
@@ -99,8 +97,6 @@ def validate_fast_preprocess_arguments(
         do_normalize=do_normalize,
         image_mean=image_mean,
         image_std=image_std,
-        do_pad=do_pad,
-        size_divisibility=size_divisibility,
         do_center_crop=do_center_crop,
         crop_size=crop_size,
         do_resize=do_resize,
@@ -181,6 +177,8 @@ class DefaultFastImageProcessorKwargs(TypedDict, total=False):
     do_normalize: Optional[bool]
     image_mean: Optional[Union[float, list[float]]]
     image_std: Optional[Union[float, list[float]]]
+    do_pad: Optional[bool]
+    pad_size: Optional[dict[str, int]]
     do_convert_rgb: Optional[bool]
     return_tensors: Optional[Union[str, TensorType]]
     data_format: Optional[ChannelDimension]
@@ -199,6 +197,8 @@ class BaseImageProcessorFast(BaseImageProcessor):
     crop_size = None
     do_resize = None
     do_center_crop = None
+    do_pad = None
+    pad_size = None
     do_rescale = None
     rescale_factor = 1 / 255
     do_normalize = None
@@ -222,6 +222,9 @@ def __init__(self, **kwargs: Unpack[DefaultFastImageProcessorKwargs]):
         )
         crop_size = kwargs.pop("crop_size", self.crop_size)
         self.crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None
+        pad_size = kwargs.pop("pad_size", self.pad_size)
+        self.pad_size = get_size_dict(size=pad_size, param_name="pad_size") if pad_size is not None else None
+
         for key in self.valid_kwargs.__annotations__:
             kwarg = kwargs.pop(key, None)
             if kwarg is not None:
@@ -239,6 +242,74 @@ def is_fast(self) -> bool:
         """
         return True

+    def pad(
+        self,
+        images: "torch.Tensor",
+        pad_size: SizeDict = None,
+        fill_value: Optional[int] = 0,
+        padding_mode: Optional[str] = "constant",
+        return_mask: Optional[bool] = False,
+        disable_grouping: Optional[bool] = False,
+        **kwargs,
+    ) -> "torch.Tensor":
+        """
+        Pads images to `(pad_size["height"], pad_size["width"])` or to the largest size in the batch.
+
+        Args:
+            images (`torch.Tensor`):
+                Images to pad.
+            pad_size (`SizeDict`, *optional*):
+                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+            fill_value (`int`, *optional*, defaults to `0`):
+                The constant value used to fill the padded area.
+            padding_mode (`str`, *optional*, defaults to `"constant"`):
+                The padding mode to use. Can be any of the modes supported by
+                `torch.nn.functional.pad` (e.g. constant, reflection, replication).
+            return_mask (`bool`, *optional*, defaults to `False`):
+                Whether to return a pixel mask to denote padded regions.
+            disable_grouping (`bool`, *optional*, defaults to `False`):
+                Whether to disable grouping of images by size.
+
+        Returns:
+            `torch.Tensor`: The padded images.
+        """
+        if pad_size is not None:
+            if not (pad_size.height and pad_size.width):
+                raise ValueError(f"Pad size must contain 'height' and 'width' keys only. Got pad_size={pad_size}.")
+            pad_size = (pad_size.height, pad_size.width)
+        else:
+            pad_size = get_max_height_width(images)
+
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+        processed_images_grouped = {}
+        processed_masks_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            image_size = stacked_images.shape[-2:]
+            padding_height = pad_size[0] - image_size[0]
+            padding_width = pad_size[1] - image_size[1]
+            if padding_height < 0 or padding_width < 0:
+                raise ValueError(
+                    f"Padding dimensions are negative. Please make sure that the `pad_size` is larger than the "
+                    f"image size. Got pad_size={pad_size}, image_size={image_size}."
+                )
+            if image_size != pad_size:
+                padding = (0, 0, padding_width, padding_height)
+                stacked_images = F.pad(stacked_images, padding, fill=fill_value, padding_mode=padding_mode)
+            processed_images_grouped[shape] = stacked_images
+
+            if return_mask:
+                # keep only one from the channel dimension in pixel mask
+                stacked_masks = torch.zeros_like(stacked_images, dtype=torch.int64)[..., 0, :, :]
+                stacked_masks[..., : image_size[0], : image_size[1]] = 1
+                processed_masks_grouped[shape] = stacked_masks
+
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+        if return_mask:
+            processed_masks = reorder_images(processed_masks_grouped, grouped_images_index)
+            return processed_images, processed_masks
+
+        return processed_images
+
     def resize(
         self,
         image: "torch.Tensor",
@@ -577,6 +648,7 @@ def _further_process_kwargs(
         self,
         size: Optional[SizeDict] = None,
         crop_size: Optional[SizeDict] = None,
+        pad_size: Optional[SizeDict] = None,
         default_to_square: Optional[bool] = None,
         image_mean: Optional[Union[float, list[float]]] = None,
         image_std: Optional[Union[float, list[float]]] = None,
@@ -593,6 +665,8 @@ def _further_process_kwargs(
             size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
         if crop_size is not None:
             crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
+        if pad_size is not None:
+            pad_size = SizeDict(**get_size_dict(size=pad_size, param_name="pad_size"))
         if isinstance(image_mean, list):
             image_mean = tuple(image_mean)
         if isinstance(image_std, list):
@@ -602,6 +676,7 @@ def _further_process_kwargs(

     kwargs["size"] = size
     kwargs["crop_size"] = crop_size
+    kwargs["pad_size"] = pad_size
     kwargs["image_mean"] = image_mean
     kwargs["image_std"] = image_std
     kwargs["data_format"] = data_format
@@ -714,6 +789,8 @@ def _preprocess(
         do_normalize: bool,
         image_mean: Optional[Union[float, list[float]]],
         image_std: Optional[Union[float, list[float]]],
+        do_pad: Optional[bool],
+        pad_size: Optional[SizeDict],
         disable_grouping: Optional[bool],
         return_tensors: Optional[Union[str, TensorType]],
         **kwargs,
@@ -739,10 +816,12 @@
                 stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
             )
             processed_images_grouped[shape] = stacked_images
-
        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

+        if do_pad:
+            processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)
+
+        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

    def to_dict(self):
```
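With `do_pad` and `pad_size` promoted into `DefaultFastImageProcessorKwargs`, any processor built on `BaseImageProcessorFast` can pad through the shared `pad()` helper or via preprocessing kwargs. A minimal sketch of how this might be used; the toy subclass, its attributes, and the tensor shapes are illustrative (not from the commit), and a torchvision-enabled install is assumed since the fast processors require it:

```python
import torch

from transformers.image_processing_utils_fast import BaseImageProcessorFast


class ToyImageProcessorFast(BaseImageProcessorFast):
    # Illustrative subclass; real processors also set size, resample, mean/std, etc.
    do_rescale = True
    do_normalize = False


processor = ToyImageProcessorFast()

# Two differently sized images; pad() aligns them to the largest height/width
# in the batch (or to an explicit pad_size) and can also return a pixel mask.
images = [torch.rand(3, 224, 200), torch.rand(3, 200, 224)]
padded, masks = processor.pad(images, return_mask=True)

# The same behavior is reachable through the new preprocessing kwargs; padding
# to a common size is what makes the final torch.stack possible here.
batch = processor(images, do_pad=True, pad_size={"height": 256, "width": 256}, return_tensors="pt")
```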

src/transformers/image_utils.py

Lines changed: 9 additions & 4 deletions
```diff
@@ -525,7 +525,7 @@ def validate_preprocess_arguments(
     image_mean: Optional[Union[float, list[float]]] = None,
     image_std: Optional[Union[float, list[float]]] = None,
     do_pad: Optional[bool] = None,
-    size_divisibility: Optional[int] = None,
+    pad_size: Optional[Union[dict[str, int], int]] = None,
     do_center_crop: Optional[bool] = None,
     crop_size: Optional[dict[str, int]] = None,
     do_resize: Optional[bool] = None,
@@ -544,10 +544,15 @@
     if do_rescale and rescale_factor is None:
         raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.")

-    if do_pad and size_divisibility is None:
-        # Here, size_divisor might be passed as the value of size
+    if do_pad and pad_size is None:
+        # Processors pad images using different args depending on the model, so the below check is pointless
+        # but we keep it for BC for now. TODO: remove in v5
+        # Usually padding can be called with:
+        # - "pad_size/size" if we're padding to specific values
+        # - "size_divisor" if we're padding to any value divisible by X
+        # - "None" if we're padding to the maximum size image in batch
         raise ValueError(
-            "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `True`."
+            "Depending on the model, `size_divisor` or `pad_size` or `size` must be specified if `do_pad` is `True`."
         )

     if do_normalize and (image_mean is None or image_std is None):
```
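The new comment spells out why this check is soft: padding targets differ per model. Under the renamed argument, the validation behaves as in this sketch (illustrative calls showing only the padding-related arguments; all other parameters default to `None`):

```python
from transformers.image_utils import validate_preprocess_arguments

# Passes: an explicit padding target is provided.
validate_preprocess_arguments(do_pad=True, pad_size={"height": 512, "width": 512})

# Raises ValueError: do_pad is set but no target size was given.
# (Models that pad to a size divisor or to the largest image in the batch
# route around this check, which is why the comment marks it for removal.)
validate_preprocess_arguments(do_pad=True)
```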

src/transformers/models/bridgetower/image_processing_bridgetower.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -480,8 +480,6 @@ def preprocess(
             do_normalize=do_normalize,
             image_mean=image_mean,
             image_std=image_std,
-            do_pad=do_pad,
-            size_divisibility=size_divisor,
             do_center_crop=do_center_crop,
             crop_size=crop_size,
             do_resize=do_resize,
```

src/transformers/models/bridgetower/image_processing_bridgetower_fast.py

Lines changed: 1 addition & 59 deletions
```diff
@@ -25,7 +25,6 @@
     SizeDict,
     TensorType,
     Unpack,
-    get_max_height_width,
     group_images_by_shape,
     reorder_images,
 )
@@ -99,13 +98,9 @@ class BridgeTowerFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
     size_divisor (`int`, *optional*, defaults to 32):
         The size by which to make sure both the height and width can be divided. Only has an effect if `do_resize`
         is set to `True`. Can be overridden by the `size_divisor` parameter in the `preprocess` method.
-    do_pad (`bool`, *optional*, defaults to `True`):
-        Whether to pad the image to the `(max_height, max_width)` of the images in the batch. Can be overridden by
-        the `do_pad` parameter in the `preprocess` method.
     """

     size_divisor: Optional[int]
-    do_pad: Optional[bool]


 @auto_docstring
@@ -224,59 +219,6 @@ def _pad_image(
         )
         return padded_image

-    def pad(
-        self,
-        images: list["torch.Tensor"],
-        constant_values: Union[float, Iterable[float]] = 0,
-        return_pixel_mask: bool = True,
-        disable_grouping: Optional[bool] = False,
-    ) -> tuple:
-        """
-        Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
-        in the batch and optionally returns their corresponding pixel mask.
-
-        Args:
-            image (`torch.Tensor`):
-                Image to pad.
-            constant_values (`float` or `Iterable[float]`, *optional*):
-                The value to use for the padding if `mode` is `"constant"`.
-            return_pixel_mask (`bool`, *optional*, defaults to `True`):
-                Whether to return a pixel mask.
-            disable_grouping (`bool`, *optional*, defaults to `False`):
-                Whether to disable grouping of images by size.
-            return_tensors (`str` or `TensorType`, *optional*):
-                The type of tensors to return. Can be one of:
-                    - Unset: Return a list of `np.ndarray`.
-                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
-                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
-                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
-                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
-        """
-        pad_size = get_max_height_width(images)
-
-        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
-        processed_images_grouped = {}
-        processed_masks_grouped = {}
-        for shape, stacked_images in grouped_images.items():
-            stacked_images = self._pad_image(
-                stacked_images,
-                pad_size,
-                constant_values=constant_values,
-            )
-            processed_images_grouped[shape] = stacked_images
-
-            if return_pixel_mask:
-                stacked_masks = make_pixel_mask(image=stacked_images, output_size=pad_size)
-                processed_masks_grouped[shape] = stacked_masks
-
-        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-
-        processed_masks = None
-        if return_pixel_mask:
-            processed_masks = reorder_images(processed_masks_grouped, grouped_images_index)
-
-        return processed_images, processed_masks
-
     def _preprocess(
         self,
         images: list["torch.Tensor"],
@@ -325,7 +267,7 @@ def _preprocess(
         data = {}
         if do_pad:
             processed_images, processed_masks = self.pad(
-                processed_images, return_pixel_mask=True, disable_grouping=disable_grouping
+                processed_images, return_mask=True, disable_grouping=disable_grouping
             )
             processed_masks = torch.stack(processed_masks, dim=0) if return_tensors else processed_masks
             data["pixel_mask"] = processed_masks
```

src/transformers/models/bridgetower/processing_bridgetower.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -16,10 +16,17 @@
 Processor class for BridgeTower.
 """

-from ...processing_utils import ProcessingKwargs, ProcessorMixin
+from typing import Optional
+
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
+
+
+class BridgeTowerImagesKwargs(ImagesKwargs):
+    size_divisor: Optional[int]


 class BridgeTowerProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: BridgeTowerImagesKwargs
     _defaults = {
         "text_kwargs": {
             "add_special_tokens": True,
```

src/transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -227,6 +227,7 @@ def _preprocess(
         image_std: Optional[Union[float, list[float]]],
         disable_grouping: Optional[bool],
         return_tensors: Optional[Union[str, TensorType]],
+        **kwargs,
     ) -> BatchFeature:
         if crop_to_patches:
             grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
```

src/transformers/models/conditional_detr/image_processing_conditional_detr_fast.py

Lines changed: 2 additions & 13 deletions
```diff
@@ -74,23 +74,12 @@ class ConditionalDetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
         Controls whether to convert the annotations to the format expected by the CONDITIONAL_DETR model. Converts the
         bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
         Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
-    do_pad (`bool`, *optional*, defaults to `True`):
-        Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
-        method. If `True`, padding will be applied to the bottom and right of the image with zeros.
-        If `pad_size` is provided, the image will be padded to the specified dimensions.
-        Otherwise, the image will be padded to the maximum height and width of the batch.
-    pad_size (`dict[str, int]`, *optional*):
-        The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
-        provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
-        height and width in the batch.
     return_segmentation_masks (`bool`, *optional*, defaults to `False`):
         Whether to return segmentation masks.
     """

     format: Optional[Union[str, AnnotationFormat]]
     do_convert_annotations: Optional[bool]
-    do_pad: Optional[bool]
-    pad_size: Optional[dict[str, int]]
     return_segmentation_masks: Optional[bool]


@@ -629,7 +618,7 @@ def _preprocess(
         image_mean: Optional[Union[float, list[float]]],
         image_std: Optional[Union[float, list[float]]],
         do_pad: bool,
-        pad_size: Optional[dict[str, int]],
+        pad_size: Optional[SizeDict],
         format: Optional[Union[str, AnnotationFormat]],
         return_tensors: Optional[Union[str, TensorType]],
         **kwargs,
@@ -698,7 +687,7 @@ def _preprocess(
         if do_pad:
             # depends on all resized image shapes so we need another loop
             if pad_size is not None:
-                padded_size = (pad_size["height"], pad_size["width"])
+                padded_size = (pad_size.height, pad_size.width)
             else:
                 padded_size = get_max_height_width(images)
```
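Because `_further_process_kwargs` now normalizes `pad_size` into a `SizeDict`, the DETR-family processors read it with attribute access instead of dict indexing. A one-line illustration (the import path is an assumption based on where `SizeDict` is exported for the fast processors):

```python
from transformers.image_processing_utils_fast import SizeDict

pad_size = SizeDict(height=1024, width=1024)
padded_size = (pad_size.height, pad_size.width)  # previously pad_size["height"], pad_size["width"]
```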
src/transformers/models/convnext/image_processing_convnext_fast.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -155,6 +155,7 @@ def _preprocess(
         image_std: Optional[Union[float, list[float]]],
         disable_grouping: Optional[bool],
         return_tensors: Optional[Union[str, TensorType]],
+        **kwargs,
     ) -> BatchFeature:
         # Group images by size for batched resizing
         grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
```
