Skip to content

Commit d016ea7

Browse files
authored
Merge pull request #1657 from jhamman/fix/1652
fixes for warnings related to unit tests and nan comparisons
2 parents c58d142 + c2126fc commit d016ea7

22 files changed

+453
-418
lines changed

doc/whats-new.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@ Breaking changes
8585
disk when calling ``repr`` (:issue:`1522`).
8686
By `Guido Imperiale <https://github.com/crusaderky>`_.
8787

88+
- Suppress ``RuntimeWarning`` issued by ``numpy`` for "invalid value comparisons"
89+
(e.g. NaNs). Xarray now behaves similarly to Pandas in its treatment of
90+
binary and unary operations on objects with ``NaN``s (:issue:`1657`).
91+
By `Joe Hamman <https://github.com/jhamman>`_.
92+
8893
- Several existing features have been deprecated and will change to new
8994
behavior in xarray v0.11. If you use any of them with xarray v0.10, you
9095
should see a ``FutureWarning`` that describes how to update your code:

xarray/core/dataarray.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1575,7 +1575,9 @@ def __array_wrap__(self, obj, context=None):
15751575
def _unary_op(f):
15761576
@functools.wraps(f)
15771577
def func(self, *args, **kwargs):
1578-
return self.__array_wrap__(f(self.variable.data, *args, **kwargs))
1578+
with np.errstate(all='ignore'):
1579+
return self.__array_wrap__(f(self.variable.data, *args,
1580+
**kwargs))
15791581

15801582
return func
15811583

xarray/core/pycompat.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def itervalues(d):
2525
from functools import reduce
2626
import builtins
2727
from urllib.request import urlretrieve
28+
from inspect import getfullargspec as getargspec
2829
else: # pragma: no cover
2930
# Python 2
3031
basestring = basestring # noqa
@@ -43,7 +44,7 @@ def itervalues(d):
4344
reduce = reduce
4445
import __builtin__ as builtins
4546
from urllib import urlretrieve
46-
47+
from inspect import getargspec
4748
try:
4849
from cyordereddict import OrderedDict
4950
except ImportError: # pragma: no cover

xarray/core/variable.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,7 +1354,8 @@ def __array_wrap__(self, obj, context=None):
13541354
def _unary_op(f):
13551355
@functools.wraps(f)
13561356
def func(self, *args, **kwargs):
1357-
return self.__array_wrap__(f(self.data, *args, **kwargs))
1357+
with np.errstate(all='ignore'):
1358+
return self.__array_wrap__(f(self.data, *args, **kwargs))
13581359
return func
13591360

13601361
@staticmethod
@@ -1364,9 +1365,10 @@ def func(self, other):
13641365
if isinstance(other, (xr.DataArray, xr.Dataset)):
13651366
return NotImplemented
13661367
self_data, other_data, dims = _broadcast_compat_data(self, other)
1367-
new_data = (f(self_data, other_data)
1368-
if not reflexive
1369-
else f(other_data, self_data))
1368+
with np.errstate(all='ignore'):
1369+
new_data = (f(self_data, other_data)
1370+
if not reflexive
1371+
else f(other_data, self_data))
13701372
result = Variable(dims, new_data)
13711373
return result
13721374
return func
@@ -1381,7 +1383,8 @@ def func(self, other):
13811383
if dims != self.dims:
13821384
raise ValueError('dimensions cannot change for in-place '
13831385
'operations')
1384-
self.values = f(self_data, other_data)
1386+
with np.errstate(all='ignore'):
1387+
self.values = f(self_data, other_data)
13851388
return self
13861389
return func
13871390

xarray/plot/facetgrid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
from __future__ import division
33
from __future__ import print_function
44

5-
import inspect
65
import warnings
76
import itertools
87
import functools
98

109
import numpy as np
1110

11+
from ..core.pycompat import getargspec
1212
from ..core.formatting import format_item
1313
from .utils import (_determine_cmap_params, _infer_xy_labels,
1414
import_matplotlib_pyplot)
@@ -228,7 +228,7 @@ def map_dataarray(self, func, x, y, **kwargs):
228228
'filled': func.__name__ != 'contour',
229229
}
230230

231-
cmap_args = inspect.getargspec(_determine_cmap_params).args
231+
cmap_args = getargspec(_determine_cmap_params).args
232232
cmap_kwargs.update((a, kwargs[a]) for a in cmap_args if a in kwargs)
233233

234234
cmap_params = _determine_cmap_params(**cmap_kwargs)

xarray/tests/test_accessors.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
import pandas as pd
88

9-
from . import TestCase, requires_dask
9+
from . import TestCase, requires_dask, raises_regex
1010

1111

1212
class TestDatetimeAccessor(TestCase):
@@ -45,7 +45,7 @@ def test_not_datetime_type(self):
4545
nontime_data = self.data.copy()
4646
int_data = np.arange(len(self.data.time)).astype('int8')
4747
nontime_data['time'].values = int_data
48-
with self.assertRaisesRegexp(TypeError, 'dt'):
48+
with raises_regex(TypeError, 'dt'):
4949
nontime_data.time.dt
5050

5151
@requires_dask
@@ -93,4 +93,4 @@ def test_seasons(self):
9393
"SON", "SON", "SON", "DJF"]
9494
seasons = xr.DataArray(seasons)
9595

96-
self.assertArrayEqual(seasons.values, dates.dt.season.values)
96+
self.assertArrayEqual(seasons.values, dates.dt.season.values)

xarray/tests/test_backends.py

Lines changed: 36 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
requires_scipy_or_netCDF4, requires_dask, requires_h5netcdf,
2929
requires_pynio, requires_pathlib, has_netCDF4, has_scipy,
3030
assert_allclose, flaky, network, requires_rasterio,
31-
assert_identical)
31+
assert_identical, raises_regex)
3232
from .test_dataset import create_test_data
3333

3434
from xarray.tests import mock
@@ -113,7 +113,7 @@ def __getitem__(self, key):
113113
return self.array[key]
114114

115115
array = UnreliableArray([0])
116-
with self.assertRaises(UnreliableArrayFailure):
116+
with pytest.raises(UnreliableArrayFailure):
117117
array[0]
118118
self.assertEqual(array[0], 0)
119119

@@ -218,7 +218,7 @@ def assert_loads(vars=None):
218218
self.assertTrue(v._in_memory)
219219
self.assertDatasetIdentical(expected, actual)
220220

221-
with self.assertRaises(AssertionError):
221+
with pytest.raises(AssertionError):
222222
# make sure the contextmanager works!
223223
with assert_loads() as ds:
224224
pass
@@ -345,8 +345,7 @@ def test_roundtrip_datetime_data(self):
345345
kwds = {'encoding': {'t0': {'units': 'days since 1950-01-01'}}}
346346
with self.roundtrip(expected, save_kwargs=kwds) as actual:
347347
self.assertDatasetIdentical(expected, actual)
348-
self.assertEquals(actual.t0.encoding['units'],
349-
'days since 1950-01-01')
348+
assert actual.t0.encoding['units'] == 'days since 1950-01-01'
350349

351350
def test_roundtrip_timedelta_data(self):
352351
time_deltas = pd.to_timedelta(['1h', '2h', 'NaT'])
@@ -528,7 +527,7 @@ def test_roundtrip_endian(self):
528527

529528
if type(self) is NetCDF4DataTest:
530529
ds['z'].encoding['endian'] = 'big'
531-
with self.assertRaises(NotImplementedError):
530+
with pytest.raises(NotImplementedError):
532531
with self.roundtrip(ds) as actual:
533532
pass
534533

@@ -539,7 +538,7 @@ def test_invalid_dataarray_names_raise(self):
539538
da = xr.DataArray(data)
540539
for name, e in zip([0, (4, 5), True, ''], [te, te, te, ve]):
541540
ds = Dataset({name: da})
542-
with self.assertRaisesRegexp(*e):
541+
with raises_regex(*e):
543542
with self.roundtrip(ds) as actual:
544543
pass
545544

@@ -551,17 +550,17 @@ def test_encoding_kwarg(self):
551550
self.assertEqual(ds.x.encoding, {})
552551

553552
kwargs = dict(encoding={'x': {'foo': 'bar'}})
554-
with self.assertRaisesRegexp(ValueError, 'unexpected encoding'):
553+
with raises_regex(ValueError, 'unexpected encoding'):
555554
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
556555
pass
557556

558557
kwargs = dict(encoding={'x': 'foo'})
559-
with self.assertRaisesRegexp(ValueError, 'must be castable'):
558+
with raises_regex(ValueError, 'must be castable'):
560559
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
561560
pass
562561

563562
kwargs = dict(encoding={'invalid': {}})
564-
with self.assertRaises(KeyError):
563+
with pytest.raises(KeyError):
565564
with self.roundtrip(ds, save_kwargs=kwargs) as actual:
566565
pass
567566

@@ -676,9 +675,9 @@ def test_open_group(self):
676675
self.assertVariableEqual(actual['x'], expected['x'])
677676

678677
# check that missing group raises appropriate exception
679-
with self.assertRaises(IOError):
678+
with pytest.raises(IOError):
680679
open_dataset(tmp_file, group='bar')
681-
with self.assertRaisesRegexp(ValueError, 'must be a string'):
680+
with raises_regex(ValueError, 'must be a string'):
682681
open_dataset(tmp_file, group=(1, 2, 3))
683682

684683
def test_open_subgroup(self):
@@ -1019,7 +1018,7 @@ def create_store(self):
10191018

10201019
def test_array_attrs(self):
10211020
ds = Dataset(attrs={'foo': [[1, 2], [3, 4]]})
1022-
with self.assertRaisesRegexp(ValueError, 'must be 1-dimensional'):
1021+
with raises_regex(ValueError, 'must be 1-dimensional'):
10231022
with self.roundtrip(ds) as roundtripped:
10241023
pass
10251024

@@ -1036,11 +1035,11 @@ def test_netcdf3_endianness(self):
10361035

10371036
@requires_netCDF4
10381037
def test_nc4_scipy(self):
1039-
with create_tmp_file() as tmp_file:
1038+
with create_tmp_file(allow_cleanup_failure=True) as tmp_file:
10401039
with nc4.Dataset(tmp_file, 'w', format='NETCDF4') as rootgrp:
10411040
rootgrp.createGroup('foo')
10421041

1043-
with self.assertRaisesRegexp(TypeError, 'pip install netcdf4'):
1042+
with raises_regex(TypeError, 'pip install netcdf4'):
10441043
open_dataset(tmp_file, engine='scipy')
10451044

10461045

@@ -1096,18 +1095,18 @@ def test_write_store(self):
10961095

10971096
def test_engine(self):
10981097
data = create_test_data()
1099-
with self.assertRaisesRegexp(ValueError, 'unrecognized engine'):
1098+
with raises_regex(ValueError, 'unrecognized engine'):
11001099
data.to_netcdf('foo.nc', engine='foobar')
1101-
with self.assertRaisesRegexp(ValueError, 'invalid engine'):
1100+
with raises_regex(ValueError, 'invalid engine'):
11021101
data.to_netcdf(engine='netcdf4')
11031102

11041103
with create_tmp_file() as tmp_file:
11051104
data.to_netcdf(tmp_file)
1106-
with self.assertRaisesRegexp(ValueError, 'unrecognized engine'):
1105+
with raises_regex(ValueError, 'unrecognized engine'):
11071106
open_dataset(tmp_file, engine='foobar')
11081107

11091108
netcdf_bytes = data.to_netcdf()
1110-
with self.assertRaisesRegexp(ValueError, 'can only read'):
1109+
with raises_regex(ValueError, 'can only read'):
11111110
open_dataset(BytesIO(netcdf_bytes), engine='foobar')
11121111

11131112
def test_cross_engine_read_write_netcdf3(self):
@@ -1389,12 +1388,12 @@ def test_common_coord_when_datavars_minimal(self):
13891388
def test_invalid_data_vars_value_should_fail(self):
13901389

13911390
with self.setup_files_and_datasets() as (files, _):
1392-
with self.assertRaises(ValueError):
1391+
with pytest.raises(ValueError):
13931392
with open_mfdataset(files, data_vars='minimum'):
13941393
pass
13951394

13961395
# test invalid coord parameter
1397-
with self.assertRaises(ValueError):
1396+
with pytest.raises(ValueError):
13981397
with open_mfdataset(files, coords='minimum'):
13991398
pass
14001399

@@ -1452,7 +1451,7 @@ def test_open_mfdataset(self):
14521451
self.assertEqual(actual.foo.variable.data.chunks,
14531452
((3, 2, 3, 2),))
14541453

1455-
with self.assertRaisesRegexp(IOError, 'no files to open'):
1454+
with raises_regex(IOError, 'no files to open'):
14561455
open_mfdataset('foo-bar-baz-*.nc', autoclose=self.autoclose)
14571456

14581457
@requires_pathlib
@@ -1483,7 +1482,7 @@ def test_attrs_mfdataset(self):
14831482
# first dataset loaded
14841483
self.assertEqual(actual.test1, ds1.test1)
14851484
# attributes from ds2 are not retained, e.g.,
1486-
with self.assertRaisesRegexp(AttributeError,
1485+
with raises_regex(AttributeError,
14871486
'no attribute'):
14881487
actual.test2
14891488

@@ -1513,15 +1512,15 @@ def test_save_mfdataset_roundtrip(self):
15131512

15141513
def test_save_mfdataset_invalid(self):
15151514
ds = Dataset()
1516-
with self.assertRaisesRegexp(ValueError, 'cannot use mode'):
1515+
with raises_regex(ValueError, 'cannot use mode'):
15171516
save_mfdataset([ds, ds], ['same', 'same'])
1518-
with self.assertRaisesRegexp(ValueError, 'same length'):
1517+
with raises_regex(ValueError, 'same length'):
15191518
save_mfdataset([ds, ds], ['only one path'])
15201519

15211520
def test_save_mfdataset_invalid_dataarray(self):
15221521
# regression test for GH1555
15231522
da = DataArray([1, 2])
1524-
with self.assertRaisesRegexp(TypeError, 'supports writing Dataset'):
1523+
with raises_regex(TypeError, 'supports writing Dataset'):
15251524
save_mfdataset([da], ['dataarray'])
15261525

15271526

@@ -1846,11 +1845,11 @@ def test_indexing(self):
18461845
# but on x and y only windowed operations are allowed, more
18471846
# exotic slicing should raise an error
18481847
err_msg = 'not valid on rasterio'
1849-
with self.assertRaisesRegexp(IndexError, err_msg):
1848+
with raises_regex(IndexError, err_msg):
18501849
actual.isel(x=[2, 4], y=[1, 3]).values
1851-
with self.assertRaisesRegexp(IndexError, err_msg):
1850+
with raises_regex(IndexError, err_msg):
18521851
actual.isel(x=[4, 2]).values
1853-
with self.assertRaisesRegexp(IndexError, err_msg):
1852+
with raises_regex(IndexError, err_msg):
18541853
actual.isel(x=slice(5, 2, -1)).values
18551854

18561855
# Integer indexing
@@ -1916,7 +1915,7 @@ def test_caching(self):
19161915

19171916
# Without cache an error is raised
19181917
err_msg = 'not valid on rasterio'
1919-
with self.assertRaisesRegexp(IndexError, err_msg):
1918+
with raises_regex(IndexError, err_msg):
19201919
actual.isel(x=[2, 4]).values
19211920

19221921
# This should cache everything
@@ -1976,7 +1975,7 @@ class TestEncodingInvalid(TestCase):
19761975

19771976
def test_extract_nc4_variable_encoding(self):
19781977
var = xr.Variable(('x',), [1, 2, 3], {}, {'foo': 'bar'})
1979-
with self.assertRaisesRegexp(ValueError, 'unexpected encoding'):
1978+
with raises_regex(ValueError, 'unexpected encoding'):
19801979
_extract_nc4_variable_encoding(var, raise_on_invalid=True)
19811980

19821981
var = xr.Variable(('x',), [1, 2, 3], {}, {'chunking': (2, 1)})
@@ -1992,7 +1991,7 @@ def test_extract_h5nc_encoding(self):
19921991
# not supported with h5netcdf (yet)
19931992
var = xr.Variable(('x',), [1, 2, 3], {},
19941993
{'least_sigificant_digit': 2})
1995-
with self.assertRaisesRegexp(ValueError, 'unexpected encoding'):
1994+
with raises_regex(ValueError, 'unexpected encoding'):
19961995
_extract_nc4_variable_encoding(var, raise_on_invalid=True)
19971996

19981997

@@ -2025,17 +2024,17 @@ def new_dataset_and_coord_attrs():
20252024
ds, attrs = new_dataset_and_attrs()
20262025

20272026
attrs[123] = 'test'
2028-
with self.assertRaisesRegexp(TypeError, 'Invalid name for attr'):
2027+
with raises_regex(TypeError, 'Invalid name for attr'):
20292028
ds.to_netcdf('test.nc')
20302029

20312030
ds, attrs = new_dataset_and_attrs()
20322031
attrs[MiscObject()] = 'test'
2033-
with self.assertRaisesRegexp(TypeError, 'Invalid name for attr'):
2032+
with raises_regex(TypeError, 'Invalid name for attr'):
20342033
ds.to_netcdf('test.nc')
20352034

20362035
ds, attrs = new_dataset_and_attrs()
20372036
attrs[''] = 'test'
2038-
with self.assertRaisesRegexp(ValueError, 'Invalid name for attr'):
2037+
with raises_regex(ValueError, 'Invalid name for attr'):
20392038
ds.to_netcdf('test.nc')
20402039

20412040
# This one should work
@@ -2046,12 +2045,12 @@ def new_dataset_and_coord_attrs():
20462045

20472046
ds, attrs = new_dataset_and_attrs()
20482047
attrs['test'] = {'a': 5}
2049-
with self.assertRaisesRegexp(TypeError, 'Invalid value for attr'):
2048+
with raises_regex(TypeError, 'Invalid value for attr'):
20502049
ds.to_netcdf('test.nc')
20512050

20522051
ds, attrs = new_dataset_and_attrs()
20532052
attrs['test'] = MiscObject()
2054-
with self.assertRaisesRegexp(TypeError, 'Invalid value for attr'):
2053+
with raises_regex(TypeError, 'Invalid value for attr'):
20552054
ds.to_netcdf('test.nc')
20562055

20572056
ds, attrs = new_dataset_and_attrs()

0 commit comments

Comments
 (0)