From 5b87c50af200209836597d0c57b318e621b213ef Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 29 Jun 2020 21:40:18 +0200
Subject: [PATCH 1/4] [WIP] functional_tensor supports more padding modes

---
 test/test_functional_tensor.py              | 18 +++++++++++-------
 torchvision/transforms/functional_tensor.py | 11 +++++++++--
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 07a699345bd..1d08d97aee9 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -248,17 +248,21 @@ def test_pad(self):
         script_fn = torch.jit.script(F_t.pad)
         tensor, pil_img = self._create_data(7, 8)
         for pad in [1, [1, ], [0, 1], (2, 2), [1, 0, 1, 2]]:
-            padding_mode = "constant"
-            for fill in [0, 10, 20]:
-                pad_tensor = F_t.pad(tensor, pad, fill=fill, padding_mode=padding_mode)
-                pad_pil_img = F_pil.pad(pil_img, pad, fill=fill, padding_mode=padding_mode)
-                self.compareTensorToPIL(pad_tensor, pad_pil_img, msg="{}, {}".format(pad, fill))
+            configs = [
+                {"padding_mode": "constant", "fill": 0},
+                {"padding_mode": "constant", "fill": 10},
+                {"padding_mode": "constant", "fill": 20},
+            ]
+            for kwargs in configs:
+                pad_tensor = F_t.pad(tensor, pad, **kwargs)
+                pad_pil_img = F_pil.pad(pil_img, pad, **kwargs)
+                self.compareTensorToPIL(pad_tensor, pad_pil_img, msg="{}, {}".format(pad, kwargs))
                 if isinstance(pad, int):
                     script_pad = [pad, ]
                 else:
                     script_pad = pad
-                pad_tensor_script = script_fn(tensor, script_pad, fill=fill, padding_mode=padding_mode)
-                self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, fill))
+                pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
+                self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))
 
 
 if __name__ == '__main__':
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index ccba4d08369..a6773854365 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -359,6 +359,13 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan
 
             - constant: pads with a constant value, this value is specified with fill
 
+            - edge: pads with the last value on the edge of the image
+
+            - reflect: pads with reflection of image (without repeating the last value on the edge)
+
+              padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
+              will result in [3, 2, 1, 2, 3, 4, 3, 2]
+
     Returns:
         Tensor: Padded image.
     """
@@ -379,8 +386,8 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan
         raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
                          "{} element tuple".format(len(padding)))
 
-    if padding_mode not in ["constant", ]:
-        raise ValueError("Only constant padding_mode supported for torch tensors")
+    if padding_mode not in ["constant", "edge", "reflect"]:
+        raise ValueError("Padding mode should be either constant, edge, reflect")
 
     if isinstance(padding, int):
         if torch.jit.is_scripting():

From ebd84b886089af0c86c2da2e6e875eb0d303db1f Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 30 Jun 2020 11:04:31 +0200
Subject: [PATCH 2/4] [WIP] Support all padding modes

---
 test/test_functional_tensor.py              | 44 +++++++++++++--------
 torchvision/transforms/functional_tensor.py | 33 ++++++++++++++--
 2 files changed, 57 insertions(+), 20 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 1d08d97aee9..ae380e3b1ca 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -14,12 +14,15 @@ class Tester(unittest.TestCase):
 
     def _create_data(self, height=3, width=3, channels=3):
-        tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8)
+        # tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8)
+        tensor = torch.arange(0, channels * height * width, dtype=torch.uint8).reshape(channels, height, width)
         pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy())
         return tensor, pil_img
 
     def compareTensorToPIL(self, tensor, pil_image, msg=None):
         pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
+        if pil_tensor.dtype != tensor.dtype:
+            pil_tensor = pil_tensor.to(tensor.dtype)
         self.assertTrue(tensor.equal(pil_tensor), msg)
 
     def test_vflip(self):
@@ -247,22 +250,29 @@ def test_ten_crop(self):
     def test_pad(self):
         script_fn = torch.jit.script(F_t.pad)
         tensor, pil_img = self._create_data(7, 8)
-        for pad in [1, [1, ], [0, 1], (2, 2), [1, 0, 1, 2]]:
-            configs = [
-                {"padding_mode": "constant", "fill": 0},
-                {"padding_mode": "constant", "fill": 10},
-                {"padding_mode": "constant", "fill": 20},
-            ]
-            for kwargs in configs:
-                pad_tensor = F_t.pad(tensor, pad, **kwargs)
-                pad_pil_img = F_pil.pad(pil_img, pad, **kwargs)
-                self.compareTensorToPIL(pad_tensor, pad_pil_img, msg="{}, {}".format(pad, kwargs))
-                if isinstance(pad, int):
-                    script_pad = [pad, ]
-                else:
-                    script_pad = pad
-                pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
-                self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))
+
+        for dt in [None, torch.float32, torch.float64]:
+            if dt is not None:
+                tensor = tensor.to(dt)
+            for pad in [2, [3, ], [0, 3], (3, 3), [4, 2, 4, 3]]:
+                configs = [
+                    {"padding_mode": "constant", "fill": 0},
+                    {"padding_mode": "constant", "fill": 10},
+                    {"padding_mode": "constant", "fill": 20},
+                    {"padding_mode": "edge"},
+                    {"padding_mode": "reflect"},
+                    # {"padding_mode": "symmetric"},
+                ]
+                for kwargs in configs:
+                    pad_tensor = F_t.pad(tensor, pad, **kwargs)
+                    pad_pil_img = F_pil.pad(pil_img, pad, **kwargs)
+                    self.compareTensorToPIL(pad_tensor, pad_pil_img, msg="{}, {}".format(pad, kwargs))
+                    if isinstance(pad, int):
+                        script_pad = [pad, ]
+                    else:
+                        script_pad = pad
+                    pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
+                    self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))
 
 
 if __name__ == '__main__':
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index a6773854365..1b5369c8976 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -342,7 +342,7 @@ def _hsv2rgb(img):
     return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)
 
 
-def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constant") -> Tensor:
+def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
     r"""Pad the given Tensor Image on all sides with specified padding mode and fill value.
 
     Args:
@@ -366,6 +366,11 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan
               padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
               will result in [3, 2, 1, 2, 3, 4, 3, 2]
 
+            - symmetric: pads with reflection of image (repeating the last value on the edge)
+
+              padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
+              will result in [2, 1, 1, 2, 3, 4, 4, 3]
+
     Returns:
         Tensor: Padded image.
     """
@@ -386,8 +391,8 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan
         raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
                          "{} element tuple".format(len(padding)))
 
-    if padding_mode not in ["constant", "edge", "reflect"]:
-        raise ValueError("Padding mode should be either constant, edge, reflect")
+    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
+        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
 
     if isinstance(padding, int):
         if torch.jit.is_scripting():
@@ -406,5 +411,27 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan
 
     p = [pad_left, pad_right, pad_top, pad_bottom]
 
+    if padding_mode == "edge":
+        # remap padding_mode str
+        padding_mode = "replicate"
+
+    need_squeeze = False
+    if img.ndim < 4:
+        img = img.unsqueeze(dim=0)
+        need_squeeze = True
+
+    out_dtype = img.dtype
+    need_cast = False
+    if img.dtype not in (torch.float32, torch.float64):
+        need_cast = True
+        img = img.to(torch.float32)
+
     img = torch.nn.functional.pad(img, p, mode=padding_mode, value=float(fill))
+
+    if need_squeeze:
+        img = img.squeeze(dim=0)
+
+    if need_cast:
+        img = img.to(out_dtype)
+
     return img

From 54b7f9243ef8b7902bf604fa0b8288140298ef08 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 30 Jun 2020 11:51:43 +0200
Subject: [PATCH 3/4] Removed wip symmetric mode

---
 test/test_functional_tensor.py              |  4 +---
 torchvision/transforms/functional_tensor.py | 12 +++++-------
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index ae380e3b1ca..f3b6aa6370c 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -14,8 +14,7 @@ class Tester(unittest.TestCase):
 
     def _create_data(self, height=3, width=3, channels=3):
-        # tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8)
-        tensor = torch.arange(0, channels * height * width, dtype=torch.uint8).reshape(channels, height, width)
+        tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8)
         pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy())
         return tensor, pil_img
 
@@ -261,7 +260,6 @@ def test_pad(self):
                     {"padding_mode": "constant", "fill": 20},
                     {"padding_mode": "edge"},
                     {"padding_mode": "reflect"},
-                    # {"padding_mode": "symmetric"},
                 ]
                 for kwargs in configs:
                     pad_tensor = F_t.pad(tensor, pad, **kwargs)
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 1b5369c8976..9fa49e67af0 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -366,11 +366,6 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
               padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
               will result in [3, 2, 1, 2, 3, 4, 3, 2]
 
-            - symmetric: pads with reflection of image (repeating the last value on the edge)
-
-              padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
-              will result in [2, 1, 1, 2, 3, 4, 4, 3]
-
     Returns:
         Tensor: Padded image.
     """
@@ -391,8 +386,8 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
         raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
                          "{} element tuple".format(len(padding)))
 
-    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
-        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
+    if padding_mode not in ["constant", "edge", "reflect"]:
+        raise ValueError("Padding mode should be either constant, edge or reflect")
 
     if isinstance(padding, int):
         if torch.jit.is_scripting():
@@ -423,6 +418,9 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
     out_dtype = img.dtype
     need_cast = False
     if img.dtype not in (torch.float32, torch.float64):
+        # Here we temporary cast input tensor to float
+        # until pytorch issue is resolved :
+        # https://github.com/pytorch/pytorch/issues/40763
         need_cast = True
         img = img.to(torch.float32)
 

From 009c6a5c14eb42c257fb7a5ad9dd25d1a73cf69c Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 30 Jun 2020 12:51:49 +0200
Subject: [PATCH 4/4] Improvements according to the review

---
 test/test_functional_tensor.py              | 12 +++++++++---
 torchvision/transforms/functional_tensor.py |  2 +-
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index f3b6aa6370c..0c2194e0f7b 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -20,8 +20,6 @@ def _create_data(self, height=3, width=3, channels=3):
 
     def compareTensorToPIL(self, tensor, pil_image, msg=None):
         pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
-        if pil_tensor.dtype != tensor.dtype:
-            pil_tensor = pil_tensor.to(tensor.dtype)
         self.assertTrue(tensor.equal(pil_tensor), msg)
 
     def test_vflip(self):
@@ -252,6 +250,7 @@ def test_pad(self):
 
         for dt in [None, torch.float32, torch.float64]:
             if dt is not None:
+                # This is a trivial cast to float of uint8 data to test all cases
                 tensor = tensor.to(dt)
             for pad in [2, [3, ], [0, 3], (3, 3), [4, 2, 4, 3]]:
                 configs = [
@@ -264,7 +263,14 @@ def test_pad(self):
                 for kwargs in configs:
                     pad_tensor = F_t.pad(tensor, pad, **kwargs)
                     pad_pil_img = F_pil.pad(pil_img, pad, **kwargs)
-                    self.compareTensorToPIL(pad_tensor, pad_pil_img, msg="{}, {}".format(pad, kwargs))
+
+                    pad_tensor_8b = pad_tensor
+                    # we need to cast to uint8 to compare with PIL image
+                    if pad_tensor_8b.dtype != torch.uint8:
+                        pad_tensor_8b = pad_tensor_8b.to(torch.uint8)
+
+                    self.compareTensorToPIL(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, kwargs))
+
                     if isinstance(pad, int):
                         script_pad = [pad, ]
                     else:
                         script_pad = pad
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 9fa49e67af0..3b87dcc142a 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -417,7 +417,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
 
     out_dtype = img.dtype
     need_cast = False
-    if img.dtype not in (torch.float32, torch.float64):
+    if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
         # Here we temporary cast input tensor to float
         # until pytorch issue is resolved :
         # https://github.com/pytorch/pytorch/issues/40763
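
After the last patch, tensor padding supports the "constant", "edge" and "reflect" modes: "edge" is remapped to torch.nn.functional.pad's "replicate" mode, inputs below 4 dimensions are temporarily unsqueezed to a batch dimension, and for the non-constant modes integer inputs are temporarily cast to float32 (see https://github.com/pytorch/pytorch/issues/40763). The snippet below is a minimal standalone sketch of that approach, not the torchvision API itself; the helper name pad_image and its 4-element [left, top, right, bottom] padding argument are illustrative assumptions.

import torch
import torch.nn.functional as F


def pad_image(img: torch.Tensor, padding, fill: float = 0.0, padding_mode: str = "constant") -> torch.Tensor:
    # Hypothetical helper mirroring the approach of the patched
    # torchvision.transforms.functional_tensor.pad; not the torchvision function.
    pad_left, pad_top, pad_right, pad_bottom = padding  # assumes a 4-element [left, top, right, bottom] sequence
    p = [pad_left, pad_right, pad_top, pad_bottom]      # torch.nn.functional.pad expects [left, right, top, bottom]

    # PIL's "edge" mode corresponds to torch's "replicate" mode
    mode = "replicate" if padding_mode == "edge" else padding_mode

    need_squeeze = img.ndim < 4
    if need_squeeze:
        # reflect/replicate padding of the last two dims expects a batched NCHW tensor
        img = img.unsqueeze(dim=0)

    out_dtype = img.dtype
    need_cast = mode != "constant" and out_dtype not in (torch.float32, torch.float64)
    if need_cast:
        # reflect/replicate are float-only, so cast temporarily and restore the dtype afterwards
        img = img.to(torch.float32)

    if mode == "constant":
        img = F.pad(img, p, mode=mode, value=float(fill))
    else:
        img = F.pad(img, p, mode=mode)

    if need_squeeze:
        img = img.squeeze(dim=0)
    if need_cast:
        img = img.to(out_dtype)
    return img


if __name__ == "__main__":
    x = torch.arange(12, dtype=torch.uint8).reshape(1, 3, 4)  # tiny 1x3x4 "image"
    print(pad_image(x, [1, 1, 1, 1], padding_mode="reflect").shape)           # torch.Size([1, 5, 6])
    print(pad_image(x, [2, 0, 2, 0], fill=10, padding_mode="constant").shape)  # torch.Size([1, 3, 8])

As in the final patch, the float cast is applied only for the non-constant modes, so constant padding of a uint8 image keeps its dtype untouched.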