-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Pytest assert functions #1147
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pytest assert functions #1147
Changes from all commits
f30d7cd
83ae3c6
25d09fb
1af5a50
f2e8209
994c3e2
95a99eb
a80829c
b09f9cc
9fa25dd
e7215b4
8bf4ae3
4dac537
e77cfa0
f0cb0ae
c1e59aa
378e713
91b3abc
dcf4f05
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,6 +39,7 @@ nosetests.xml | |
# PyCharm and Vim | ||
.idea | ||
*.swp | ||
.DS_Store | ||
|
||
# xarray specific | ||
doc/_build | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
|
||
import numpy as np | ||
from numpy.testing import assert_array_equal | ||
import pytest | ||
|
||
from xarray.core import utils, nputils, ops | ||
from xarray.core.variable import as_variable | ||
|
@@ -71,42 +72,41 @@ | |
except ImportError: | ||
has_bottleneck = False | ||
|
||
# slighly simpler construction that the full functions. | ||
# Generally `pytest.importorskip('package')` inline is even easier | ||
requires_matplotlib = pytest.mark.skipif(not has_matplotlib, reason='requires matplotlib') | ||
|
||
def requires_scipy(test): | ||
return test if has_scipy else unittest.skip('requires scipy')(test) | ||
return test if has_scipy else pytest.mark.skip('requires scipy')(test) | ||
|
||
|
||
def requires_pydap(test): | ||
return test if has_pydap else unittest.skip('requires pydap.client')(test) | ||
return test if has_pydap else pytest.mark.skip('requires pydap.client')(test) | ||
|
||
|
||
def requires_netCDF4(test): | ||
return test if has_netCDF4 else unittest.skip('requires netCDF4')(test) | ||
return test if has_netCDF4 else pytest.mark.skip('requires netCDF4')(test) | ||
|
||
|
||
def requires_h5netcdf(test): | ||
return test if has_h5netcdf else unittest.skip('requires h5netcdf')(test) | ||
return test if has_h5netcdf else pytest.mark.skip('requires h5netcdf')(test) | ||
|
||
|
||
def requires_pynio(test): | ||
return test if has_pynio else unittest.skip('requires pynio')(test) | ||
return test if has_pynio else pytest.mark.skip('requires pynio')(test) | ||
|
||
|
||
def requires_scipy_or_netCDF4(test): | ||
return (test if has_scipy or has_netCDF4 | ||
else unittest.skip('requires scipy or netCDF4')(test)) | ||
else pytest.mark.skip('requires scipy or netCDF4')(test)) | ||
|
||
|
||
def requires_dask(test): | ||
return test if has_dask else unittest.skip('requires dask')(test) | ||
|
||
|
||
def requires_matplotlib(test): | ||
return test if has_matplotlib else unittest.skip('requires matplotlib')(test) | ||
return test if has_dask else pytest.mark.skip('requires dask')(test) | ||
|
||
|
||
def requires_bottleneck(test): | ||
return test if has_bottleneck else unittest.skip('requires bottleneck')(test) | ||
return test if has_bottleneck else pytest.mark.skip('requires bottleneck')(test) | ||
|
||
|
||
def decode_string_data(data): | ||
|
@@ -154,67 +154,92 @@ def assertWarns(self, message): | |
assert any(message in str(wi.message) for wi in w) | ||
|
||
def assertVariableEqual(self, v1, v2): | ||
assert as_variable(v1).equals(v2), (v1, v2) | ||
assert_xarray_equal(v1, v2) | ||
|
||
def assertVariableIdentical(self, v1, v2): | ||
assert as_variable(v1).identical(v2), (v1, v2) | ||
assert_xarray_identical(v1, v2) | ||
|
||
def assertVariableAllClose(self, v1, v2, rtol=1e-05, atol=1e-08): | ||
self.assertEqual(v1.dims, v2.dims) | ||
allclose = data_allclose_or_equiv( | ||
v1.values, v2.values, rtol=rtol, atol=atol) | ||
assert allclose, (v1.values, v2.values) | ||
assert_xarray_allclose(v1, v2, rtol=rtol, atol=atol) | ||
|
||
def assertVariableNotEqual(self, v1, v2): | ||
self.assertFalse(as_variable(v1).equals(v2)) | ||
assert not v1.equals(v2) | ||
|
||
def assertArrayEqual(self, a1, a2): | ||
assert_array_equal(a1, a2) | ||
|
||
# TODO: write a generic "assertEqual" that uses the equals method, or just | ||
# switch to py.test and add an appropriate hook. | ||
|
||
def assertEqual(self, a1, a2): | ||
assert a1 == a2 or (a1 != a1 and a2 != a2) | ||
|
||
def assertDatasetEqual(self, d1, d2): | ||
# this method is functionally equivalent to `assert d1 == d2`, but it | ||
# checks each aspect of equality separately for easier debugging | ||
assert d1.equals(d2), (d1, d2) | ||
assert_xarray_equal(d1, d2) | ||
|
||
def assertDatasetIdentical(self, d1, d2): | ||
# this method is functionally equivalent to `assert d1.identical(d2)`, | ||
# but it checks each aspect of equality separately for easier debugging | ||
assert d1.identical(d2), (d1, d2) | ||
assert_xarray_identical(d1, d2) | ||
|
||
def assertDatasetAllClose(self, d1, d2, rtol=1e-05, atol=1e-08): | ||
self.assertEqual(sorted(d1, key=str), sorted(d2, key=str)) | ||
self.assertItemsEqual(d1.coords, d2.coords) | ||
for k in d1: | ||
v1 = d1.variables[k] | ||
v2 = d2.variables[k] | ||
self.assertVariableAllClose(v1, v2, rtol=rtol, atol=atol) | ||
assert_xarray_allclose(d1, d2, rtol=rtol, atol=atol) | ||
|
||
def assertCoordinatesEqual(self, d1, d2): | ||
self.assertEqual(sorted(d1.coords), sorted(d2.coords)) | ||
for k in d1.coords: | ||
v1 = d1.coords[k] | ||
v2 = d2.coords[k] | ||
self.assertVariableEqual(v1, v2) | ||
assert_xarray_equal(d1, d2) | ||
|
||
def assertDataArrayEqual(self, ar1, ar2): | ||
self.assertVariableEqual(ar1, ar2) | ||
self.assertCoordinatesEqual(ar1, ar2) | ||
assert_xarray_equal(ar1, ar2) | ||
|
||
def assertDataArrayIdentical(self, ar1, ar2): | ||
self.assertEqual(ar1.name, ar2.name) | ||
self.assertDatasetIdentical(ar1._to_temp_dataset(), | ||
ar2._to_temp_dataset()) | ||
assert_xarray_identical(ar1, ar2) | ||
|
||
def assertDataArrayAllClose(self, ar1, ar2, rtol=1e-05, atol=1e-08): | ||
self.assertVariableAllClose(ar1, ar2, rtol=rtol, atol=atol) | ||
self.assertCoordinatesEqual(ar1, ar2) | ||
assert_xarray_allclose(ar1, ar2, rtol=rtol, atol=atol) | ||
|
||
def assert_xarray_equal(a, b): | ||
import xarray as xr | ||
___tracebackhide__ = True # noqa: F841 | ||
assert type(a) == type(b) | ||
if isinstance(a, (xr.Variable, xr.DataArray, xr.Dataset)): | ||
assert a.equals(b), '{}\n{}'.format(a, b) | ||
else: | ||
raise TypeError('{} not supported by assertion comparison' | ||
.format(type(a))) | ||
|
||
def assert_xarray_identical(a, b): | ||
import xarray as xr | ||
___tracebackhide__ = True # noqa: F841 | ||
assert type(a) == type(b) | ||
if isinstance(a, xr.DataArray): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you have this mixed up with the |
||
assert a.name == b.name | ||
assert_xarray_identical(a._to_temp_dataset(), b._to_temp_dataset()) | ||
elif isinstance(a, (xr.Dataset, xr.Variable)): | ||
assert a.identical(b), '{}\n{}'.format(a, b) | ||
else: | ||
raise TypeError('{} not supported by assertion comparison' | ||
.format(type(a))) | ||
|
||
def assert_xarray_allclose(a, b, rtol=1e-05, atol=1e-08): | ||
import xarray as xr | ||
___tracebackhide__ = True # noqa: F841 | ||
assert type(a) == type(b) | ||
if isinstance(a, xr.Variable): | ||
assert a.dims == b.dims | ||
allclose = data_allclose_or_equiv( | ||
a.values, b.values, rtol=rtol, atol=atol) | ||
assert allclose, '{}\n{}'.format(a.values, b.values) | ||
elif isinstance(a, xr.DataArray): | ||
assert_xarray_allclose(a.variable, b.variable) | ||
for v in a.coords.variables: | ||
# can't recurse with this function as coord is sometimes a DataArray, | ||
# so call into data_allclose_or_equiv directly | ||
allclose = data_allclose_or_equiv( | ||
a.coords[v].values, b.coords[v].values, rtol=rtol, atol=atol) | ||
assert allclose, '{}\n{}'.format(a.coords[v].values, b.coords[v].values) | ||
elif isinstance(a, xr.Dataset): | ||
assert sorted(a, key=str) == sorted(b, key=str) | ||
for k in list(a.variables) + list(a.coords): | ||
assert_xarray_allclose(a[k], b[k], rtol=rtol, atol=atol) | ||
|
||
else: | ||
raise TypeError('{} not supported by assertion comparison' | ||
.format(type(a))) | ||
|
||
class UnexpectedDataAccess(Exception): | ||
pass | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,7 @@ | |
|
||
from . import (TestCase, requires_scipy, requires_netCDF4, requires_pydap, | ||
requires_scipy_or_netCDF4, requires_dask, requires_h5netcdf, | ||
requires_pynio, has_netCDF4, has_scipy) | ||
requires_pynio, has_netCDF4, has_scipy, assert_xarray_allclose) | ||
from .test_dataset import create_test_data | ||
|
||
try: | ||
|
@@ -492,7 +492,7 @@ def create_tmp_file(suffix='.nc', allow_cleanup_failure=False): | |
if not allow_cleanup_failure: | ||
raise | ||
|
||
|
||
@requires_netCDF4 | ||
class BaseNetCDF4Test(CFEncodedDataTest): | ||
def test_open_group(self): | ||
# Create a netCDF file with a dataset stored within a group | ||
|
@@ -693,6 +693,7 @@ def test_variable_len_strings(self): | |
|
||
@requires_netCDF4 | ||
class NetCDF4DataTest(BaseNetCDF4Test, TestCase): | ||
|
||
@contextlib.contextmanager | ||
def create_store(self): | ||
with create_tmp_file() as tmp_file: | ||
|
@@ -912,7 +913,12 @@ def test_cross_engine_read_write_netcdf3(self): | |
for read_engine in valid_engines: | ||
with open_dataset(tmp_file, | ||
engine=read_engine) as actual: | ||
self.assertDatasetAllClose(data, actual) | ||
# hack to allow test to work: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is fine for now, but we should look into this later because it looks like a legit bug to me. It might be worth making a dedicated There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed it's a bug |
||
# coord comes back as DataArray rather than coord, and so | ||
# need to loop through here rather than in the test | ||
# function (or we get recursion) | ||
[assert_xarray_allclose(data[k].variable, actual[k].variable) | ||
for k in data] | ||
|
||
|
||
@requires_h5netcdf | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dask -> netCDF4