Skip to content

Issue 2350 - support of all padding modes with tensors #2368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,18 +247,36 @@ def test_ten_crop(self):
def test_pad(self):
script_fn = torch.jit.script(F_t.pad)
tensor, pil_img = self._create_data(7, 8)
for pad in [1, [1, ], [0, 1], (2, 2), [1, 0, 1, 2]]:
padding_mode = "constant"
for fill in [0, 10, 20]:
pad_tensor = F_t.pad(tensor, pad, fill=fill, padding_mode=padding_mode)
pad_pil_img = F_pil.pad(pil_img, pad, fill=fill, padding_mode=padding_mode)
self.compareTensorToPIL(pad_tensor, pad_pil_img, msg="{}, {}".format(pad, fill))
if isinstance(pad, int):
script_pad = [pad, ]
else:
script_pad = pad
pad_tensor_script = script_fn(tensor, script_pad, fill=fill, padding_mode=padding_mode)
self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, fill))

for dt in [None, torch.float32, torch.float64]:
if dt is not None:
# This is a trivial cast of uint8 data to float, to exercise all dtype cases
tensor = tensor.to(dt)
for pad in [2, [3, ], [0, 3], (3, 3), [4, 2, 4, 3]]:
configs = [
{"padding_mode": "constant", "fill": 0},
{"padding_mode": "constant", "fill": 10},
{"padding_mode": "constant", "fill": 20},
{"padding_mode": "edge"},
{"padding_mode": "reflect"},
]
for kwargs in configs:
pad_tensor = F_t.pad(tensor, pad, **kwargs)
pad_pil_img = F_pil.pad(pil_img, pad, **kwargs)

pad_tensor_8b = pad_tensor
# we need to cast to uint8 to compare with PIL image
if pad_tensor_8b.dtype != torch.uint8:
pad_tensor_8b = pad_tensor_8b.to(torch.uint8)

self.compareTensorToPIL(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, kwargs))

if isinstance(pad, int):
script_pad = [pad, ]
else:
script_pad = pad
pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))


if __name__ == '__main__':
Expand Down
38 changes: 35 additions & 3 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def _hsv2rgb(img):
return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)


def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constant") -> Tensor:
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
r"""Pad the given Tensor Image on all sides with specified padding mode and fill value.

Args:
Expand All @@ -359,6 +359,13 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan

- constant: pads with a constant value, this value is specified with fill

- edge: pads with the last value on the edge of the image

- reflect: pads with reflection of image (without repeating the last value on the edge)

padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
will result in [3, 2, 1, 2, 3, 4, 3, 2]

Returns:
Tensor: Padded image.
"""
Expand All @@ -379,8 +386,8 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan
raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))

if padding_mode not in ["constant", ]:
raise ValueError("Only constant padding_mode supported for torch tensors")
if padding_mode not in ["constant", "edge", "reflect"]:
raise ValueError("Padding mode should be either constant, edge or reflect")

if isinstance(padding, int):
if torch.jit.is_scripting():
Expand All @@ -399,5 +406,30 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan

p = [pad_left, pad_right, pad_top, pad_bottom]

if padding_mode == "edge":
# remap padding_mode str
padding_mode = "replicate"

need_squeeze = False
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True

out_dtype = img.dtype
need_cast = False
if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
# Here we temporarily cast the input tensor to float
# until the following PyTorch issue is resolved:
# https://github.com/pytorch/pytorch/issues/40763
need_cast = True
img = img.to(torch.float32)

img = torch.nn.functional.pad(img, p, mode=padding_mode, value=float(fill))

if need_squeeze:
img = img.squeeze(dim=0)

if need_cast:
img = img.to(out_dtype)

return img