Skip to content

Commit 60f3e74

Browse files
Migrate datatree mapping.py (#8948)
* DAS-2064: rename/relocate mapping.py -> xarray.core.datatree_mapping.py DAS-2064: fix circular import issue. * DAS-2064 - Minor changes to datatree_mapping.py. --------- Co-authored-by: Matt Savoie <[email protected]>
1 parent 9a37053 commit 60f3e74

File tree

8 files changed

+35
-33
lines changed

8 files changed

+35
-33
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ Bug fixes
3636

3737
Internal Changes
3838
~~~~~~~~~~~~~~~~
39+
- Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`)
40+
By `Matt Savoie <https://github.com/flamingbear>`_ `Owen Littlejohns
41+
<https://github.com/owenlittlejohns>` and `Tom Nicholas <https://github.com/TomNicholas>`_.
3942

4043

4144
.. _whats-new.2024.03.0:

xarray/core/datatree.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
from xarray.core.coordinates import DatasetCoordinates
1919
from xarray.core.dataarray import DataArray
2020
from xarray.core.dataset import Dataset, DataVariables
21+
from xarray.core.datatree_mapping import (
22+
TreeIsomorphismError,
23+
check_isomorphic,
24+
map_over_subtree,
25+
)
2126
from xarray.core.indexes import Index, Indexes
2227
from xarray.core.merge import dataset_update_method
2328
from xarray.core.options import OPTIONS as XR_OPTS
@@ -36,11 +41,6 @@
3641
from xarray.datatree_.datatree.formatting_html import (
3742
datatree_repr as datatree_repr_html,
3843
)
39-
from xarray.datatree_.datatree.mapping import (
40-
TreeIsomorphismError,
41-
check_isomorphic,
42-
map_over_subtree,
43-
)
4444
from xarray.datatree_.datatree.ops import (
4545
DataTreeArithmeticMixin,
4646
MappedDatasetMethodsMixin,

xarray/datatree_/datatree/mapping.py renamed to xarray/core/datatree_mapping.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
import sys
55
from itertools import repeat
66
from textwrap import dedent
7-
from typing import TYPE_CHECKING, Callable, Tuple
7+
from typing import TYPE_CHECKING, Callable
88

99
from xarray import DataArray, Dataset
10-
1110
from xarray.core.iterators import LevelOrderIter
1211
from xarray.core.treenode import NodePath, TreeNode
1312

@@ -84,14 +83,13 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> s
8483
for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)):
8584
path_a, path_b = node_a.path, node_b.path
8685

87-
if require_names_equal:
88-
if node_a.name != node_b.name:
89-
diff = dedent(
90-
f"""\
86+
if require_names_equal and node_a.name != node_b.name:
87+
diff = dedent(
88+
f"""\
9189
Node '{path_a}' in the left object has name '{node_a.name}'
9290
Node '{path_b}' in the right object has name '{node_b.name}'"""
93-
)
94-
return diff
91+
)
92+
return diff
9593

9694
if len(node_a.children) != len(node_b.children):
9795
diff = dedent(
@@ -125,7 +123,7 @@ def map_over_subtree(func: Callable) -> Callable:
125123
func : callable
126124
Function to apply to datasets with signature:
127125
128-
`func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.
126+
`func(*args, **kwargs) -> Union[DataTree, Iterable[DataTree]]`.
129127
130128
(i.e. func must accept at least one Dataset and return at least one Dataset.)
131129
Function will not be applied to any nodes without datasets.
@@ -154,7 +152,7 @@ def map_over_subtree(func: Callable) -> Callable:
154152
# TODO inspect function to work out immediately if the wrong number of arguments were passed for it?
155153

156154
@functools.wraps(func)
157-
def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
155+
def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]:
158156
"""Internal function which maps func over every node in tree, returning a tree of the results."""
159157
from xarray.core.datatree import DataTree
160158

@@ -259,19 +257,18 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
259257
return _map_over_subtree
260258

261259

262-
def _handle_errors_with_path_context(path):
260+
def _handle_errors_with_path_context(path: str):
263261
"""Wraps given function so that if it fails it also raises path to node on which it failed."""
264262

265263
def decorator(func):
266264
def wrapper(*args, **kwargs):
267265
try:
268266
return func(*args, **kwargs)
269267
except Exception as e:
270-
if sys.version_info >= (3, 11):
271-
# Add the context information to the error message
272-
e.add_note(
273-
f"Raised whilst mapping function over node with path {path}"
274-
)
268+
# Add the context information to the error message
269+
add_note(
270+
e, f"Raised whilst mapping function over node with path {path}"
271+
)
275272
raise
276273

277274
return wrapper
@@ -287,7 +284,9 @@ def add_note(err: BaseException, msg: str) -> None:
287284
err.add_note(msg)
288285

289286

290-
def _check_single_set_return_values(path_to_node, obj):
287+
def _check_single_set_return_values(
288+
path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray]
289+
):
291290
"""Check types returned from single evaluation of func, and return number of return values received from func."""
292291
if isinstance(obj, (Dataset, DataArray)):
293292
return 1

xarray/core/iterators.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from collections import abc
43
from collections.abc import Iterator
54
from typing import Callable
65

@@ -9,7 +8,7 @@
98
"""These iterators are copied from anytree.iterators, with minor modifications."""
109

1110

12-
class LevelOrderIter(abc.Iterator):
11+
class LevelOrderIter(Iterator):
1312
"""Iterate over tree applying level-order strategy starting at `node`.
1413
This is the iterator used by `DataTree` to traverse nodes.
1514

xarray/datatree_/datatree/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
# import public API
2-
from .mapping import TreeIsomorphismError, map_over_subtree
32
from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError
43

54

65
__all__ = (
7-
"TreeIsomorphismError",
86
"InvalidTreeError",
97
"NotFoundInTreeError",
10-
"map_over_subtree",
118
)

xarray/datatree_/datatree/formatting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from xarray.core.formatting import _compat_to_str, diff_dataset_repr
44

5-
from xarray.datatree_.datatree.mapping import diff_treestructure
5+
from xarray.core.datatree_mapping import diff_treestructure
66
from xarray.datatree_.datatree.render import RenderTree
77

88
if TYPE_CHECKING:

xarray/datatree_/datatree/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from xarray import Dataset
44

5-
from .mapping import map_over_subtree
5+
from xarray.core.datatree_mapping import map_over_subtree
66

77
"""
88
Module which specifies the subset of xarray.Dataset's API which we wish to copy onto DataTree.

xarray/datatree_/datatree/tests/test_mapping.py renamed to xarray/tests/test_datatree_mapping.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import numpy as np
22
import pytest
3-
import xarray as xr
43

4+
import xarray as xr
55
from xarray.core.datatree import DataTree
6-
from xarray.datatree_.datatree.mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree
6+
from xarray.core.datatree_mapping import (
7+
TreeIsomorphismError,
8+
check_isomorphic,
9+
map_over_subtree,
10+
)
711
from xarray.datatree_.datatree.testing import assert_equal
812

913
empty = xr.Dataset()
@@ -12,7 +16,7 @@
1216
class TestCheckTreesIsomorphic:
1317
def test_not_a_tree(self):
1418
with pytest.raises(TypeError, match="not a tree"):
15-
check_isomorphic("s", 1)
19+
check_isomorphic("s", 1) # type: ignore[arg-type]
1620

1721
def test_different_widths(self):
1822
dt1 = DataTree.from_dict(d={"a": empty})
@@ -69,7 +73,7 @@ def test_not_isomorphic_complex_tree(self, create_test_datatree):
6973
def test_checking_from_root(self, create_test_datatree):
7074
dt1 = create_test_datatree()
7175
dt2 = create_test_datatree()
72-
real_root = DataTree(name="real root")
76+
real_root: DataTree = DataTree(name="real root")
7377
dt2.name = "not_real_root"
7478
dt2.parent = real_root
7579
with pytest.raises(TreeIsomorphismError):

0 commit comments

Comments
 (0)