@@ -1,5 +1,6 @@
 import os

+import numpy as np
 import pytest
 import torch

@@ -20,14 +21,14 @@ def test_zero_div():


 def test_invalid_ssim():
-    y_pred = torch.rand(16, 1, 32, 32)
+    y_pred = torch.rand(1, 1, 4, 4)
     y = y_pred + 0.125
-    with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number. Got 10."):
-        ssim = SSIM(data_range=1.0, kernel_size=10)
+    with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number."):
+        ssim = SSIM(data_range=1.0, kernel_size=2)
         ssim.update((y_pred, y))
         ssim.compute()

-    with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number. Got -1."):
+    with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number."):
         ssim = SSIM(data_range=1.0, kernel_size=-1)
         ssim.update((y_pred, y))
         ssim.compute()
@@ -42,38 +43,73 @@ def test_invalid_ssim():
         ssim.update((y_pred, y))
         ssim.compute()

+    with pytest.raises(ValueError, match=r"Expected sigma to have positive number."):
+        ssim = SSIM(data_range=1.0, sigma=(-1, -1))
+        ssim.update((y_pred, y))
+        ssim.compute()
+
     with pytest.raises(ValueError, match=r"Argument sigma should be either float or a sequence of float."):
         ssim = SSIM(data_range=1.0, sigma=1)
         ssim.update((y_pred, y))
         ssim.compute()

+    with pytest.raises(ValueError, match=r"Expected y_pred and y to have the same shape."):
+        y = y.squeeze(dim=0)
+        ssim = SSIM(data_range=1.0)
+        ssim.update((y_pred, y))
+        ssim.compute()

-def test_ssim():
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    ssim = SSIM(data_range=1.0, device=device)
-    y_pred = torch.rand(16, 3, 64, 64, device=device)
-    y = y_pred * 0.65
-    ssim.update((y_pred, y))
+    with pytest.raises(ValueError, match=r"Expected y_pred and y to have BxCxHxW shape."):
+        y = y.squeeze(dim=0)
+        ssim = SSIM(data_range=1.0)
+        ssim.update((y, y))
+        ssim.compute()

-    np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
-    np_y = np_pred * 0.65
-    np_ssim = ski_ssim(np_pred, np_y, win_size=11, multichannel=True, gaussian_weights=True, data_range=1.0)
+    with pytest.raises(TypeError, match=r"Expected y_pred and y to have the same data type."):
+        y = y.double()
+        ssim = SSIM(data_range=1.0)
+        ssim.update((y_pred, y))
+        ssim.compute()

-    assert isinstance(ssim.compute(), torch.Tensor)
-    assert torch.allclose(ssim.compute(), torch.tensor(np_ssim, dtype=torch.float64, device=device), atol=1e-4)

-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    ssim = SSIM(data_range=1.0, gaussian=False, kernel_size=7, device=device)
-    y_pred = torch.rand(16, 3, 227, 227, device=device)
-    y = y_pred * 0.65
+def _test_ssim(y_pred, y, data_range, kernel_size, sigma, gaussian, use_sample_covariance, device):
+    atol = 7e-5
+    ssim = SSIM(data_range=data_range, sigma=sigma, device=device)
     ssim.update((y_pred, y))
+    ignite_ssim = ssim.compute()
+
+    skimg_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
+    skimg_y = skimg_pred * 0.8
+    skimg_ssim = ski_ssim(
+        skimg_pred,
+        skimg_y,
+        win_size=kernel_size,
+        sigma=sigma,
+        multichannel=True,
+        gaussian_weights=gaussian,
+        data_range=data_range,
+        use_sample_covariance=use_sample_covariance,
+    )
+
+    assert isinstance(ignite_ssim, torch.Tensor)
+    assert ignite_ssim.dtype == torch.float64
+    assert ignite_ssim.device == torch.device(device)
+    assert np.allclose(ignite_ssim.numpy(), skimg_ssim, atol=atol)

-    np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
-    np_y = np_pred * 0.65
-    np_ssim = ski_ssim(np_pred, np_y, win_size=7, multichannel=True, gaussian_weights=False, data_range=1.0)

-    assert isinstance(ssim.compute(), torch.Tensor)
-    assert torch.allclose(ssim.compute(), torch.tensor(np_ssim, dtype=torch.float64, device=device), atol=1e-4)
+def test_ssim():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    y_pred = torch.rand(8, 3, 224, 224, device=device)
+    y = y_pred * 0.8
+    _test_ssim(
+        y_pred, y, data_range=1.0, kernel_size=7, sigma=1.5, gaussian=False, use_sample_covariance=True, device=device
+    )
+
+    y_pred = torch.rand(12, 3, 28, 28, device=device)
+    y = y_pred * 0.8
+    _test_ssim(
+        y_pred, y, data_range=1.0, kernel_size=11, sigma=1.5, gaussian=True, use_sample_covariance=False, device=device
+    )


 def _test_distrib_integration(device, tol=1e-4):
@@ -105,7 +141,16 @@ def update(engine, i):
     np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
     np_true = np_pred * 0.65
-    true_res = ski_ssim(np_pred, np_true, win_size=11, multichannel=True, gaussian_weights=True, data_range=1.0)
+    true_res = ski_ssim(
+        np_pred,
+        np_true,
+        win_size=11,
+        sigma=1.5,
+        multichannel=True,
+        gaussian_weights=True,
+        data_range=1.0,
+        use_sample_covariance=False,
+    )

     assert pytest.approx(res, abs=tol) == true_res
@@ -142,7 +187,7 @@ def _test_distrib_accumulator_device(device):
         type(ssim._kernel.device), ssim._kernel.device, type(metric_device), metric_device
     )

-    y_pred = torch.rand(4, 3, 28, 28, dtype=torch.float, device=device)
+    y_pred = torch.rand(2, 3, 28, 28, dtype=torch.float, device=device)
     y = y_pred * 0.65
     ssim.update((y_pred, y))
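Note: the new `_test_ssim` helper validates ignite's `SSIM` metric against scikit-image's `structural_similarity` on the same data. Below is a minimal standalone sketch of that comparison, assuming the imports this test file already uses (`from ignite.metrics import SSIM`, `from skimage.metrics import structural_similarity as ski_ssim`); the batch shape and tolerance are illustrative only.

import numpy as np
import torch

from ignite.metrics import SSIM
from skimage.metrics import structural_similarity as ski_ssim

# Illustrative inputs: a small random batch and a scaled copy as the "target".
y_pred = torch.rand(4, 3, 64, 64)
y = y_pred * 0.8

# ignite: accumulate the batch, then compute a scalar SSIM value.
ssim = SSIM(data_range=1.0)  # defaults: 11x11 Gaussian kernel, sigma=1.5
ssim.update((y_pred, y))
ignite_res = ssim.compute()

# scikit-image expects channel-last arrays, so permute NCHW -> NHWC.
np_pred = y_pred.permute(0, 2, 3, 1).numpy()
np_y = y.permute(0, 2, 3, 1).numpy()
skimg_res = ski_ssim(
    np_pred,
    np_y,
    win_size=11,
    sigma=1.5,
    multichannel=True,
    gaussian_weights=True,
    data_range=1.0,
    use_sample_covariance=False,
)

# The two implementations should agree up to a small tolerance.
assert np.allclose(ignite_res, skimg_res, atol=1e-4)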