Skip to content

Commit 83359e6

Browse files
committed
Use pytest_pyfunc_call hook to intercept figure
Instead of modifying the test function itself by wrapping it in a function which runs the tests, use pytest hooks to intercept the generated figure and then run the tests. This should be a more robust approach that doesn't need as many special cases to be hardcoded. I have also refactored the get_marker and get_compare functions/methods to simplify these.
1 parent cdb37c9 commit 83359e6

File tree

1 file changed

+103
-128
lines changed

1 file changed

+103
-128
lines changed

pytest_mpl/plugin.py

Lines changed: 103 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ def pathify(path):
8383
return Path(path + ext)
8484

8585

86+
def _pytest_pyfunc_call(obj, pyfuncitem):
87+
testfunction = pyfuncitem.obj
88+
funcargs = pyfuncitem.funcargs
89+
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
90+
obj.result = testfunction(**testargs)
91+
return True
92+
93+
8694
def pytest_report_header(config, startdir):
8795
import matplotlib
8896
import matplotlib.ft2font
@@ -211,13 +219,11 @@ def close_mpl_figure(fig):
211219
plt.close(fig)
212220

213221

214-
def get_marker(item, marker_name):
215-
if hasattr(item, 'get_closest_marker'):
216-
return item.get_closest_marker(marker_name)
217-
else:
218-
# "item.keywords.get" was deprecated in pytest 3.6
219-
# See https://docs.pytest.org/en/latest/mark.html#updating-code
220-
return item.keywords.get(marker_name)
222+
def get_compare(item):
223+
"""
224+
Return the mpl_image_compare marker for the given item.
225+
"""
226+
return item.get_closest_marker("mpl_image_compare")
221227

222228

223229
def path_is_not_none(apath):
@@ -278,20 +284,14 @@ def __init__(self,
278284
logging.basicConfig(level=level)
279285
self.logger = logging.getLogger('pytest-mpl')
280286

281-
def get_compare(self, item):
282-
"""
283-
Return the mpl_image_compare marker for the given item.
284-
"""
285-
return get_marker(item, 'mpl_image_compare')
286-
287287
def generate_filename(self, item):
288288
"""
289289
Given a pytest item, generate the figure filename.
290290
"""
291291
if self.config.getini('mpl-use-full-test-name'):
292292
filename = self.generate_test_name(item) + '.png'
293293
else:
294-
compare = self.get_compare(item)
294+
compare = get_compare(item)
295295
# Find test name to use as plot name
296296
filename = compare.kwargs.get('filename', None)
297297
if filename is None:
@@ -319,7 +319,7 @@ def baseline_directory_specified(self, item):
319319
"""
320320
Returns `True` if a non-default baseline directory is specified.
321321
"""
322-
compare = self.get_compare(item)
322+
compare = get_compare(item)
323323
item_baseline_dir = compare.kwargs.get('baseline_dir', None)
324324
return item_baseline_dir or self.baseline_dir or self.baseline_relative_dir
325325

@@ -330,7 +330,7 @@ def get_baseline_directory(self, item):
330330
Using the global and per-test configuration return the absolute
331331
baseline dir, if the baseline file is local else return base URL.
332332
"""
333-
compare = self.get_compare(item)
333+
compare = get_compare(item)
334334
baseline_dir = compare.kwargs.get('baseline_dir', None)
335335
if baseline_dir is None:
336336
if self.baseline_dir is None:
@@ -394,7 +394,7 @@ def generate_baseline_image(self, item, fig):
394394
"""
395395
Generate reference figures.
396396
"""
397-
compare = self.get_compare(item)
397+
compare = get_compare(item)
398398
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})
399399

400400
if not os.path.exists(self.generate_dir):
@@ -413,7 +413,7 @@ def generate_image_hash(self, item, fig):
413413
For a `matplotlib.figure.Figure`, returns the SHA256 hash as a hexadecimal
414414
string.
415415
"""
416-
compare = self.get_compare(item)
416+
compare = get_compare(item)
417417
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})
418418

419419
imgdata = io.BytesIO()
@@ -436,7 +436,7 @@ def compare_image_to_baseline(self, item, fig, result_dir, summary=None):
436436
if summary is None:
437437
summary = {}
438438

439-
compare = self.get_compare(item)
439+
compare = get_compare(item)
440440
tolerance = compare.kwargs.get('tolerance', 2)
441441
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})
442442

@@ -510,7 +510,7 @@ def compare_image_to_hash_library(self, item, fig, result_dir, summary=None):
510510
if summary is None:
511511
summary = {}
512512

513-
compare = self.get_compare(item)
513+
compare = get_compare(item)
514514
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})
515515

516516
if not self.results_hash_library_name:
@@ -582,11 +582,13 @@ def compare_image_to_hash_library(self, item, fig, result_dir, summary=None):
582582
return
583583
return summary['status_msg']
584584

585-
def pytest_runtest_setup(self, item): # noqa
585+
@pytest.hookimpl(hookwrapper=True)
586+
def pytest_runtest_call(self, item): # noqa
586587

587-
compare = self.get_compare(item)
588+
compare = get_compare(item)
588589

589590
if compare is None:
591+
yield
590592
return
591593

592594
import matplotlib.pyplot as plt
@@ -600,95 +602,82 @@ def pytest_runtest_setup(self, item): # noqa
600602
remove_text = compare.kwargs.get('remove_text', False)
601603
backend = compare.kwargs.get('backend', 'agg')
602604

603-
original = item.function
604-
605-
@wraps(item.function)
606-
def item_function_wrapper(*args, **kwargs):
607-
608-
with plt.style.context(style, after_reset=True), switch_backend(backend):
609-
610-
# Run test and get figure object
611-
if inspect.ismethod(original): # method
612-
# In some cases, for example if setup_method is used,
613-
# original appears to belong to an instance of the test
614-
# class that is not the same as args[0], and args[0] is the
615-
# one that has the correct attributes set up from setup_method
616-
# so we ignore original.__self__ and use args[0] instead.
617-
fig = original.__func__(*args, **kwargs)
618-
else: # function
619-
fig = original(*args, **kwargs)
620-
621-
if remove_text:
622-
remove_ticks_and_titles(fig)
623-
624-
test_name = self.generate_test_name(item)
625-
result_dir = self.make_test_results_dir(item)
626-
627-
summary = {
628-
'status': None,
629-
'image_status': None,
630-
'hash_status': None,
631-
'status_msg': None,
632-
'baseline_image': None,
633-
'diff_image': None,
634-
'rms': None,
635-
'tolerance': None,
636-
'result_image': None,
637-
'baseline_hash': None,
638-
'result_hash': None,
639-
}
640-
641-
# What we do now depends on whether we are generating the
642-
# reference images or simply running the test.
643-
if self.generate_dir is not None:
644-
summary['status'] = 'skipped'
645-
summary['image_status'] = 'generated'
646-
summary['status_msg'] = 'Skipped test, since generating image.'
647-
generate_image = self.generate_baseline_image(item, fig)
648-
if self.results_always: # Make baseline image available in HTML
649-
result_image = (result_dir / "baseline.png").absolute()
650-
shutil.copy(generate_image, result_image)
651-
summary['baseline_image'] = \
652-
result_image.relative_to(self.results_dir).as_posix()
653-
654-
if self.generate_hash_library is not None:
655-
summary['hash_status'] = 'generated'
656-
image_hash = self.generate_image_hash(item, fig)
657-
self._generated_hash_library[test_name] = image_hash
658-
summary['baseline_hash'] = image_hash
659-
660-
# Only test figures if not generating images
661-
if self.generate_dir is None:
662-
# Compare to hash library
663-
if self.hash_library or compare.kwargs.get('hash_library', None):
664-
msg = self.compare_image_to_hash_library(item, fig, result_dir, summary=summary)
665-
666-
# Compare against a baseline if specified
667-
else:
668-
msg = self.compare_image_to_baseline(item, fig, result_dir, summary=summary)
669-
670-
close_mpl_figure(fig)
671-
672-
if msg is None:
673-
if not self.results_always:
674-
shutil.rmtree(result_dir)
675-
for image_type in ['baseline_image', 'diff_image', 'result_image']:
676-
summary[image_type] = None # image no longer exists
677-
else:
678-
self._test_results[test_name] = summary
679-
pytest.fail(msg, pytrace=False)
605+
with plt.style.context(style, after_reset=True), switch_backend(backend):
606+
607+
# Run test and get figure object
608+
yield
609+
fig = self.result
610+
611+
if remove_text:
612+
remove_ticks_and_titles(fig)
613+
614+
test_name = self.generate_test_name(item)
615+
result_dir = self.make_test_results_dir(item)
616+
617+
summary = {
618+
'status': None,
619+
'image_status': None,
620+
'hash_status': None,
621+
'status_msg': None,
622+
'baseline_image': None,
623+
'diff_image': None,
624+
'rms': None,
625+
'tolerance': None,
626+
'result_image': None,
627+
'baseline_hash': None,
628+
'result_hash': None,
629+
}
630+
631+
# What we do now depends on whether we are generating the
632+
# reference images or simply running the test.
633+
if self.generate_dir is not None:
634+
summary['status'] = 'skipped'
635+
summary['image_status'] = 'generated'
636+
summary['status_msg'] = 'Skipped test, since generating image.'
637+
generate_image = self.generate_baseline_image(item, fig)
638+
if self.results_always: # Make baseline image available in HTML
639+
result_image = (result_dir / "baseline.png").absolute()
640+
shutil.copy(generate_image, result_image)
641+
summary['baseline_image'] = \
642+
result_image.relative_to(self.results_dir).as_posix()
643+
644+
if self.generate_hash_library is not None:
645+
summary['hash_status'] = 'generated'
646+
image_hash = self.generate_image_hash(item, fig)
647+
self._generated_hash_library[test_name] = image_hash
648+
summary['baseline_hash'] = image_hash
649+
650+
# Only test figures if not generating images
651+
if self.generate_dir is None:
652+
# Compare to hash library
653+
if self.hash_library or compare.kwargs.get('hash_library', None):
654+
msg = self.compare_image_to_hash_library(item, fig, result_dir, summary=summary)
655+
656+
# Compare against a baseline if specified
657+
else:
658+
msg = self.compare_image_to_baseline(item, fig, result_dir, summary=summary)
680659

681660
close_mpl_figure(fig)
682661

683-
self._test_results[test_name] = summary
662+
if msg is None:
663+
if not self.results_always:
664+
shutil.rmtree(result_dir)
665+
for image_type in ['baseline_image', 'diff_image', 'result_image']:
666+
summary[image_type] = None # image no longer exists
667+
else:
668+
self._test_results[test_name] = summary
669+
pytest.fail(msg, pytrace=False)
670+
671+
close_mpl_figure(fig)
684672

685-
if summary['status'] == 'skipped':
686-
pytest.skip(summary['status_msg'])
673+
self._test_results[test_name] = summary
687674

688-
if item.cls is not None:
689-
setattr(item.cls, item.function.__name__, item_function_wrapper)
690-
else:
691-
item.obj = item_function_wrapper
675+
if summary['status'] == 'skipped':
676+
pytest.skip(summary['status_msg'])
677+
678+
@pytest.hookimpl(tryfirst=True)
679+
def pytest_pyfunc_call(self, pyfuncitem):
680+
return _pytest_pyfunc_call(self, pyfuncitem)
692681

693682
def generate_summary_json(self):
694683
json_file = self.results_dir / 'results.json'
@@ -742,26 +731,12 @@ class FigureCloser:
742731
def __init__(self, config):
743732
self.config = config
744733

745-
def pytest_runtest_setup(self, item):
746-
747-
compare = get_marker(item, 'mpl_image_compare')
748-
749-
if compare is None:
750-
return
751-
752-
original = item.function
753-
754-
@wraps(item.function)
755-
def item_function_wrapper(*args, **kwargs):
756-
757-
if inspect.ismethod(original): # method
758-
fig = original.__func__(*args, **kwargs)
759-
else: # function
760-
fig = original(*args, **kwargs)
761-
762-
close_mpl_figure(fig)
734+
@pytest.hookimpl(hookwrapper=True)
735+
def pytest_runtest_call(self, item):
736+
yield
737+
if get_compare(item) is not None:
738+
close_mpl_figure(self.result)
763739

764-
if item.cls is not None:
765-
setattr(item.cls, item.function.__name__, item_function_wrapper)
766-
else:
767-
item.obj = item_function_wrapper
740+
@pytest.hookimpl(tryfirst=True)
741+
def pytest_pyfunc_call(self, pyfuncitem):
742+
return _pytest_pyfunc_call(self, pyfuncitem)

0 commit comments

Comments
 (0)