Skip to content

Commit 4106dbb

Browse files
tejank10 and vfdev-5
authored
Added GaussianBlur transform (#2658)
* Added GaussianBlur transform * fixed linting * supports fixed radius for kernel * [WIP] New API for gaussian_blur * Gaussian blur with kernelsize and sigma API * Fixed implementation and updated tests * Added large input case and refactored gt into a file * Updated docs * fix kernel dimensions order while creating kernel * added tests for exception handling of gaussian blur * fix linting, bug in tests * Fixed failing tests, refactored code and other minor fixes Co-authored-by: vfdev-5 <[email protected]>
1 parent 87c7864 commit 4106dbb

9 files changed

+352
-20
lines changed

docs/source/transforms.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ Transforms on PIL Image
8181

8282
.. autoclass:: TenCrop
8383

84+
.. autoclass:: GaussianBlur
85+
8486
Transforms on torch.\*Tensor
8587
----------------------------
8688

48 KB
Binary file not shown.

test/test_functional_tensor.py

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import unittest
23
import colorsys
34
import math
@@ -675,14 +676,14 @@ def test_rotate(self):
675676
batch_tensors, F.rotate, angle=32, resample=0, expand=True, center=center
676677
)
677678

678-
def _test_perspective(self, tensor, pil_img, scripted_tranform, test_configs):
679+
def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
679680
dt = tensor.dtype
680681
for r in [0, ]:
681682
for spoints, epoints in test_configs:
682683
out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
683684
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
684685

685-
for fn in [F.perspective, scripted_tranform]:
686+
for fn in [F.perspective, scripted_transform]:
686687
out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu()
687688

688689
if out_tensor.dtype != torch.uint8:
@@ -707,7 +708,7 @@ def test_perspective(self):
707708
from torchvision.transforms import RandomPerspective
708709

709710
data = [self._create_data(26, 34, device=self.device), self._create_data(26, 26, device=self.device)]
710-
scripted_tranform = torch.jit.script(F.perspective)
711+
scripted_transform = torch.jit.script(F.perspective)
711712

712713
for tensor, pil_img in data:
713714

@@ -730,7 +731,7 @@ def test_perspective(self):
730731
if dt is not None:
731732
tensor = tensor.to(dtype=dt)
732733

733-
self._test_perspective(tensor, pil_img, scripted_tranform, test_configs)
734+
self._test_perspective(tensor, pil_img, scripted_transform, test_configs)
734735

735736
batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
736737
if dt is not None:
@@ -741,6 +742,70 @@ def test_perspective(self):
741742
batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0
742743
)
743744

745+
def test_gaussian_blur(self):
    """Compare ``F.gaussian_blur`` (eager and scripted) with stored OpenCV outputs.

    Two uint8 inputs are used: a 3-channel 10x12 image and a 1-channel 26x28
    image. Each is also cast to float32/float64 (and float16 when not on CPU).
    Every (kernel_size, sigma) combination that has a stored reference is
    checked; combinations without a reference are skipped.
    """
    small_image_tensor = torch.from_numpy(
        np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
    ).permute(2, 0, 1).to(self.device)

    large_image_tensor = torch.from_numpy(
        np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))
    ).to(self.device)

    scripted_transform = torch.jit.script(F.gaussian_blur)

    # The references were generated offline with OpenCV, e.g.
    #   np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
    #   cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
    #   np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
    #   cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
    # and are looked up below with keys of the form "{H}_{W}_{C}__{kx}_{ky}_{sigma}".
    p = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'gaussian_blur_opencv_results.pt')
    true_cv2_results = torch.load(p)

    for tensor in [small_image_tensor, large_image_tensor]:

        for dt in [None, torch.float32, torch.float64, torch.float16]:
            if dt == torch.float16 and torch.device(self.device).type == "cpu":
                # skip float16 on CPU case
                continue

            # Cast into a separate name so the loop variable is never overwritten.
            inp = tensor if dt is None else tensor.to(dtype=dt)
            shape = inp.shape

            for ksize in [(3, 3), [3, 5], (23, 23)]:
                for sigma in [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]:

                    gt_key = "{}_{}_{}__{}_{}_{}".format(
                        shape[-2], shape[-1], shape[-3],
                        ksize[0], ksize[1], sigma[0]
                    )
                    if gt_key not in true_cv2_results:
                        continue

                    # Stored references are laid out as (H, W, C); bring them to (C, H, W)
                    # and to the dtype/device of the input.
                    true_out = torch.tensor(
                        true_cv2_results[gt_key]
                    ).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(inp)

                    for fn in [F.gaussian_blur, scripted_transform]:
                        out = fn(inp, kernel_size=ksize, sigma=sigma)
                        self.assertEqual(true_out.shape, out.shape, msg="{}, {}".format(ksize, sigma))
                        # Use the absolute difference: a signed difference would let
                        # outputs that overshoot the reference pass unnoticed.
                        max_abs_diff = torch.max(torch.abs(true_out.float() - out.float())).item()
                        self.assertLessEqual(
                            max_abs_diff,
                            1.0,
                            msg="{}, {}".format(ksize, sigma)
                        )
808+
744809

745810
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
746811
class CUDATester(Tester):

test/test_transforms.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,6 +1654,48 @@ def test_random_grayscale(self):
16541654
# Checking if RandomGrayscale can be printed as string
16551655
trans3.__repr__()
16561656

1657+
def test_gaussian_blur_asserts(self):
    """Check the errors raised for invalid ``kernel_size`` / ``sigma`` arguments.

    Each case pairs the expected exception type and message pattern with a
    call to either ``F.gaussian_blur`` or the ``transforms.GaussianBlur``
    constructor. Cases run in order and the first failure stops the test.
    """
    np_img = np.ones((100, 100, 3), dtype=np.uint8) * 255
    img = F.to_pil_image(np_img, "RGB")

    # (expected exception, message regex, zero-argument callable that must raise)
    cases = [
        # kernel_size sequences of the wrong length
        (ValueError, r"If kernel_size is a sequence its length should be 2",
         lambda: F.gaussian_blur(img, [3])),
        (ValueError, r"If kernel_size is a sequence its length should be 2",
         lambda: F.gaussian_blur(img, [3, 3, 3])),
        (ValueError, r"Kernel size should be a tuple/list of two integers",
         lambda: transforms.GaussianBlur([3, 3, 3])),
        # even kernel sizes
        (ValueError, r"kernel_size should have odd and positive integers",
         lambda: F.gaussian_blur(img, [4, 4])),
        (ValueError, r"Kernel size value should be an odd and positive number",
         lambda: transforms.GaussianBlur([4, 4])),
        # negative kernel sizes
        (ValueError, r"kernel_size should have odd and positive integers",
         lambda: F.gaussian_blur(img, [-3, -3])),
        (ValueError, r"Kernel size value should be an odd and positive number",
         lambda: transforms.GaussianBlur([-3, -3])),
        # sigma sequences of the wrong length
        (ValueError, r"If sigma is a sequence, its length should be 2",
         lambda: F.gaussian_blur(img, 3, [1, 1, 1])),
        (ValueError, r"sigma should be a single number or a list/tuple with length 2",
         lambda: transforms.GaussianBlur(3, [1, 1, 1])),
        # non-positive sigma
        (ValueError, r"sigma should have positive values",
         lambda: F.gaussian_blur(img, 3, -1.0)),
        (ValueError, r"If sigma is a single number, it must be positive",
         lambda: transforms.GaussianBlur(3, -1.0)),
        # kernel_size of the wrong type
        (TypeError, r"kernel_size should be int or a sequence of integers",
         lambda: F.gaussian_blur(img, "kernel_size_string")),
        (ValueError, r"Kernel size should be a tuple/list of two integers",
         lambda: transforms.GaussianBlur("kernel_size_string")),
        # sigma of the wrong type
        (TypeError, r"sigma should be either float or sequence of floats",
         lambda: F.gaussian_blur(img, 3, "sigma_string")),
        (ValueError, r"sigma should be a single number or a list/tuple with length 2",
         lambda: transforms.GaussianBlur(3, "sigma_string")),
    ]

    for expected_error, message_pattern, call in cases:
        with self.assertRaisesRegex(expected_error, message_pattern):
            call()
1698+
16571699

16581700
if __name__ == '__main__':
16591701
unittest.main()

test/test_transforms_tensor.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,38 @@ def test_compose(self):
466466
with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"):
467467
torch.jit.script(t)
468468

469+
def test_gaussian_blur(self):
    """Run the shared class-op check for ``GaussianBlur`` over several argument forms.

    Covers int and sequence ``kernel_size`` (including a length-1 list) and
    float, list and tuple ``sigma``. Results are compared non-exactly, using
    the "max" aggregation and a tolerance just above 1.0.
    """
    tol = 1.0 + 1e-10

    kwargs_variants = [
        {"kernel_size": 3, "sigma": 0.75},
        {"kernel_size": 23, "sigma": [0.1, 2.0]},
        {"kernel_size": 23, "sigma": (0.1, 2.0)},
        {"kernel_size": [3, 3], "sigma": (1.0, 1.0)},
        {"kernel_size": (3, 3), "sigma": (0.1, 2.0)},
        {"kernel_size": [23], "sigma": 0.75},
    ]

    for meth_kwargs in kwargs_variants:
        self._test_class_op(
            "GaussianBlur", meth_kwargs=meth_kwargs,
            test_exact_match=False, agg_method="max", tol=tol
        )
500+
469501
def test_random_erasing(self):
470502
img = torch.rand(3, 60, 60)
471503

torchvision/transforms/functional.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def pil_to_tensor(pic):
115115
Returns:
116116
Tensor: Converted image.
117117
"""
118-
if not(F_pil._is_pil_image(pic)):
118+
if not F_pil._is_pil_image(pic):
119119
raise TypeError('pic should be PIL Image. Got {}'.format(type(pic)))
120120

121121
if accimage is not None and isinstance(pic, accimage.Image):
@@ -297,7 +297,7 @@ def resize(img: Tensor, size: List[int], interpolation: int = Image.BILINEAR) ->
297297
the smaller edge of the image will be matched to this number maintaining
298298
the aspect ratio. i.e, if height > width, then image will be rescaled to
299299
:math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
300-
In torchscript mode padding as single int is not supported, use a tuple or
300+
In torchscript mode size as single int is not supported, use a tuple or
301301
list of length 1: ``[size, ]``.
302302
interpolation (int, optional): Desired interpolation enum defined by `filters`_.
303303
Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR``
@@ -988,3 +988,63 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
988988

989989
img[..., i:i + h, j:j + w] = v
990990
return img
991+
992+
993+
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
    """Performs Gaussian blurring on the img by given kernel.
    The image can be a PIL Image or a Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): Image to be blurred
        kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of two
            integers like ``(kx, ky)`` or a single integer for square kernels. Every value must
            be odd and positive. A sequence of any other length (including length 1) is rejected.
            In torchscript mode kernel_size as single int is not supported, use a
            list of length 2 instead: ``[ksize, ksize]``.
        sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a
            sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the
            same sigma in both X/Y directions. If None, then it is computed using
            ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``
            (i.e. ``0.15 * kernel_size + 0.35``), separately for each direction.
            Default, None. In torchscript mode sigma as single float is
            not supported, use a tuple or list of length 1: ``[sigma, ]``.

    Returns:
        PIL Image or Tensor: Gaussian Blurred version of the image.

    Raises:
        TypeError: If ``kernel_size`` is not an int, list or tuple; if ``sigma`` is not a
            number, list or tuple; or if ``img`` is neither a Tensor nor a PIL Image.
        ValueError: If ``kernel_size`` does not resolve to exactly two odd, positive values,
            or if ``sigma`` does not resolve to exactly two positive values.
    """
    # NOTE: this function is compiled with torch.jit.script, so the validation
    # below is kept as a flat sequence of simple checks.
    if not isinstance(kernel_size, (int, list, tuple)):
        raise TypeError('kernel_size should be int or a sequence of integers. Got {}'.format(type(kernel_size)))
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    if len(kernel_size) != 2:
        raise ValueError('If kernel_size is a sequence its length should be 2. Got {}'.format(len(kernel_size)))
    for ksize in kernel_size:
        if ksize <= 0 or ksize % 2 == 0:
            raise ValueError('kernel_size should have odd and positive integers. Got {}'.format(kernel_size))

    if sigma is None:
        # Default standard deviation derived from the kernel size, per direction.
        sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]

    if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
        raise TypeError('sigma should be either float or sequence of floats. Got {}'.format(type(sigma)))
    if isinstance(sigma, (int, float)):
        sigma = [float(sigma), float(sigma)]
    if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
        # A single-element sequence means the same sigma in both directions.
        sigma = [sigma[0], sigma[0]]
    if len(sigma) != 2:
        raise ValueError('If sigma is a sequence, its length should be 2. Got {}'.format(len(sigma)))
    for s in sigma:
        if s <= 0.:
            raise ValueError('sigma should have positive values. Got {}'.format(sigma))

    t_img = img
    if not isinstance(img, torch.Tensor):
        if not F_pil._is_pil_image(img):
            raise TypeError('img should be PIL Image or Tensor. Got {}'.format(type(img)))

        # PIL input is blurred through the tensor implementation and converted back below.
        t_img = to_tensor(img)

    output = F_t.gaussian_blur(t_img, kernel_size, sigma)

    if not isinstance(img, torch.Tensor):
        output = to_pil_image(output)
    return output

torchvision/transforms/functional_pil.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import numpy as np
55
import torch
6-
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
6+
from PIL import Image, ImageOps, ImageEnhance, ImageFilter, __version__ as PILLOW_VERSION
77

88
try:
99
import accimage

0 commit comments

Comments
 (0)