diff --git a/.gitignore b/.gitignore index 0f0e5aeb0d6..13b75724e36 100644 --- a/.gitignore +++ b/.gitignore @@ -39,6 +39,7 @@ nosetests.xml # PyCharm and Vim .idea *.swp +.DS_Store # xarray specific doc/_build diff --git a/doc/api.rst b/doc/api.rst index c66e61dddf8..1e58e83c4a2 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -471,3 +471,14 @@ arguments for the ``from_store`` and ``dump_to_store`` Dataset methods: backends.H5NetCDFStore backends.PydapDataStore backends.ScipyDataStore + + +Testing +======= + +.. autosummary:: + :toctree: generated/ + + test.assert_xarray_equal + test.assert_xarray_identical + test.assert_xarray_allclose diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 05828051b6d..784787a378f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -144,6 +144,10 @@ Enhancements :py:class:`FacetGrid` and :py:func:`~xarray.plot.plot`, so axes sharing can be disabled for polar plots. By `Bas Hoonhout `_. +- New utility functions :py:func:`~xarray.test.assert_xarray_equal`, + :py:func:`~xarray.test.assert_xarray_identical`, and + :py:func:`~xarray.test.assert_xarray_allclose` for asserting relationships + between xarray objects, designed for use in a pytest test suite. - ``figsize``, ``size`` and ``aspect`` plot arguments are now supported for all plots (:issue:`897`). See :ref:`plotting.figsize` for more details. By `Stephan Hoyer `_ and diff --git a/xarray/test/__init__.py b/xarray/test/__init__.py index bdafef7c3ad..69cbe25a87d 100644 --- a/xarray/test/__init__.py +++ b/xarray/test/__init__.py @@ -6,6 +6,7 @@ import numpy as np from numpy.testing import assert_array_equal +import pytest from xarray.core import utils, nputils, ops from xarray.core.variable import as_variable @@ -71,42 +72,41 @@ except ImportError: has_bottleneck = False +# slightly simpler construction than the full functions.
+# Generally `pytest.importorskip('package')` inline is even easier +requires_matplotlib = pytest.mark.skipif(not has_matplotlib, reason='requires matplotlib') def requires_scipy(test): - return test if has_scipy else unittest.skip('requires scipy')(test) + return test if has_scipy else pytest.mark.skip('requires scipy')(test) def requires_pydap(test): - return test if has_pydap else unittest.skip('requires pydap.client')(test) + return test if has_pydap else pytest.mark.skip('requires pydap.client')(test) def requires_netCDF4(test): - return test if has_netCDF4 else unittest.skip('requires netCDF4')(test) + return test if has_netCDF4 else pytest.mark.skip('requires netCDF4')(test) def requires_h5netcdf(test): - return test if has_h5netcdf else unittest.skip('requires h5netcdf')(test) + return test if has_h5netcdf else pytest.mark.skip('requires h5netcdf')(test) def requires_pynio(test): - return test if has_pynio else unittest.skip('requires pynio')(test) + return test if has_pynio else pytest.mark.skip('requires pynio')(test) def requires_scipy_or_netCDF4(test): return (test if has_scipy or has_netCDF4 - else unittest.skip('requires scipy or netCDF4')(test)) + else pytest.mark.skip('requires scipy or netCDF4')(test)) def requires_dask(test): - return test if has_dask else unittest.skip('requires dask')(test) - - -def requires_matplotlib(test): - return test if has_matplotlib else unittest.skip('requires matplotlib')(test) + return test if has_dask else pytest.mark.skip('requires dask')(test) def requires_bottleneck(test): - return test if has_bottleneck else unittest.skip('requires bottleneck')(test) + return test if has_bottleneck else pytest.mark.skip('requires bottleneck')(test) def decode_string_data(data): @@ -154,67 +154,92 @@ def assertWarns(self, message): assert any(message in str(wi.message) for wi in w) def assertVariableEqual(self, v1, v2): - assert as_variable(v1).equals(v2), (v1, v2) + assert_xarray_equal(v1, v2) def assertVariableIdentical(self, v1, v2): - assert as_variable(v1).identical(v2), (v1, v2) + assert_xarray_identical(v1, v2) def assertVariableAllClose(self, v1, v2, rtol=1e-05, atol=1e-08): - self.assertEqual(v1.dims, v2.dims) - allclose = data_allclose_or_equiv( - v1.values, v2.values, rtol=rtol, atol=atol) - assert allclose, (v1.values, v2.values) + assert_xarray_allclose(v1, v2, rtol=rtol, atol=atol) def assertVariableNotEqual(self, v1, v2): - self.assertFalse(as_variable(v1).equals(v2)) + assert not v1.equals(v2) def assertArrayEqual(self, a1, a2): assert_array_equal(a1, a2) - # TODO: write a generic "assertEqual" that uses the equals method, or just - # switch to py.test and add an appropriate hook. 
- def assertEqual(self, a1, a2): assert a1 == a2 or (a1 != a1 and a2 != a2) def assertDatasetEqual(self, d1, d2): - # this method is functionally equivalent to `assert d1 == d2`, but it - # checks each aspect of equality separately for easier debugging - assert d1.equals(d2), (d1, d2) + assert_xarray_equal(d1, d2) def assertDatasetIdentical(self, d1, d2): - # this method is functionally equivalent to `assert d1.identical(d2)`, - # but it checks each aspect of equality separately for easier debugging - assert d1.identical(d2), (d1, d2) + assert_xarray_identical(d1, d2) def assertDatasetAllClose(self, d1, d2, rtol=1e-05, atol=1e-08): - self.assertEqual(sorted(d1, key=str), sorted(d2, key=str)) - self.assertItemsEqual(d1.coords, d2.coords) - for k in d1: - v1 = d1.variables[k] - v2 = d2.variables[k] - self.assertVariableAllClose(v1, v2, rtol=rtol, atol=atol) + assert_xarray_allclose(d1, d2, rtol=rtol, atol=atol) def assertCoordinatesEqual(self, d1, d2): - self.assertEqual(sorted(d1.coords), sorted(d2.coords)) - for k in d1.coords: - v1 = d1.coords[k] - v2 = d2.coords[k] - self.assertVariableEqual(v1, v2) + assert_xarray_equal(d1, d2) def assertDataArrayEqual(self, ar1, ar2): - self.assertVariableEqual(ar1, ar2) - self.assertCoordinatesEqual(ar1, ar2) + assert_xarray_equal(ar1, ar2) def assertDataArrayIdentical(self, ar1, ar2): - self.assertEqual(ar1.name, ar2.name) - self.assertDatasetIdentical(ar1._to_temp_dataset(), - ar2._to_temp_dataset()) + assert_xarray_identical(ar1, ar2) def assertDataArrayAllClose(self, ar1, ar2, rtol=1e-05, atol=1e-08): - self.assertVariableAllClose(ar1, ar2, rtol=rtol, atol=atol) - self.assertCoordinatesEqual(ar1, ar2) + assert_xarray_allclose(ar1, ar2, rtol=rtol, atol=atol) + +def assert_xarray_equal(a, b): + import xarray as xr + __tracebackhide__ = True # noqa: F841 + assert type(a) == type(b) + if isinstance(a, (xr.Variable, xr.DataArray, xr.Dataset)): + assert a.equals(b), '{}\n{}'.format(a, b) + else: + raise TypeError('{} not supported by assertion comparison' + .format(type(a))) + +def assert_xarray_identical(a, b): + import xarray as xr + __tracebackhide__ = True # noqa: F841 + assert type(a) == type(b) + if isinstance(a, xr.DataArray): + assert a.name == b.name + assert_xarray_identical(a._to_temp_dataset(), b._to_temp_dataset()) + elif isinstance(a, (xr.Dataset, xr.Variable)): + assert a.identical(b), '{}\n{}'.format(a, b) + else: + raise TypeError('{} not supported by assertion comparison' + .format(type(a))) + +def assert_xarray_allclose(a, b, rtol=1e-05, atol=1e-08): + import xarray as xr + __tracebackhide__ = True # noqa: F841 + assert type(a) == type(b) + if isinstance(a, xr.Variable): + assert a.dims == b.dims + allclose = data_allclose_or_equiv( + a.values, b.values, rtol=rtol, atol=atol) + assert allclose, '{}\n{}'.format(a.values, b.values) + elif isinstance(a, xr.DataArray): + assert_xarray_allclose(a.variable, b.variable, rtol=rtol, atol=atol) + for v in a.coords.variables: + # can't recurse with this function as coord is sometimes a DataArray, + # so call into data_allclose_or_equiv directly + allclose = data_allclose_or_equiv( + a.coords[v].values, b.coords[v].values, rtol=rtol, atol=atol) + assert allclose, '{}\n{}'.format(a.coords[v].values, b.coords[v].values) + elif isinstance(a, xr.Dataset): + assert sorted(a, key=str) == sorted(b, key=str) + for k in list(a.variables) + list(a.coords): + assert_xarray_allclose(a[k], b[k], rtol=rtol, atol=atol) + else: + raise TypeError('{} not supported by assertion comparison' + .format(type(a))) class 
UnexpectedDataAccess(Exception): pass diff --git a/xarray/test/test_backends.py b/xarray/test/test_backends.py index 82001fffb83..8c1e00a698c 100644 --- a/xarray/test/test_backends.py +++ b/xarray/test/test_backends.py @@ -26,7 +26,7 @@ from . import (TestCase, requires_scipy, requires_netCDF4, requires_pydap, requires_scipy_or_netCDF4, requires_dask, requires_h5netcdf, - requires_pynio, has_netCDF4, has_scipy) + requires_pynio, has_netCDF4, has_scipy, assert_xarray_allclose) from .test_dataset import create_test_data try: @@ -492,7 +492,7 @@ def create_tmp_file(suffix='.nc', allow_cleanup_failure=False): if not allow_cleanup_failure: raise - +@requires_netCDF4 class BaseNetCDF4Test(CFEncodedDataTest): def test_open_group(self): # Create a netCDF file with a dataset stored within a group @@ -693,6 +693,7 @@ def test_variable_len_strings(self): @requires_netCDF4 class NetCDF4DataTest(BaseNetCDF4Test, TestCase): + @contextlib.contextmanager def create_store(self): with create_tmp_file() as tmp_file: @@ -912,7 +913,12 @@ def test_cross_engine_read_write_netcdf3(self): for read_engine in valid_engines: with open_dataset(tmp_file, engine=read_engine) as actual: - self.assertDatasetAllClose(data, actual) + # hack to allow test to work: + # coord comes back as DataArray rather than coord, and so + # need to loop through here rather than in the test + # function (or we get recursion) + [assert_xarray_allclose(data[k].variable, actual[k].variable) + for k in data] @requires_h5netcdf diff --git a/xarray/test/test_dataarray.py b/xarray/test/test_dataarray.py index bd6c5ccff92..9b0a4485287 100644 --- a/xarray/test/test_dataarray.py +++ b/xarray/test/test_dataarray.py @@ -15,8 +15,10 @@ from xarray.core.pycompat import iteritems, OrderedDict from xarray.core.common import full_like -from xarray.test import (TestCase, ReturnItem, source_ndarray, unittest, - requires_dask, requires_bottleneck) +from xarray.test import ( + TestCase, ReturnItem, source_ndarray, unittest, requires_dask, + assert_xarray_identical, assert_xarray_equal, + assert_xarray_allclose, assert_array_equal) class TestDataArray(TestCase): @@ -65,7 +67,7 @@ def test_properties(self): for attr in ['dims', 'dtype', 'shape', 'size', 'nbytes', 'ndim', 'attrs']: self.assertEqual(getattr(self.dv, attr), getattr(self.v, attr)) self.assertEqual(len(self.dv), len(self.v)) - self.assertVariableEqual(self.dv, self.v) + self.assertVariableEqual(self.dv.variable, self.v) self.assertItemsEqual(list(self.dv.coords), list(self.ds.coords)) for k, v in iteritems(self.dv.coords): self.assertArrayEqual(v, self.ds.coords[k]) @@ -416,7 +418,7 @@ def test_getitem(self): for i in [I[0], I[:, 0], I[:3, :2], I[x.values[:3]], I[x.variable[:3]], I[x[:3]], I[x[:3], y[:4]], I[x.values > 3], I[x.variable > 3], I[x > 3], I[x > 3, y > 3]]: - self.assertVariableEqual(self.v[i], self.dv[i]) + assert_array_equal(self.v[i], self.dv[i]) def test_getitem_dict(self): actual = self.dv[{'x': slice(3), 'y': 0}] @@ -648,7 +650,7 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False, if renamed_dim: self.assertEqual(da.dims[0], renamed_dim) da = da.rename({renamed_dim: 'x'}) - self.assertVariableIdentical(da, expected_da) + self.assertVariableIdentical(da.variable, expected_da.variable) self.assertVariableNotEqual(da['x'], expected_da['x']) test_sel(('a', 1, -1), 0) @@ -917,13 +919,13 @@ def test_array_interface(self): self.assertArrayEqual(np.asarray(self.dv), self.x) # test patched in methods self.assertArrayEqual(self.dv.astype(float), self.v.astype(float)) - 
self.assertVariableEqual(self.dv.argsort(), self.v.argsort()) - self.assertVariableEqual(self.dv.clip(2, 3), self.v.clip(2, 3)) + assert_array_equal(self.dv.argsort(), self.v.argsort()) + assert_array_equal(self.dv.clip(2, 3), self.v.clip(2, 3)) # test ufuncs expected = deepcopy(self.ds) expected['foo'][:] = np.sin(self.x) self.assertDataArrayEqual(expected['foo'], np.sin(self.dv)) - self.assertDataArrayEqual(self.dv, np.maximum(self.v, self.dv)) + assert_array_equal(self.dv, np.maximum(self.v, self.dv)) bar = Variable(['x', 'y'], np.zeros((10, 20))) self.assertDataArrayEqual(self.dv, np.maximum(self.dv, bar)) @@ -1137,10 +1139,10 @@ def test_unstack_pandas_consistency(self): def test_transpose(self): self.assertVariableEqual(self.dv.variable.transpose(), - self.dv.transpose()) + self.dv.transpose().variable) def test_squeeze(self): - self.assertVariableEqual(self.dv.variable.squeeze(), self.dv.squeeze()) + self.assertVariableEqual(self.dv.variable.squeeze(), self.dv.squeeze().variable) def test_squeeze_drop(self): array = DataArray([1], [('x', [0])]) @@ -1247,7 +1249,7 @@ def test_reduce(self): expected = DataArray([0, 0], {'x': coords['x'], 'c': -999}, 'x') self.assertDataArrayIdentical(expected, actual) - self.assertVariableEqual(self.dv.reduce(np.mean, 'x'), + self.assertVariableEqual(self.dv.reduce(np.mean, 'x').variable, self.v.reduce(np.mean, 'x')) orig = DataArray([[1, 0, np.nan], [3, 0, 3]], coords, dims=['x', 'y']) @@ -1586,119 +1588,6 @@ def test_groupby_bins_sort(self): binned_mean = data.groupby_bins('x', bins=11).mean() assert binned_mean.to_index().is_monotonic - def make_rolling_example_array(self): - times = pd.date_range('2000-01-01', freq='1D', periods=21) - values = np.random.random((21, 4)) - da = DataArray(values, dims=('time', 'x')) - da['time'] = times - - return da - - def test_rolling_iter(self): - da = self.make_rolling_example_array() - - rolling_obj = da.rolling(time=7) - - self.assertEqual(len(rolling_obj.window_labels), len(da['time'])) - self.assertDataArrayIdentical(rolling_obj.window_labels, da['time']) - - for i, (label, window_da) in enumerate(rolling_obj): - self.assertEqual(label, da['time'].isel(time=i)) - - def test_rolling_properties(self): - da = self.make_rolling_example_array() - rolling_obj = da.rolling(time=4) - - self.assertEqual(rolling_obj._axis_num, 0) - - # catching invalid args - with self.assertRaisesRegexp(ValueError, 'exactly one dim/window should'): - da.rolling(time=7, x=2) - with self.assertRaisesRegexp(ValueError, 'window must be > 0'): - da.rolling(time=-2) - with self.assertRaisesRegexp(ValueError, 'min_periods must be greater'): - da.rolling(time=2, min_periods=0) - - @requires_bottleneck - def test_rolling_wrapped_bottleneck(self): - import bottleneck as bn - - da = self.make_rolling_example_array() - - # Test all bottleneck functions - rolling_obj = da.rolling(time=7) - for name in ('sum', 'mean', 'std', 'min', 'max', 'median'): - func_name = 'move_{0}'.format(name) - actual = getattr(rolling_obj, name)() - expected = getattr(bn, func_name)(da.values, window=7, axis=0) - self.assertArrayEqual(actual.values, expected) - - # Using min_periods - rolling_obj = da.rolling(time=7, min_periods=1) - for name in ('sum', 'mean', 'std', 'min', 'max'): - func_name = 'move_{0}'.format(name) - actual = getattr(rolling_obj, name)() - expected = getattr(bn, func_name)(da.values, window=7, axis=0, - min_count=1) - self.assertArrayEqual(actual.values, expected) - - # Using center=False - rolling_obj = da.rolling(time=7, center=False) - for name 
in ('sum', 'mean', 'std', 'min', 'max', 'median'): - actual = getattr(rolling_obj, name)()['time'] - self.assertDataArrayEqual(actual, da['time']) - - # Using center=True - rolling_obj = da.rolling(time=7, center=True) - for name in ('sum', 'mean', 'std', 'min', 'max', 'median'): - actual = getattr(rolling_obj, name)()['time'] - self.assertDataArrayEqual(actual, da['time']) - - # catching invalid args - with self.assertRaisesRegexp(ValueError, 'Rolling.median does not'): - da.rolling(time=7, min_periods=1).median() - - def test_rolling_pandas_compat(self): - s = pd.Series(range(10)) - da = DataArray.from_series(s) - - for center in (False, True): - for window in [1, 2, 3, 4]: - for min_periods in [None, 1, 2, 3]: - if min_periods is not None and window < min_periods: - min_periods = window - s_rolling = pd.rolling_mean(s, window, center=center, - min_periods=min_periods) - da_rolling = da.rolling(index=window, center=center, - min_periods=min_periods).mean() - # pandas does some fancy stuff in the last position, - # we're not going to do that yet! - np.testing.assert_allclose(s_rolling.values[:-1], - da_rolling.values[:-1]) - np.testing.assert_allclose(s_rolling.index, - da_rolling['index']) - - def test_rolling_reduce(self): - da = self.make_rolling_example_array() - for da in [self.make_rolling_example_array(), - DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], - dims='time')]: - for center in (False, True): - for window in [1, 2, 3, 4]: - for min_periods in [None, 1, 2, 3]: - if min_periods is not None and window < min_periods: - min_periods = window - # we can use this rolling object for all methods below - rolling_obj = da.rolling(time=window, center=center, - min_periods=min_periods) - for name in ['sum', 'mean', 'min', 'max']: - # add nan prefix to numpy methods to get similar - # behavior as bottleneck - actual = rolling_obj.reduce( - getattr(np, 'nan%s' % name)) - expected = getattr(rolling_obj, name)() - self.assertDataArrayAllClose(actual, expected) - def test_resample(self): times = pd.date_range('2000-01-01', freq='6H', periods=10) array = DataArray(np.arange(10), [('time', times)]) @@ -2409,3 +2298,115 @@ def test_binary_op_join_setting(self): join=align_type) expected = xr.DataArray([np.nan, 2, 4, np.nan], [(dim, [0, 1, 2, 3])]) self.assertDataArrayEqual(actual, expected) + + +@pytest.fixture(params=[1]) +def da(request): + if request.param == 1: + times = pd.date_range('2000-01-01', freq='1D', periods=21) + values = np.random.random((21, 4)) + da = DataArray(values, dims=('time', 'x')) + da['time'] = times + return da + + if request.param == 2: + return DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims='time') + +def test_rolling_iter(da): + + rolling_obj = da.rolling(time=7) + + assert len(rolling_obj.window_labels) == len(da['time']) + assert_xarray_identical(rolling_obj.window_labels, da['time']) + + for i, (label, window_da) in enumerate(rolling_obj): + assert label == da['time'].isel(time=i) + +def test_rolling_properties(da): + pytest.importorskip('bottleneck') + + rolling_obj = da.rolling(time=4) + + assert rolling_obj._axis_num == 0 + + # catching invalid args + with pytest.raises(ValueError) as exception: + da.rolling(time=7, x=2) + assert 'exactly one dim/window should' in str(exception) + with pytest.raises(ValueError) as exception: + da.rolling(time=-2) + assert 'window must be > 0' in str(exception) + with pytest.raises(ValueError) as exception: + da.rolling(time=2, min_periods=0) + assert 'min_periods must be greater than zero' in 
str(exception) + + +@pytest.mark.parametrize('name', ('sum', 'mean', 'std', 'min', 'max', 'median')) +@pytest.mark.parametrize('center', (True, False, None)) +@pytest.mark.parametrize('min_periods', (1, None)) +def test_rolling_wrapped_bottleneck(da, name, center, min_periods): + pytest.importorskip('bottleneck') + import bottleneck as bn + + # skip if median and min_periods + if (min_periods == 1) and (name == 'median'): + pytest.skip() + + # Test all bottleneck functions + rolling_obj = da.rolling(time=7, min_periods=min_periods) + + func_name = 'move_{0}'.format(name) + actual = getattr(rolling_obj, name)() + expected = getattr(bn, func_name)(da.values, window=7, axis=0, min_count=min_periods) + assert_array_equal(actual.values, expected) + + # Test center + rolling_obj = da.rolling(time=7, center=center) + actual = getattr(rolling_obj, name)()['time'] + assert_xarray_equal(actual, da['time']) + +def test_rolling_invalid_args(da): + pytest.importorskip('bottleneck') + with pytest.raises(ValueError) as exception: + da.rolling(time=7, min_periods=1).median() + assert 'Rolling.median does not' in str(exception) + + +@pytest.mark.parametrize('center', (True, False)) +@pytest.mark.parametrize('min_periods', (None, 1, 2, 3)) +@pytest.mark.parametrize('window', (1, 2, 3, 4)) +def test_rolling_pandas_compat(da, center, window, min_periods): + s = pd.Series(range(10)) + da = DataArray.from_series(s) + + if min_periods is not None and window < min_periods: + min_periods = window + + s_rolling = pd.rolling_mean(s, window, center=center, + min_periods=min_periods) + da_rolling = da.rolling(index=window, center=center, + min_periods=min_periods).mean() + # pandas does some fancy stuff in the last position, + # we're not going to do that yet! + np.testing.assert_allclose(s_rolling.values[:-1], + da_rolling.values[:-1]) + np.testing.assert_allclose(s_rolling.index, + da_rolling['index']) + +@pytest.mark.parametrize('da', (1, 2), indirect=True) +@pytest.mark.parametrize('center', (True, False)) +@pytest.mark.parametrize('min_periods', (None, 1, 2, 3)) +@pytest.mark.parametrize('window', (1, 2, 3, 4)) +@pytest.mark.parametrize('name', ('sum', 'mean', 'std', 'max')) +def test_rolling_reduce(da, center, min_periods, window, name): + + if min_periods is not None and window < min_periods: + min_periods = window + + rolling_obj = da.rolling(time=window, center=center, + min_periods=min_periods) + + # add nan prefix to numpy methods to get similar # behavior as bottleneck + actual = rolling_obj.reduce(getattr(np, 'nan%s' % name)) + expected = getattr(rolling_obj, name)() + assert_xarray_allclose(actual, expected) diff --git a/xarray/test/test_dataset.py b/xarray/test/test_dataset.py index 1a970fe718d..be7266ee165 100644 --- a/xarray/test/test_dataset.py +++ b/xarray/test/test_dataset.py @@ -439,7 +439,7 @@ def test_modify_inplace(self): a['x'] = ('x', vec, attributes) self.assertTrue('x' in a.coords) self.assertIsInstance(a.coords['x'].to_index(), pd.Index) - self.assertVariableIdentical(a.coords['x'], a.variables['x']) + self.assertVariableIdentical(a.coords['x'].variable, a.variables['x']) b = Dataset() b['x'] = ('x', vec, attributes) self.assertVariableIdentical(a['x'], b['x']) @@ -473,8 +473,8 @@ def test_coords_properties(self): self.assertItemsEqual(['x', 'y', 'a', 'b'], list(data.coords)) - self.assertVariableIdentical(data.coords['x'], data['x'].variable) - self.assertVariableIdentical(data.coords['y'], data['y'].variable) + self.assertVariableIdentical(data.coords['x'].variable, data['x'].variable) + 
self.assertVariableIdentical(data.coords['y'].variable, data['y'].variable) self.assertIn('x', data.coords) self.assertIn('a', data.coords) @@ -1012,7 +1012,8 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False, if renamed_dim: self.assertEqual(ds['var'].dims[0], renamed_dim) ds = ds.rename({renamed_dim: 'x'}) - self.assertVariableIdentical(ds['var'], expected_ds['var']) + self.assertVariableIdentical(ds['var'].variable, + expected_ds['var'].variable) self.assertVariableNotEqual(ds['x'], expected_ds['x']) test_sel(('a', 1, -1), 0) @@ -1162,7 +1163,7 @@ def test_align(self): self.assertDatasetIdentical(left2, right2) left2, right2 = align(left, right, join='outer') - self.assertVariableEqual(left2['dim3'], right2['dim3']) + self.assertVariableEqual(left2['dim3'].variable, right2['dim3'].variable) self.assertArrayEqual(left2['dim3'], union) self.assertDatasetIdentical(left2.sel(dim3=intersection), right2.sel(dim3=intersection)) @@ -1170,15 +1171,15 @@ def test_align(self): self.assertTrue(np.isnan(right2['var3'][:2]).all()) left2, right2 = align(left, right, join='left') - self.assertVariableEqual(left2['dim3'], right2['dim3']) - self.assertVariableEqual(left2['dim3'], left['dim3']) + self.assertVariableEqual(left2['dim3'].variable, right2['dim3'].variable) + self.assertVariableEqual(left2['dim3'].variable, left['dim3'].variable) self.assertDatasetIdentical(left2.sel(dim3=intersection), right2.sel(dim3=intersection)) self.assertTrue(np.isnan(right2['var3'][:2]).all()) left2, right2 = align(left, right, join='right') - self.assertVariableEqual(left2['dim3'], right2['dim3']) - self.assertVariableEqual(left2['dim3'], right['dim3']) + self.assertVariableEqual(left2['dim3'].variable, right2['dim3'].variable) + self.assertVariableEqual(left2['dim3'].variable, right['dim3'].variable) self.assertDatasetIdentical(left2.sel(dim3=intersection), right2.sel(dim3=intersection)) self.assertTrue(np.isnan(left2['var3'][-2:]).all()) @@ -1396,7 +1397,7 @@ def test_rename(self): dims[dims.index(name)] = newname self.assertVariableEqual(Variable(dims, v.values, v.attrs), - renamed[k]) + renamed[k].variable.to_base_variable()) self.assertEqual(v.encoding, renamed[k].encoding) self.assertEqual(type(v), type(renamed.variables[k])) @@ -1548,7 +1549,7 @@ def test_update_auto_align(self): def test_getitem(self): data = create_test_data() self.assertIsInstance(data['var1'], DataArray) - self.assertVariableEqual(data['var1'], data.variables['var1']) + self.assertVariableEqual(data['var1'].variable, data.variables['var1']) with self.assertRaises(KeyError): data['notfound'] with self.assertRaises(KeyError): @@ -1644,9 +1645,9 @@ def test_time_season(self): def test_slice_virtual_variable(self): data = create_test_data() - self.assertVariableEqual(data['time.dayofyear'][:10], + self.assertVariableEqual(data['time.dayofyear'][:10].variable, Variable(['time'], 1 + np.arange(10))) - self.assertVariableEqual(data['time.dayofyear'][0], Variable([], 1)) + self.assertVariableEqual(data['time.dayofyear'][0].variable, Variable([], 1)) def test_setitem(self): # assign a variable diff --git a/xarray/test/test_indexing.py b/xarray/test/test_indexing.py index 9d22b4f2c87..79e841e0f3b 100644 --- a/xarray/test/test_indexing.py +++ b/xarray/test/test_indexing.py @@ -119,11 +119,13 @@ def test_get_dim_indexers(self): self.assertEqual(dim_indexers, {'x': {'one': 'a', 'two': 1}}) with self.assertRaisesRegexp(ValueError, 'cannot combine'): - _ = indexing.get_dim_indexers(mdata, {'x': 'a', 'two': 1}) + indexing.get_dim_indexers(mdata, 
{'x': 'a', 'two': 1}) with self.assertRaisesRegexp(ValueError, 'do not exist'): - _ = indexing.get_dim_indexers(mdata, {'y': 'a'}) - _ = indexing.get_dim_indexers(data, {'four': 1}) + indexing.get_dim_indexers(mdata, {'y': 'a'}) + + with self.assertRaisesRegexp(ValueError, 'do not exist'): + indexing.get_dim_indexers(mdata, {'four': 1}) def test_remap_label_indexers(self): def test_indexer(data, x, expected_pos, expected_idx=None): diff --git a/xarray/test/test_variable.py b/xarray/test/test_variable.py index e3eced07716..e99e77abf99 100644 --- a/xarray/test/test_variable.py +++ b/xarray/test/test_variable.py @@ -213,15 +213,19 @@ def test_pandas_period_index(self): def test_1d_math(self): x = 1.0 * np.arange(5) y = np.ones(5) + + # should we need `.to_base_variable()`? + # probably a break that `+v` changes type? v = self.cls(['x'], x) + base_v = v.to_base_variable() # unary ops - self.assertVariableIdentical(v, +v) - self.assertVariableIdentical(v, abs(v)) + self.assertVariableIdentical(base_v, +v) + self.assertVariableIdentical(base_v, abs(v)) self.assertArrayEqual((-v).values, -x) # binary ops with numbers - self.assertVariableIdentical(v, v + 0) - self.assertVariableIdentical(v, 0 + v) - self.assertVariableIdentical(v, v * 1) + self.assertVariableIdentical(base_v, v + 0) + self.assertVariableIdentical(base_v, 0 + v) + self.assertVariableIdentical(base_v, v * 1) self.assertArrayEqual((v > 2).values, x > 2) self.assertArrayEqual((0 == v).values, 0 == x) self.assertArrayEqual((v - 1).values, x - 1) @@ -233,11 +237,11 @@ def test_1d_math(self): self.assertArrayEqual(y - v, 1 - v) # verify attributes are dropped v2 = self.cls(['x'], x, {'units': 'meters'}) - self.assertVariableIdentical(v, +v2) + self.assertVariableIdentical(base_v, +v2) # binary ops with all variables self.assertArrayEqual(v + v, 2 * v) w = self.cls(['x'], y, {'foo': 'bar'}) - self.assertVariableIdentical(v + w, self.cls(['x'], x + y)) + self.assertVariableIdentical(v + w, self.cls(['x'], x + y).to_base_variable()) self.assertArrayEqual((v * w).values, x * y) # something complicated self.assertArrayEqual((v ** 2 * w - 1 + x).values, x ** 2 * y - 1 + x) @@ -266,10 +270,12 @@ def test_array_interface(self): self.assertArrayEqual(np.asarray(v), x) # test patched in methods self.assertArrayEqual(v.astype(float), x.astype(float)) - self.assertVariableIdentical(v.argsort(), v) - self.assertVariableIdentical(v.clip(2, 3), self.cls('x', x.clip(2, 3))) + # think this is a break, that argsort changes the type + self.assertVariableIdentical(v.argsort(), v.to_base_variable()) + self.assertVariableIdentical(v.clip(2, 3), + self.cls('x', x.clip(2, 3)).to_base_variable()) # test ufuncs - self.assertVariableIdentical(np.sin(v), self.cls(['x'], np.sin(x))) + self.assertVariableIdentical(np.sin(v), self.cls(['x'], np.sin(x)).to_base_variable()) self.assertIsInstance(np.sin(v), Variable) self.assertNotIsInstance(np.sin(v), IndexVariable) @@ -304,7 +310,7 @@ def test_equals_all_dtypes(self): def test_eq_all_dtypes(self): # ensure that we don't choke on comparisons for which numpy returns # scalars - expected = self.cls('x', 3 * [False]) + expected = Variable('x', 3 * [False]) for v, _ in self.example_1d_objects(): actual = 'z' == v self.assertVariableIdentical(expected, actual) @@ -320,7 +326,9 @@ def test_encoding_preserved(self): expected.expand_dims({'x': 3}), expected.copy(deep=True), expected.copy(deep=False)]: - self.assertVariableIdentical(expected, actual) + + self.assertVariableIdentical(expected.to_base_variable(), + 
actual.to_base_variable()) self.assertEqual(expected.encoding, actual.encoding) def test_concat(self): @@ -357,7 +365,7 @@ def test_concat_attrs(self): # different or conflicting attributes should be removed v = self.cls('a', np.arange(5), {'foo': 'bar'}) w = self.cls('a', np.ones(5)) - expected = self.cls('a', np.concatenate([np.arange(5), np.ones(5)])) + expected = self.cls('a', np.concatenate([np.arange(5), np.ones(5)])).to_base_variable() self.assertVariableIdentical(expected, Variable.concat([v, w], 'a')) w.attrs['foo'] = 2 self.assertVariableIdentical(expected, Variable.concat([v, w], 'a')) @@ -419,7 +427,7 @@ def test_real_and_imag(self): expected_im = self.cls('x', -np.arange(3), {'foo': 'bar'}) self.assertVariableIdentical(v.imag, expected_im) - expected_abs = self.cls('x', np.sqrt(2 * np.arange(3) ** 2)) + expected_abs = self.cls('x', np.sqrt(2 * np.arange(3) ** 2)).to_base_variable() self.assertVariableAllClose(abs(v), expected_abs) def test_aggregate_complex(self): @@ -604,7 +612,7 @@ def test_as_variable(self): self.assertVariableIdentical(expected, as_variable(expected)) ds = Dataset({'x': expected}) - self.assertVariableIdentical(expected, as_variable(ds['x'])) + self.assertVariableIdentical(expected, as_variable(ds['x']).to_base_variable()) self.assertNotIsInstance(ds['x'], Variable) self.assertIsInstance(as_variable(ds['x']), Variable) @@ -622,8 +630,7 @@ def test_as_variable(self): as_variable(data) actual = as_variable(data, name='x') - self.assertVariableIdentical(expected, actual) - self.assertIsInstance(actual, IndexVariable) + self.assertVariableIdentical(expected.to_index_variable(), actual) actual = as_variable(0) expected = Variable([], 0)
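
A minimal sketch of how the new assertion helpers added in this patch might be used in a pytest-style test. The test function name and the DataArray built here are illustrative only; the helpers and the ``xarray.test`` import path are those introduced above:

    import numpy as np
    import xarray as xr

    from xarray.test import assert_xarray_equal, assert_xarray_allclose

    def test_assert_helpers_example():
        # hypothetical example data, not part of the patch
        da = xr.DataArray(np.arange(4.0), dims='x', name='foo')

        # exact equality of dims, values and coords
        assert_xarray_equal(da, da.copy())

        # approximate equality within the given tolerances
        assert_xarray_allclose(da, da + 1e-9, rtol=0, atol=1e-6)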