Skip to content

Added GaussianBlur transform #2658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Oct 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ Transforms on PIL Image

.. autoclass:: TenCrop

.. autoclass:: GaussianBlur

Transforms on torch.\*Tensor
----------------------------

Expand Down
Binary file added test/assets/gaussian_blur_opencv_results.pt
Binary file not shown.
73 changes: 69 additions & 4 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import unittest
import colorsys
import math
Expand Down Expand Up @@ -675,14 +676,14 @@ def test_rotate(self):
batch_tensors, F.rotate, angle=32, resample=0, expand=True, center=center
)

def _test_perspective(self, tensor, pil_img, scripted_tranform, test_configs):
def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
dt = tensor.dtype
for r in [0, ]:
for spoints, epoints in test_configs:
out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

for fn in [F.perspective, scripted_tranform]:
for fn in [F.perspective, scripted_transform]:
out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu()

if out_tensor.dtype != torch.uint8:
Expand All @@ -707,7 +708,7 @@ def test_perspective(self):
from torchvision.transforms import RandomPerspective

data = [self._create_data(26, 34, device=self.device), self._create_data(26, 26, device=self.device)]
scripted_tranform = torch.jit.script(F.perspective)
scripted_transform = torch.jit.script(F.perspective)

for tensor, pil_img in data:

Expand All @@ -730,7 +731,7 @@ def test_perspective(self):
if dt is not None:
tensor = tensor.to(dtype=dt)

self._test_perspective(tensor, pil_img, scripted_tranform, test_configs)
self._test_perspective(tensor, pil_img, scripted_transform, test_configs)

batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
if dt is not None:
Expand All @@ -741,6 +742,70 @@ def test_perspective(self):
batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0
)

def test_gaussian_blur(self):
    """Compare ``F.gaussian_blur`` (eager and scripted) with stored OpenCV reference outputs.

    Two synthetic images (a 3-channel 10x12 one and a 1-channel 26x28 one) are
    blurred with several kernel sizes / sigmas and dtypes. Every configuration
    that has an entry in ``assets/gaussian_blur_opencv_results.pt`` must match
    the reference within a maximum absolute difference of 1.0.
    """
    small_image_tensor = torch.from_numpy(
        np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
    ).permute(2, 0, 1).to(self.device)

    large_image_tensor = torch.from_numpy(
        np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))
    ).to(self.device)

    scripted_transform = torch.jit.script(F.gaussian_blur)

    # Reference outputs were produced with OpenCV, e.g.:
    #   np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
    #   cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
    #   cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
    #   cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
    #   cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
    #   np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
    #   cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
    # Results are looked up by the key "{H}_{W}_{C}__{kx}_{ky}_{sigma}" built below.
    p = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'gaussian_blur_opencv_results.pt')
    true_cv2_results = torch.load(p)

    for tensor in [small_image_tensor, large_image_tensor]:

        for dt in [None, torch.float32, torch.float64, torch.float16]:
            if dt == torch.float16 and torch.device(self.device).type == "cpu":
                # skip float16 on CPU case
                continue

            if dt is not None:
                tensor = tensor.to(dtype=dt)

            for ksize in [(3, 3), [3, 5], (23, 23)]:
                for sigma in [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]:

                    _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
                    _sigma = sigma[0] if sigma is not None else None
                    shape = tensor.shape
                    gt_key = "{}_{}_{}__{}_{}_{}".format(
                        shape[-2], shape[-1], shape[-3],
                        _ksize[0], _ksize[1], _sigma
                    )
                    # Only configurations with a stored reference are checked.
                    if gt_key not in true_cv2_results:
                        continue

                    # The reference is stored flattened in (H, W, C) order; bring it
                    # to (C, H, W) with the dtype/device of the input tensor.
                    true_out = torch.tensor(
                        true_cv2_results[gt_key]
                    ).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)

                    for fn in [F.gaussian_blur, scripted_transform]:
                        out = fn(tensor, kernel_size=ksize, sigma=sigma)
                        self.assertEqual(true_out.shape, out.shape, msg="{}, {}".format(ksize, sigma))
                        # Use the absolute difference: a signed maximum would let
                        # arbitrarily large negative deviations pass unnoticed.
                        self.assertLessEqual(
                            torch.max(torch.abs(true_out.float() - out.float())),
                            1.0,
                            msg="{}, {}".format(ksize, sigma)
                        )


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
Expand Down
42 changes: 42 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,48 @@ def test_random_grayscale(self):
# Checking if RandomGrayscale can be printed as string
trans3.__repr__()

def test_gaussian_blur_asserts(self):
    """Invalid ``kernel_size`` / ``sigma`` arguments must raise the expected errors.

    Each case pairs the functional ``F.gaussian_blur`` call with the matching
    ``transforms.GaussianBlur`` constructor call and checks both the exception
    type and the message.
    """
    img = F.to_pil_image(np.ones((100, 100, 3), dtype=np.uint8) * 255, "RGB")

    # (expected exception, message regex, zero-argument trigger); evaluated in order.
    cases = [
        (ValueError, r"If kernel_size is a sequence its length should be 2",
         lambda: F.gaussian_blur(img, [3])),

        (ValueError, r"If kernel_size is a sequence its length should be 2",
         lambda: F.gaussian_blur(img, [3, 3, 3])),
        (ValueError, r"Kernel size should be a tuple/list of two integers",
         lambda: transforms.GaussianBlur([3, 3, 3])),

        (ValueError, r"kernel_size should have odd and positive integers",
         lambda: F.gaussian_blur(img, [4, 4])),
        (ValueError, r"Kernel size value should be an odd and positive number",
         lambda: transforms.GaussianBlur([4, 4])),

        (ValueError, r"kernel_size should have odd and positive integers",
         lambda: F.gaussian_blur(img, [-3, -3])),
        (ValueError, r"Kernel size value should be an odd and positive number",
         lambda: transforms.GaussianBlur([-3, -3])),

        (ValueError, r"If sigma is a sequence, its length should be 2",
         lambda: F.gaussian_blur(img, 3, [1, 1, 1])),
        (ValueError, r"sigma should be a single number or a list/tuple with length 2",
         lambda: transforms.GaussianBlur(3, [1, 1, 1])),

        (ValueError, r"sigma should have positive values",
         lambda: F.gaussian_blur(img, 3, -1.0)),
        (ValueError, r"If sigma is a single number, it must be positive",
         lambda: transforms.GaussianBlur(3, -1.0)),

        (TypeError, r"kernel_size should be int or a sequence of integers",
         lambda: F.gaussian_blur(img, "kernel_size_string")),
        (ValueError, r"Kernel size should be a tuple/list of two integers",
         lambda: transforms.GaussianBlur("kernel_size_string")),

        (TypeError, r"sigma should be either float or sequence of floats",
         lambda: F.gaussian_blur(img, 3, "sigma_string")),
        (ValueError, r"sigma should be a single number or a list/tuple with length 2",
         lambda: transforms.GaussianBlur(3, "sigma_string")),
    ]

    for expected_exc, message_regex, trigger in cases:
        with self.assertRaisesRegex(expected_exc, message_regex):
            trigger()


if __name__ == '__main__':
unittest.main()
32 changes: 32 additions & 0 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,38 @@ def test_compose(self):
with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"):
torch.jit.script(t)

def test_gaussian_blur(self):
    """Run the shared class-op check for ``GaussianBlur`` over several argument forms.

    Covers int / list / tuple ``kernel_size`` (including a length-1 list) and
    scalar / list / tuple ``sigma``. Results are compared non-exactly using the
    "max" aggregation with a tolerance just above 1.0.
    """
    tol = 1.0 + 1e-10

    # Constructor kwargs for each configuration, exercised in this order.
    kwargs_per_case = [
        {"kernel_size": 3, "sigma": 0.75},
        {"kernel_size": 23, "sigma": [0.1, 2.0]},
        {"kernel_size": 23, "sigma": (0.1, 2.0)},
        {"kernel_size": [3, 3], "sigma": (1.0, 1.0)},
        {"kernel_size": (3, 3), "sigma": (0.1, 2.0)},
        {"kernel_size": [23], "sigma": 0.75},
    ]

    for meth_kwargs in kwargs_per_case:
        self._test_class_op(
            "GaussianBlur", meth_kwargs=meth_kwargs,
            test_exact_match=False, agg_method="max", tol=tol
        )

def test_random_erasing(self):
img = torch.rand(3, 60, 60)

Expand Down
64 changes: 62 additions & 2 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def pil_to_tensor(pic):
Returns:
Tensor: Converted image.
"""
if not(F_pil._is_pil_image(pic)):
if not F_pil._is_pil_image(pic):
raise TypeError('pic should be PIL Image. Got {}'.format(type(pic)))

if accimage is not None and isinstance(pic, accimage.Image):
Expand Down Expand Up @@ -297,7 +297,7 @@ def resize(img: Tensor, size: List[int], interpolation: int = Image.BILINEAR) ->
the smaller edge of the image will be matched to this number maintaining
the aspect ratio. i.e, if height > width, then image will be rescaled to
:math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
In torchscript mode padding as single int is not supported, use a tuple or
In torchscript mode size as single int is not supported, use a tuple or
list of length 1: ``[size, ]``.
interpolation (int, optional): Desired interpolation enum defined by `filters`_.
Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR``
Expand Down Expand Up @@ -988,3 +988,63 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool

img[..., i:i + h, j:j + w] = v
return img


def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
    """Blur an image with a Gaussian kernel.

    The image can be a PIL Image or a Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    A PIL Image is converted to a tensor, blurred, and converted back to a PIL Image.

    Args:
        img (PIL Image or Tensor): Image to be blurred
        kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers
            like ``(kx, ky)`` or a single integer for square kernels. Every value must be odd
            and positive.
            In torchscript mode kernel_size as single int is not supported, use a tuple or
            list instead. Note that a sequence must have exactly two elements
            (``[ksize, ksize]``); any other length raises ``ValueError``.
        sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a
            sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the
            same sigma in both X/Y directions. If None, then it is computed using
            ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``.
            Default, None. In torchscript mode sigma as single float is
            not supported, use a tuple or list of length 1: ``[sigma, ]``.

    Returns:
        PIL Image or Tensor: Gaussian Blurred version of the image.

    Raises:
        TypeError: If ``kernel_size`` is not an int, list or tuple; if ``sigma`` is not
            a number, list or tuple; or if ``img`` is neither a Tensor nor a PIL Image.
        ValueError: If ``kernel_size`` does not have length 2 or contains an even or
            negative value; if ``sigma`` does not have length 1 or 2 or contains a
            non-positive value.
    """
    if not isinstance(kernel_size, (int, list, tuple)):
        raise TypeError('kernel_size should be int or a sequence of integers. Got {}'.format(type(kernel_size)))
    # A single int means a square kernel.
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    if len(kernel_size) != 2:
        raise ValueError('If kernel_size is a sequence its length should be 2. Got {}'.format(len(kernel_size)))
    for ksize in kernel_size:
        # Zero is rejected by the evenness test, negative odd values by the sign test.
        if ksize % 2 == 0 or ksize < 0:
            raise ValueError('kernel_size should have odd and positive integers. Got {}'.format(kernel_size))

    if sigma is None:
        # Same as the documented 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8, expanded.
        sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]

    if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
        raise TypeError('sigma should be either float or sequence of floats. Got {}'.format(type(sigma)))
    # A single number applies to both directions.
    if isinstance(sigma, (int, float)):
        sigma = [float(sigma), float(sigma)]
    # A length-1 sequence is expanded the same way as a single number.
    if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
        sigma = [sigma[0], sigma[0]]
    if len(sigma) != 2:
        raise ValueError('If sigma is a sequence, its length should be 2. Got {}'.format(len(sigma)))
    for s in sigma:
        if s <= 0.:
            raise ValueError('sigma should have positive values. Got {}'.format(sigma))

    # The blur itself is implemented for tensors only; PIL input takes a round trip.
    t_img = img
    if not isinstance(img, torch.Tensor):
        if not F_pil._is_pil_image(img):
            raise TypeError('img should be PIL Image or Tensor. Got {}'.format(type(img)))

        t_img = to_tensor(img)

    output = F_t.gaussian_blur(t_img, kernel_size, sigma)

    # Return the same kind of object that was passed in.
    if not isinstance(img, torch.Tensor):
        output = to_pil_image(output)
    return output
2 changes: 1 addition & 1 deletion torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
import torch
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
from PIL import Image, ImageOps, ImageEnhance, ImageFilter, __version__ as PILLOW_VERSION

try:
import accimage
Expand Down
Loading