Skip to content

Pytest assert functions #1147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Dec 22, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ nosetests.xml
# PyCharm and Vim
.idea
*.swp
.DS_Store

# xarray specific
doc/_build
Expand Down
11 changes: 11 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,14 @@ arguments for the ``from_store`` and ``dump_to_store`` Dataset methods:
backends.H5NetCDFStore
backends.PydapDataStore
backends.ScipyDataStore


Testing
=======

.. autosummary::
:toctree: generated/

test.assert_xarray_equal
test.assert_xarray_identical
test.assert_xarray_allclose
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ Enhancements
:py:class:`FacetGrid` and :py:func:`~xarray.plot.plot`, so axes
sharing can be disabled for polar plots.
By `Bas Hoonhout <https://github.com/hoonhout>`_.
- New utility functions :py:func:`~xarray.test.assert_xarray_equal`,
:py:func:`~xarray.test.assert_xarray_identical`, and
:py:func:`~xarray.test.assert_xarray_allclose` for asserting relationships
between xarray objects, designed for use in a pytest test suite.
- ``figsize``, ``size`` and ``aspect`` plot arguments are now supported for all
plots (:issue:`897`). See :ref:`plotting.figsize` for more details.
By `Stephan Hoyer <https://github.com/shoyer>`_ and
Expand Down
117 changes: 71 additions & 46 deletions xarray/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
from numpy.testing import assert_array_equal
import pytest

from xarray.core import utils, nputils, ops
from xarray.core.variable import as_variable
Expand Down Expand Up @@ -71,42 +72,41 @@
except ImportError:
has_bottleneck = False

# slighly simpler construction that the full functions.
# Generally `pytest.importorskip('package')` inline is even easier
requires_matplotlib = pytest.mark.skipif(not has_matplotlib, reason='requires matplotlib')

def requires_scipy(test):
return test if has_scipy else unittest.skip('requires scipy')(test)
return test if has_scipy else pytest.mark.skip('requires scipy')(test)


def requires_pydap(test):
return test if has_pydap else unittest.skip('requires pydap.client')(test)
return test if has_pydap else pytest.mark.skip('requires pydap.client')(test)


def requires_netCDF4(test):
return test if has_netCDF4 else unittest.skip('requires netCDF4')(test)
return test if has_netCDF4 else pytest.mark.skip('requires netCDF4')(test)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dask -> netCDF4


def requires_h5netcdf(test):
return test if has_h5netcdf else unittest.skip('requires h5netcdf')(test)
return test if has_h5netcdf else pytest.mark.skip('requires h5netcdf')(test)


def requires_pynio(test):
return test if has_pynio else unittest.skip('requires pynio')(test)
return test if has_pynio else pytest.mark.skip('requires pynio')(test)


def requires_scipy_or_netCDF4(test):
return (test if has_scipy or has_netCDF4
else unittest.skip('requires scipy or netCDF4')(test))
else pytest.mark.skip('requires scipy or netCDF4')(test))


def requires_dask(test):
return test if has_dask else unittest.skip('requires dask')(test)


def requires_matplotlib(test):
return test if has_matplotlib else unittest.skip('requires matplotlib')(test)
return test if has_dask else pytest.mark.skip('requires dask')(test)


def requires_bottleneck(test):
return test if has_bottleneck else unittest.skip('requires bottleneck')(test)
return test if has_bottleneck else pytest.mark.skip('requires bottleneck')(test)


def decode_string_data(data):
Expand Down Expand Up @@ -154,67 +154,92 @@ def assertWarns(self, message):
assert any(message in str(wi.message) for wi in w)

def assertVariableEqual(self, v1, v2):
assert as_variable(v1).equals(v2), (v1, v2)
assert_xarray_equal(v1, v2)

def assertVariableIdentical(self, v1, v2):
assert as_variable(v1).identical(v2), (v1, v2)
assert_xarray_identical(v1, v2)

def assertVariableAllClose(self, v1, v2, rtol=1e-05, atol=1e-08):
self.assertEqual(v1.dims, v2.dims)
allclose = data_allclose_or_equiv(
v1.values, v2.values, rtol=rtol, atol=atol)
assert allclose, (v1.values, v2.values)
assert_xarray_allclose(v1, v2, rtol=rtol, atol=atol)

def assertVariableNotEqual(self, v1, v2):
self.assertFalse(as_variable(v1).equals(v2))
assert not v1.equals(v2)

def assertArrayEqual(self, a1, a2):
assert_array_equal(a1, a2)

# TODO: write a generic "assertEqual" that uses the equals method, or just
# switch to py.test and add an appropriate hook.

def assertEqual(self, a1, a2):
assert a1 == a2 or (a1 != a1 and a2 != a2)

def assertDatasetEqual(self, d1, d2):
# this method is functionally equivalent to `assert d1 == d2`, but it
# checks each aspect of equality separately for easier debugging
assert d1.equals(d2), (d1, d2)
assert_xarray_equal(d1, d2)

def assertDatasetIdentical(self, d1, d2):
# this method is functionally equivalent to `assert d1.identical(d2)`,
# but it checks each aspect of equality separately for easier debugging
assert d1.identical(d2), (d1, d2)
assert_xarray_identical(d1, d2)

def assertDatasetAllClose(self, d1, d2, rtol=1e-05, atol=1e-08):
self.assertEqual(sorted(d1, key=str), sorted(d2, key=str))
self.assertItemsEqual(d1.coords, d2.coords)
for k in d1:
v1 = d1.variables[k]
v2 = d2.variables[k]
self.assertVariableAllClose(v1, v2, rtol=rtol, atol=atol)
assert_xarray_allclose(d1, d2, rtol=rtol, atol=atol)

def assertCoordinatesEqual(self, d1, d2):
self.assertEqual(sorted(d1.coords), sorted(d2.coords))
for k in d1.coords:
v1 = d1.coords[k]
v2 = d2.coords[k]
self.assertVariableEqual(v1, v2)
assert_xarray_equal(d1, d2)

def assertDataArrayEqual(self, ar1, ar2):
self.assertVariableEqual(ar1, ar2)
self.assertCoordinatesEqual(ar1, ar2)
assert_xarray_equal(ar1, ar2)

def assertDataArrayIdentical(self, ar1, ar2):
self.assertEqual(ar1.name, ar2.name)
self.assertDatasetIdentical(ar1._to_temp_dataset(),
ar2._to_temp_dataset())
assert_xarray_identical(ar1, ar2)

def assertDataArrayAllClose(self, ar1, ar2, rtol=1e-05, atol=1e-08):
self.assertVariableAllClose(ar1, ar2, rtol=rtol, atol=atol)
self.assertCoordinatesEqual(ar1, ar2)
assert_xarray_allclose(ar1, ar2, rtol=rtol, atol=atol)

def assert_xarray_equal(a, b):
import xarray as xr
___tracebackhide__ = True # noqa: F841
assert type(a) == type(b)
if isinstance(a, (xr.Variable, xr.DataArray, xr.Dataset)):
assert a.equals(b), '{}\n{}'.format(a, b)
else:
raise TypeError('{} not supported by assertion comparison'
.format(type(a)))

def assert_xarray_identical(a, b):
import xarray as xr
___tracebackhide__ = True # noqa: F841
assert type(a) == type(b)
if isinstance(a, xr.DataArray):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you have this mixed up with the Dataset clause?

assert a.name == b.name
assert_xarray_identical(a._to_temp_dataset(), b._to_temp_dataset())
elif isinstance(a, (xr.Dataset, xr.Variable)):
assert a.identical(b), '{}\n{}'.format(a, b)
else:
raise TypeError('{} not supported by assertion comparison'
.format(type(a)))

def assert_xarray_allclose(a, b, rtol=1e-05, atol=1e-08):
import xarray as xr
___tracebackhide__ = True # noqa: F841
assert type(a) == type(b)
if isinstance(a, xr.Variable):
assert a.dims == b.dims
allclose = data_allclose_or_equiv(
a.values, b.values, rtol=rtol, atol=atol)
assert allclose, '{}\n{}'.format(a.values, b.values)
elif isinstance(a, xr.DataArray):
assert_xarray_allclose(a.variable, b.variable)
for v in a.coords.variables:
# can't recurse with this function as coord is sometimes a DataArray,
# so call into data_allclose_or_equiv directly
allclose = data_allclose_or_equiv(
a.coords[v].values, b.coords[v].values, rtol=rtol, atol=atol)
assert allclose, '{}\n{}'.format(a.coords[v].values, b.coords[v].values)
elif isinstance(a, xr.Dataset):
assert sorted(a, key=str) == sorted(b, key=str)
for k in list(a.variables) + list(a.coords):
assert_xarray_allclose(a[k], b[k], rtol=rtol, atol=atol)

else:
raise TypeError('{} not supported by assertion comparison'
.format(type(a)))

class UnexpectedDataAccess(Exception):
pass
Expand Down
12 changes: 9 additions & 3 deletions xarray/test/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from . import (TestCase, requires_scipy, requires_netCDF4, requires_pydap,
requires_scipy_or_netCDF4, requires_dask, requires_h5netcdf,
requires_pynio, has_netCDF4, has_scipy)
requires_pynio, has_netCDF4, has_scipy, assert_xarray_allclose)
from .test_dataset import create_test_data

try:
Expand Down Expand Up @@ -492,7 +492,7 @@ def create_tmp_file(suffix='.nc', allow_cleanup_failure=False):
if not allow_cleanup_failure:
raise


@requires_netCDF4
class BaseNetCDF4Test(CFEncodedDataTest):
def test_open_group(self):
# Create a netCDF file with a dataset stored within a group
Expand Down Expand Up @@ -693,6 +693,7 @@ def test_variable_len_strings(self):

@requires_netCDF4
class NetCDF4DataTest(BaseNetCDF4Test, TestCase):

@contextlib.contextmanager
def create_store(self):
with create_tmp_file() as tmp_file:
Expand Down Expand Up @@ -912,7 +913,12 @@ def test_cross_engine_read_write_netcdf3(self):
for read_engine in valid_engines:
with open_dataset(tmp_file,
engine=read_engine) as actual:
self.assertDatasetAllClose(data, actual)
# hack to allow test to work:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine for now, but we should look into this later because it looks like a legit bug to me.

It might be worth making a dedicated VariablesDict that is simply an OrderedDict with runtime type checking that verifies you can never put in anything other than a Variable. Or we should use pytype for static checks to ensure Dataset._variables and DataArray._coords are typed typing.OrderedDict[Any, xarray.Variable].

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed it's a bug

# coord comes back as DataArray rather than coord, and so
# need to loop through here rather than in the test
# function (or we get recursion)
[assert_xarray_allclose(data[k].variable, actual[k].variable)
for k in data]


@requires_h5netcdf
Expand Down
Loading