From 04e365cfbdc7992fbd2772107726706942aae374 Mon Sep 17 00:00:00 2001 From: Conor MacBride Date: Tue, 25 Oct 2022 22:51:09 +0100 Subject: [PATCH 1/4] Test inside `pytest_runtest_call` hook --- pytest_arraydiff/plugin.py | 170 +++++++++++++++++++++---------------- 1 file changed, 96 insertions(+), 74 deletions(-) diff --git a/pytest_arraydiff/plugin.py b/pytest_arraydiff/plugin.py index da78dfb..ae44e7f 100755 --- a/pytest_arraydiff/plugin.py +++ b/pytest_arraydiff/plugin.py @@ -223,6 +223,35 @@ def pytest_configure(config): default_format=default_format)) +def generate_test_name(item): + """ + Generate a unique name for this test. + """ + if item.cls is not None: + name = f"{item.module.__name__}.{item.cls.__name__}.{item.name}" + else: + name = f"{item.module.__name__}.{item.name}" + return name + + +def wrap_array_interceptor(plugin, item): + """ + Intercept and store arrays returned by test functions. + """ + # Only intercept array on marked array tests + if item.get_closest_marker('array_compare') is not None: + + # Use the full test name as a key to ensure correct array is being retrieved + test_name = generate_test_name(item) + + def array_interceptor(store, obj): + def wrapper(*args, **kwargs): + store.return_value[test_name] = obj(*args, **kwargs) + return wrapper + + item.obj = array_interceptor(plugin, item.obj) + + class ArrayComparison(object): def __init__(self, config, reference_dir=None, generate_dir=None, default_format='text'): @@ -230,12 +259,15 @@ def __init__(self, config, reference_dir=None, generate_dir=None, default_format self.reference_dir = reference_dir self.generate_dir = generate_dir self.default_format = default_format + self.return_value = {} - def pytest_runtest_setup(self, item): + @pytest.hookimpl(hookwrapper=True) + def pytest_runtest_call(self, item): compare = item.get_closest_marker('array_compare') if compare is None: + yield return file_format = compare.kwargs.get('file_format', self.default_format) @@ -255,85 +287,75 @@ def pytest_runtest_setup(self, item): write_kwargs = compare.kwargs.get('write_kwargs', {}) - original = item.function + reference_dir = compare.kwargs.get('reference_dir', None) + if reference_dir is None: + if self.reference_dir is None: + reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), 'reference') + else: + reference_dir = self.reference_dir + else: + if not reference_dir.startswith(('http://', 'https://')): + reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), reference_dir) - @wraps(item.function) - def item_function_wrapper(*args, **kwargs): + baseline_remote = reference_dir.startswith('http') - reference_dir = compare.kwargs.get('reference_dir', None) - if reference_dir is None: - if self.reference_dir is None: - reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), 'reference') - else: - reference_dir = self.reference_dir - else: - if not reference_dir.startswith(('http://', 'https://')): - reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), reference_dir) - - baseline_remote = reference_dir.startswith('http') - - # Run test and get figure object - import inspect - if inspect.ismethod(original): # method - array = original(*args[1:], **kwargs) - else: # function - array = original(*args, **kwargs) - - # Find test name to use as plot name - filename = compare.kwargs.get('filename', None) - if filename is None: - if single_reference: - filename = original.__name__ + '.' + extension - else: - filename = item.name + '.' + extension - filename = filename.replace('[', '_').replace(']', '_') - filename = filename.replace('_.' + extension, '.' + extension) - - # What we do now depends on whether we are generating the reference - # files or simply running the test. - if self.generate_dir is None: - - # Save the figure - result_dir = tempfile.mkdtemp() - test_array = os.path.abspath(os.path.join(result_dir, filename)) - - FORMATS[file_format].write(test_array, array, **write_kwargs) - - # Find path to baseline array - if baseline_remote: - baseline_file_ref = _download_file(reference_dir + filename) - else: - baseline_file_ref = os.path.abspath(os.path.join(os.path.dirname(item.fspath.strpath), reference_dir, filename)) - - if not os.path.exists(baseline_file_ref): - raise Exception("""File not found for comparison test - Generated file: - \t{test} - This is expected for new tests.""".format( - test=test_array)) - - # setuptools may put the baseline arrays in non-accessible places, - # copy to our tmpdir to be sure to keep them in case of failure - baseline_file = os.path.abspath(os.path.join(result_dir, 'reference-' + filename)) - shutil.copyfile(baseline_file_ref, baseline_file) - - identical, msg = FORMATS[file_format].compare(baseline_file, test_array, atol=atol, rtol=rtol) - - if identical: - shutil.rmtree(result_dir) - else: - raise Exception(msg) + # Run test and get array object + wrap_array_interceptor(self, item) + yield + test_name = generate_test_name(item) + if test_name not in self.return_value: + # Test function did not complete successfully + return + array = self.return_value[test_name] + + # Find test name to use as plot name + filename = compare.kwargs.get('filename', None) + if filename is None: + filename = item.name + '.' + extension + if not single_reference: + filename = filename.replace('[', '_').replace(']', '_') + filename = filename.replace('_.' + extension, '.' + extension) + + # What we do now depends on whether we are generating the reference + # files or simply running the test. + if self.generate_dir is None: + + # Save the figure + result_dir = tempfile.mkdtemp() + test_array = os.path.abspath(os.path.join(result_dir, filename)) + + FORMATS[file_format].write(test_array, array, **write_kwargs) + # Find path to baseline array + if baseline_remote: + baseline_file_ref = _download_file(reference_dir + filename) else: + baseline_file_ref = os.path.abspath(os.path.join(os.path.dirname(item.fspath.strpath), reference_dir, filename)) - if not os.path.exists(self.generate_dir): - os.makedirs(self.generate_dir) + if not os.path.exists(baseline_file_ref): + raise Exception("""File not found for comparison test + Generated file: + \t{test} + This is expected for new tests.""".format( + test=test_array)) - FORMATS[file_format].write(os.path.abspath(os.path.join(self.generate_dir, filename)), array, **write_kwargs) + # setuptools may put the baseline arrays in non-accessible places, + # copy to our tmpdir to be sure to keep them in case of failure + baseline_file = os.path.abspath(os.path.join(result_dir, 'reference-' + filename)) + shutil.copyfile(baseline_file_ref, baseline_file) - pytest.skip("Skipping test, since generating data") + identical, msg = FORMATS[file_format].compare(baseline_file, test_array, atol=atol, rtol=rtol) + + if identical: + shutil.rmtree(result_dir) + else: + raise Exception(msg) - if item.cls is not None: - setattr(item.cls, item.function.__name__, item_function_wrapper) else: - item.obj = item_function_wrapper + + if not os.path.exists(self.generate_dir): + os.makedirs(self.generate_dir) + + FORMATS[file_format].write(os.path.abspath(os.path.join(self.generate_dir, filename)), array, **write_kwargs) + + pytest.skip("Skipping test, since generating data") From 864bad2852aa0d42ba9448d2efa13c2a3903fb4f Mon Sep 17 00:00:00 2001 From: Conor MacBride Date: Tue, 25 Oct 2022 23:27:20 +0100 Subject: [PATCH 2/4] If arraydiff is not enabled, still intercept the return value This is to prevent a pytest warning or error --- pytest_arraydiff/plugin.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/pytest_arraydiff/plugin.py b/pytest_arraydiff/plugin.py index ae44e7f..bb7c07e 100755 --- a/pytest_arraydiff/plugin.py +++ b/pytest_arraydiff/plugin.py @@ -221,6 +221,8 @@ def pytest_configure(config): reference_dir=reference_dir, generate_dir=generate_dir, default_format=default_format)) + else: + config.pluginmanager.register(ArrayInterceptor(config)) def generate_test_name(item): @@ -359,3 +361,23 @@ def pytest_runtest_call(self, item): FORMATS[file_format].write(os.path.abspath(os.path.join(self.generate_dir, filename)), array, **write_kwargs) pytest.skip("Skipping test, since generating data") + + +class ArrayInterceptor: + """ + This is used in place of ArrayComparison when the array comparison option is not used, + to make sure that we still intercept arrays returned by tests. + """ + + def __init__(self, config): + self.config = config + self.return_value = {} + + @pytest.hookimpl(hookwrapper=True) + def pytest_runtest_call(self, item): + + if item.get_closest_marker('array_compare') is not None: + wrap_array_interceptor(self, item) + + yield + return From 35422474b726018b435a17818ed5bf94559fc3f1 Mon Sep 17 00:00:00 2001 From: Conor MacBride Date: Tue, 25 Oct 2022 23:27:53 +0100 Subject: [PATCH 3/4] Upgrade warnings to errors when testing --- setup.cfg | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.cfg b/setup.cfg index 674af9c..e7d2549 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,6 +50,8 @@ testpaths = tests xfail_strict = true markers = array_compare: for functions using array comparison +filterwarnings = + error [flake8] max-line-length = 150 From 9ffd246a761586a7e50b436fa0322d0d55e66d88 Mon Sep 17 00:00:00 2001 From: Conor MacBride Date: Tue, 25 Oct 2022 23:40:21 +0100 Subject: [PATCH 4/4] =?UTF-8?q?Ignore=20warning=20=E2=80=9Cdistutils=20Ver?= =?UTF-8?q?sion=20classes=20are=20deprecated=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.cfg b/setup.cfg index e7d2549..5bc3a65 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,6 +52,8 @@ markers = array_compare: for functions using array comparison filterwarnings = error + # Can be removed when min Python is >=3.8 + ignore:distutils Version classes are deprecated [flake8] max-line-length = 150