Skip to content

[WIP] ENH: add to_xarray conversion method #11950

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/requirements-2.7.run
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ html5lib=1.0b2
beautiful-soup=4.2.1
statsmodels
jinja2=2.8
xray
1 change: 1 addition & 0 deletions ci/requirements-3.5.run
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ html5lib
lxml
matplotlib
jinja2
xray

# currently causing some warnings
#sqlalchemy
Expand Down
1 change: 1 addition & 0 deletions doc/source/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ Optional Dependencies
* `Cython <http://www.cython.org>`__: Only necessary to build development
version. Version 0.19.1 or higher.
* `SciPy <http://www.scipy.org>`__: miscellaneous statistical functions
* `xray <http://xray.readthedocs.org>`__: pandas-like handling for data with more than 2 dimensions.
* `PyTables <http://www.pytables.org>`__: necessary for HDF5-based storage. Version 3.0.0 or higher required, Version 3.2.1 or higher highly recommended.
* `SQLAlchemy <http://www.sqlalchemy.org>`__: for SQL database support. Version 0.8.1 or higher recommended.
* Besides SQLAlchemy, you also need a database specific driver.
Expand Down
24 changes: 23 additions & 1 deletion pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,7 +1041,29 @@ def to_clipboard(self, excel=None, sep=None, **kwargs):
from pandas.io import clipboard
clipboard.to_clipboard(self, excel=excel, sep=sep, **kwargs)

#----------------------------------------------------------------------
def to_xarray(self):
    """
    Convert this pandas object into the matching xray container.

    The conversion is chosen from ``self.ndim``; ``xray`` is imported
    inside the method, so it is only needed when the method is called.

    Returns
    -------
    a DataArray for a Series
    a Dataset for a DataFrame
    a Dataset for higher dims
    """
    import xray

    ndim = self.ndim
    if ndim == 1:
        return xray.DataArray.from_series(self)
    if ndim == 2:
        return xray.Dataset.from_dataframe(self)

    # > 2 dims: pair every axis name with that axis' labels, wrap the
    # object in a DataArray and convert the array to a Dataset
    coords = []
    for axis_name in self._AXIS_ORDERS:
        coords.append((axis_name, self._get_axis(axis_name)))
    array = xray.DataArray(self, coords=coords)
    return array.to_dataset()

# ----------------------------------------------------------------------
# Fancy Indexing

@classmethod
Expand Down
11 changes: 6 additions & 5 deletions pandas/tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5457,10 +5457,10 @@ def test_repr_column_name_unicode_truncation_bug(self):
def test_head_tail(self):
assert_frame_equal(self.frame.head(), self.frame[:5])
assert_frame_equal(self.frame.tail(), self.frame[-5:])

assert_frame_equal(self.frame.head(0), self.frame[0:0])
assert_frame_equal(self.frame.tail(0), self.frame[0:0])

assert_frame_equal(self.frame.head(-1), self.frame[:-1])
assert_frame_equal(self.frame.tail(-1), self.frame[1:])
assert_frame_equal(self.frame.head(1), self.frame[:1])
Expand Down Expand Up @@ -13564,10 +13564,11 @@ def test_round_issue(self):

decimals = pd.Series([1, 0, 2], index=['A', 'B', 'A'])
self.assertRaises(ValueError, df.round, decimals)

def test_built_in_round(self):
if not compat.PY3:
raise nose.SkipTest('build in round cannot be overriden prior to Python 3')
raise nose.SkipTest('build in round cannot be '
'overriden prior to Python 3')

# GH11763
# Here's the test frame we'll be working with
Expand All @@ -13578,7 +13579,7 @@ def test_built_in_round(self):
expected_rounded = DataFrame(
{'col1': [1., 2., 3.], 'col2': [1., 2., 3.]})
tm.assert_frame_equal(round(df), expected_rounded)

def test_quantile(self):
from numpy import percentile

Expand Down
157 changes: 154 additions & 3 deletions pandas/tests/test_generic.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# -*- coding: utf-8 -*-
# pylint: disable-msg=E1101,W0612

from distutils.version import LooseVersion
from datetime import datetime, timedelta
import nose
import numpy as np
from numpy import nan
import pandas as pd

from pandas import (Index, Series, DataFrame, Panel,
isnull, notnull, date_range, period_range)
from pandas.core.index import Index, MultiIndex
from pandas import (Index, Series, DataFrame, Panel, Panel4D,
isnull, notnull, date_range, period_range,
MultiIndex)

import pandas.core.common as com

Expand All @@ -18,6 +19,7 @@
from pandas.util.testing import (assert_series_equal,
assert_frame_equal,
assert_panel_equal,
assert_panel4d_equal,
assert_almost_equal,
assert_equal,
ensure_clean)
Expand Down Expand Up @@ -1032,6 +1034,55 @@ def test_describe_none(self):
expected = Series([0, 0], index=['count', 'unique'], name='None')
assert_series_equal(noneSeries.describe(), expected)

def test_to_xarray(self):
    """
    Exercise ``Series.to_xarray`` for an empty Series, a range of index
    types, a categorical index and a two-level MultiIndex.

    The test is skipped when ``xray`` cannot be imported.
    """

    tm._skip_if_no_xarray()
    import xray
    from xray import DataArray

    if LooseVersion(xray.__version__) > '0.6.1':
        # https://github.com/xray/xray/issues/697
        s = Series([])
        s.index.name = 'foo'
        result = s.to_xarray()
        self.assertEqual(len(result), 0)
        self.assertEqual(len(result.coords), 1)
        # materialize the keys first: on Python 3 ``keys()`` is a view
        # object, which never compares equal to a list
        self.assertEqual(list(result.coords.keys()), ['foo'])
        self.assertIsInstance(result, DataArray)

    for index in [tm.makeFloatIndex, tm.makeIntIndex,
                  tm.makeStringIndex, tm.makeUnicodeIndex,
                  tm.makeDateIndex, tm.makePeriodIndex,
                  tm.makeTimedeltaIndex]:
        s = Series(range(6), index=index(6))
        s.index.name = 'foo'
        result = s.to_xarray()
        # repr must not raise for these index types
        repr(result)
        self.assertEqual(len(result), 6)
        self.assertEqual(len(result.coords), 1)
        assert_almost_equal(result.coords.keys(), ['foo'])
        self.assertIsInstance(result, DataArray)

        # idempotency: converting back must give the original Series
        assert_series_equal(result.to_series(), s)

    # fails ATM
    # https://github.com/xray/xray/issues/700
    for index in [tm.makeCategoricalIndex]:
        s = Series(range(6), index=index(6))
        s.index.name = 'foo'

        result = s.to_xarray()
        self.assertRaises(ValueError, lambda: repr(result))

    # a two-level MultiIndex becomes two coordinates; the length of the
    # result is the size of the first level
    s.index = pd.MultiIndex.from_product([['a', 'b'], range(3)],
                                         names=['one', 'two'])
    result = s.to_xarray()
    self.assertEqual(len(result), 2)
    assert_almost_equal(result.coords.keys(), ['one', 'two'])
    self.assertIsInstance(result, DataArray)
    assert_series_equal(result.to_series(), s)


class TestDataFrame(tm.TestCase, Generic):
_typ = DataFrame
Expand Down Expand Up @@ -1715,10 +1766,110 @@ def test_pct_change(self):

self.assert_frame_equal(result, expected)

def test_to_xarray(self):
    """
    Exercise ``DataFrame.to_xarray`` on a mixed-dtype frame: an empty
    slice, a range of index types with a round trip back to a frame,
    and a MultiIndex, which is expected to raise ValueError.

    The test is skipped when ``xray`` cannot be imported.
    """

    tm._skip_if_no_xarray()
    import xray
    from xray import Dataset

    columns = {'a': list('abc'),
               'b': list(range(1, 4)),
               'c': np.arange(3, 6).astype('u1'),
               'd': np.arange(4.0, 7.0, dtype='float64'),
               'e': [True, False, True],
               'f': pd.Categorical(list('abc')),
               'g': pd.date_range('20130101', periods=3),
               'h': pd.date_range('20130101',
                                  periods=3,
                                  tz='US/Eastern')}
    df = DataFrame(columns)

    if LooseVersion(xray.__version__) > '0.6.1':
        # https://github.com/pydata/xarray/issues/697
        df.index.name = 'foo'
        empty_ds = df[0:0].to_xarray()
        self.assertEqual(empty_ds.dims['foo'], 0)
        self.assertIsInstance(empty_ds, Dataset)

    index_makers = [tm.makeFloatIndex, tm.makeIntIndex,
                    tm.makeStringIndex, tm.makeUnicodeIndex,
                    tm.makeDateIndex, tm.makePeriodIndex,
                    tm.makeCategoricalIndex, tm.makeTimedeltaIndex]
    for make_index in index_makers:
        df.index = make_index(3)
        df.index.name = 'foo'
        df.columns.name = 'bar'

        ds = df.to_xarray()
        self.assertEqual(ds.dims['foo'], 3)
        self.assertEqual(len(ds.coords), 1)
        self.assertEqual(len(ds.data_vars), 8)
        assert_almost_equal(ds.coords.keys(), ['foo'])
        self.assertIsInstance(ds, Dataset)

        # idempotency, adjusted for what the round trip loses:
        # categoricals are not preserved
        # datetimes w/tz are not preserved
        # column names are lost
        expected = df.copy()
        expected['f'] = expected['f'].astype(object)
        expected['h'] = expected['h'].astype('datetime64[ns]')
        expected.columns.name = None
        roundtripped = ds.to_dataframe()
        assert_frame_equal(roundtripped, expected, check_index_type=False)

    # not implemented
    df.index = pd.MultiIndex.from_product([['a'], range(3)],
                                          names=['one', 'two'])
    with self.assertRaises(ValueError):
        df.to_xarray()


class TestPanel(tm.TestCase, Generic):
    """Generic NDFrame tests run against ``Panel``."""

    _typ = Panel

    def _comparator(self, x, y):
        # equality check for two panels
        return assert_panel_equal(x, y)

    def test_to_xarray(self):
        """A Panel converts to a Dataset with one coordinate per axis."""

        tm._skip_if_no_xarray()
        import xray
        from xray import Dataset

        panel = tm.makePanel()

        if LooseVersion(xray.__version__) > '0.6.1':
            # https://github.com/pydata/xarray/issues/697
            # placeholder: nothing is checked for newer versions yet
            pass

        ds = panel.to_xarray()
        self.assertIsInstance(ds, Dataset)
        self.assertEqual(len(ds.coords), 3)
        expected_axes = ['items', 'major_axis', 'minor_axis']
        assert_almost_equal(ds.coords.keys(), expected_axes)
        self.assertEqual(len(ds.dims), 3)


class TestPanel4D(tm.TestCase, Generic):
    """Generic NDFrame tests run against ``Panel4D``."""

    _typ = Panel4D

    def _comparator(self, x, y):
        # equality check for two 4-D panels
        return assert_panel4d_equal(x, y)

    def test_to_xarray(self):
        """A Panel4D converts to a Dataset with one coordinate per axis."""

        tm._skip_if_no_xarray()
        import xray
        from xray import Dataset

        panel4d = tm.makePanel4D()

        if LooseVersion(xray.__version__) > '0.6.1':
            # https://github.com/xray/xray/issues/697
            # placeholder: nothing is checked for newer versions yet
            pass

        ds = panel4d.to_xarray()
        self.assertIsInstance(ds, Dataset)
        self.assertEqual(len(ds.coords), 4)
        expected_axes = ['labels', 'items', 'major_axis', 'minor_axis']
        assert_almost_equal(ds.coords.keys(), expected_axes)
        self.assertEqual(len(ds.dims), 4)


class TestNDFrame(tm.TestCase):
# tests that don't fit elsewhere
Expand Down
1 change: 1 addition & 0 deletions pandas/util/print_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def show_versions(as_json=False):
("numpy", lambda mod: mod.version.version),
("scipy", lambda mod: mod.version.version),
("statsmodels", lambda mod: mod.__version__),
("xray", lambda mod: mod.__version__),
("IPython", lambda mod: mod.__version__),
("sphinx", lambda mod: mod.__version__),
("patsy", lambda mod: mod.__version__),
Expand Down
8 changes: 8 additions & 0 deletions pandas/util/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,14 @@ def _skip_if_no_scipy():
raise nose.SkipTest('scipy.interpolate missing')


def _skip_if_no_xarray():
try:
import xray
except ImportError:
import nose
raise nose.SkipTest("xarray not installed")


def _skip_if_no_pytz():
try:
import pytz
Expand Down