Skip to content

Commit 27ebe08

Browse files
authored
Merge pull request #5 from astrofrog/fits-hdu-support
Add back FITS HDU support
2 parents d760656 + 0036e2a commit 27ebe08

File tree

5 files changed

+100
-32
lines changed

5 files changed

+100
-32
lines changed

README.rst

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,21 @@ About
44
-----
55

66
This is a `py.test <http://pytest.org>`__ plugin to facilitate the
7-
generation and comparison of arrays produced during tests (this is a
7+
generation and comparison of data arrays produced during tests (this is a
88
spin-off from
99
`pytest-arraydiff <https://github.com/astrofrog/pytest-arraydiff>`__).
1010

11-
The basic idea is that you can write a test that generates a Numpy
12-
array. You can then either run the tests in a mode to **generate**
13-
reference files from the arrays, or you can run the tests in
14-
**comparison** mode, which will compare the results of the tests to the
15-
reference ones within some tolerance.
11+
The basic idea is that you can write a test that generates a Numpy array (or
12+
other related objects depending on the format). You can then either run the
13+
tests in a mode to **generate** reference files from the arrays, or you can run
14+
the tests in **comparison** mode, which will compare the results of the tests to
15+
the reference ones within some tolerance.
1616

1717
At the moment, the supported file formats for the reference files are:
1818

19-
- The FITS format (requires `astropy <http://www.astropy.org>`__)
2019
- A plain text-based format (baed on Numpy ``loadtxt`` output)
20+
- The FITS format (requires `astropy <http://www.astropy.org>`__). With this
21+
format, tests can return either a Numpy array for a FITS HDU object.
2122

2223
For more information on how to write tests to do this, see the **Using**
2324
section below.
@@ -66,7 +67,7 @@ function returns a plain Numpy array:
6667
def test_succeeds():
6768
return np.arange(3 * 5 * 4).reshape((3, 5, 4))
6869

69-
To generate the reference FITS files, run the tests with the
70+
To generate the reference data files, run the tests with the
7071
``--arraydiff-generate-path`` option with the name of the directory
7172
where the generated files should be placed:
7273

@@ -79,7 +80,7 @@ be interpreted as being relative to where you are running ``py.test``.
7980
Make sure you manually check the reference images to ensure they are
8081
correct.
8182

82-
Once you are happy with the generated FITS files, you should move them
83+
Once you are happy with the generated data files, you should move them
8384
to a sub-directory called ``reference`` relative to the test files (this
8485
name is configurable, see below). You can also generate the baseline
8586
images directly in the right directory.

pytest_arraydiff/plugin.py

Lines changed: 80 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,78 @@
3232
from functools import wraps
3333

3434
import os
35-
import sys
35+
import abc
3636
import shutil
3737
import tempfile
3838
import warnings
3939

40+
from six import add_metaclass, PY2
41+
from six.moves.urllib.request import urlopen
42+
4043
import pytest
4144
import numpy as np
4245

43-
if sys.version_info[0] == 2:
44-
from urllib import urlopen
46+
47+
if PY2:
48+
def abstractstaticmethod(func):
49+
return func
50+
def abstractclassmethod(func):
51+
return func
4552
else:
46-
from urllib.request import urlopen
53+
abstractstaticmethod = abc.abstractstaticmethod
54+
abstractclassmethod = abc.abstractclassmethod
55+
56+
57+
@add_metaclass(abc.ABCMeta)
58+
class BaseDiff(object):
59+
60+
@abstractstaticmethod
61+
def read(filename):
62+
"""
63+
Given a filename, return a data object.
64+
"""
65+
raise NotImplementedError()
66+
67+
@abstractstaticmethod
68+
def write(filename, data, **kwargs):
69+
"""
70+
Given a filename and a data object (and optional keyword arguments),
71+
write the data to a file.
72+
"""
73+
raise NotImplementedError()
74+
75+
@abstractclassmethod
76+
def compare(self, reference_file, test_file, atol=None, rtol=None):
77+
"""
78+
Given a reference and test filename, compare the data to the specified
79+
absolute (``atol``) and relative (``rtol``) tolerances.
80+
81+
Should return two arguments: a boolean indicating whether the data are
82+
identical, and a string giving the full error message if not.
83+
"""
84+
raise NotImplementedError()
85+
86+
87+
class SimpleArrayDiff(BaseDiff):
88+
89+
@classmethod
90+
def compare(cls, reference_file, test_file, atol=None, rtol=None):
91+
92+
array_ref = cls.read(reference_file)
93+
array_new = cls.read(test_file)
94+
95+
try:
96+
np.testing.assert_allclose(array_ref, array_new, atol=atol, rtol=rtol)
97+
except AssertionError as exc:
98+
message = "\n\na: {0}".format(test_file) + '\n'
99+
message += "b: {0}".format(reference_file) + '\n'
100+
message += exc.args[0]
101+
return False, message
102+
else:
103+
return True, ""
47104

48105

49-
class FITSDiff(object):
106+
class FITSDiff(BaseDiff):
50107

51108
extension = 'fits'
52109

@@ -56,12 +113,20 @@ def read(filename):
56113
return fits.getdata(filename)
57114

58115
@staticmethod
59-
def write(filename, array, **kwargs):
116+
def write(filename, data, **kwargs):
60117
from astropy.io import fits
61-
return fits.writeto(filename, array, **kwargs)
118+
if isinstance(data, np.ndarray):
119+
data = fits.PrimaryHDU(data)
120+
return data.writeto(filename, **kwargs)
121+
122+
@classmethod
123+
def compare(cls, reference_file, test_file, atol=None, rtol=None):
124+
from astropy.io.fits.diff import FITSDiff
125+
diff = FITSDiff(reference_file, test_file, tolerance=rtol)
126+
return diff.identical, diff.report()
62127

63128

64-
class TextDiff(object):
129+
class TextDiff(SimpleArrayDiff):
65130

66131
extension = 'txt'
67132

@@ -70,10 +135,10 @@ def read(filename):
70135
return np.loadtxt(filename)
71136

72137
@staticmethod
73-
def write(filename, array, **kwargs):
138+
def write(filename, data, **kwargs):
74139
if 'fmt' not in kwargs:
75140
kwargs['fmt'] = '%g'
76-
return np.savetxt(filename, array, **kwargs)
141+
return np.savetxt(filename, data, **kwargs)
77142

78143

79144
FORMATS = {}
@@ -219,17 +284,12 @@ def item_function_wrapper(*args, **kwargs):
219284
baseline_file = os.path.abspath(os.path.join(result_dir, 'reference-' + filename))
220285
shutil.copyfile(baseline_file_ref, baseline_file)
221286

222-
array_ref = FORMATS[file_format].read(baseline_file)
287+
identical, msg = FORMATS[file_format].compare(baseline_file, test_image, atol=atol, rtol=rtol)
223288

224-
try:
225-
np.testing.assert_allclose(array_ref, array, atol=atol, rtol=rtol)
226-
except AssertionError as exc:
227-
message = "\n\na: {0}".format(test_image) + '\n'
228-
message += "b: {0}".format(baseline_file) + '\n'
229-
message += exc.args[0]
230-
raise AssertionError(message)
231-
232-
shutil.rmtree(result_dir)
289+
if identical:
290+
shutil.rmtree(result_dir)
291+
else:
292+
raise Exception(msg)
233293

234294
else:
235295

setup.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
name="pytest-arraydiff",
1717
description='pytest plugin to help with comparing array output from tests',
1818
long_description=long_description,
19-
packages = ['pytest_arraydiff'],
19+
packages=['pytest_arraydiff'],
20+
install_requires=['numpy', 'six', 'pytest'],
2021
license='BSD',
2122
author='Thomas Robitaille',
2223
author_email='[email protected]',
23-
entry_points = {'pytest11': ['pytest_arraydiff = pytest_arraydiff.plugin',]},
24-
)
24+
entry_points={'pytest11': ['pytest_arraydiff = pytest_arraydiff.plugin']},
25+
)
5.63 KB
Binary file not shown.

tests/test_pytest_arraydiff.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ def test_succeeds_func_fits():
2323
return np.arange(3 * 5).reshape((3, 5))
2424

2525

26+
@pytest.mark.array_compare(file_format='fits', reference_dir=reference_dir)
27+
def test_succeeds_func_fits_hdu():
28+
from astropy.io import fits
29+
return fits.PrimaryHDU(np.arange(3 * 5).reshape((3, 5)))
30+
31+
2632
class TestClass(object):
2733

2834
@pytest.mark.array_compare(file_format='fits', reference_dir=reference_dir)

0 commit comments

Comments
 (0)