@@ -942,65 +942,78 @@ def perspective(
942
942
return _apply_grid_transform (img , grid , mode )
943
943
944
944
945
- def _get_kernel ( radius : float , passes : int ):
946
- sigma2 = torch . Tensor ([ radius ** 2 / passes ])
945
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
    """Build a normalized 1D Gaussian kernel.

    The Gaussian is sampled at ``kernel_size`` evenly spaced points centred on
    zero (unit spacing between taps) and rescaled so that the taps sum to one.

    Args:
        kernel_size (int): Number of taps in the kernel. Must be at least 1.
        sigma (float): Standard deviation of the Gaussian. Must be positive.

    Returns:
        Tensor: 1D tensor of length ``kernel_size`` whose entries sum to 1.

    Raises:
        ValueError: If ``kernel_size`` is smaller than 1 or ``sigma`` is not
            a positive number.
    """
    if kernel_size < 1:
        raise ValueError('kernel_size should be a positive integer. Got {}'.format(kernel_size))
    # ``not sigma > 0`` (rather than ``sigma <= 0``) also rejects NaN; a zero
    # sigma would otherwise divide by zero below and yield an all-NaN kernel.
    if not sigma > 0.:
        raise ValueError('sigma should be a positive value. Got {}'.format(sigma))

    # Half-width of the sampling window: taps sit at -half, ..., 0, ..., +half.
    ksize_half = (kernel_size - 1) * 0.5

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    # Unnormalized Gaussian density; the constant factor cancels in the
    # normalization on the next line.
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    kernel1d = pdf / pdf.sum()

    return kernel1d
951
953
952
- kernel_rad_float = (2 * kernel_rad_int + 1 ) * (kernel_rad_int * (kernel_rad_int + 1 ) - 3 * sigma2 )
953
- kernel_rad_float /= 6 * (sigma2 - (kernel_rad_int + 1 ) * (kernel_rad_int + 1 ))
954
- kernel_rad_float = kernel_rad_float .item ()
955
954
956
- kernel_rad = kernel_rad_int + kernel_rad_float
955
def _get_gaussian_kernel2d(kernel_size: List[int], sigma: List[float]):
    """Build a 2D Gaussian kernel from two independent 1D kernels.

    Args:
        kernel_size (list of int): Two-element list ``[ksize_x, ksize_y]``
            giving the kernel size along x and along y.
        sigma (list of float): Two-element list ``[sigma_x, sigma_y]`` giving
            the standard deviation along x and along y.

    Returns:
        2D tensor with ``ksize_y`` rows and ``ksize_x`` columns, formed as the
        product of the y kernel (as a column) and the x kernel (as a row).
    """
    width, height = kernel_size
    std_x, std_y = sigma

    row = _get_gaussian_kernel1d(width, std_x)
    column = _get_gaussian_kernel1d(height, std_y)

    # (height, 1) @ (1, width): the matrix product of a column by a row is
    # the outer product of the two 1D kernels.
    return torch.mm(column.unsqueeze(1), row.unsqueeze(0))
967
965
968
966
969
- def gaussian_blur (img : Tensor , radius : float ) -> Tensor :
967
+ def gaussian_blur (img : Tensor , kernel_size : List [ int ], sigma : Optional [ List [ float ]] = None ) -> Tensor :
970
968
"""Performs Gaussian blurring on the img by given kernel.
971
969
972
970
Args:
973
971
img (Tensor): Image to be blurred
974
- radius (float): Blur radius
972
+ kernel_size (sequence of int or int): Kernel size of the Gaussian kernel
973
+ sigma (sequence of float or float or None): Standard deviation of the Gaussian kernel
975
974
976
975
Returns:
977
- Tensor: An image that is blurred using kernel of given radius
976
+ Tensor: An image that is blurred using gaussian kernel of given parameters
978
977
"""
979
- if not (isinstance (img , torch .Tensor ) and _is_tensor_a_torch_image (img )):
978
+ if not (isinstance (img , torch .Tensor ) or _is_tensor_a_torch_image (img )):
980
979
raise TypeError ('img should be Tensor Image. Got {}' .format (type (img )))
981
- if not isinstance (radius , (float , int )):
982
- raise TypeError ('radius should be either float or int. Got {}' .format (type (radius )))
980
+ if not isinstance (kernel_size , (int , list , tuple )):
981
+ raise TypeError ('kernel_size should be int or a sequence of integers. Got {}' .format (type (kernel_size )))
982
+ if not isinstance (sigma , (float , int , list , tuple )) and sigma != None :
983
+ raise TypeError ('sigma should be either float or int or its sequence. Got {}' .format (type (sigma )))
984
+
985
+ if isinstance (kernel_size , int ):
986
+ kernel_size = [kernel_size ] * 2
987
+ if isinstance (sigma , (int , float , None )):
988
+ sigma = [sigma ] * 2
989
+
990
+ if len (kernel_size ) != 2 :
991
+ raise ValueError ('If kernel_size is a sequence its length should be 2. Got {}' .format (len (kernel_size )))
992
+ if len (sigma ) != 2 :
993
+ raise ValueError ('If sigma is a sequence its length should be 2. Got {}' .format (len (sigma )))
994
+
995
+ if any ([ksize % 2 == 0 or not isinstance (ksize , int ) for ksize in kernel_size ]):
996
+ raise ValueError ('kernel_size should have odd and positive integers. Got {}' .format (kernel_size ))
997
+
998
+ sigma = [s if s != None else 0.3 * ((ksize - 1 ) * 0.5 - 1 ) + 0.8 for ksize , s in zip (kernel_size , sigma )]
983
999
984
- radius = float ( radius )
985
- passes = 3
1000
+ if any ([ s <= 0. for s in sigma ]):
1001
+ raise ValueError ( 'sigma should have positive values. Got {}' . format ( sigma ))
986
1002
987
1003
ndim = img .ndim
988
1004
if ndim == 2 :
989
1005
img = img .unsqueeze (0 )
990
1006
if ndim == 3 :
991
1007
img = img .unsqueeze (0 )
992
1008
993
- kernel = _get_kernel ( radius , passes )
1009
+ kernel = _get_gaussian_kernel2d ( kernel_size , sigma )
994
1010
995
- padding = _compute_padding (kernel . shape [:: - 1 ] )
1011
+ padding = _compute_padding (kernel_size )
996
1012
997
1013
kernel = kernel [None , None , :, :].repeat (img .size (- 3 ), 1 , 1 , 1 )
998
1014
999
- padded_img = pad (img , padding , padding_mode = 'edge ' )
1015
+ padded_img = pad (img , padding , padding_mode = 'reflect ' )
1000
1016
blurred_img = conv2d (padded_img , kernel , groups = img .size (- 3 ))
1001
- for _ in range (passes - 1 ):
1002
- padded_img = pad (blurred_img , padding , padding_mode = 'edge' )
1003
- blurred_img = conv2d (padded_img , kernel , groups = img .size (- 3 ))
1004
1017
1005
1018
if ndim == 2 :
1006
1019
return blurred_img [0 , 0 ]
0 commit comments