
Commit 3bde732

Author: Jeff Yang

[metrics] speed up SSIM tests (#1467)

* Update setup.cfg
* [metrics] update ssim
* use np.allclose instead of torch.allclose
* Apply suggestions from code review
* extract into _test_ssim

1 parent: 6b2f235
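
The headline change is comparing ignite's SSIM against scikit-image with np.allclose on plain CPU values, and on much smaller inputs, instead of building a torch.tensor reference on the metric device for torch.allclose. The sketch below is not the repository's test code; it is a minimal, self-contained illustration of that comparison pattern, reusing the shapes, the 0.8 multiplier, and the 7e-5 tolerance that appear in the diff, and assuming `ignite.metrics.SSIM` and `skimage.metrics.structural_similarity` as used there.

```python
# Minimal sketch (not the repository's test code) of the comparison pattern this
# commit adopts: evaluate ignite's SSIM on a small random batch and check it
# against scikit-image's structural_similarity via np.allclose on the CPU.
import numpy as np
import torch
from ignite.metrics import SSIM
from skimage.metrics import structural_similarity as ski_ssim

y_pred = torch.rand(12, 3, 28, 28)  # small NCHW batch keeps the check fast
y = y_pred * 0.8

ssim = SSIM(data_range=1.0)  # defaults: 11x11 Gaussian kernel, sigma 1.5
ssim.update((y_pred, y))
ignite_val = ssim.compute()  # a tensor in the ignite version this commit targets

np_pred = y_pred.permute(0, 2, 3, 1).numpy()  # NHWC layout for scikit-image
np_y = y.permute(0, 2, 3, 1).numpy()
ski_val = ski_ssim(
    np_pred,
    np_y,
    win_size=11,
    sigma=1.5,
    multichannel=True,  # newer scikit-image releases use channel_axis=-1 instead
    gaussian_weights=True,
    data_range=1.0,
    use_sample_covariance=False,
)

# Compare plain CPU values rather than constructing a reference tensor on the device.
assert np.allclose(np.asarray(ignite_val), ski_val, atol=7e-5)
```

In the diff itself this pattern is factored into a `_test_ssim` helper so that `test_ssim` can exercise both the uniform-kernel and the Gaussian-kernel configurations without duplicating the scikit-image call.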


tests/ignite/metrics/test_ssim.py: 71 additions, 26 deletions
```diff
@@ -1,5 +1,6 @@
 import os

+import numpy as np
 import pytest
 import torch

@@ -20,14 +21,14 @@ def test_zero_div():


 def test_invalid_ssim():
-    y_pred = torch.rand(16, 1, 32, 32)
+    y_pred = torch.rand(1, 1, 4, 4)
     y = y_pred + 0.125
-    with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number. Got 10."):
-        ssim = SSIM(data_range=1.0, kernel_size=10)
+    with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number."):
+        ssim = SSIM(data_range=1.0, kernel_size=2)
         ssim.update((y_pred, y))
         ssim.compute()

-    with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number. Got -1."):
+    with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number."):
         ssim = SSIM(data_range=1.0, kernel_size=-1)
         ssim.update((y_pred, y))
         ssim.compute()
@@ -42,38 +43,73 @@ def test_invalid_ssim():
         ssim.update((y_pred, y))
         ssim.compute()

+    with pytest.raises(ValueError, match=r"Expected sigma to have positive number."):
+        ssim = SSIM(data_range=1.0, sigma=(-1, -1))
+        ssim.update((y_pred, y))
+        ssim.compute()
+
     with pytest.raises(ValueError, match=r"Argument sigma should be either float or a sequence of float."):
         ssim = SSIM(data_range=1.0, sigma=1)
         ssim.update((y_pred, y))
         ssim.compute()

+    with pytest.raises(ValueError, match=r"Expected y_pred and y to have the same shape."):
+        y = y.squeeze(dim=0)
+        ssim = SSIM(data_range=1.0)
+        ssim.update((y_pred, y))
+        ssim.compute()

-def test_ssim():
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    ssim = SSIM(data_range=1.0, device=device)
-    y_pred = torch.rand(16, 3, 64, 64, device=device)
-    y = y_pred * 0.65
-    ssim.update((y_pred, y))
+    with pytest.raises(ValueError, match=r"Expected y_pred and y to have BxCxHxW shape."):
+        y = y.squeeze(dim=0)
+        ssim = SSIM(data_range=1.0)
+        ssim.update((y, y))
+        ssim.compute()

-    np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
-    np_y = np_pred * 0.65
-    np_ssim = ski_ssim(np_pred, np_y, win_size=11, multichannel=True, gaussian_weights=True, data_range=1.0)
+    with pytest.raises(TypeError, match=r"Expected y_pred and y to have the same data type."):
+        y = y.double()
+        ssim = SSIM(data_range=1.0)
+        ssim.update((y_pred, y))
+        ssim.compute()

-    assert isinstance(ssim.compute(), torch.Tensor)
-    assert torch.allclose(ssim.compute(), torch.tensor(np_ssim, dtype=torch.float64, device=device), atol=1e-4)

-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    ssim = SSIM(data_range=1.0, gaussian=False, kernel_size=7, device=device)
-    y_pred = torch.rand(16, 3, 227, 227, device=device)
-    y = y_pred * 0.65
+def _test_ssim(y_pred, y, data_range, kernel_size, sigma, gaussian, use_sample_covariance, device):
+    atol = 7e-5
+    ssim = SSIM(data_range=data_range, sigma=sigma, device=device)
     ssim.update((y_pred, y))
+    ignite_ssim = ssim.compute()
+
+    skimg_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
+    skimg_y = skimg_pred * 0.8
+    skimg_ssim = ski_ssim(
+        skimg_pred,
+        skimg_y,
+        win_size=kernel_size,
+        sigma=sigma,
+        multichannel=True,
+        gaussian_weights=gaussian,
+        data_range=data_range,
+        use_sample_covariance=use_sample_covariance,
+    )
+
+    assert isinstance(ignite_ssim, torch.Tensor)
+    assert ignite_ssim.dtype == torch.float64
+    assert ignite_ssim.device == torch.device(device)
+    assert np.allclose(ignite_ssim.numpy(), skimg_ssim, atol=atol)

-    np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
-    np_y = np_pred * 0.65
-    np_ssim = ski_ssim(np_pred, np_y, win_size=7, multichannel=True, gaussian_weights=False, data_range=1.0)

-    assert isinstance(ssim.compute(), torch.Tensor)
-    assert torch.allclose(ssim.compute(), torch.tensor(np_ssim, dtype=torch.float64, device=device), atol=1e-4)
+def test_ssim():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    y_pred = torch.rand(8, 3, 224, 224, device=device)
+    y = y_pred * 0.8
+    _test_ssim(
+        y_pred, y, data_range=1.0, kernel_size=7, sigma=1.5, gaussian=False, use_sample_covariance=True, device=device
+    )
+
+    y_pred = torch.rand(12, 3, 28, 28, device=device)
+    y = y_pred * 0.8
+    _test_ssim(
+        y_pred, y, data_range=1.0, kernel_size=11, sigma=1.5, gaussian=True, use_sample_covariance=False, device=device
+    )


 def _test_distrib_integration(device, tol=1e-4):
@@ -105,7 +141,16 @@ def update(engine, i):

     np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
     np_true = np_pred * 0.65
-    true_res = ski_ssim(np_pred, np_true, win_size=11, multichannel=True, gaussian_weights=True, data_range=1.0)
+    true_res = ski_ssim(
+        np_pred,
+        np_true,
+        win_size=11,
+        sigma=1.5,
+        multichannel=True,
+        gaussian_weights=True,
+        data_range=1.0,
+        use_sample_covariance=False,
+    )

     assert pytest.approx(res, abs=tol) == true_res

@@ -142,7 +187,7 @@ def _test_distrib_accumulator_device(device):
            type(ssim._kernel.device), ssim._kernel.device, type(metric_device), metric_device
        )

-        y_pred = torch.rand(4, 3, 28, 28, dtype=torch.float, device=device)
+        y_pred = torch.rand(2, 3, 28, 28, dtype=torch.float, device=device)
        y = y_pred * 0.65
        ssim.update((y_pred, y))
```