32
32
from functools import wraps
33
33
34
34
import os
35
- import sys
35
+ import abc
36
36
import shutil
37
37
import tempfile
38
38
import warnings
39
39
40
+ from six import add_metaclass
41
+ from six .moves .urllib .request import urlopen
42
+
40
43
import pytest
41
44
import numpy as np
42
45
43
- if sys .version_info [0 ] == 2 :
44
- from urllib import urlopen
45
- else :
46
- from urllib .request import urlopen
46
+
47
+ @add_metaclass (abc .ABCMeta )
48
+ class BaseDiff (object ):
49
+
50
+ @abc .abstractstaticmethod
51
+ def read (filename ):
52
+ """
53
+ Given a filename, return a data object.
54
+ """
55
+ raise NotImplementedError ()
56
+
57
+ @abc .abstractstaticmethod
58
+ def write (filename , data , ** kwargs ):
59
+ """
60
+ Given a filename and a data object (and optional keyword arguments),
61
+ write the data to a file.
62
+ """
63
+ raise NotImplementedError ()
64
+
65
+ @abc .abstractclassmethod
66
+ def compare (self , reference_file , test_file , atol = None , rtol = None ):
67
+ """
68
+ Given a reference and test filename, compare the data to the specified
69
+ absolute (``atol``) and relative (``rtol``) tolerances.
70
+
71
+ Should return two arguments: a boolean indicating whether the data are
72
+ identical, and a string giving the full error message if not.
73
+ """
74
+ raise NotImplementedError ()
75
+
76
+
77
+ class SimpleArrayDiff (BaseDiff ):
78
+
79
+ @classmethod
80
+ def compare (cls , reference_file , test_file , atol = None , rtol = None ):
81
+
82
+ array_ref = cls .read (reference_file )
83
+ array_new = cls .read (test_file )
84
+
85
+ try :
86
+ np .testing .assert_allclose (array_ref , array_new , atol = atol , rtol = rtol )
87
+ except AssertionError as exc :
88
+ message = "\n \n a: {0}" .format (test_file ) + '\n '
89
+ message += "b: {0}" .format (reference_file ) + '\n '
90
+ message += exc .args [0 ]
91
+ return False , message
92
+ else :
93
+ return True , ""
47
94
48
95
49
- class FITSDiff (object ):
96
+ class FITSDiff (BaseDiff ):
50
97
51
98
extension = 'fits'
52
99
@@ -56,12 +103,20 @@ def read(filename):
56
103
return fits .getdata (filename )
57
104
58
105
@staticmethod
59
- def write (filename , array , ** kwargs ):
106
+ def write (filename , data , ** kwargs ):
60
107
from astropy .io import fits
61
- return fits .writeto (filename , array , ** kwargs )
108
+ if isinstance (data , np .ndarray ):
109
+ data = fits .PrimaryHDU (data )
110
+ return data .writeto (filename , ** kwargs )
111
+
112
+ @classmethod
113
+ def compare (cls , reference_file , test_file , atol = None , rtol = None ):
114
+ from astropy .io .fits .diff import FITSDiff
115
+ diff = FITSDiff (reference_file , test_file , tolerance = rtol )
116
+ return diff .identical , diff .report ()
62
117
63
118
64
- class TextDiff (object ):
119
+ class TextDiff (SimpleArrayDiff ):
65
120
66
121
extension = 'txt'
67
122
@@ -70,10 +125,10 @@ def read(filename):
70
125
return np .loadtxt (filename )
71
126
72
127
@staticmethod
73
- def write (filename , array , ** kwargs ):
128
+ def write (filename , data , ** kwargs ):
74
129
if 'fmt' not in kwargs :
75
130
kwargs ['fmt' ] = '%g'
76
- return np .savetxt (filename , array , ** kwargs )
131
+ return np .savetxt (filename , data , ** kwargs )
77
132
78
133
79
134
FORMATS = {}
@@ -219,17 +274,12 @@ def item_function_wrapper(*args, **kwargs):
219
274
baseline_file = os .path .abspath (os .path .join (result_dir , 'reference-' + filename ))
220
275
shutil .copyfile (baseline_file_ref , baseline_file )
221
276
222
- array_ref = FORMATS [file_format ].read (baseline_file )
277
+ identical , msg = FORMATS [file_format ].compare (baseline_file , test_image , atol = atol , rtol = rtol )
223
278
224
- try :
225
- np .testing .assert_allclose (array_ref , array , atol = atol , rtol = rtol )
226
- except AssertionError as exc :
227
- message = "\n \n a: {0}" .format (test_image ) + '\n '
228
- message += "b: {0}" .format (baseline_file ) + '\n '
229
- message += exc .args [0 ]
230
- raise AssertionError (message )
231
-
232
- shutil .rmtree (result_dir )
279
+ if identical :
280
+ shutil .rmtree (result_dir )
281
+ else :
282
+ raise Exception (msg )
233
283
234
284
else :
235
285
0 commit comments