diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 07a699345bd..0c2194e0f7b 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -247,18 +247,36 @@ def test_ten_crop(self): def test_pad(self): script_fn = torch.jit.script(F_t.pad) tensor, pil_img = self._create_data(7, 8) - for pad in [1, [1, ], [0, 1], (2, 2), [1, 0, 1, 2]]: - padding_mode = "constant" - for fill in [0, 10, 20]: - pad_tensor = F_t.pad(tensor, pad, fill=fill, padding_mode=padding_mode) - pad_pil_img = F_pil.pad(pil_img, pad, fill=fill, padding_mode=padding_mode) - self.compareTensorToPIL(pad_tensor, pad_pil_img, msg="{}, {}".format(pad, fill)) - if isinstance(pad, int): - script_pad = [pad, ] - else: - script_pad = pad - pad_tensor_script = script_fn(tensor, script_pad, fill=fill, padding_mode=padding_mode) - self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, fill)) + + for dt in [None, torch.float32, torch.float64]: + if dt is not None: + # This is a trivial cast to float of uint8 data to test all cases + tensor = tensor.to(dt) + for pad in [2, [3, ], [0, 3], (3, 3), [4, 2, 4, 3]]: + configs = [ + {"padding_mode": "constant", "fill": 0}, + {"padding_mode": "constant", "fill": 10}, + {"padding_mode": "constant", "fill": 20}, + {"padding_mode": "edge"}, + {"padding_mode": "reflect"}, + ] + for kwargs in configs: + pad_tensor = F_t.pad(tensor, pad, **kwargs) + pad_pil_img = F_pil.pad(pil_img, pad, **kwargs) + + pad_tensor_8b = pad_tensor + # we need to cast to uint8 to compare with PIL image + if pad_tensor_8b.dtype != torch.uint8: + pad_tensor_8b = pad_tensor_8b.to(torch.uint8) + + self.compareTensorToPIL(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, kwargs)) + + if isinstance(pad, int): + script_pad = [pad, ] + else: + script_pad = pad + pad_tensor_script = script_fn(tensor, script_pad, **kwargs) + self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs)) if __name__ == '__main__': diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index ccba4d08369..3b87dcc142a 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -342,7 +342,7 @@ def _hsv2rgb(img): return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4) -def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constant") -> Tensor: +def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor: r"""Pad the given Tensor Image on all sides with specified padding mode and fill value. Args: @@ -359,6 +359,13 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan - constant: pads with a constant value, this value is specified with fill + - edge: pads with the last value on the edge of the image + + - reflect: pads with reflection of image (without repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + Returns: Tensor: Padded image. """ @@ -379,8 +386,8 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))) - if padding_mode not in ["constant", ]: - raise ValueError("Only constant padding_mode supported for torch tensors") + if padding_mode not in ["constant", "edge", "reflect"]: + raise ValueError("Padding mode should be either constant, edge or reflect") if isinstance(padding, int): if torch.jit.is_scripting(): @@ -399,5 +406,30 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan p = [pad_left, pad_right, pad_top, pad_bottom] + if padding_mode == "edge": + # remap padding_mode str + padding_mode = "replicate" + + need_squeeze = False + if img.ndim < 4: + img = img.unsqueeze(dim=0) + need_squeeze = True + + out_dtype = img.dtype + need_cast = False + if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64): + # Here we temporary cast input tensor to float + # until pytorch issue is resolved : + # https://github.com/pytorch/pytorch/issues/40763 + need_cast = True + img = img.to(torch.float32) + img = torch.nn.functional.pad(img, p, mode=padding_mode, value=float(fill)) + + if need_squeeze: + img = img.squeeze(dim=0) + + if need_cast: + img = img.to(out_dtype) + return img