@@ -355,6 +355,29 @@ def _hsv2rgb(img):
355355 return torch .einsum ("ijk, xijk -> xjk" , mask .to (dtype = img .dtype ), a4 )
356356
357357
358+ def _pad_symmetric (img : Tensor , padding : List [int ]) -> Tensor :
359+ # padding is left, right, top, bottom
360+ in_sizes = img .size ()
361+
362+ x_indices = [i for i in range (in_sizes [- 1 ])] # [0, 1, 2, 3, ...]
363+ left_indices = [i for i in range (padding [0 ] - 1 , - 1 , - 1 )] # e.g. [3, 2, 1, 0]
364+ right_indices = [- (i + 1 ) for i in range (padding [1 ])] # e.g. [-1, -2, -3]
365+ x_indices = torch .tensor (left_indices + x_indices + right_indices )
366+
367+ y_indices = [i for i in range (in_sizes [- 2 ])]
368+ top_indices = [i for i in range (padding [2 ] - 1 , - 1 , - 1 )]
369+ bottom_indices = [- (i + 1 ) for i in range (padding [3 ])]
370+ y_indices = torch .tensor (top_indices + y_indices + bottom_indices )
371+
372+ ndim = img .ndim
373+ if ndim == 3 :
374+ return img [:, y_indices [:, None ], x_indices [None , :]]
375+ elif ndim == 4 :
376+ return img [:, :, y_indices [:, None ], x_indices [None , :]]
377+ else :
378+ raise RuntimeError ("Symmetric padding of N-D tensors are not supported yet" )
379+
380+
358381def pad (img : Tensor , padding : List [int ], fill : int = 0 , padding_mode : str = "constant" ) -> Tensor :
359382 r"""Pad the given Tensor Image on all sides with specified padding mode and fill value.
360383
@@ -380,6 +403,11 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
380403 padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
381404 will result in [3, 2, 1, 2, 3, 4, 3, 2]
382405
406+ - symmetric: pads with reflection of image (repeating the last value on the edge)
407+
408+ padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
409+ will result in [2, 1, 1, 2, 3, 4, 4, 3]
410+
383411 Returns:
384412 Tensor: Padded image.
385413 """
@@ -400,8 +428,8 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
400428 raise ValueError ("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
401429 "{} element tuple" .format (len (padding )))
402430
403- if padding_mode not in ["constant" , "edge" , "reflect" ]:
404- raise ValueError ("Padding mode should be either constant, edge or reflect " )
431+ if padding_mode not in ["constant" , "edge" , "reflect" , "symmetric" ]:
432+ raise ValueError ("Padding mode should be either constant, edge, reflect or symmetric " )
405433
406434 if isinstance (padding , int ):
407435 if torch .jit .is_scripting ():
@@ -423,6 +451,11 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
423451 if padding_mode == "edge" :
424452 # remap padding_mode str
425453 padding_mode = "replicate"
454+ elif padding_mode == "symmetric" :
455+ # route to another implementation
456+ if p [0 ] < 0 or p [1 ] < 0 or p [2 ] < 0 or p [3 ] < 0 : # no any support for torch script
457+ raise ValueError ("Padding can not be negative for symmetric padding_mode" )
458+ return _pad_symmetric (img , p )
426459
427460 need_squeeze = False
428461 if img .ndim < 4 :
0 commit comments