Skip to content

Commit 38f3c9c

Browse files
committed
Gaussian blur with kernelsize and sigma API
1 parent d2ccb49 commit 38f3c9c

File tree

3 files changed

+79
-58
lines changed

3 files changed

+79
-58
lines changed

torchvision/transforms/functional.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,18 +1028,18 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
10281028
return img
10291029

10301030

1031-
def gaussian_blur(img: Tensor, kernel_size: int, sigma: float = None) -> Tensor:
1031+
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
10321032
"""Performs Gaussian blurring on the img by given kernel.
10331033
The image can be a PIL Image or a Tensor, in which case it is expected
10341034
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
10351035
10361036
Args:
10371037
img (PIL Image or Tensor): Image to be blurred
1038-
kernel_size (sequence or int): Gaussian kernel size. Can be a sequence of integers
1038+
kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers
10391039
like ``(kx, ky)`` or a single integer for square kernels.
10401040
In torchscript mode kernel_size as single int is not supported, use a tuple or
10411041
list of length 1: ``[size, ]``.
1042-
sigma (sequence or float, optional): Gaussian kernel standard deviation. Can be a
1042+
sigma (sequence of floats or float or None, optional): Gaussian kernel standard deviation. Can be a
10431043
sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the
10441044
same sigma in both X/Y directions. If None, then it is computed using
10451045
``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``.
@@ -1056,12 +1056,10 @@ def gaussian_blur(img: Tensor, kernel_size: int, sigma: float = None) -> Tensor:
10561056
raise TypeError('img should be PIL Image or Tensor. Got {}'.format(type(img)))
10571057

10581058
is_pil_image = True
1059-
t_img = pil_to_tensor(img)
1059+
t_img = to_tensor(img)
10601060

10611061
output = F_t.gaussian_blur(t_img, kernel_size, sigma)
10621062

10631063
if is_pil_image:
1064-
output = output.permute((1, 2, 0))
1065-
output = Image.fromarray(output.numpy())
1066-
1064+
output = to_pil_image(output)
10671065
return output

torchvision/transforms/functional_tensor.py

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -942,65 +942,78 @@ def perspective(
942942
return _apply_grid_transform(img, grid, mode)
943943

944944

945-
def _get_kernel(radius: float, passes: int):
946-
sigma2 = torch.Tensor([radius ** 2 / passes])
945+
def _get_gaussian_kernel1d(kernel_size: int, sigma: float):
946+
ksize_half = (kernel_size - 1) * 0.5
947947

948-
kernel_rad = (torch.sqrt(12. * sigma2 + 1.) - 1.) / 2.
948+
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
949+
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
950+
kernel1d = pdf / pdf.sum()
949951

950-
kernel_rad_int = kernel_rad.long().item()
952+
return kernel1d
951953

952-
kernel_rad_float = (2 * kernel_rad_int + 1) * (kernel_rad_int * (kernel_rad_int + 1) - 3 * sigma2)
953-
kernel_rad_float /= 6 * (sigma2 - (kernel_rad_int + 1) * (kernel_rad_int + 1))
954-
kernel_rad_float = kernel_rad_float.item()
955954

956-
kernel_rad = kernel_rad_int + kernel_rad_float
955+
def _get_gaussian_kernel2d(kernel_size: List[int], sigma: List[float]):
956+
ksize_x, ksize_y = kernel_size
957+
sigma_x, sigma_y = sigma
957958

958-
ksize = 2 * kernel_rad_int + 1 + 2 * (kernel_rad_float > 0)
959-
kernel1d = torch.ones(ksize) / (2 * kernel_rad + 1)
959+
kernel1d_x = _get_gaussian_kernel1d(ksize_x, sigma_x)
960+
kernel1d_y = _get_gaussian_kernel1d(ksize_y, sigma_y)
960961

961-
if kernel_rad_float > 0:
962-
kernel1d[[0, -1]] = kernel_rad_float / (2 * kernel_rad + 1)
963-
964-
kernel2d = torch.mm(kernel1d[:, None], kernel1d[None, :])
962+
kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
965963

966964
return kernel2d
967965

968966

969-
def gaussian_blur(img: Tensor, radius: float) -> Tensor:
967+
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
970968
"""Performs Gaussian blurring on the img by given kernel.
971969
972970
Args:
973971
img (Tensor): Image to be blurred
974-
radius (float): Blur radius
972+
kernel_size (sequence of int or int): Kernel size of the Gaussian kernel
973+
sigma (sequence of float or float or None): Standard deviation of the Gaussian kernel
975974
976975
Returns:
977-
Tensor: An image that is blurred using kernel of given radius
976+
Tensor: An image that is blurred using gaussian kernel of given parameters
978977
"""
979-
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
978+
if not (isinstance(img, torch.Tensor) or _is_tensor_a_torch_image(img)):
980979
raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
981-
if not isinstance(radius, (float, int)):
982-
raise TypeError('radius should be either float or int. Got {}'.format(type(radius)))
980+
if not isinstance(kernel_size, (int, list, tuple)):
981+
raise TypeError('kernel_size should be int or a sequence of integers. Got {}'.format(type(kernel_size)))
982+
if not isinstance(sigma, (float, int, list, tuple)) and sigma != None:
983+
raise TypeError('sigma should be either float or int or its sequence. Got {}'.format(type(sigma)))
984+
985+
if isinstance(kernel_size, int):
986+
kernel_size = [kernel_size] * 2
987+
if isinstance(sigma, (int, float, None)):
988+
sigma = [sigma] * 2
989+
990+
if len(kernel_size) != 2:
991+
raise ValueError('If kernel_size is a sequence its length should be 2. Got {}'.format(len(kernel_size)))
992+
if len(sigma) != 2:
993+
raise ValueError('If sigma is a sequence its length should be 2. Got {}'.format(len(sigma)))
994+
995+
if any([ksize % 2 == 0 or not isinstance(ksize, int) for ksize in kernel_size]):
996+
raise ValueError('kernel_size should have odd and positive integers. Got {}'.format(kernel_size))
997+
998+
sigma = [s if s != None else 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8 for ksize, s in zip(kernel_size, sigma)]
983999

984-
radius = float(radius)
985-
passes = 3
1000+
if any([s <= 0. for s in sigma]):
1001+
raise ValueError('sigma should have positive values. Got {}'.format(sigma))
9861002

9871003
ndim = img.ndim
9881004
if ndim == 2:
9891005
img = img.unsqueeze(0)
9901006
if ndim == 3:
9911007
img = img.unsqueeze(0)
9921008

993-
kernel = _get_kernel(radius, passes)
1009+
kernel = _get_gaussian_kernel2d(kernel_size, sigma)
9941010

995-
padding = _compute_padding(kernel.shape[::-1])
1011+
padding = _compute_padding(kernel_size)
9961012

9971013
kernel = kernel[None, None, :, :].repeat(img.size(-3), 1, 1, 1)
9981014

999-
padded_img = pad(img, padding, padding_mode='edge')
1015+
padded_img = pad(img, padding, padding_mode='reflect')
10001016
blurred_img = conv2d(padded_img, kernel, groups=img.size(-3))
1001-
for _ in range(passes - 1):
1002-
padded_img = pad(blurred_img, padding, padding_mode='edge')
1003-
blurred_img = conv2d(padded_img, kernel, groups=img.size(-3))
10041017

10051018
if ndim == 2:
10061019
return blurred_img[0, 0]

torchvision/transforms/transforms.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1553,44 +1553,52 @@ class GaussianBlur(torch.nn.Module):
15531553
dimensions
15541554
15551555
Args:
1556-
radius (float or tuple of float (min, max)): Radius to be used for creating
1557-
kernel to perform blurring. If float, radius is fixed. If it is tuple of
1558-
float (min, max), kernel radius is chosen uniformly at random to lie in the
1556+
ksize (int): Size of the Gaussian kernel.
1557+
sigma (float or tuple of float (min, max)): Standard deviation to be used for
1558+
creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
1559+
of float (min, max), sigma is chosen uniformly at random to lie in the
15591560
given range.
15601561
15611562
Returns:
15621563
PIL Image or Tensor: Gaussian blurred version of the input image.
15631564
15641565
"""
15651566

1566-
def __init__(self, radius=(0.1, 2.0)):
1567+
def __init__(self, ksize, sigma=(0.1, 2.0)):
15671568
super().__init__()
15681569

1569-
if isinstance(radius, numbers.Number):
1570-
if radius <= 0:
1571-
raise ValueError("If radius is a single number, it must be positive.")
1572-
radius = (radius, radius)
1573-
elif isinstance(radius, (tuple, list)) and len(radius) == 2:
1574-
if not 0. < radius[0] <= radius[1]:
1575-
raise ValueError("radius values should be positive and of the form (min, max)")
1570+
if isinstance(ksize, numbers.Number):
1571+
if ksize <= 0 or ksize % 2 == 0:
1572+
raise ValueError("ksize should be an odd and positive number.")
15761573
else:
1577-
raise TypeError("radius should be a single number or a list/tuple with length 2.")
1574+
raise TypeError("ksize should be a single number.")
1575+
1576+
if isinstance(sigma, numbers.Number):
1577+
if sigma <= 0:
1578+
raise ValueError("If sigma is a single number, it must be positive.")
1579+
sigma = (sigma, sigma)
1580+
elif isinstance(sigma, (tuple, list)) and len(sigma) == 2:
1581+
if not 0. < sigma[0] <= sigma[1]:
1582+
raise ValueError("sigma values should be positive and of the form (min, max).")
1583+
else:
1584+
raise TypeError("sigma should be a single number or a list/tuple with length 2.")
15781585

1579-
self.rad_min, self.rad_max = radius
1586+
self.ksize = ksize
1587+
self.sigma_min, self.sigma_max = sigma
15801588

15811589
@staticmethod
1582-
def get_params(rad_min: float, rad_max: float):
1583-
"""Choose radius for ``gaussian_blur`` for random gaussian blurring.
1590+
def get_params(sigma_min: float, sigma_max: float):
1591+
"""Choose sigma for ``gaussian_blur`` for random gaussian blurring.
15841592
15851593
Args:
1586-
rad_min (float): Minimum radius that can be chosen for blurring kernel.
1587-
rad_max (float): Maximum radius that can be chosen for blurring kernel.
1594+
sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
1595+
sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.
15881596
15891597
Returns:
1590-
float: radius be passed to ``gaussian_blur`` for gaussian blurring.
1598+
float: Standard deviation to be passed to calculate kernel for gaussian blurring.
15911599
"""
1592-
radius = random.uniform(rad_min, rad_max)
1593-
return radius
1600+
sigma = random.uniform(sigma_min, sigma_max)
1601+
return sigma
15941602

15951603
def forward(self, img):
15961604
"""
@@ -1600,8 +1608,10 @@ def forward(self, img):
16001608
Returns:
16011609
PIL Image or Tensor: Gaussian blurred image
16021610
"""
1603-
radius = self.get_params(self.rad_min, self.rad_max)
1604-
return F.gaussian_blur(img, radius)
1611+
sigma = self.get_params(self.sigma_min, self.sigma_max)
1612+
return F.gaussian_blur(img, self.ksize, sigma)
16051613

16061614
def __repr__(self):
1607-
return self.__class__.__name__ + '(rad_min={0}, rad_max={1})'.format(self.rad_min, self.rad_max)
1615+
s = 'kernel size={0}, '.format(self.ksize)
1616+
s += '(sigma_min={0}, sigma_max={1})'.format(self.sigma_min, self.sigma_max)
1617+
return self.__class__.__name__ + s

0 commit comments

Comments
 (0)