Skip to content

Commit 39e4057

Browse files
authored
Added symmetric padding mode for Tensors (#2373)
* [WIP] Added symmetric padding mode * Added check and raise error if padding is negative for symmetric padding mode * Added test check for raising error if negative pad
1 parent e4b9823 commit 39e4057

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

test/test_functional_tensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ def test_pad(self):
259259
{"padding_mode": "constant", "fill": 20},
260260
{"padding_mode": "edge"},
261261
{"padding_mode": "reflect"},
262+
{"padding_mode": "symmetric"},
262263
]
263264
for kwargs in configs:
264265
pad_tensor = F_t.pad(tensor, pad, **kwargs)
@@ -278,6 +279,9 @@ def test_pad(self):
278279
pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
279280
self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))
280281

282+
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
283+
F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
284+
281285

282286
if __name__ == '__main__':
283287
unittest.main()

torchvision/transforms/functional_tensor.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,29 @@ def _hsv2rgb(img):
355355
return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)
356356

357357

358+
def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
359+
# padding is left, right, top, bottom
360+
in_sizes = img.size()
361+
362+
x_indices = [i for i in range(in_sizes[-1])] # [0, 1, 2, 3, ...]
363+
left_indices = [i for i in range(padding[0] - 1, -1, -1)] # e.g. [3, 2, 1, 0]
364+
right_indices = [-(i + 1) for i in range(padding[1])] # e.g. [-1, -2, -3]
365+
x_indices = torch.tensor(left_indices + x_indices + right_indices)
366+
367+
y_indices = [i for i in range(in_sizes[-2])]
368+
top_indices = [i for i in range(padding[2] - 1, -1, -1)]
369+
bottom_indices = [-(i + 1) for i in range(padding[3])]
370+
y_indices = torch.tensor(top_indices + y_indices + bottom_indices)
371+
372+
ndim = img.ndim
373+
if ndim == 3:
374+
return img[:, y_indices[:, None], x_indices[None, :]]
375+
elif ndim == 4:
376+
return img[:, :, y_indices[:, None], x_indices[None, :]]
377+
else:
378+
raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")
379+
380+
358381
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
359382
r"""Pad the given Tensor Image on all sides with specified padding mode and fill value.
360383
@@ -380,6 +403,11 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
380403
padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
381404
will result in [3, 2, 1, 2, 3, 4, 3, 2]
382405
406+
- symmetric: pads with reflection of image (repeating the last value on the edge)
407+
408+
padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
409+
will result in [2, 1, 1, 2, 3, 4, 4, 3]
410+
383411
Returns:
384412
Tensor: Padded image.
385413
"""
@@ -400,8 +428,8 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
400428
raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
401429
"{} element tuple".format(len(padding)))
402430

403-
if padding_mode not in ["constant", "edge", "reflect"]:
404-
raise ValueError("Padding mode should be either constant, edge or reflect")
431+
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
432+
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
405433

406434
if isinstance(padding, int):
407435
if torch.jit.is_scripting():
@@ -423,6 +451,11 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
423451
if padding_mode == "edge":
424452
# remap padding_mode str
425453
padding_mode = "replicate"
454+
elif padding_mode == "symmetric":
455+
# route to another implementation
456+
if p[0] < 0 or p[1] < 0 or p[2] < 0 or p[3] < 0: # no any support for torch script
457+
raise ValueError("Padding can not be negative for symmetric padding_mode")
458+
return _pad_symmetric(img, p)
426459

427460
need_squeeze = False
428461
if img.ndim < 4:

0 commit comments

Comments
 (0)