@@ -355,6 +355,29 @@ def _hsv2rgb(img):
355
355
return torch .einsum ("ijk, xijk -> xjk" , mask .to (dtype = img .dtype ), a4 )
356
356
357
357
358
+ def _pad_symmetric (img : Tensor , padding : List [int ]) -> Tensor :
359
+ # padding is left, right, top, bottom
360
+ in_sizes = img .size ()
361
+
362
+ x_indices = [i for i in range (in_sizes [- 1 ])] # [0, 1, 2, 3, ...]
363
+ left_indices = [i for i in range (padding [0 ] - 1 , - 1 , - 1 )] # e.g. [3, 2, 1, 0]
364
+ right_indices = [- (i + 1 ) for i in range (padding [1 ])] # e.g. [-1, -2, -3]
365
+ x_indices = torch .tensor (left_indices + x_indices + right_indices )
366
+
367
+ y_indices = [i for i in range (in_sizes [- 2 ])]
368
+ top_indices = [i for i in range (padding [2 ] - 1 , - 1 , - 1 )]
369
+ bottom_indices = [- (i + 1 ) for i in range (padding [3 ])]
370
+ y_indices = torch .tensor (top_indices + y_indices + bottom_indices )
371
+
372
+ ndim = img .ndim
373
+ if ndim == 3 :
374
+ return img [:, y_indices [:, None ], x_indices [None , :]]
375
+ elif ndim == 4 :
376
+ return img [:, :, y_indices [:, None ], x_indices [None , :]]
377
+ else :
378
+ raise RuntimeError ("Symmetric padding of N-D tensors are not supported yet" )
379
+
380
+
358
381
def pad (img : Tensor , padding : List [int ], fill : int = 0 , padding_mode : str = "constant" ) -> Tensor :
359
382
r"""Pad the given Tensor Image on all sides with specified padding mode and fill value.
360
383
@@ -380,6 +403,11 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
380
403
padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
381
404
will result in [3, 2, 1, 2, 3, 4, 3, 2]
382
405
406
+ - symmetric: pads with reflection of image (repeating the last value on the edge)
407
+
408
+ padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
409
+ will result in [2, 1, 1, 2, 3, 4, 4, 3]
410
+
383
411
Returns:
384
412
Tensor: Padded image.
385
413
"""
@@ -400,8 +428,8 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
400
428
raise ValueError ("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
401
429
"{} element tuple" .format (len (padding )))
402
430
403
- if padding_mode not in ["constant" , "edge" , "reflect" ]:
404
- raise ValueError ("Padding mode should be either constant, edge or reflect " )
431
+ if padding_mode not in ["constant" , "edge" , "reflect" , "symmetric" ]:
432
+ raise ValueError ("Padding mode should be either constant, edge, reflect or symmetric " )
405
433
406
434
if isinstance (padding , int ):
407
435
if torch .jit .is_scripting ():
@@ -423,6 +451,11 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
423
451
if padding_mode == "edge" :
424
452
# remap padding_mode str
425
453
padding_mode = "replicate"
454
+ elif padding_mode == "symmetric" :
455
+ # route to another implementation
456
+ if p [0 ] < 0 or p [1 ] < 0 or p [2 ] < 0 or p [3 ] < 0 : # no any support for torch script
457
+ raise ValueError ("Padding can not be negative for symmetric padding_mode" )
458
+ return _pad_symmetric (img , p )
426
459
427
460
need_squeeze = False
428
461
if img .ndim < 4 :
0 commit comments