
Commit c43c414

Fixed implementation and updated tests
1 parent 8a72b87 commit c43c414

5 files changed: +216 -103 lines changed


test/test_functional_tensor.py

Lines changed: 112 additions & 4 deletions
@@ -678,14 +678,14 @@ def test_rotate(self):
                 batch_tensors, F.rotate, angle=32, resample=0, expand=True, center=center
             )
 
-    def _test_perspective(self, tensor, pil_img, scripted_tranform, test_configs):
+    def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
         dt = tensor.dtype
         for r in [0, ]:
             for spoints, epoints in test_configs:
                 out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
                 out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
 
-                for fn in [F.perspective, scripted_tranform]:
+                for fn in [F.perspective, scripted_transform]:
                     out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu()
 
                     if out_tensor.dtype != torch.uint8:
@@ -710,7 +710,7 @@ def test_perspective(self):
         from torchvision.transforms import RandomPerspective
 
         data = [self._create_data(26, 34, device=self.device), self._create_data(26, 26, device=self.device)]
-        scripted_tranform = torch.jit.script(F.perspective)
+        scripted_transform = torch.jit.script(F.perspective)
 
         for tensor, pil_img in data:
 
@@ -733,7 +733,7 @@ def test_perspective(self):
                 if dt is not None:
                     tensor = tensor.to(dtype=dt)
 
-                self._test_perspective(tensor, pil_img, scripted_tranform, test_configs)
+                self._test_perspective(tensor, pil_img, scripted_transform, test_configs)
 
                 batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
                 if dt is not None:
@@ -744,6 +744,114 @@ def test_perspective(self):
                         batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0
                     )
 
+    def test_gaussian_blur(self):
+        tensor = torch.from_numpy(
+            np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
+        ).permute(2, 0, 1).to(self.device)
+
+        scripted_transform = torch.jit.script(F.gaussian_blur)
+
+        true_cv2_results = {
+            # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
+            "3_3_0.8":
+                [19, 20, 21, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
+                 43, 44, 45, 46, 47, 48, 49, 49, 50, 51, 37, 38, 39, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
+                 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 68, 69, 70, 73, 74, 75, 75, 76, 77,
+                 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102,
+                 103, 104, 104, 105, 106, 109, 110, 111, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122,
+                 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 140, 141,
+                 142, 145, 146, 147, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162,
+                 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 176, 177, 178, 181, 182, 183,
+                 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202,
+                 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 212, 213, 214, 217, 189, 190, 204, 174, 175, 176,
+                 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181,
+                 182, 183, 184, 185, 186, 187, 187, 188, 189, 192, 130, 131, 162, 93, 94, 95, 64, 65, 66, 67, 68, 69,
+                 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 89, 90, 91, 94, 66, 67,
+                 81, 51, 52, 53, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
+                 61, 62, 63, 64, 64, 65, 66, 52, 53, 54, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
+                 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 82, 83, 84],
+            # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
+            "3_3_0.5":
+                [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
+                 35, 36, 37, 38, 39, 40, 40, 41, 42, 37, 38, 39, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
+                 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 68, 69, 70, 73, 74, 75, 75, 76, 77, 78, 79, 80,
+                 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 104,
+                 105, 106, 109, 110, 111, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
+                 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 140, 141, 142, 145, 146, 147, 147,
+                 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
+                 169, 170, 171, 172, 173, 174, 175, 176, 176, 177, 178, 181, 182, 183, 183, 184, 185, 186, 187, 188, 189,
+                 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
+                 211, 212, 212, 213, 214, 217, 212, 213, 216, 196, 197, 198, 196, 197, 198, 199, 200, 201, 202, 203, 204,
+                 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 221, 222, 223, 226,
+                 184, 185, 207, 48, 49, 50, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
+                 50, 51, 52, 53, 54, 55, 55, 56, 57, 60, 55, 56, 59, 39, 40, 41, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
+                 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 64, 65, 66, 61, 62, 63, 63, 64, 65, 66, 67, 68,
+                 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95
+                ],
+            # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
+            "3_5_0.8":
+                [21, 22, 23, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
+                 46, 47, 48, 49, 50, 51, 52, 51, 52, 53, 39, 40, 41, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
+                 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 69, 70, 71, 73, 74, 75, 75, 76, 77,
+                 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102,
+                 103, 104, 104, 105, 106, 109, 110, 111, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122,
+                 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 140, 141,
+                 142, 145, 146, 147, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162,
+                 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 176, 177, 178, 181, 180, 181,
+                 182, 179, 180, 181, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196,
+                 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 206, 207, 208, 211, 185, 186, 199, 170, 171, 172,
+                 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178,
+                 179, 180, 181, 182, 183, 184, 184, 185, 186, 189, 129, 130, 161, 95, 96, 97, 67, 68, 69, 70, 71, 72,
+                 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 92, 93, 94, 96, 69, 70,
+                 83, 54, 55, 56, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
+                 65, 66, 67, 68, 68, 69, 70, 62, 57, 58, 60, 55, 56, 57, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66,
+                 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 80, 81, 82],
+            # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
+            "3_5_0.5":
+                [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
+                 34, 35, 36, 37, 38, 39, 40, 40, 41, 42, 37, 38, 39, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
+                 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 68, 69, 70, 73, 74, 75, 75, 76, 77,
+                 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102,
+                 103, 104, 104, 105, 106, 109, 110, 111, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122,
+                 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 140, 141,
+                 142, 145, 146, 147, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162,
+                 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 176, 177, 178, 181, 182, 183,
+                 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202,
+                 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 212, 213, 214, 217, 212, 213, 216, 196, 197, 198,
+                 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215,
+                 216, 217, 218, 219, 220, 221, 221, 222, 223, 226, 184, 185, 207, 48, 49, 50, 30, 31, 32, 33, 34, 35,
+                 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 55, 56, 57, 60, 55, 56,
+                 59, 39, 40, 41, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
+                 61, 62, 63, 64, 64, 65, 66, 61, 62, 63, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78,
+                 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95]
+        }
+
+        for dt in [None, torch.float32, torch.float64, torch.float16]:
+            if dt == torch.float16 and torch.device(self.device).type == "cpu":
+                # skip float16 on CPU case
+                continue
+
+            if dt is not None:
+                tensor = tensor.to(dtype=dt)
+
+            for ksize in [(3, 3), [3, 5]]:
+                for sigma in [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8)]:
+
+                    _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
+                    _sigma = sigma[0] if sigma is not None else None
+                    true_out = torch.tensor(
+                        true_cv2_results["{}_{}_{}".format(_ksize[0], _ksize[1], _sigma)]
+                    ).reshape(10, 12, 3).permute(2, 0, 1)
+
+                    for fn in [F.gaussian_blur, scripted_transform]:
+                        out = fn(tensor, kernel_size=ksize, sigma=sigma)
+                        self.assertEqual(true_out.shape, out.shape, msg="{}, {}".format(ksize, sigma))
+                        self.assertLessEqual(
+                            torch.max(true_out.float() - out.float()),
+                            1.0,
+                            msg="{}, {}".format(ksize, sigma)
+                        )
+
 
 @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
 class CUDATester(Tester):

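The hard-coded lists in `true_cv2_results` above are OpenCV reference outputs for the same 10x12x3 `arange` image, as the inline comments state. A minimal sketch of how such reference values could be regenerated (assumes OpenCV is installed; the helper name `make_cv2_reference` is illustrative and not part of the test file):

import cv2
import numpy as np

def make_cv2_reference(ksize=(3, 3), sigma=0.8):
    # Same input as the test above: a 10x12x3 uint8 image filled from np.arange.
    np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
    # Mirrors the comments in the test, e.g. cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8).
    blurred = cv2.GaussianBlur(np_img, ksize=ksize, sigmaX=sigma)
    # The test stores the row-major (H, W, C) values as a flat Python list.
    return blurred.reshape(-1).tolist()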
test/test_transforms_tensor.py

Lines changed: 18 additions & 0 deletions
@@ -433,6 +433,24 @@ def test_compose(self):
         with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"):
             torch.jit.script(t)
 
+    def test_gaussian_blur(self):
+
+        tol = 1.0 + 1e-10
+        self._test_class_op(
+            "GaussianBlur", meth_kwargs={"kernel_size": 3, "sigma": 0.75},
+            test_exact_match=False, agg_method="max", tol=tol
+        )
+
+        self._test_class_op(
+            "GaussianBlur", meth_kwargs={"kernel_size": 23, "sigma": [0.1, 2.0]},
+            test_exact_match=False, agg_method="max", tol=tol
+        )
+
+        self._test_class_op(
+            "GaussianBlur", meth_kwargs={"kernel_size": 23, "sigma": (0.1, 2.0)},
+            test_exact_match=False, agg_method="max", tol=tol
+        )
+
 
 @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
 class CUDATester(Tester):

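The `_test_class_op` calls above construct the `GaussianBlur` transform class with these keyword arguments and also check scripted execution. A rough sketch of the eager-mode usage these configurations correspond to (illustrative only, not part of the diff; it assumes `GaussianBlur` is importable from `torchvision.transforms` at this commit, which is what the tests exercise):

import torch
from torchvision import transforms as T

img = torch.randint(0, 256, (3, 26, 26), dtype=torch.uint8)

# Fixed sigma, as in the first config above.
blur = T.GaussianBlur(kernel_size=3, sigma=0.75)
out = blur(img)

# sigma given as a (min, max) range, as in the other configs.
blur_range = T.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))
scripted = torch.jit.script(blur_range)  # scriptability is part of what these tests check
out_scripted = scripted(img)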
torchvision/transforms/functional.py

Lines changed: 28 additions & 5 deletions
@@ -1038,8 +1038,8 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa
         kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers
             like ``(kx, ky)`` or a single integer for square kernels.
             In torchscript mode kernel_size as single int is not supported, use a tuple or
-            list of length 1: ``[size, ]``.
-        sigma (sequence of floats or float or None, optional): Gaussian kernel standard deviation. Can be a
+            list of length 1: ``[ksize, ]``.
+        sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a
             sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the
             same sigma in both X/Y directions. If None, then it is computed using
             ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``.
@@ -1049,17 +1049,40 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa
     Returns:
         PIL Image or Tensor: Gaussian Blurred version of the image.
     """
-    is_pil_image = False
+    if not isinstance(kernel_size, (int, list, tuple)):
+        raise TypeError('kernel_size should be int or a sequence of integers. Got {}'.format(type(kernel_size)))
+    if isinstance(kernel_size, int):
+        kernel_size = [kernel_size, kernel_size]
+    if len(kernel_size) != 2:
+        raise ValueError('If kernel_size is a sequence its length should be 2. Got {}'.format(len(kernel_size)))
+    for ksize in kernel_size:
+        if ksize % 2 == 0 or ksize < 0:
+            raise ValueError('kernel_size should have odd and positive integers. Got {}'.format(kernel_size))
+
+    if sigma is None:
+        sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
+
+    if sigma is not None and not isinstance(sigma, (float, list, tuple)):
+        raise TypeError('sigma should be either float or sequence of floats. Got {}'.format(type(sigma)))
+    if isinstance(sigma, float):
+        sigma = [sigma, sigma]
+    if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
+        sigma = [sigma[0], sigma[0]]
+    if len(sigma) != 2:
+        raise ValueError('If sigma is a sequence, its length should be 2. Got {}'.format(len(sigma)))
+    for s in sigma:
+        if s <= 0.:
+            raise ValueError('sigma should have positive values. Got {}'.format(sigma))
+
     t_img = img
     if not isinstance(img, torch.Tensor):
         if not F_pil._is_pil_image(img):
             raise TypeError('img should be PIL Image or Tensor. Got {}'.format(type(img)))
 
-        is_pil_image = True
         t_img = to_tensor(img)
 
     output = F_t.gaussian_blur(t_img, kernel_size, sigma)
 
-    if is_pil_image:
+    if not isinstance(img, torch.Tensor):
         output = to_pil_image(output)
     return output

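Putting the docstring and the new validation together, here is a minimal sketch of calling the functional op directly (illustrative, not part of the commit; note that the fallback ``ksize * 0.15 + 0.35`` is algebraically the same as the docstring formula ``0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``):

import torch
from torchvision.transforms import functional as F

img = torch.randint(0, 256, (3, 10, 12), dtype=torch.uint8)

# Explicit kernel size and sigma; the X and Y values may differ.
out = F.gaussian_blur(img, kernel_size=[3, 5], sigma=[0.8, 0.8])

# sigma=None falls back to ksize * 0.15 + 0.35 per axis, as in the code above.
out_default = F.gaussian_blur(img, kernel_size=[3, 3], sigma=None)

# Even or non-positive kernel sizes are rejected by the new checks.
try:
    F.gaussian_blur(img, kernel_size=[4, 4])
except ValueError as exc:
    print(exc)  # kernel_size should have odd and positive integers. Got [4, 4]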