Skip to content

Commit e006ea1

Browse files
authored
Merge pull request #36 from ConorMacBride/update-pytest-integration
Test inside `pytest_runtest_call` hook
2 parents 1af2acd + 9ffd246 commit e006ea1

File tree

2 files changed

+122
-74
lines changed

2 files changed

+122
-74
lines changed

pytest_arraydiff/plugin.py

Lines changed: 118 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,37 @@ def pytest_configure(config):
221221
reference_dir=reference_dir,
222222
generate_dir=generate_dir,
223223
default_format=default_format))
224+
else:
225+
config.pluginmanager.register(ArrayInterceptor(config))
226+
227+
228+
def generate_test_name(item):
229+
"""
230+
Generate a unique name for this test.
231+
"""
232+
if item.cls is not None:
233+
name = f"{item.module.__name__}.{item.cls.__name__}.{item.name}"
234+
else:
235+
name = f"{item.module.__name__}.{item.name}"
236+
return name
237+
238+
239+
def wrap_array_interceptor(plugin, item):
240+
"""
241+
Intercept and store arrays returned by test functions.
242+
"""
243+
# Only intercept array on marked array tests
244+
if item.get_closest_marker('array_compare') is not None:
245+
246+
# Use the full test name as a key to ensure correct array is being retrieved
247+
test_name = generate_test_name(item)
248+
249+
def array_interceptor(store, obj):
250+
def wrapper(*args, **kwargs):
251+
store.return_value[test_name] = obj(*args, **kwargs)
252+
return wrapper
253+
254+
item.obj = array_interceptor(plugin, item.obj)
224255

225256

226257
class ArrayComparison(object):
@@ -230,12 +261,15 @@ def __init__(self, config, reference_dir=None, generate_dir=None, default_format
230261
self.reference_dir = reference_dir
231262
self.generate_dir = generate_dir
232263
self.default_format = default_format
264+
self.return_value = {}
233265

234-
def pytest_runtest_setup(self, item):
266+
@pytest.hookimpl(hookwrapper=True)
267+
def pytest_runtest_call(self, item):
235268

236269
compare = item.get_closest_marker('array_compare')
237270

238271
if compare is None:
272+
yield
239273
return
240274

241275
file_format = compare.kwargs.get('file_format', self.default_format)
@@ -255,85 +289,95 @@ def pytest_runtest_setup(self, item):
255289

256290
write_kwargs = compare.kwargs.get('write_kwargs', {})
257291

258-
original = item.function
292+
reference_dir = compare.kwargs.get('reference_dir', None)
293+
if reference_dir is None:
294+
if self.reference_dir is None:
295+
reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), 'reference')
296+
else:
297+
reference_dir = self.reference_dir
298+
else:
299+
if not reference_dir.startswith(('http://', 'https://')):
300+
reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), reference_dir)
259301

260-
@wraps(item.function)
261-
def item_function_wrapper(*args, **kwargs):
302+
baseline_remote = reference_dir.startswith('http')
262303

263-
reference_dir = compare.kwargs.get('reference_dir', None)
264-
if reference_dir is None:
265-
if self.reference_dir is None:
266-
reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), 'reference')
267-
else:
268-
reference_dir = self.reference_dir
269-
else:
270-
if not reference_dir.startswith(('http://', 'https://')):
271-
reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), reference_dir)
272-
273-
baseline_remote = reference_dir.startswith('http')
274-
275-
# Run test and get figure object
276-
import inspect
277-
if inspect.ismethod(original): # method
278-
array = original(*args[1:], **kwargs)
279-
else: # function
280-
array = original(*args, **kwargs)
281-
282-
# Find test name to use as plot name
283-
filename = compare.kwargs.get('filename', None)
284-
if filename is None:
285-
if single_reference:
286-
filename = original.__name__ + '.' + extension
287-
else:
288-
filename = item.name + '.' + extension
289-
filename = filename.replace('[', '_').replace(']', '_')
290-
filename = filename.replace('_.' + extension, '.' + extension)
291-
292-
# What we do now depends on whether we are generating the reference
293-
# files or simply running the test.
294-
if self.generate_dir is None:
295-
296-
# Save the figure
297-
result_dir = tempfile.mkdtemp()
298-
test_array = os.path.abspath(os.path.join(result_dir, filename))
299-
300-
FORMATS[file_format].write(test_array, array, **write_kwargs)
301-
302-
# Find path to baseline array
303-
if baseline_remote:
304-
baseline_file_ref = _download_file(reference_dir + filename)
305-
else:
306-
baseline_file_ref = os.path.abspath(os.path.join(os.path.dirname(item.fspath.strpath), reference_dir, filename))
307-
308-
if not os.path.exists(baseline_file_ref):
309-
raise Exception("""File not found for comparison test
310-
Generated file:
311-
\t{test}
312-
This is expected for new tests.""".format(
313-
test=test_array))
314-
315-
# setuptools may put the baseline arrays in non-accessible places,
316-
# copy to our tmpdir to be sure to keep them in case of failure
317-
baseline_file = os.path.abspath(os.path.join(result_dir, 'reference-' + filename))
318-
shutil.copyfile(baseline_file_ref, baseline_file)
319-
320-
identical, msg = FORMATS[file_format].compare(baseline_file, test_array, atol=atol, rtol=rtol)
321-
322-
if identical:
323-
shutil.rmtree(result_dir)
324-
else:
325-
raise Exception(msg)
304+
# Run test and get array object
305+
wrap_array_interceptor(self, item)
306+
yield
307+
test_name = generate_test_name(item)
308+
if test_name not in self.return_value:
309+
# Test function did not complete successfully
310+
return
311+
array = self.return_value[test_name]
312+
313+
# Find test name to use as plot name
314+
filename = compare.kwargs.get('filename', None)
315+
if filename is None:
316+
filename = item.name + '.' + extension
317+
if not single_reference:
318+
filename = filename.replace('[', '_').replace(']', '_')
319+
filename = filename.replace('_.' + extension, '.' + extension)
320+
321+
# What we do now depends on whether we are generating the reference
322+
# files or simply running the test.
323+
if self.generate_dir is None:
324+
325+
# Save the figure
326+
result_dir = tempfile.mkdtemp()
327+
test_array = os.path.abspath(os.path.join(result_dir, filename))
326328

329+
FORMATS[file_format].write(test_array, array, **write_kwargs)
330+
331+
# Find path to baseline array
332+
if baseline_remote:
333+
baseline_file_ref = _download_file(reference_dir + filename)
327334
else:
335+
baseline_file_ref = os.path.abspath(os.path.join(os.path.dirname(item.fspath.strpath), reference_dir, filename))
336+
337+
if not os.path.exists(baseline_file_ref):
338+
raise Exception("""File not found for comparison test
339+
Generated file:
340+
\t{test}
341+
This is expected for new tests.""".format(
342+
test=test_array))
328343

329-
if not os.path.exists(self.generate_dir):
330-
os.makedirs(self.generate_dir)
344+
# setuptools may put the baseline arrays in non-accessible places,
345+
# copy to our tmpdir to be sure to keep them in case of failure
346+
baseline_file = os.path.abspath(os.path.join(result_dir, 'reference-' + filename))
347+
shutil.copyfile(baseline_file_ref, baseline_file)
331348

332-
FORMATS[file_format].write(os.path.abspath(os.path.join(self.generate_dir, filename)), array, **write_kwargs)
349+
identical, msg = FORMATS[file_format].compare(baseline_file, test_array, atol=atol, rtol=rtol)
333350

334-
pytest.skip("Skipping test, since generating data")
351+
if identical:
352+
shutil.rmtree(result_dir)
353+
else:
354+
raise Exception(msg)
335355

336-
if item.cls is not None:
337-
setattr(item.cls, item.function.__name__, item_function_wrapper)
338356
else:
339-
item.obj = item_function_wrapper
357+
358+
if not os.path.exists(self.generate_dir):
359+
os.makedirs(self.generate_dir)
360+
361+
FORMATS[file_format].write(os.path.abspath(os.path.join(self.generate_dir, filename)), array, **write_kwargs)
362+
363+
pytest.skip("Skipping test, since generating data")
364+
365+
366+
class ArrayInterceptor:
367+
"""
368+
This is used in place of ArrayComparison when the array comparison option is not used,
369+
to make sure that we still intercept arrays returned by tests.
370+
"""
371+
372+
def __init__(self, config):
373+
self.config = config
374+
self.return_value = {}
375+
376+
@pytest.hookimpl(hookwrapper=True)
377+
def pytest_runtest_call(self, item):
378+
379+
if item.get_closest_marker('array_compare') is not None:
380+
wrap_array_interceptor(self, item)
381+
382+
yield
383+
return

setup.cfg

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ testpaths = tests
5050
xfail_strict = true
5151
markers =
5252
array_compare: for functions using array comparison
53+
filterwarnings =
54+
error
55+
# Can be removed when min Python is >=3.8
56+
ignore:distutils Version classes are deprecated
5357

5458
[flake8]
5559
max-line-length = 150

0 commit comments

Comments
 (0)