Skip to content

Commit 3f4b025

Browse files
shoyer and crusaderky
authored and committed
sparse=True option for from_dataframe and from_series (#3210)
sparse=True option for from_dataframe and from_series Fixes #3206
1 parent 851f763 commit 3f4b025

File tree

7 files changed

+148
-29
lines changed

7 files changed

+148
-29
lines changed

doc/whats-new.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ New functions/methods
4343
This requires `sparse>=0.8.0`. By `Nezar Abdennur <https://github.com/nvictus>`_
4444
and `Guido Imperiale <https://github.com/crusaderky>`_.
4545

46+
- :py:meth:`~Dataset.from_dataframe` and :py:meth:`~DataArray.from_series` now
47+
support ``sparse=True`` for converting pandas objects into xarray objects
48+
wrapping sparse arrays. This is particularly useful with sparsely populated
49+
hierarchical indexes. (:issue:`3206`)
50+
By `Stephan Hoyer <https://github.com/shoyer>`_.
51+
4652
- The xarray package is now discoverable by mypy (although typing hints coverage is not
4753
complete yet). mypy type checking is now enforced by CI. Libraries that depend on
4854
xarray and use mypy can now remove from their setup.cfg the lines::

xarray/core/dataarray.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,7 @@ def reset_coords(
733733
else:
734734
if self.name is None:
735735
raise ValueError(
736-
"cannot reset_coords with drop=False " "on an unnamed DataArrray"
736+
"cannot reset_coords with drop=False on an unnamed DataArrray"
737737
)
738738
dataset[self.name] = self.variable
739739
return dataset
@@ -1448,9 +1448,7 @@ def expand_dims(
14481448
This object, but with an additional dimension(s).
14491449
"""
14501450
if isinstance(dim, int):
1451-
raise TypeError(
1452-
"dim should be hashable or sequence/mapping of " "hashables"
1453-
)
1451+
raise TypeError("dim should be hashable or sequence/mapping of hashables")
14541452
elif isinstance(dim, Sequence) and not isinstance(dim, str):
14551453
if len(dim) != len(set(dim)):
14561454
raise ValueError("dims should not contain duplicate values.")
@@ -2277,19 +2275,27 @@ def from_dict(cls, d: dict) -> "DataArray":
22772275
return obj
22782276

22792277
@classmethod
2280-
def from_series(cls, series: pd.Series) -> "DataArray":
2278+
def from_series(cls, series: pd.Series, sparse: bool = False) -> "DataArray":
22812279
"""Convert a pandas.Series into an xarray.DataArray.
22822280
22832281
If the series's index is a MultiIndex, it will be expanded into a
22842282
tensor product of one-dimensional coordinates (filling in missing
22852283
values with NaN). Thus this operation should be the inverse of the
22862284
`to_series` method.
2285+
2286+
If sparse=True, creates a sparse array instead of a dense NumPy array.
2287+
Requires the pydata/sparse package.
2288+
2289+
See also
2290+
--------
2291+
xarray.Dataset.from_dataframe
22872292
"""
2288-
# TODO: add a 'name' parameter
2289-
name = series.name
2290-
df = pd.DataFrame({name: series})
2291-
ds = Dataset.from_dataframe(df)
2292-
return ds[name]
2293+
temp_name = "__temporary_name"
2294+
df = pd.DataFrame({temp_name: series})
2295+
ds = Dataset.from_dataframe(df, sparse=sparse)
2296+
result = cast(DataArray, ds[temp_name])
2297+
result.name = series.name
2298+
return result
22932299

22942300
def to_cdms2(self) -> "cdms2_Variable":
22952301
"""Convert this array into a cdms2.Variable
@@ -2704,7 +2710,7 @@ def dot(
27042710
"""
27052711
if isinstance(other, Dataset):
27062712
raise NotImplementedError(
2707-
"dot products are not yet supported " "with Dataset objects."
2713+
"dot products are not yet supported with Dataset objects."
27082714
)
27092715
if not isinstance(other, DataArray):
27102716
raise TypeError("dot only operates on DataArrays.")

xarray/core/dataset.py

Lines changed: 85 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,12 +1214,13 @@ def loc(self) -> _LocIndexer:
12141214
"""
12151215
return _LocIndexer(self)
12161216

1217-
def __getitem__(self, key: object) -> "Union[DataArray, Dataset]":
1217+
def __getitem__(self, key: Any) -> "Union[DataArray, Dataset]":
12181218
"""Access variables or coordinates this dataset as a
12191219
:py:class:`~xarray.DataArray`.
12201220
12211221
Indexing with a list of names will return a new ``Dataset`` object.
12221222
"""
1223+
# TODO(shoyer): type this properly: https://github.com/python/mypy/issues/7328
12231224
if utils.is_dict_like(key):
12241225
return self.isel(**cast(Mapping, key))
12251226

@@ -3916,8 +3917,61 @@ def to_dataframe(self):
39163917
"""
39173918
return self._to_dataframe(self.dims)
39183919

3920+
def _set_sparse_data_from_dataframe(
3921+
self, dataframe: pd.DataFrame, dims: tuple, shape: Tuple[int, ...]
3922+
) -> None:
3923+
from sparse import COO
3924+
3925+
idx = dataframe.index
3926+
if isinstance(idx, pd.MultiIndex):
3927+
try:
3928+
codes = idx.codes
3929+
except AttributeError:
3930+
# deprecated since pandas 0.24
3931+
codes = idx.labels
3932+
coords = np.stack([np.asarray(code) for code in codes], axis=0)
3933+
is_sorted = idx.is_lexsorted
3934+
else:
3935+
coords = np.arange(idx.size).reshape(1, -1)
3936+
is_sorted = True
3937+
3938+
for name, series in dataframe.items():
3939+
# Cast to a NumPy array first, in case the Series is a pandas
3940+
# Extension array (which doesn't have a valid NumPy dtype)
3941+
values = np.asarray(series)
3942+
3943+
# In virtually all real use cases, the sparse array will now have
3944+
# missing values and needs a fill_value. For consistency, don't
3945+
# special case the rare exceptions (e.g., dtype=int without a
3946+
# MultiIndex).
3947+
dtype, fill_value = dtypes.maybe_promote(values.dtype)
3948+
values = np.asarray(values, dtype=dtype)
3949+
3950+
data = COO(
3951+
coords,
3952+
values,
3953+
shape,
3954+
has_duplicates=False,
3955+
sorted=is_sorted,
3956+
fill_value=fill_value,
3957+
)
3958+
self[name] = (dims, data)
3959+
3960+
def _set_numpy_data_from_dataframe(
3961+
self, dataframe: pd.DataFrame, dims: tuple, shape: Tuple[int, ...]
3962+
) -> None:
3963+
idx = dataframe.index
3964+
if isinstance(idx, pd.MultiIndex):
3965+
# expand the DataFrame to include the product of all levels
3966+
full_idx = pd.MultiIndex.from_product(idx.levels, names=idx.names)
3967+
dataframe = dataframe.reindex(full_idx)
3968+
3969+
for name, series in dataframe.items():
3970+
data = np.asarray(series).reshape(shape)
3971+
self[name] = (dims, data)
3972+
39193973
@classmethod
3920-
def from_dataframe(cls, dataframe):
3974+
def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Dataset":
39213975
"""Convert a pandas.DataFrame into an xarray.Dataset
39223976
39233977
Each column will be converted into an independent variable in the
@@ -3926,7 +3980,24 @@ def from_dataframe(cls, dataframe):
39263980
values with NaN). This method will produce a Dataset very similar to
39273981
that on which the 'to_dataframe' method was called, except with
39283982
possibly redundant dimensions (since all dataset variables will have
3929-
the same dimensionality).
3983+
the same dimensionality)
3984+
3985+
Parameters
3986+
----------
3987+
dataframe : pandas.DataFrame
3988+
DataFrame from which to copy data and indices.
3989+
sparse : bool
3990+
If true, create a sparse arrays instead of dense numpy arrays. This
3991+
can potentially save a large amount of memory if the DataFrame has
3992+
a MultiIndex. Requires the sparse package (sparse.pydata.org).
3993+
3994+
Returns
3995+
-------
3996+
New Dataset.
3997+
3998+
See also
3999+
--------
4000+
xarray.DataArray.from_series
39304001
"""
39314002
# TODO: Add an option to remove dimensions along which the variables
39324003
# are constant, to enable consistent serialization to/from a dataframe,
@@ -3939,25 +4010,23 @@ def from_dataframe(cls, dataframe):
39394010
obj = cls()
39404011

39414012
if isinstance(idx, pd.MultiIndex):
3942-
# it's a multi-index
3943-
# expand the DataFrame to include the product of all levels
3944-
full_idx = pd.MultiIndex.from_product(idx.levels, names=idx.names)
3945-
dataframe = dataframe.reindex(full_idx)
3946-
dims = [
4013+
dims = tuple(
39474014
name if name is not None else "level_%i" % n
39484015
for n, name in enumerate(idx.names)
3949-
]
4016+
)
39504017
for dim, lev in zip(dims, idx.levels):
39514018
obj[dim] = (dim, lev)
3952-
shape = [lev.size for lev in idx.levels]
4019+
shape = tuple(lev.size for lev in idx.levels)
39534020
else:
3954-
dims = (idx.name if idx.name is not None else "index",)
3955-
obj[dims[0]] = (dims, idx)
3956-
shape = -1
4021+
index_name = idx.name if idx.name is not None else "index"
4022+
dims = (index_name,)
4023+
obj[index_name] = (dims, idx)
4024+
shape = (idx.size,)
39574025

3958-
for name, series in dataframe.items():
3959-
data = np.asarray(series).reshape(shape)
3960-
obj[name] = (dims, data)
4026+
if sparse:
4027+
obj._set_sparse_data_from_dataframe(dataframe, dims, shape)
4028+
else:
4029+
obj._set_numpy_data_from_dataframe(dataframe, dims, shape)
39614030
return obj
39624031

39634032
def to_dask_dataframe(self, dim_order=None, set_index=False):

xarray/tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def LooseVersion(vstring):
8484
has_iris, requires_iris = _importorskip("iris")
8585
has_cfgrib, requires_cfgrib = _importorskip("cfgrib")
8686
has_numbagg, requires_numbagg = _importorskip("numbagg")
87+
has_sparse, requires_sparse = _importorskip("sparse")
8788

8889
# some special cases
8990
has_h5netcdf07, requires_h5netcdf07 = _importorskip("h5netcdf", minversion="0.7")

xarray/tests/test_dataarray.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
requires_np113,
3030
requires_numbagg,
3131
requires_scipy,
32+
requires_sparse,
3233
source_ndarray,
3334
)
3435

@@ -3398,6 +3399,19 @@ def test_to_and_from_series(self):
33983399
expected_da = self.dv.rename(None)
33993400
assert_identical(expected_da, DataArray.from_series(actual).drop(["x", "y"]))
34003401

3402+
@requires_sparse
3403+
def test_from_series_sparse(self):
3404+
import sparse
3405+
3406+
series = pd.Series([1, 2], index=[("a", 1), ("b", 2)])
3407+
3408+
actual_sparse = DataArray.from_series(series, sparse=True)
3409+
actual_dense = DataArray.from_series(series, sparse=False)
3410+
3411+
assert isinstance(actual_sparse.data, sparse.COO)
3412+
actual_sparse.data = actual_sparse.data.todense()
3413+
assert_identical(actual_sparse, actual_dense)
3414+
34013415
def test_to_and_from_empty_series(self):
34023416
# GH697
34033417
expected = pd.Series([])

xarray/tests/test_dataset.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
requires_dask,
4747
requires_numbagg,
4848
requires_scipy,
49+
requires_sparse,
4950
source_ndarray,
5051
)
5152

@@ -3653,6 +3654,28 @@ def test_to_and_from_dataframe(self):
36533654
expected = pd.DataFrame([[]], index=idx)
36543655
assert expected.equals(actual), (expected, actual)
36553656

3657+
@requires_sparse
3658+
def test_from_dataframe_sparse(self):
3659+
import sparse
3660+
3661+
df_base = pd.DataFrame(
3662+
{"x": range(10), "y": list("abcdefghij"), "z": np.arange(0, 100, 10)}
3663+
)
3664+
3665+
ds_sparse = Dataset.from_dataframe(df_base.set_index("x"), sparse=True)
3666+
ds_dense = Dataset.from_dataframe(df_base.set_index("x"), sparse=False)
3667+
assert isinstance(ds_sparse["y"].data, sparse.COO)
3668+
assert isinstance(ds_sparse["z"].data, sparse.COO)
3669+
ds_sparse["y"].data = ds_sparse["y"].data.todense()
3670+
ds_sparse["z"].data = ds_sparse["z"].data.todense()
3671+
assert_identical(ds_dense, ds_sparse)
3672+
3673+
ds_sparse = Dataset.from_dataframe(df_base.set_index(["x", "y"]), sparse=True)
3674+
ds_dense = Dataset.from_dataframe(df_base.set_index(["x", "y"]), sparse=False)
3675+
assert isinstance(ds_sparse["z"].data, sparse.COO)
3676+
ds_sparse["z"].data = ds_sparse["z"].data.todense()
3677+
assert_identical(ds_dense, ds_sparse)
3678+
36563679
def test_to_and_from_empty_dataframe(self):
36573680
# GH697
36583681
expected = pd.DataFrame({"foo": []})

xarray/tests/test_duck_array_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,9 @@ def construct_dataarray(dim_num, dtype, contains_nan, dask):
245245

246246

247247
def from_series_or_scalar(se):
248-
try:
248+
if isinstance(se, pd.Series):
249249
return DataArray.from_series(se)
250-
except AttributeError: # scalar case
250+
else: # scalar case
251251
return DataArray(se)
252252

253253

0 commit comments

Comments (0)