diff --git a/tests/ignite/metrics/test_ssim.py b/tests/ignite/metrics/test_ssim.py index b158c181e79b..3c77dda13f56 100644 --- a/tests/ignite/metrics/test_ssim.py +++ b/tests/ignite/metrics/test_ssim.py @@ -1,5 +1,6 @@ import os +import numpy as np import pytest import torch @@ -20,14 +21,14 @@ def test_zero_div(): def test_invalid_ssim(): - y_pred = torch.rand(16, 1, 32, 32) + y_pred = torch.rand(1, 1, 4, 4) y = y_pred + 0.125 - with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number. Got 10."): - ssim = SSIM(data_range=1.0, kernel_size=10) + with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number."): + ssim = SSIM(data_range=1.0, kernel_size=2) ssim.update((y_pred, y)) ssim.compute() - with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number. Got -1."): + with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number."): ssim = SSIM(data_range=1.0, kernel_size=-1) ssim.update((y_pred, y)) ssim.compute() @@ -42,38 +43,73 @@ def test_invalid_ssim(): ssim.update((y_pred, y)) ssim.compute() + with pytest.raises(ValueError, match=r"Expected sigma to have positive number."): + ssim = SSIM(data_range=1.0, sigma=(-1, -1)) + ssim.update((y_pred, y)) + ssim.compute() + with pytest.raises(ValueError, match=r"Argument sigma should be either float or a sequence of float."): ssim = SSIM(data_range=1.0, sigma=1) ssim.update((y_pred, y)) ssim.compute() + with pytest.raises(ValueError, match=r"Expected y_pred and y to have the same shape."): + y = y.squeeze(dim=0) + ssim = SSIM(data_range=1.0) + ssim.update((y_pred, y)) + ssim.compute() -def test_ssim(): - device = "cuda" if torch.cuda.is_available() else "cpu" - ssim = SSIM(data_range=1.0, device=device) - y_pred = torch.rand(16, 3, 64, 64, device=device) - y = y_pred * 0.65 - ssim.update((y_pred, y)) + with pytest.raises(ValueError, match=r"Expected y_pred and y to have BxCxHxW shape."): + y = y.squeeze(dim=0) + ssim = SSIM(data_range=1.0) + ssim.update((y, y)) + ssim.compute() - np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy() - np_y = np_pred * 0.65 - np_ssim = ski_ssim(np_pred, np_y, win_size=11, multichannel=True, gaussian_weights=True, data_range=1.0) + with pytest.raises(TypeError, match=r"Expected y_pred and y to have the same data type."): + y = y.double() + ssim = SSIM(data_range=1.0) + ssim.update((y_pred, y)) + ssim.compute() - assert isinstance(ssim.compute(), torch.Tensor) - assert torch.allclose(ssim.compute(), torch.tensor(np_ssim, dtype=torch.float64, device=device), atol=1e-4) - device = "cuda" if torch.cuda.is_available() else "cpu" - ssim = SSIM(data_range=1.0, gaussian=False, kernel_size=7, device=device) - y_pred = torch.rand(16, 3, 227, 227, device=device) - y = y_pred * 0.65 +def _test_ssim(y_pred, y, data_range, kernel_size, sigma, gaussian, use_sample_covariance, device): + atol = 7e-5 + ssim = SSIM(data_range=data_range, sigma=sigma, device=device) ssim.update((y_pred, y)) + ignite_ssim = ssim.compute() + + skimg_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy() + skimg_y = skimg_pred * 0.8 + skimg_ssim = ski_ssim( + skimg_pred, + skimg_y, + win_size=kernel_size, + sigma=sigma, + multichannel=True, + gaussian_weights=gaussian, + data_range=data_range, + use_sample_covariance=use_sample_covariance, + ) + + assert isinstance(ignite_ssim, torch.Tensor) + assert ignite_ssim.dtype == torch.float64 + assert ignite_ssim.device == torch.device(device) + assert np.allclose(ignite_ssim.numpy(), skimg_ssim, atol=atol) - np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy() - np_y = np_pred * 0.65 - np_ssim = ski_ssim(np_pred, np_y, win_size=7, multichannel=True, gaussian_weights=False, data_range=1.0) - assert isinstance(ssim.compute(), torch.Tensor) - assert torch.allclose(ssim.compute(), torch.tensor(np_ssim, dtype=torch.float64, device=device), atol=1e-4) +def test_ssim(): + device = "cuda" if torch.cuda.is_available() else "cpu" + y_pred = torch.rand(8, 3, 224, 224, device=device) + y = y_pred * 0.8 + _test_ssim( + y_pred, y, data_range=1.0, kernel_size=7, sigma=1.5, gaussian=False, use_sample_covariance=True, device=device + ) + + y_pred = torch.rand(12, 3, 28, 28, device=device) + y = y_pred * 0.8 + _test_ssim( + y_pred, y, data_range=1.0, kernel_size=11, sigma=1.5, gaussian=True, use_sample_covariance=False, device=device + ) def _test_distrib_integration(device, tol=1e-4): @@ -105,7 +141,16 @@ def update(engine, i): np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy() np_true = np_pred * 0.65 - true_res = ski_ssim(np_pred, np_true, win_size=11, multichannel=True, gaussian_weights=True, data_range=1.0) + true_res = ski_ssim( + np_pred, + np_true, + win_size=11, + sigma=1.5, + multichannel=True, + gaussian_weights=True, + data_range=1.0, + use_sample_covariance=False, + ) assert pytest.approx(res, abs=tol) == true_res @@ -142,7 +187,7 @@ def _test_distrib_accumulator_device(device): type(ssim._kernel.device), ssim._kernel.device, type(metric_device), metric_device ) - y_pred = torch.rand(4, 3, 28, 28, dtype=torch.float, device=device) + y_pred = torch.rand(2, 3, 28, 28, dtype=torch.float, device=device) y = y_pred * 0.65 ssim.update((y_pred, y))