32
32
from functools import wraps
33
33
34
34
import os
35
- import sys
35
+ import abc
36
36
import shutil
37
37
import tempfile
38
38
import warnings
39
39
40
+ from six import add_metaclass , PY2
41
+ from six .moves .urllib .request import urlopen
42
+
40
43
import pytest
41
44
import numpy as np
42
45
43
- if sys .version_info [0 ] == 2 :
44
- from urllib import urlopen
46
+
47
+ if PY2 :
48
+ def abstractstaticmethod (func ):
49
+ return func
50
+ def abstractclassmethod (func ):
51
+ return func
45
52
else :
46
- from urllib .request import urlopen
53
+ abstractstaticmethod = abc .abstractstaticmethod
54
+ abstractclassmethod = abc .abstractclassmethod
55
+
56
+
57
+ @add_metaclass (abc .ABCMeta )
58
+ class BaseDiff (object ):
59
+
60
+ @abstractstaticmethod
61
+ def read (filename ):
62
+ """
63
+ Given a filename, return a data object.
64
+ """
65
+ raise NotImplementedError ()
66
+
67
+ @abstractstaticmethod
68
+ def write (filename , data , ** kwargs ):
69
+ """
70
+ Given a filename and a data object (and optional keyword arguments),
71
+ write the data to a file.
72
+ """
73
+ raise NotImplementedError ()
74
+
75
+ @abstractclassmethod
76
+ def compare (self , reference_file , test_file , atol = None , rtol = None ):
77
+ """
78
+ Given a reference and test filename, compare the data to the specified
79
+ absolute (``atol``) and relative (``rtol``) tolerances.
80
+
81
+ Should return two arguments: a boolean indicating whether the data are
82
+ identical, and a string giving the full error message if not.
83
+ """
84
+ raise NotImplementedError ()
85
+
86
+
87
+ class SimpleArrayDiff (BaseDiff ):
88
+
89
+ @classmethod
90
+ def compare (cls , reference_file , test_file , atol = None , rtol = None ):
91
+
92
+ array_ref = cls .read (reference_file )
93
+ array_new = cls .read (test_file )
94
+
95
+ try :
96
+ np .testing .assert_allclose (array_ref , array_new , atol = atol , rtol = rtol )
97
+ except AssertionError as exc :
98
+ message = "\n \n a: {0}" .format (test_file ) + '\n '
99
+ message += "b: {0}" .format (reference_file ) + '\n '
100
+ message += exc .args [0 ]
101
+ return False , message
102
+ else :
103
+ return True , ""
47
104
48
105
49
- class FITSDiff (object ):
106
+ class FITSDiff (BaseDiff ):
50
107
51
108
extension = 'fits'
52
109
@@ -56,12 +113,20 @@ def read(filename):
56
113
return fits .getdata (filename )
57
114
58
115
@staticmethod
59
- def write (filename , array , ** kwargs ):
116
+ def write (filename , data , ** kwargs ):
60
117
from astropy .io import fits
61
- return fits .writeto (filename , array , ** kwargs )
118
+ if isinstance (data , np .ndarray ):
119
+ data = fits .PrimaryHDU (data )
120
+ return data .writeto (filename , ** kwargs )
121
+
122
+ @classmethod
123
+ def compare (cls , reference_file , test_file , atol = None , rtol = None ):
124
+ from astropy .io .fits .diff import FITSDiff
125
+ diff = FITSDiff (reference_file , test_file , tolerance = rtol )
126
+ return diff .identical , diff .report ()
62
127
63
128
64
- class TextDiff (object ):
129
+ class TextDiff (SimpleArrayDiff ):
65
130
66
131
extension = 'txt'
67
132
@@ -70,10 +135,10 @@ def read(filename):
70
135
return np .loadtxt (filename )
71
136
72
137
@staticmethod
73
- def write (filename , array , ** kwargs ):
138
+ def write (filename , data , ** kwargs ):
74
139
if 'fmt' not in kwargs :
75
140
kwargs ['fmt' ] = '%g'
76
- return np .savetxt (filename , array , ** kwargs )
141
+ return np .savetxt (filename , data , ** kwargs )
77
142
78
143
79
144
FORMATS = {}
@@ -219,17 +284,12 @@ def item_function_wrapper(*args, **kwargs):
219
284
baseline_file = os .path .abspath (os .path .join (result_dir , 'reference-' + filename ))
220
285
shutil .copyfile (baseline_file_ref , baseline_file )
221
286
222
- array_ref = FORMATS [file_format ].read (baseline_file )
287
+ identical , msg = FORMATS [file_format ].compare (baseline_file , test_image , atol = atol , rtol = rtol )
223
288
224
- try :
225
- np .testing .assert_allclose (array_ref , array , atol = atol , rtol = rtol )
226
- except AssertionError as exc :
227
- message = "\n \n a: {0}" .format (test_image ) + '\n '
228
- message += "b: {0}" .format (baseline_file ) + '\n '
229
- message += exc .args [0 ]
230
- raise AssertionError (message )
231
-
232
- shutil .rmtree (result_dir )
289
+ if identical :
290
+ shutil .rmtree (result_dir )
291
+ else :
292
+ raise Exception (msg )
233
293
234
294
else :
235
295
0 commit comments