Skip to content

Commit 1f5889a

Browse files
keewis and TomNicholas authored
implement dask methods on DataTree (#9670)
* implement `compute` and `load` * also shallow-copy variables * implement `chunksizes` * add tests for `load` * add tests for `chunksizes` * improve the `load` tests using `DataTree.chunksizes` * add a test for `compute` * un-xfail a xpassing test * implement and test `DataTree.chunk` * link to `Dataset.load` Co-authored-by: Tom Nicholas <[email protected]> * use `tree.subtree` to get absolute paths * filter out missing dims before delegating to `Dataset.chunk` * fix the type hints for `DataTree.chunksizes` * try using `self.from_dict` instead * type-hint intermediate test variables * use `_node_dims` instead * raise on unknown chunk dim * check that errors in `chunk` are raised properly * adapt the docstrings of the new methods * allow computing / loading unchunked trees * reword the `chunksizes` properties * also freeze the top-level chunk sizes * also reword `DataArray.chunksizes` * fix a copy-paste error * same for `NamedArray.chunksizes` --------- Co-authored-by: Tom Nicholas <[email protected]>
1 parent 521b087 commit 1f5889a

File tree

5 files changed

+342
-14
lines changed

5 files changed

+342
-14
lines changed

xarray/core/dataarray.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,8 +1343,10 @@ def chunks(self) -> tuple[tuple[int, ...], ...] | None:
13431343
@property
13441344
def chunksizes(self) -> Mapping[Any, tuple[int, ...]]:
13451345
"""
1346-
Mapping from dimension names to block lengths for this dataarray's data, or None if
1347-
the underlying data is not a dask array.
1346+
Mapping from dimension names to block lengths for this dataarray's data.
1347+
1348+
If this dataarray does not contain chunked arrays, the mapping will be empty.
1349+
13481350
Cannot be modified directly, but can be modified by calling .chunk().
13491351
13501352
Differs from DataArray.chunks because it returns a mapping of dimensions to chunk shapes

xarray/core/dataset.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2658,8 +2658,10 @@ def info(self, buf: IO | None = None) -> None:
26582658
@property
26592659
def chunks(self) -> Mapping[Hashable, tuple[int, ...]]:
26602660
"""
2661-
Mapping from dimension names to block lengths for this dataset's data, or None if
2662-
the underlying data is not a dask array.
2661+
Mapping from dimension names to block lengths for this dataset's data.
2662+
2663+
If this dataset does not contain chunked arrays, the mapping will be empty.
2664+
26632665
Cannot be modified directly, but can be modified by calling .chunk().
26642666
26652667
Same as Dataset.chunksizes, but maintained for backwards compatibility.
@@ -2675,8 +2677,10 @@ def chunks(self) -> Mapping[Hashable, tuple[int, ...]]:
26752677
@property
26762678
def chunksizes(self) -> Mapping[Hashable, tuple[int, ...]]:
26772679
"""
2678-
Mapping from dimension names to block lengths for this dataset's data, or None if
2679-
the underlying data is not a dask array.
2680+
Mapping from dimension names to block lengths for this dataset's data.
2681+
2682+
If this dataset does not contain chunked arrays, the mapping will be empty.
2683+
26802684
Cannot be modified directly, but can be modified by calling .chunk().
26812685
26822686
Same as Dataset.chunks.

xarray/core/datatree.py

Lines changed: 208 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from xarray.core._aggregations import DataTreeAggregations
1919
from xarray.core._typed_ops import DataTreeOpsMixin
2020
from xarray.core.alignment import align
21-
from xarray.core.common import TreeAttrAccessMixin
21+
from xarray.core.common import TreeAttrAccessMixin, get_chunksizes
2222
from xarray.core.coordinates import Coordinates, DataTreeCoordinates
2323
from xarray.core.dataarray import DataArray
2424
from xarray.core.dataset import Dataset, DataVariables
@@ -49,6 +49,8 @@
4949
parse_dims_as_set,
5050
)
5151
from xarray.core.variable import Variable
52+
from xarray.namedarray.parallelcompat import get_chunked_array_type
53+
from xarray.namedarray.pycompat import is_chunked_array
5254

5355
try:
5456
from xarray.core.variable import calculate_dimensions
@@ -68,8 +70,11 @@
6870
ErrorOptions,
6971
ErrorOptionsWithWarn,
7072
NetcdfWriteModes,
73+
T_ChunkDimFreq,
74+
T_ChunksFreq,
7175
ZarrWriteModes,
7276
)
77+
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint
7378

7479
# """
7580
# DEVELOPERS' NOTE
@@ -862,9 +867,9 @@ def _copy_node(
862867
) -> Self:
863868
"""Copy just one node of a tree."""
864869
new_node = super()._copy_node(inherit=inherit, deep=deep, memo=memo)
865-
data = self._to_dataset_view(rebuild_dims=False, inherit=inherit)
866-
if deep:
867-
data = data._copy(deep=True, memo=memo)
870+
data = self._to_dataset_view(rebuild_dims=False, inherit=inherit)._copy(
871+
deep=deep, memo=memo
872+
)
868873
new_node._set_node_data(data)
869874
return new_node
870875

@@ -1896,3 +1901,202 @@ def apply_indexers(dataset, node_indexers):
18961901

18971902
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel")
18981903
return self._selective_indexing(apply_indexers, indexers)
1904+
1905+
def load(self, **kwargs) -> Self:
    """Manually trigger loading and/or computation of this datatree's data
    from disk or a remote source into memory and return this datatree.
    Unlike compute, the original datatree is modified and returned.

    Normally, it should not be necessary to call this method in user code,
    because all xarray functions should either work on deferred data or
    load data automatically. However, this method can be necessary when
    working with many file objects on disk.

    Chunked arrays from every node of the tree are evaluated together in a
    single ``compute`` call of their chunk manager; every remaining variable
    is then loaded one at a time with ``Variable.load``.

    Parameters
    ----------
    **kwargs : dict
        Additional keyword arguments passed on to the chunk manager's
        ``compute`` (``dask.compute`` for dask arrays). They are not
        forwarded to the sequential ``Variable.load`` calls.

    Returns
    -------
    object : DataTree
        This same tree, with its variables' data replaced in place.

    See Also
    --------
    Dataset.load
    dask.compute
    """
    # Read ``_data`` directly (not ``.data``) so that nothing is coerced
    # while we are only looking for chunked arrays. Keys are
    # ``(node path, variable name)`` pairs: variable names alone are not
    # unique across the tree.
    flat_lazy_data = {
        (path, var_name): var._data
        for path, node in self.subtree_with_keys
        for var_name, var in node.variables.items()
        if is_chunked_array(var._data)
    }
    if flat_lazy_data:
        chunkmanager = get_chunked_array_type(*flat_lazy_data.values())

        # evaluate all the chunked arrays simultaneously
        evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute(
            *flat_lazy_data.values(), **kwargs
        )

        for (path, var_name), data in zip(
            flat_lazy_data, evaluated_data, strict=False
        ):
            self[path].variables[var_name].data = data

    # Load everything else sequentially. The membership test must use the
    # same ``(path, name)`` key as above: testing the bare variable name
    # against a mapping keyed by node path would skip any variable that
    # happens to share its name with a node path, leaving it unloaded.
    for path, node in self.subtree_with_keys:
        for var_name, var in node.variables.items():
            if (path, var_name) not in flat_lazy_data:
                var.load()

    return self
1959+
1960+
def compute(self, **kwargs) -> Self:
1961+
"""Manually trigger loading and/or computation of this datatree's data
1962+
from disk or a remote source into memory and return a new datatree.
1963+
Unlike load, the original datatree is left unaltered.
1964+
1965+
Normally, it should not be necessary to call this method in user code,
1966+
because all xarray functions should either work on deferred data or
1967+
load data automatically. However, this method can be necessary when
1968+
working with many file objects on disk.
1969+
1970+
Parameters
1971+
----------
1972+
**kwargs : dict
1973+
Additional keyword arguments passed on to ``dask.compute``.
1974+
1975+
Returns
1976+
-------
1977+
object : DataTree
1978+
New object with lazy data variables and coordinates as in-memory arrays.
1979+
1980+
See Also
1981+
--------
1982+
dask.compute
1983+
"""
1984+
new = self.copy(deep=False)
1985+
return new.load(**kwargs)
1986+
1987+
@property
def chunksizes(self) -> Mapping[str, Mapping[Hashable, tuple[int, ...]]]:
    """
    Per-group chunk sizes, keyed by each group's path.

    Every node in the subtree contributes one entry whose value is the
    mapping of chunksizes computed from that node's variables. If there's
    no chunked data in a group, the corresponding mapping of chunksizes
    will be empty.

    The returned mapping is read-only; chunk sizes can be changed by
    calling .chunk().

    See Also
    --------
    DataTree.chunk
    Dataset.chunksizes
    """
    sizes_by_path = {}
    for group in self.subtree:
        sizes_by_path[group.path] = get_chunksizes(group.variables.values())
    # Freeze the outer mapping so that callers cannot mutate it.
    return Frozen(sizes_by_path)
2007+
2008+
def chunk(
    self,
    chunks: T_ChunksFreq = {},  # noqa: B006  # {} even though it's technically unsafe, is being used intentionally here (#4667)
    name_prefix: str = "xarray-",
    token: str | None = None,
    lock: bool = False,
    inline_array: bool = False,
    chunked_array_type: str | ChunkManagerEntrypoint | None = None,
    from_array_kwargs=None,
    **chunks_kwargs: T_ChunkDimFreq,
) -> Self:
    """Rechunk every group of this tree and return the result as a new tree.

    Each node's dataset is passed to :py:meth:`Dataset.chunk` together with
    the subset of the requested chunk sizes whose dimensions exist on that
    node; the rechunked datasets are then reassembled into a tree with
    the same name via ``from_dict``.

    If chunks are not provided for one or more dimensions, no chunk size for
    those dimensions is passed on to :py:meth:`Dataset.chunk`.

    Along datetime-like dimensions, a :py:class:`groupers.TimeResampler` object is also accepted.

    Parameters
    ----------
    chunks : mapping of hashable to int, tuple of int, "auto" or a TimeResampler, optional
        Chunk sizes along each dimension, e.g. ``{"x": 5, "y": 5}`` or
        ``{"x": 5, "time": TimeResampler(freq="YE")}``. Unlike
        :py:meth:`Dataset.chunk`, only mappings are accepted here.
    name_prefix : str, default: "xarray-"
        Prefix for the name of any new dask arrays.
    token : str, optional
        Token uniquely identifying this datatree.
    lock : bool, default: False
        Passed on to :py:func:`dask.array.from_array`, if the array is not
        already a dask array.
    inline_array: bool, default: False
        Passed on to :py:func:`dask.array.from_array`, if the array is not
        already a dask array.
    chunked_array_type: str, optional
        Which chunked array type to coerce this datatree's arrays to.
        Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntryPoint` system.
        Experimental API that should not be relied upon.
    from_array_kwargs: dict, optional
        Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create
        chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg.
        For example, with dask as the default chunked array type, this method would pass additional kwargs
        to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
    **chunks_kwargs : {dim: chunks, ...}, optional
        The keyword arguments form of ``chunks``.

    Returns
    -------
    chunked : xarray.DataTree

    Raises
    ------
    TypeError
        If ``chunks`` is not a mapping.
    ValueError
        If a requested dimension is not among the dimensions reported by
        ``self._get_all_dims()``.

    See Also
    --------
    Dataset.chunk
    Dataset.chunksizes
    xarray.unify_chunks
    dask.array.from_array
    """
    # The deprecated, non-mapping ways of passing chunks are rejected outright.
    if not isinstance(chunks, Mapping):
        raise TypeError(
            f"invalid type for chunks: {type(chunks)}. Only mappings are supported."
        )
    requested = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")

    # Validate against the dimensions of the whole tree up front, because
    # the per-node filtering below would otherwise silently drop a
    # misspelled dimension name.
    known_dims = self._get_all_dims()
    unknown_dims = requested.keys() - known_dims
    if unknown_dims:
        raise ValueError(
            f"chunks keys {tuple(unknown_dims)} not found in data dimensions {tuple(known_dims)}"
        )

    # Options that are forwarded unchanged to every ``Dataset.chunk`` call.
    forwarded_options = dict(
        name_prefix=name_prefix,
        token=token,
        lock=lock,
        inline_array=inline_array,
        chunked_array_type=chunked_array_type,
        from_array_kwargs=from_array_kwargs,
    )

    rechunked_groups = {}
    for path, node in self.subtree_with_keys:
        # Only hand each node the dimensions it actually has.
        node_chunks = {
            dim: size for dim, size in requested.items() if dim in node._node_dims
        }
        rechunked_groups[path] = node.dataset.chunk(node_chunks, **forwarded_options)

    return self.from_dict(rechunked_groups, name=self.name)

xarray/namedarray/core.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -725,8 +725,10 @@ def chunksizes(
725725
self,
726726
) -> Mapping[_Dim, _Shape]:
727727
"""
728-
Mapping from dimension names to block lengths for this namedArray's data, or None if
729-
the underlying data is not a dask array.
728+
Mapping from dimension names to block lengths for this NamedArray's data.
729+
730+
If this NamedArray does not contain chunked arrays, the mapping will be empty.
731+
730732
Cannot be modified directly, but can be modified by calling .chunk().
731733
732734
Differs from NamedArray.chunks because it returns a mapping of dimensions to chunk shapes

0 commit comments

Comments
 (0)