Skip to content

Commit a86c3ff

Browse files
authored
Hierarchical coordinates in DataTree (#9063)
* Inheritance of data coordinates * Simplify __init__ * Include path name in alignment errors * Fix some mypy errors * mypy fix * simplify DataTree data model * Add to_dataset(local=True) * Fix mypy failure in tests * Fix to_zarr for inherited coords * Fix to_netcdf for hierarchical coords * Add ChainSet * Revise internal data model; remove ChainSet * add another way to construct inherited indexes * Finish refactoring error message * include inherited dimensions in HTML repr, too * Construct ChainMap objects on demand. * slightly better error message with mis-aligned data trees * mypy fix * use float64 instead of float32 for windows * clean-up per review * Add note about inheritance to .ds docs
1 parent 6c2d8c3 commit a86c3ff

8 files changed

+527
-240
lines changed

xarray/core/datatree.py

Lines changed: 228 additions & 215 deletions
Large diffs are not rendered by default.

xarray/core/datatree_io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def _datatree_to_netcdf(
8585
unlimited_dims = {}
8686

8787
for node in dt.subtree:
88-
ds = node.ds
88+
ds = node.to_dataset(inherited=False)
8989
group_path = node.path
9090
if ds is None:
9191
_create_empty_netcdf_group(filepath, group_path, mode, engine)
@@ -151,7 +151,7 @@ def _datatree_to_zarr(
151151
)
152152

153153
for node in dt.subtree:
154-
ds = node.ds
154+
ds = node.to_dataset(inherited=False)
155155
group_path = node.path
156156
if ds is None:
157157
_create_empty_zarr_group(store, group_path, mode)

xarray/core/formatting.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,27 @@ def dataset_repr(ds):
748748
return "\n".join(summary)
749749

750750

751+
def dims_and_coords_repr(ds) -> str:
752+
"""Partial Dataset repr for use inside DataTree inheritance errors."""
753+
summary = []
754+
755+
col_width = _calculate_col_width(ds.coords)
756+
max_rows = OPTIONS["display_max_rows"]
757+
758+
dims_start = pretty_print("Dimensions:", col_width)
759+
dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows)
760+
summary.append(f"{dims_start}({dims_values})")
761+
762+
if ds.coords:
763+
summary.append(coords_repr(ds.coords, col_width=col_width, max_rows=max_rows))
764+
765+
unindexed_dims_str = unindexed_dims_repr(ds.dims, ds.coords, max_rows=max_rows)
766+
if unindexed_dims_str:
767+
summary.append(unindexed_dims_str)
768+
769+
return "\n".join(summary)
770+
771+
751772
def diff_dim_summary(a, b):
752773
if a.sizes != b.sizes:
753774
return f"Differing dimensions:\n ({dim_summary(a)}) != ({dim_summary(b)})"
@@ -1030,7 +1051,7 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat):
10301051
def _single_node_repr(node: DataTree) -> str:
10311052
"""Information about this node, not including its relationships to other nodes."""
10321053
if node.has_data or node.has_attrs:
1033-
ds_info = "\n" + repr(node.ds)
1054+
ds_info = "\n" + repr(node._to_dataset_view(rebuild_dims=False))
10341055
else:
10351056
ds_info = ""
10361057
return f"Group: {node.path}{ds_info}"

xarray/core/formatting_html.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def summarize_datatree_children(children: Mapping[str, DataTree]) -> str:
386386
def datatree_node_repr(group_title: str, dt: DataTree) -> str:
387387
header_components = [f"<div class='xr-obj-type'>{escape(group_title)}</div>"]
388388

389-
ds = dt.ds
389+
ds = dt._to_dataset_view(rebuild_dims=False)
390390

391391
sections = [
392392
children_section(dt.children),

xarray/core/treenode.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,14 @@ def _attach(self, parent: Tree | None, child_name: str | None = None) -> None:
138138
"To directly set parent, child needs a name, but child is unnamed"
139139
)
140140

141-
self._pre_attach(parent)
141+
self._pre_attach(parent, child_name)
142142
parentchildren = parent._children
143143
assert not any(
144144
child is self for child in parentchildren
145145
), "Tree is corrupt."
146146
parentchildren[child_name] = self
147147
self._parent = parent
148-
self._post_attach(parent)
148+
self._post_attach(parent, child_name)
149149
else:
150150
self._parent = None
151151

@@ -415,11 +415,11 @@ def _post_detach(self: Tree, parent: Tree) -> None:
415415
"""Method call after detaching from `parent`."""
416416
pass
417417

418-
def _pre_attach(self: Tree, parent: Tree) -> None:
418+
def _pre_attach(self: Tree, parent: Tree, name: str) -> None:
419419
"""Method call before attaching to `parent`."""
420420
pass
421421

422-
def _post_attach(self: Tree, parent: Tree) -> None:
422+
def _post_attach(self: Tree, parent: Tree, name: str) -> None:
423423
"""Method call after attaching to `parent`."""
424424
pass
425425

@@ -567,6 +567,9 @@ def same_tree(self, other: Tree) -> bool:
567567
return self.root is other.root
568568

569569

570+
AnyNamedNode = TypeVar("AnyNamedNode", bound="NamedNode")
571+
572+
570573
class NamedNode(TreeNode, Generic[Tree]):
571574
"""
572575
A TreeNode which knows its own name.
@@ -606,10 +609,9 @@ def __repr__(self, level=0):
606609
def __str__(self) -> str:
607610
return f"NamedNode('{self.name}')" if self.name else "NamedNode()"
608611

609-
def _post_attach(self: NamedNode, parent: NamedNode) -> None:
612+
def _post_attach(self: AnyNamedNode, parent: AnyNamedNode, name: str) -> None:
610613
"""Ensures child has name attribute corresponding to key under which it has been stored."""
611-
key = next(k for k, v in parent.children.items() if v is self)
612-
self.name = key
614+
self.name = name
613615

614616
@property
615617
def path(self) -> str:

xarray/tests/test_backends_datatree.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, cast
44

55
import pytest
66

7+
import xarray as xr
78
from xarray.backends.api import open_datatree
9+
from xarray.core.datatree import DataTree
810
from xarray.testing import assert_equal
911
from xarray.tests import (
1012
requires_h5netcdf,
@@ -13,11 +15,11 @@
1315
)
1416

1517
if TYPE_CHECKING:
16-
from xarray.backends.api import T_NetcdfEngine
18+
from xarray.core.datatree_io import T_DataTreeNetcdfEngine
1719

1820

1921
class DatatreeIOBase:
20-
engine: T_NetcdfEngine | None = None
22+
engine: T_DataTreeNetcdfEngine | None = None
2123

2224
def test_to_netcdf(self, tmpdir, simple_datatree):
2325
filepath = tmpdir / "test.nc"
@@ -27,6 +29,21 @@ def test_to_netcdf(self, tmpdir, simple_datatree):
2729
roundtrip_dt = open_datatree(filepath, engine=self.engine)
2830
assert_equal(original_dt, roundtrip_dt)
2931

32+
def test_to_netcdf_inherited_coords(self, tmpdir):
33+
filepath = tmpdir / "test.nc"
34+
original_dt = DataTree.from_dict(
35+
{
36+
"/": xr.Dataset({"a": (("x",), [1, 2])}, coords={"x": [3, 4]}),
37+
"/sub": xr.Dataset({"b": (("x",), [5, 6])}),
38+
}
39+
)
40+
original_dt.to_netcdf(filepath, engine=self.engine)
41+
42+
roundtrip_dt = open_datatree(filepath, engine=self.engine)
43+
assert_equal(original_dt, roundtrip_dt)
44+
subtree = cast(DataTree, roundtrip_dt["/sub"])
45+
assert "x" not in subtree.to_dataset(inherited=False).coords
46+
3047
def test_netcdf_encoding(self, tmpdir, simple_datatree):
3148
filepath = tmpdir / "test.nc"
3249
original_dt = simple_datatree
@@ -48,12 +65,12 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree):
4865

4966
@requires_netCDF4
5067
class TestNetCDF4DatatreeIO(DatatreeIOBase):
51-
engine: T_NetcdfEngine | None = "netcdf4"
68+
engine: T_DataTreeNetcdfEngine | None = "netcdf4"
5269

5370

5471
@requires_h5netcdf
5572
class TestH5NetCDFDatatreeIO(DatatreeIOBase):
56-
engine: T_NetcdfEngine | None = "h5netcdf"
73+
engine: T_DataTreeNetcdfEngine | None = "h5netcdf"
5774

5875

5976
@requires_zarr
@@ -119,3 +136,18 @@ def test_to_zarr_default_write_mode(self, tmpdir, simple_datatree):
119136
# with default settings, to_zarr should not overwrite an existing dir
120137
with pytest.raises(zarr.errors.ContainsGroupError):
121138
simple_datatree.to_zarr(tmpdir)
139+
140+
def test_to_zarr_inherited_coords(self, tmpdir):
141+
original_dt = DataTree.from_dict(
142+
{
143+
"/": xr.Dataset({"a": (("x",), [1, 2])}, coords={"x": [3, 4]}),
144+
"/sub": xr.Dataset({"b": (("x",), [5, 6])}),
145+
}
146+
)
147+
filepath = tmpdir / "test.zarr"
148+
original_dt.to_zarr(filepath)
149+
150+
roundtrip_dt = open_datatree(filepath, engine="zarr")
151+
assert_equal(original_dt, roundtrip_dt)
152+
subtree = cast(DataTree, roundtrip_dt["/sub"])
153+
assert "x" not in subtree.to_dataset(inherited=False).coords

0 commit comments

Comments
 (0)