|
| 1 | +# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- |
| 2 | +# vi: set ft=python sts=4 ts=4 sw=4 et: |
| 3 | +### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## |
| 4 | +# |
| 5 | +# See COPYING file distributed along with the NiBabel package for the |
| 6 | +# copyright and license terms. |
| 7 | +# |
| 8 | +### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## |
| 9 | +''' Utilities for testing ''' |
| 10 | + |
| 11 | +import re |
| 12 | +import os |
| 13 | +import sys |
| 14 | +import warnings |
| 15 | +from pkg_resources import resource_filename |
| 16 | +from os.path import dirname, abspath, join as pjoin |
| 17 | + |
| 18 | +import numpy as np |
| 19 | +from numpy.testing import assert_array_equal, assert_warns |
| 20 | +from numpy.testing import dec |
| 21 | +skipif = dec.skipif |
| 22 | +slow = dec.slow |
| 23 | + |
| 24 | +from ..deprecated import deprecate_with_version as _deprecate_with_version |
| 25 | + |
| 26 | + |
| 27 | +from itertools import zip_longest |
| 28 | + |
| 29 | + |
| 30 | +def test_data(subdir=None, fname=None): |
| 31 | + if subdir is None: |
| 32 | + resource = os.path.join('tests', 'data') |
| 33 | + elif subdir in ('gifti', 'nicom', 'externals'): |
| 34 | + resource = os.path.join(subdir, 'tests', 'data') |
| 35 | + else: |
| 36 | + raise ValueError("Unknown test data directory: %s" % subdir) |
| 37 | + |
| 38 | + if fname is not None: |
| 39 | + resource = os.path.join(resource, fname) |
| 40 | + |
| 41 | + return resource_filename('nibabel', resource) |
| 42 | + |
| 43 | + |
| 44 | +# set path to example data |
| 45 | +data_path = test_data() |
| 46 | + |
| 47 | + |
| 48 | +from .np_features import memmap_after_ufunc |
| 49 | + |
| 50 | +def assert_dt_equal(a, b): |
| 51 | + """ Assert two numpy dtype specifiers are equal |
| 52 | +
|
| 53 | + Avoids failed comparison between int32 / int64 and intp |
| 54 | + """ |
| 55 | + assert np.dtype(a).str == np.dtype(b).str |
| 56 | + |
| 57 | + |
| 58 | +def assert_allclose_safely(a, b, match_nans=True, rtol=1e-5, atol=1e-8): |
| 59 | + """ Allclose in integers go all wrong for large integers |
| 60 | + """ |
| 61 | + a = np.atleast_1d(a) # 0d arrays cannot be indexed |
| 62 | + a, b = np.broadcast_arrays(a, b) |
| 63 | + if match_nans: |
| 64 | + nans = np.isnan(a) |
| 65 | + np.testing.assert_array_equal(nans, np.isnan(b)) |
| 66 | + to_test = ~nans |
| 67 | + else: |
| 68 | + to_test = np.ones(a.shape, dtype=bool) |
| 69 | + # Deal with float128 inf comparisons (bug in numpy 1.9.2) |
| 70 | + # np.allclose(np.float128(np.inf), np.float128(np.inf)) == False |
| 71 | + to_test = to_test & (a != b) |
| 72 | + a = a[to_test] |
| 73 | + b = b[to_test] |
| 74 | + if a.dtype.kind in 'ui': |
| 75 | + a = a.astype(float) |
| 76 | + if b.dtype.kind in 'ui': |
| 77 | + b = b.astype(float) |
| 78 | + assert np.allclose(a, b, rtol=rtol, atol=atol) |
| 79 | + |
| 80 | + |
| 81 | +def assert_arrays_equal(arrays1, arrays2): |
| 82 | + """ Check two iterables yield the same sequence of arrays. """ |
| 83 | + for arr1, arr2 in zip_longest(arrays1, arrays2, fillvalue=None): |
| 84 | + assert (arr1 is not None and arr2 is not None) |
| 85 | + assert_array_equal(arr1, arr2) |
| 86 | + |
| 87 | + |
| 88 | +def assert_re_in(regex, c, flags=0): |
| 89 | + """Assert that container (list, str, etc) contains entry matching the regex |
| 90 | + """ |
| 91 | + if not isinstance(c, (list, tuple)): |
| 92 | + c = [c] |
| 93 | + for e in c: |
| 94 | + if re.match(regex, e, flags=flags): |
| 95 | + return |
| 96 | + raise AssertionError("Not a single entry matched %r in %r" % (regex, c)) |
| 97 | + |
| 98 | + |
| 99 | +def get_fresh_mod(mod_name=__name__): |
| 100 | + # Get this module, with warning registry empty |
| 101 | + my_mod = sys.modules[mod_name] |
| 102 | + try: |
| 103 | + my_mod.__warningregistry__.clear() |
| 104 | + except AttributeError: |
| 105 | + pass |
| 106 | + return my_mod |
| 107 | + |
| 108 | + |
| 109 | +class clear_and_catch_warnings(warnings.catch_warnings): |
| 110 | + """ Context manager that resets warning registry for catching warnings |
| 111 | +
|
| 112 | + Warnings can be slippery, because, whenever a warning is triggered, Python |
| 113 | + adds a ``__warningregistry__`` member to the *calling* module. This makes |
| 114 | + it impossible to retrigger the warning in this module, whatever you put in |
| 115 | + the warnings filters. This context manager accepts a sequence of `modules` |
| 116 | + as a keyword argument to its constructor and: |
| 117 | +
|
| 118 | + * stores and removes any ``__warningregistry__`` entries in given `modules` |
| 119 | + on entry; |
| 120 | + * resets ``__warningregistry__`` to its previous state on exit. |
| 121 | +
|
| 122 | + This makes it possible to trigger any warning afresh inside the context |
| 123 | + manager without disturbing the state of warnings outside. |
| 124 | +
|
| 125 | + For compatibility with Python 3.0, please consider all arguments to be |
| 126 | + keyword-only. |
| 127 | +
|
| 128 | + Parameters |
| 129 | + ---------- |
| 130 | + record : bool, optional |
| 131 | + Specifies whether warnings should be captured by a custom |
| 132 | + implementation of ``warnings.showwarning()`` and be appended to a list |
| 133 | + returned by the context manager. Otherwise None is returned by the |
| 134 | + context manager. The objects appended to the list are arguments whose |
| 135 | + attributes mirror the arguments to ``showwarning()``. |
| 136 | +
|
| 137 | + NOTE: nibabel difference from numpy: default is True |
| 138 | +
|
| 139 | + modules : sequence, optional |
| 140 | + Sequence of modules for which to reset warnings registry on entry and |
| 141 | + restore on exit |
| 142 | +
|
| 143 | + Examples |
| 144 | + -------- |
| 145 | + >>> import warnings |
| 146 | + >>> with clear_and_catch_warnings(modules=[np.core.fromnumeric]): |
| 147 | + ... warnings.simplefilter('always') |
| 148 | + ... # do something that raises a warning in np.core.fromnumeric |
| 149 | + """ |
| 150 | + class_modules = () |
| 151 | + |
| 152 | + def __init__(self, record=True, modules=()): |
| 153 | + self.modules = set(modules).union(self.class_modules) |
| 154 | + self._warnreg_copies = {} |
| 155 | + super(clear_and_catch_warnings, self).__init__(record=record) |
| 156 | + |
| 157 | + def __enter__(self): |
| 158 | + for mod in self.modules: |
| 159 | + if hasattr(mod, '__warningregistry__'): |
| 160 | + mod_reg = mod.__warningregistry__ |
| 161 | + self._warnreg_copies[mod] = mod_reg.copy() |
| 162 | + mod_reg.clear() |
| 163 | + return super(clear_and_catch_warnings, self).__enter__() |
| 164 | + |
| 165 | + def __exit__(self, *exc_info): |
| 166 | + super(clear_and_catch_warnings, self).__exit__(*exc_info) |
| 167 | + for mod in self.modules: |
| 168 | + if hasattr(mod, '__warningregistry__'): |
| 169 | + mod.__warningregistry__.clear() |
| 170 | + if mod in self._warnreg_copies: |
| 171 | + mod.__warningregistry__.update(self._warnreg_copies[mod]) |
| 172 | + |
| 173 | + |
| 174 | +class error_warnings(clear_and_catch_warnings): |
| 175 | + """ Context manager to check for warnings as errors. Usually used with |
| 176 | + ``assert_raises`` in the with block |
| 177 | +
|
| 178 | + Examples |
| 179 | + -------- |
| 180 | + >>> with error_warnings(): |
| 181 | + ... try: |
| 182 | + ... warnings.warn('Message', UserWarning) |
| 183 | + ... except UserWarning: |
| 184 | + ... print('I consider myself warned') |
| 185 | + I consider myself warned |
| 186 | + """ |
| 187 | + filter = 'error' |
| 188 | + |
| 189 | + def __enter__(self): |
| 190 | + mgr = super(error_warnings, self).__enter__() |
| 191 | + warnings.simplefilter(self.filter) |
| 192 | + return mgr |
| 193 | + |
| 194 | + |
| 195 | +class suppress_warnings(error_warnings): |
| 196 | + """ Version of ``catch_warnings`` class that suppresses warnings |
| 197 | + """ |
| 198 | + filter = 'ignore' |
| 199 | + |
| 200 | + |
| 201 | +@_deprecate_with_version('catch_warn_reset is deprecated; use ' |
| 202 | + 'nibabel.testing.clear_and_catch_warnings.', |
| 203 | + since='2.1.0', until='3.0.0') |
| 204 | +class catch_warn_reset(clear_and_catch_warnings): |
| 205 | + pass |
| 206 | + |
| 207 | + |
| 208 | +EXTRA_SET = os.environ.get('NIPY_EXTRA_TESTS', '').split(',') |
| 209 | + |
| 210 | + |
| 211 | +def runif_extra_has(test_str): |
| 212 | + """Decorator checks to see if NIPY_EXTRA_TESTS env var contains test_str""" |
| 213 | + return skipif(test_str not in EXTRA_SET, |
| 214 | + "Skip {0} tests.".format(test_str)) |
| 215 | + |
| 216 | + |
| 217 | +def assert_arr_dict_equal(dict1, dict2): |
| 218 | + """ Assert that two dicts are equal, where dicts contain arrays |
| 219 | + """ |
| 220 | + assert set(dict1) == set(dict2) |
| 221 | + for key, value1 in dict1.items(): |
| 222 | + value2 = dict2[key] |
| 223 | + assert_array_equal(value1, value2) |
0 commit comments