Skip to content

Commit 58b42af

Browse files
committed
Handle non-numpy dtypes without erroring
Fixes GH716
1 parent 4be7088 commit 58b42af

File tree

4 files changed

+38
-11
lines changed

4 files changed

+38
-11
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ Bug fixes
6969

7070
- ``DataArray.to_masked_array`` always returns masked array with mask being an array
7171
(not a scalar value) (:issue:`684`)
72+
- You can now pass pandas objects with non-numpy dtypes (e.g., ``categorical``
73+
or ``datetime64`` with a timezone) into xray without an error
74+
(:issue:`716`).
7275

7376
v0.6.2 (unreleased)
7477
-------------------

xray/core/indexing.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,10 @@ def __init__(self, array, dtype=None):
371371
# if a PeriodIndex, force an object dtype
372372
if isinstance(array, pd.PeriodIndex):
373373
dtype = np.dtype('O')
374+
elif hasattr(array, 'categories'):
375+
dtype = array.categories.dtype
376+
elif not utils.is_valid_numpy_dtype(array.dtype):
377+
dtype = np.dtype('O')
374378
else:
375379
dtype = array.dtype
376380
self._dtype = dtype
@@ -387,7 +391,7 @@ def __array__(self, dtype=None):
387391
with suppress(AttributeError):
388392
# this might not be public API
389393
array = array.asobject
390-
return array.values.astype(dtype)
394+
return np.asarray(array, dtype)
391395

392396
def __getitem__(self, key):
393397
if isinstance(key, tuple) and len(key) == 1:

xray/core/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,15 @@ def is_scalar(value):
169169
value is None)
170170

171171

172+
def is_valid_numpy_dtype(dtype):
    """Check whether ``dtype`` can be interpreted as a numpy dtype.

    Parameters
    ----------
    dtype : object
        Anything that might be passed to ``np.dtype``.

    Returns
    -------
    bool
        True if ``np.dtype(dtype)`` succeeds, False if it raises
        ``TypeError`` or ``ValueError``.
    """
    try:
        # Only the attempt matters; the resulting dtype object is discarded.
        np.dtype(dtype)
    except (TypeError, ValueError):
        return False
    else:
        return True
179+
180+
172181
def dict_equiv(first, second, compat=equivalent):
173182
"""Test equivalence of two dict-like objects. If any of the values are
174183
numpy arrays, compare them correctly.

xray/test/test_variable.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from distutils.version import LooseVersion
77
import numpy as np
8+
import pytz
89
import pandas as pd
910

1011
from xray import Variable, Dataset, DataArray
@@ -153,7 +154,7 @@ def test_0d_time_data(self):
153154
def test_datetime64_conversion(self):
154155
times = pd.date_range('2000-01-01', periods=3)
155156
for values, preserve_source in [
156-
(times, False),
157+
(times, True),
157158
(times.values, True),
158159
(times.values.astype('datetime64[s]'), False),
159160
(times.to_pydatetime(), False),
@@ -163,15 +164,12 @@ def test_datetime64_conversion(self):
163164
self.assertArrayEqual(v.values, times.values)
164165
self.assertEqual(v.values.dtype, np.dtype('datetime64[ns]'))
165166
same_source = source_ndarray(v.values) is source_ndarray(values)
166-
if preserve_source and self.cls is Variable:
167-
self.assertTrue(same_source)
168-
else:
169-
self.assertFalse(same_source)
167+
assert preserve_source == same_source
170168

171169
def test_timedelta64_conversion(self):
172170
times = pd.timedelta_range(start=0, periods=3)
173171
for values, preserve_source in [
174-
(times, False),
172+
(times, True),
175173
(times.values, True),
176174
(times.values.astype('timedelta64[s]'), False),
177175
(times.to_pytimedelta(), False),
@@ -181,10 +179,7 @@ def test_timedelta64_conversion(self):
181179
self.assertArrayEqual(v.values, times.values)
182180
self.assertEqual(v.values.dtype, np.dtype('timedelta64[ns]'))
183181
same_source = source_ndarray(v.values) is source_ndarray(values)
184-
if preserve_source and self.cls is Variable:
185-
self.assertTrue(same_source)
186-
else:
187-
self.assertFalse(same_source)
182+
assert preserve_source == same_source
188183

189184
def test_object_conversion(self):
190185
data = np.arange(5).astype(str).astype(object)
@@ -405,6 +400,22 @@ def test_aggregate_complex(self):
405400
expected = Variable((), 0.5 + 1j)
406401
self.assertVariableAllClose(v.mean(), expected)
407402

403+
def test_pandas_cateogrical_dtype(self):
    # NOTE(review): the method name misspells "categorical"; it is kept
    # unchanged here so existing test IDs stay stable.
    categorical = pd.Categorical(np.arange(10, dtype='int64'))
    variable = self.cls('x', categorical)
    # Formatting the variable for printing must not raise.
    print(variable)
    # The reported dtype is expected to be that of the int64 input values.
    assert variable.dtype == 'int64'
408+
409+
def test_pandas_datetime64_with_tz(self):
410+
data = pd.date_range(start='2000-01-01',
411+
tz=pytz.timezone('America/New_York'),
412+
periods=10, freq='1h')
413+
v = self.cls('x', data)
414+
print(v) # should not error
415+
if 'America/New_York' in str(data.dtype):
416+
# pandas is new enough that it has datetime64 with timezone dtype
417+
assert v.dtype == 'object'
418+
408419

409420
class TestVariable(TestCase, VariableSubclassTestCases):
410421
cls = staticmethod(Variable)

0 commit comments

Comments
 (0)