Skip to content

Commit fe87162

Browse files
malmans2 and crusaderky
authored
Add xr.unify_chunks() top level method (#5445)
Co-authored-by: crusaderky <[email protected]>
1 parent 4e61a26 commit fe87162

File tree

7 files changed

+90
-36
lines changed

7 files changed

+90
-36
lines changed

doc/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ Top-level functions
3636
map_blocks
3737
show_versions
3838
set_options
39+
unify_chunks
3940

4041
Dataset
4142
=======

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ v0.18.3 (unreleased)
2121

2222
New Features
2323
~~~~~~~~~~~~
24+
- New top-level function :py:func:`unify_chunks`.
25+
By `Mattia Almansi <https://github.com/malmans2>`_.
2426
- Allow assigning values to a subset of a dataset using positional or label-based
2527
indexing (:issue:`3015`, :pull:`5362`).
2628
By `Matthias Göbel <https://github.com/matzegoebel>`_.

xarray/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from .core.alignment import align, broadcast
1919
from .core.combine import combine_by_coords, combine_nested
2020
from .core.common import ALL_DIMS, full_like, ones_like, zeros_like
21-
from .core.computation import apply_ufunc, corr, cov, dot, polyval, where
21+
from .core.computation import apply_ufunc, corr, cov, dot, polyval, unify_chunks, where
2222
from .core.concat import concat
2323
from .core.dataarray import DataArray
2424
from .core.dataset import Dataset
@@ -74,6 +74,7 @@
7474
"save_mfdataset",
7575
"set_options",
7676
"show_versions",
77+
"unify_chunks",
7778
"where",
7879
"zeros_like",
7980
# Classes

xarray/core/computation.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""
22
Functions for applying functions that act on arrays to xarray's labeled data.
33
"""
4+
from __future__ import annotations
5+
46
import functools
57
import itertools
68
import operator
@@ -19,6 +21,7 @@
1921
Optional,
2022
Sequence,
2123
Tuple,
24+
TypeVar,
2225
Union,
2326
)
2427

@@ -34,8 +37,11 @@
3437

3538
if TYPE_CHECKING:
3639
from .coordinates import Coordinates # noqa
40+
from .dataarray import DataArray
3741
from .dataset import Dataset
3842

43+
T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)
44+
3945
_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
4046
_DEFAULT_NAME = utils.ReprObject("<default-name>")
4147
_JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"})
@@ -1721,3 +1727,61 @@ def _calc_idxminmax(
17211727
res.attrs = indx.attrs
17221728

17231729
return res
1730+
1731+
1732+
def unify_chunks(*objects: T_DSorDA) -> Tuple[T_DSorDA, ...]:
    """
    Given any number of Dataset and/or DataArray objects, return new
    objects with unified chunk size along all chunked dimensions.

    Parameters
    ----------
    *objects : Dataset or DataArray
        Objects whose dask-backed variables should share chunk sizes.

    Returns
    -------
    unified : tuple of Dataset or DataArray
        Objects of the same type as ``*objects``, with consistent chunk
        sizes for all dask-array variables.

    Raises
    ------
    ValueError
        If the same dimension has different sizes across the inputs.

    See Also
    --------
    dask.array.core.unify_chunks
    """
    from .dataarray import DataArray

    # Convert all objects to datasets so they can be processed uniformly.
    # DataArrays round-trip through a temporary single-variable Dataset;
    # Datasets are shallow-copied so the originals are never mutated.
    datasets = [
        obj._to_temp_dataset() if isinstance(obj, DataArray) else obj.copy()
        for obj in objects
    ]

    # Get arguments to pass into dask.array.core.unify_chunks:
    # a flat [array, dims, array, dims, ...] sequence of chunked variables.
    unify_chunks_args = []
    sizes: dict[Hashable, int] = {}
    for ds in datasets:
        for v in ds._variables.values():
            if v.chunks is not None:
                # Check that dimension sizes match across different datasets;
                # record the first size seen for each dimension.
                for dim, size in v.sizes.items():
                    if dim not in sizes:
                        sizes[dim] = size
                    elif sizes[dim] != size:
                        raise ValueError(
                            f"Dimension {dim!r} size mismatch: {sizes[dim]} != {size}"
                        )
                unify_chunks_args += [v._data, v._dims]

    # No dask arrays: Return inputs
    if not unify_chunks_args:
        return objects

    # Run dask.array.core.unify_chunks. Import locally (and aliased, to avoid
    # shadowing this function's name) so dask stays an optional dependency.
    from dask.array.core import unify_chunks as dask_unify_chunks

    _, dask_data = unify_chunks_args and dask_unify_chunks(*unify_chunks_args)
    dask_data_iter = iter(dask_data)

    # Splice the rechunked dask arrays back into each dataset, consuming them
    # in the same order they were collected above, then restore DataArrays.
    out = []
    for obj, ds in zip(objects, datasets):
        for k, v in ds._variables.items():
            if v.chunks is not None:
                ds._variables[k] = v.copy(data=next(dask_data_iter))
        out.append(obj._from_temp_dataset(ds) if isinstance(obj, DataArray) else ds)

    return tuple(out)

xarray/core/dataarray.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
)
4545
from .arithmetic import DataArrayArithmetic
4646
from .common import AbstractArray, DataWithCoords
47+
from .computation import unify_chunks
4748
from .coordinates import (
4849
DataArrayCoordinates,
4950
assert_coordinate_consistent,
@@ -3686,8 +3687,8 @@ def unify_chunks(self) -> "DataArray":
36863687
--------
36873688
dask.array.core.unify_chunks
36883689
"""
3689-
ds = self._to_temp_dataset().unify_chunks()
3690-
return self._from_temp_dataset(ds)
3690+
3691+
return unify_chunks(self)[0]
36913692

36923693
def map_blocks(
36933694
self,

xarray/core/dataset.py

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from .alignment import _broadcast_helper, _get_broadcast_dims_map_common_coords, align
5454
from .arithmetic import DatasetArithmetic
5555
from .common import DataWithCoords, _contains_datetime_like_objects
56+
from .computation import unify_chunks
5657
from .coordinates import (
5758
DatasetCoordinates,
5859
assert_coordinate_consistent,
@@ -6566,37 +6567,7 @@ def unify_chunks(self) -> "Dataset":
65666567
dask.array.core.unify_chunks
65676568
"""
65686569

6569-
try:
6570-
self.chunks
6571-
except ValueError: # "inconsistent chunks"
6572-
pass
6573-
else:
6574-
# No variables with dask backend, or all chunks are already aligned
6575-
return self.copy()
6576-
6577-
# import dask is placed after the quick exit test above to allow
6578-
# running this method if dask isn't installed and there are no chunks
6579-
import dask.array
6580-
6581-
ds = self.copy()
6582-
6583-
dims_pos_map = {dim: index for index, dim in enumerate(ds.dims)}
6584-
6585-
dask_array_names = []
6586-
dask_unify_args = []
6587-
for name, variable in ds.variables.items():
6588-
if isinstance(variable.data, dask.array.Array):
6589-
dims_tuple = [dims_pos_map[dim] for dim in variable.dims]
6590-
dask_array_names.append(name)
6591-
dask_unify_args.append(variable.data)
6592-
dask_unify_args.append(dims_tuple)
6593-
6594-
_, rechunked_arrays = dask.array.core.unify_chunks(*dask_unify_args)
6595-
6596-
for name, new_array in zip(dask_array_names, rechunked_arrays):
6597-
ds.variables[name]._data = new_array
6598-
6599-
return ds
6570+
return unify_chunks(self)[0]
66006571

66016572
def map_blocks(
66026573
self,

xarray/tests/test_dask.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,12 +1069,26 @@ def test_unify_chunks(map_ds):
10691069
with pytest.raises(ValueError, match=r"inconsistent chunks"):
10701070
ds_copy.chunks
10711071

1072-
expected_chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5), "z": (4,)}
1072+
expected_chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5)}
10731073
with raise_if_dask_computes():
10741074
actual_chunks = ds_copy.unify_chunks().chunks
1075-
expected_chunks == actual_chunks
1075+
assert actual_chunks == expected_chunks
10761076
assert_identical(map_ds, ds_copy.unify_chunks())
10771077

1078+
out_a, out_b = xr.unify_chunks(ds_copy.cxy, ds_copy.drop_vars("cxy"))
1079+
assert out_a.chunks == ((4, 4, 2), (5, 5, 5, 5))
1080+
assert out_b.chunks == expected_chunks
1081+
1082+
# Test unordered dims
1083+
da = ds_copy["cxy"]
1084+
out_a, out_b = xr.unify_chunks(da.chunk({"x": -1}), da.T.chunk({"y": -1}))
1085+
assert out_a.chunks == ((4, 4, 2), (5, 5, 5, 5))
1086+
assert out_b.chunks == ((5, 5, 5, 5), (4, 4, 2))
1087+
1088+
# Test mismatch
1089+
with pytest.raises(ValueError, match=r"Dimension 'x' size mismatch: 10 != 2"):
1090+
xr.unify_chunks(da, da.isel(x=slice(2)))
1091+
10781092

10791093
@pytest.mark.parametrize("obj", [make_ds(), make_da()])
10801094
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)