@@ -414,78 +414,80 @@ def __init__(
414
414
_check_padding_arg (padding )
415
415
_check_padding_mode_arg (padding_mode )
416
416
417
- self .padding = padding
417
+ self .padding = F . _geometry . _parse_pad_padding ( padding ) if padding else None # type: ignore[arg-type]
418
418
self .pad_if_needed = pad_if_needed
419
419
self .fill = _setup_fill_arg (fill )
420
420
self .padding_mode = padding_mode
421
421
422
422
def _get_params(self, sample: Any) -> Dict[str, Any]:
    """Compute the padding and the random crop window for ``sample``.

    The static padding stored in ``self.padding`` and, when
    ``self.pad_if_needed`` is set, the extra padding required to reach the
    requested crop size are merged into a single padding list, so that
    ``_transform`` has to pad at most once.

    Args:
        sample: Input from which ``query_chw`` extracts the height and width.

    Returns:
        A dict with the keys

        * ``needs_crop``: whether the padded input is larger than the crop
          size in at least one dimension,
        * ``top`` / ``left``: offset of the crop window inside the padded input,
        * ``height`` / ``width``: the crop size taken from ``self.size``,
        * ``needs_pad``: whether any entry of ``padding`` is non-zero,
        * ``padding``: ``[left, top, right, bottom]`` padding amounts.

    Raises:
        ValueError: If the (padded) input is still smaller than the crop size.
    """
    _, padded_height, padded_width = query_chw(sample)

    if self.padding is not None:
        # self.padding is unpacked in (left, right, top, bottom) order.
        pad_left, pad_right, pad_top, pad_bottom = self.padding
        padded_height += pad_top + pad_bottom
        padded_width += pad_left + pad_right
    else:
        pad_left = pad_right = pad_top = pad_bottom = 0

    cropped_height, cropped_width = self.size

    if self.pad_if_needed:
        # The missing amount is added on *both* sides of the deficient dimension.
        if padded_height < cropped_height:
            diff = cropped_height - padded_height

            pad_top += diff
            pad_bottom += diff
            padded_height += 2 * diff

        if padded_width < cropped_width:
            diff = cropped_width - padded_width

            pad_left += diff
            pad_right += diff
            padded_width += 2 * diff

    if padded_height < cropped_height or padded_width < cropped_width:
        raise ValueError(
            f"Required crop size {(cropped_height, cropped_width)} is larger than "
            f"{'padded ' if self.padding is not None else ''}input image size {(padded_height, padded_width)}."
        )

    # We need a different order here than we have in self.padding since this padding will be parsed again in `F.pad`
    padding = [pad_left, pad_top, pad_right, pad_bottom]
    needs_pad = any(padding)

    # Only draw a random offset for a dimension that actually has slack;
    # otherwise the crop window starts at 0 in that dimension.
    if padded_height > cropped_height:
        needs_vert_crop = True
        top = int(torch.randint(0, padded_height - cropped_height + 1, size=()))
    else:
        needs_vert_crop = False
        top = 0

    if padded_width > cropped_width:
        needs_horz_crop = True
        left = int(torch.randint(0, padded_width - cropped_width + 1, size=()))
    else:
        needs_horz_crop = False
        left = 0

    return dict(
        needs_crop=needs_vert_crop or needs_horz_crop,
        top=top,
        left=left,
        height=cropped_height,
        width=cropped_width,
        needs_pad=needs_pad,
        padding=padding,
    )
465
479
466
480
def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
467
- # TODO: (PERF) check for speed optimization if we avoid repeated pad calls
468
- fill = self .fill [type (inpt )]
469
- fill = F ._geometry ._convert_fill_arg (fill )
481
+ if params [ "needs_pad" ]:
482
+ fill = self .fill [type (inpt )]
483
+ fill = F ._geometry ._convert_fill_arg (fill )
470
484
471
- if self .padding is not None :
472
- # This cast does Sequence[int] -> List[int] and is required to make mypy happy
473
- padding = self .padding
474
- if not isinstance (padding , int ):
475
- padding = list (padding )
485
+ inpt = F .pad (inpt , padding = params ["padding" ], fill = fill , padding_mode = self .padding_mode )
476
486
477
- inpt = F .pad (inpt , padding = padding , fill = fill , padding_mode = self .padding_mode )
487
+ if params ["needs_crop" ]:
488
+ inpt = F .crop (inpt , top = params ["top" ], left = params ["left" ], height = params ["height" ], width = params ["width" ])
478
489
479
- if self .pad_if_needed :
480
- input_width , input_height = params ["input_width" ], params ["input_height" ]
481
- if input_width < self .size [1 ]:
482
- padding = [self .size [1 ] - input_width , 0 ]
483
- inpt = F .pad (inpt , padding = padding , fill = fill , padding_mode = self .padding_mode )
484
- if input_height < self .size [0 ]:
485
- padding = [0 , self .size [0 ] - input_height ]
486
- inpt = F .pad (inpt , padding = padding , fill = fill , padding_mode = self .padding_mode )
487
-
488
- return F .crop (inpt , top = params ["top" ], left = params ["left" ], height = params ["height" ], width = params ["width" ])
490
+ return inpt
489
491
490
492
491
493
class RandomPerspective (_RandomApplyTransform ):
0 commit comments