Skip to content

Commit f4a49ec

Browse files
committed
Add back support for return HDUs from tests and comparing full FITS files
1 parent d760656 commit f4a49ec

File tree

3 files changed

+81
-24
lines changed

3 files changed

+81
-24
lines changed

pytest_arraydiff/plugin.py

Lines changed: 71 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,68 @@
3232
from functools import wraps
3333

3434
import os
35-
import sys
35+
import abc
3636
import shutil
3737
import tempfile
3838
import warnings
3939

40+
from six import add_metaclass
41+
from six.moves.urllib.request import urlopen
42+
4043
import pytest
4144
import numpy as np
4245

43-
if sys.version_info[0] == 2:
44-
from urllib import urlopen
45-
else:
46-
from urllib.request import urlopen
46+
47+
@add_metaclass(abc.ABCMeta)
48+
class BaseDiff(object):
49+
50+
@abc.abstractstaticmethod
51+
def read(filename):
52+
"""
53+
Given a filename, return a data object.
54+
"""
55+
raise NotImplementedError()
56+
57+
@abc.abstractstaticmethod
58+
def write(filename, data, **kwargs):
59+
"""
60+
Given a filename and a data object (and optional keyword arguments),
61+
write the data to a file.
62+
"""
63+
raise NotImplementedError()
64+
65+
@abc.abstractclassmethod
66+
def compare(self, reference_file, test_file, atol=None, rtol=None):
67+
"""
68+
Given a reference and test filename, compare the data to the specified
69+
absolute (``atol``) and relative (``rtol``) tolerances.
70+
71+
Should return two arguments: a boolean indicating whether the data are
72+
identical, and a string giving the full error message if not.
73+
"""
74+
raise NotImplementedError()
75+
76+
77+
class SimpleArrayDiff(BaseDiff):
78+
79+
@classmethod
80+
def compare(cls, reference_file, test_file, atol=None, rtol=None):
81+
82+
array_ref = cls.read(reference_file)
83+
array_new = cls.read(test_file)
84+
85+
try:
86+
np.testing.assert_allclose(array_ref, array_new, atol=atol, rtol=rtol)
87+
except AssertionError as exc:
88+
message = "\n\na: {0}".format(test_file) + '\n'
89+
message += "b: {0}".format(reference_file) + '\n'
90+
message += exc.args[0]
91+
return False, message
92+
else:
93+
return True, ""
4794

4895

49-
class FITSDiff(object):
96+
class FITSDiff(BaseDiff):
5097

5198
extension = 'fits'
5299

@@ -56,12 +103,20 @@ def read(filename):
56103
return fits.getdata(filename)
57104

58105
@staticmethod
59-
def write(filename, array, **kwargs):
106+
def write(filename, data, **kwargs):
60107
from astropy.io import fits
61-
return fits.writeto(filename, array, **kwargs)
108+
if isinstance(data, np.ndarray):
109+
data = fits.PrimaryHDU(data)
110+
return data.writeto(filename, **kwargs)
111+
112+
@classmethod
113+
def compare(cls, reference_file, test_file, atol=None, rtol=None):
114+
from astropy.io.fits.diff import FITSDiff
115+
diff = FITSDiff(reference_file, test_file, tolerance=rtol)
116+
return diff.identical, diff.report()
62117

63118

64-
class TextDiff(object):
119+
class TextDiff(SimpleArrayDiff):
65120

66121
extension = 'txt'
67122

@@ -70,10 +125,10 @@ def read(filename):
70125
return np.loadtxt(filename)
71126

72127
@staticmethod
73-
def write(filename, array, **kwargs):
128+
def write(filename, data, **kwargs):
74129
if 'fmt' not in kwargs:
75130
kwargs['fmt'] = '%g'
76-
return np.savetxt(filename, array, **kwargs)
131+
return np.savetxt(filename, data, **kwargs)
77132

78133

79134
FORMATS = {}
@@ -219,17 +274,12 @@ def item_function_wrapper(*args, **kwargs):
219274
baseline_file = os.path.abspath(os.path.join(result_dir, 'reference-' + filename))
220275
shutil.copyfile(baseline_file_ref, baseline_file)
221276

222-
array_ref = FORMATS[file_format].read(baseline_file)
277+
identical, msg = FORMATS[file_format].compare(baseline_file, test_image, atol=atol, rtol=rtol)
223278

224-
try:
225-
np.testing.assert_allclose(array_ref, array, atol=atol, rtol=rtol)
226-
except AssertionError as exc:
227-
message = "\n\na: {0}".format(test_image) + '\n'
228-
message += "b: {0}".format(baseline_file) + '\n'
229-
message += exc.args[0]
230-
raise AssertionError(message)
231-
232-
shutil.rmtree(result_dir)
279+
if identical:
280+
shutil.rmtree(result_dir)
281+
else:
282+
raise Exception(msg)
233283

234284
else:
235285

setup.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
name="pytest-arraydiff",
1717
description='pytest plugin to help with comparing array output from tests',
1818
long_description=long_description,
19-
packages = ['pytest_arraydiff'],
19+
packages=['pytest_arraydiff'],
20+
install_requires=['numpy', 'six', 'pytest'],
2021
license='BSD',
2122
author='Thomas Robitaille',
2223
author_email='[email protected]',
23-
entry_points = {'pytest11': ['pytest_arraydiff = pytest_arraydiff.plugin',]},
24-
)
24+
entry_points={'pytest11': ['pytest_arraydiff = pytest_arraydiff.plugin']},
25+
)

tests/test_pytest_arraydiff.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ def test_succeeds_func_fits():
2323
return np.arange(3 * 5).reshape((3, 5))
2424

2525

26+
@pytest.mark.array_compare(file_format='fits', reference_dir=reference_dir)
27+
def test_succeeds_func_fits_hdu():
28+
from astropy.io import fits
29+
return fits.PrimaryHDU(np.arange(3 * 5).reshape((3, 5)))
30+
31+
2632
class TestClass(object):
2733

2834
@pytest.mark.array_compare(file_format='fits', reference_dir=reference_dir)

0 commit comments

Comments
 (0)