Skip to content

Commit 6fe11d5

Browse files
authored
Issue 2350 - support of all padding modes with tensors (#2368)
* [WIP] functional_tensor supports more padding modes * [WIP] Support all padding modes * Removed wip symmetric mode * Improvements according to the review
1 parent a99b6bd commit 6fe11d5

File tree

2 files changed

+65
-15
lines changed

2 files changed

+65
-15
lines changed

test/test_functional_tensor.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -247,18 +247,36 @@ def test_ten_crop(self):
247247
def test_pad(self):
248248
script_fn = torch.jit.script(F_t.pad)
249249
tensor, pil_img = self._create_data(7, 8)
250-
for pad in [1, [1, ], [0, 1], (2, 2), [1, 0, 1, 2]]:
251-
padding_mode = "constant"
252-
for fill in [0, 10, 20]:
253-
pad_tensor = F_t.pad(tensor, pad, fill=fill, padding_mode=padding_mode)
254-
pad_pil_img = F_pil.pad(pil_img, pad, fill=fill, padding_mode=padding_mode)
255-
self.compareTensorToPIL(pad_tensor, pad_pil_img, msg="{}, {}".format(pad, fill))
256-
if isinstance(pad, int):
257-
script_pad = [pad, ]
258-
else:
259-
script_pad = pad
260-
pad_tensor_script = script_fn(tensor, script_pad, fill=fill, padding_mode=padding_mode)
261-
self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, fill))
250+
251+
for dt in [None, torch.float32, torch.float64]:
252+
if dt is not None:
253+
# This is a trivial cast to float of uint8 data to test all cases
254+
tensor = tensor.to(dt)
255+
for pad in [2, [3, ], [0, 3], (3, 3), [4, 2, 4, 3]]:
256+
configs = [
257+
{"padding_mode": "constant", "fill": 0},
258+
{"padding_mode": "constant", "fill": 10},
259+
{"padding_mode": "constant", "fill": 20},
260+
{"padding_mode": "edge"},
261+
{"padding_mode": "reflect"},
262+
]
263+
for kwargs in configs:
264+
pad_tensor = F_t.pad(tensor, pad, **kwargs)
265+
pad_pil_img = F_pil.pad(pil_img, pad, **kwargs)
266+
267+
pad_tensor_8b = pad_tensor
268+
# we need to cast to uint8 to compare with PIL image
269+
if pad_tensor_8b.dtype != torch.uint8:
270+
pad_tensor_8b = pad_tensor_8b.to(torch.uint8)
271+
272+
self.compareTensorToPIL(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, kwargs))
273+
274+
if isinstance(pad, int):
275+
script_pad = [pad, ]
276+
else:
277+
script_pad = pad
278+
pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
279+
self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))
262280

263281

264282
if __name__ == '__main__':

torchvision/transforms/functional_tensor.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def _hsv2rgb(img):
346346
return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)
347347

348348

349-
def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constant") -> Tensor:
349+
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
350350
r"""Pad the given Tensor Image on all sides with specified padding mode and fill value.
351351
352352
Args:
@@ -363,6 +363,13 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan
363363
364364
- constant: pads with a constant value, this value is specified with fill
365365
366+
- edge: pads with the last value on the edge of the image
367+
368+
- reflect: pads with reflection of image (without repeating the last value on the edge)
369+
370+
padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
371+
will result in [3, 2, 1, 2, 3, 4, 3, 2]
372+
366373
Returns:
367374
Tensor: Padded image.
368375
"""
@@ -383,8 +390,8 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan
383390
raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
384391
"{} element tuple".format(len(padding)))
385392

386-
if padding_mode not in ["constant", ]:
387-
raise ValueError("Only constant padding_mode supported for torch tensors")
393+
if padding_mode not in ["constant", "edge", "reflect"]:
394+
raise ValueError("Padding mode should be either constant, edge or reflect")
388395

389396
if isinstance(padding, int):
390397
if torch.jit.is_scripting():
@@ -403,5 +410,30 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan
403410

404411
p = [pad_left, pad_right, pad_top, pad_bottom]
405412

413+
if padding_mode == "edge":
414+
# remap padding_mode str
415+
padding_mode = "replicate"
416+
417+
need_squeeze = False
418+
if img.ndim < 4:
419+
img = img.unsqueeze(dim=0)
420+
need_squeeze = True
421+
422+
out_dtype = img.dtype
423+
need_cast = False
424+
if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
425+
# Here we temporary cast input tensor to float
426+
# until pytorch issue is resolved :
427+
# https://github.com/pytorch/pytorch/issues/40763
428+
need_cast = True
429+
img = img.to(torch.float32)
430+
406431
img = torch.nn.functional.pad(img, p, mode=padding_mode, value=float(fill))
432+
433+
if need_squeeze:
434+
img = img.squeeze(dim=0)
435+
436+
if need_cast:
437+
img = img.to(out_dtype)
438+
407439
return img

0 commit comments

Comments
 (0)