diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ab531c00a48..52391c6008f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -69,6 +69,9 @@ Bug fixes - ``DataArray.to_masked_array`` always returns masked array with mask being an array (not a scalar value) (:issue:`684`) +- You can now pass pandas objects with non-numpy dtypes (e.g., ``categorical`` + or ``datetime64`` with a timezone) into xray without an error + (:issue:`716`). v0.6.2 (unreleased) ------------------- diff --git a/xray/core/indexing.py b/xray/core/indexing.py index 8efe34d920b..277428966a5 100644 --- a/xray/core/indexing.py +++ b/xray/core/indexing.py @@ -371,6 +371,10 @@ def __init__(self, array, dtype=None): # if a PeriodIndex, force an object dtype if isinstance(array, pd.PeriodIndex): dtype = np.dtype('O') + elif hasattr(array, 'categories'): + dtype = array.categories.dtype + elif not utils.is_valid_numpy_dtype(array.dtype): + dtype = np.dtype('O') else: dtype = array.dtype self._dtype = dtype @@ -387,7 +391,7 @@ def __array__(self, dtype=None): with suppress(AttributeError): # this might not be public API array = array.asobject - return array.values.astype(dtype) + return np.asarray(array, dtype) def __getitem__(self, key): if isinstance(key, tuple) and len(key) == 1: diff --git a/xray/core/utils.py b/xray/core/utils.py index e36893b87a1..660a69af211 100644 --- a/xray/core/utils.py +++ b/xray/core/utils.py @@ -169,6 +169,15 @@ def is_scalar(value): value is None) +def is_valid_numpy_dtype(dtype): + try: + np.dtype(dtype) + except (TypeError, ValueError): + return False + else: + return True + + def dict_equiv(first, second, compat=equivalent): """Test equivalence of two dict-like objects. If any of the values are numpy arrays, compare them correctly. diff --git a/xray/test/test_variable.py b/xray/test/test_variable.py index 25018cce8ce..272816ca432 100644 --- a/xray/test/test_variable.py +++ b/xray/test/test_variable.py @@ -5,6 +5,7 @@ from distutils.version import LooseVersion import numpy as np +import pytz import pandas as pd from xray import Variable, Dataset, DataArray @@ -153,7 +154,7 @@ def test_0d_time_data(self): def test_datetime64_conversion(self): times = pd.date_range('2000-01-01', periods=3) for values, preserve_source in [ - (times, False), + (times, True), (times.values, True), (times.values.astype('datetime64[s]'), False), (times.to_pydatetime(), False), @@ -163,15 +164,12 @@ def test_datetime64_conversion(self): self.assertArrayEqual(v.values, times.values) self.assertEqual(v.values.dtype, np.dtype('datetime64[ns]')) same_source = source_ndarray(v.values) is source_ndarray(values) - if preserve_source and self.cls is Variable: - self.assertTrue(same_source) - else: - self.assertFalse(same_source) + assert preserve_source == same_source def test_timedelta64_conversion(self): times = pd.timedelta_range(start=0, periods=3) for values, preserve_source in [ - (times, False), + (times, True), (times.values, True), (times.values.astype('timedelta64[s]'), False), (times.to_pytimedelta(), False), @@ -181,10 +179,7 @@ def test_timedelta64_conversion(self): self.assertArrayEqual(v.values, times.values) self.assertEqual(v.values.dtype, np.dtype('timedelta64[ns]')) same_source = source_ndarray(v.values) is source_ndarray(values) - if preserve_source and self.cls is Variable: - self.assertTrue(same_source) - else: - self.assertFalse(same_source) + assert preserve_source == same_source def test_object_conversion(self): data = np.arange(5).astype(str).astype(object) @@ -405,6 +400,22 @@ def test_aggregate_complex(self): expected = Variable((), 0.5 + 1j) self.assertVariableAllClose(v.mean(), expected) + def test_pandas_cateogrical_dtype(self): + data = pd.Categorical(np.arange(10, dtype='int64')) + v = self.cls('x', data) + print(v) # should not error + assert v.dtype == 'int64' + + def test_pandas_datetime64_with_tz(self): + data = pd.date_range(start='2000-01-01', + tz=pytz.timezone('America/New_York'), + periods=10, freq='1h') + v = self.cls('x', data) + print(v) # should not error + if 'America/New_York' in str(data.dtype): + # pandas is new enough that it has datetime64 with timezone dtype + assert v.dtype == 'object' + class TestVariable(TestCase, VariableSubclassTestCases): cls = staticmethod(Variable)