This repository was archived by the owner on Oct 24, 2024. It is now read-only.

Map over multiple subtrees #32

Merged
merged 22 commits on Sep 2, 2021
22 commits
c857a30
pseudocode ideas for generalizing map_over_subtree
TomNicholas Aug 26, 2021
871802a
pseudocode for a generalized map_over_subtree (still only one return …
TomNicholas Aug 26, 2021
2b61af3
pseudocode for mapping but now multiple return values
TomNicholas Aug 27, 2021
9045c23
pseudocode for mapping but with multiple return values
TomNicholas Aug 27, 2021
a335bd5
Merge branch 'main' into map_over_multiple_subtrees
TomNicholas Aug 27, 2021
c30dd16
Merge branch 'main' into check_isomorphism
TomNicholas Aug 27, 2021
1e4c68d
check_isomorphism works and has tests
TomNicholas Aug 27, 2021
e14f7a9
cleaned up the mapping tests a bit
TomNicholas Aug 27, 2021
600bc3d
Merge branch 'check_isomorphism' into map_over_multiple_subtrees
TomNicholas Aug 27, 2021
f8d9801
tests for mapping over multiple trees
TomNicholas Aug 30, 2021
70c5c7d
incorrect pseudocode attempt to map over multiple subtrees
TomNicholas Aug 30, 2021
a030b1f
small improvements
TomNicholas Aug 30, 2021
ef7ee4d
fixed test
TomNicholas Aug 30, 2021
0d9af33
zipping of multiple arguments
TomNicholas Aug 30, 2021
64b149a
passes for mapping over a single tree
TomNicholas Sep 1, 2021
a17aa20
successfully maps over multiple trees
TomNicholas Sep 1, 2021
7057027
successfully returns multiple trees
TomNicholas Sep 1, 2021
06b22bf
filled out all tests
TomNicholas Sep 1, 2021
c24ca3f
checking types now works for trees with only one node
TomNicholas Sep 2, 2021
02814d8
improved docstring
TomNicholas Sep 2, 2021
ca35e9d
merge in main
TomNicholas Sep 2, 2021
32c3303
merge most recent version of main
TomNicholas Sep 2, 2021
4 changes: 3 additions & 1 deletion datatree/datatree.py
@@ -424,10 +424,12 @@ def __init__(
else:
node_path, node_name = "/", path

relative_path = node_path.replace(self.name, "")

# Create and set new node
new_node = DataNode(name=node_name, data=data)
self.set_node(
node_path,
relative_path,
new_node,
allow_overwrite=False,
new_nodes_along_path=True,
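The only change to datatree.py makes the path handed to set_node relative, by stripping the tree's own name from node_path. A tiny sketch of the string operation, using hypothetical values chosen purely for illustration (the names and paths below are not taken from this PR):

# Hypothetical stand-ins for self.name and node_path, to show what the .replace() call above produces
name = "root"
node_path = "root/group1"
relative_path = node_path.replace(name, "")
print(relative_path)  # "/group1" -- the path is now relative to the node calling set_node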
192 changes: 166 additions & 26 deletions datatree/mapping.py
@@ -1,6 +1,8 @@
import functools
from itertools import repeat

from anytree.iterators import LevelOrderIter
from xarray import DataArray, Dataset

from .treenode import TreeNode

@@ -43,11 +45,11 @@ def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False):

if not isinstance(subtree_a, TreeNode):
raise TypeError(
f"Argument `subtree_a is not a tree, it is of type {type(subtree_a)}"
f"Argument `subtree_a` is not a tree, it is of type {type(subtree_a)}"
)
if not isinstance(subtree_b, TreeNode):
raise TypeError(
f"Argument `subtree_b is not a tree, it is of type {type(subtree_b)}"
f"Argument `subtree_b` is not a tree, it is of type {type(subtree_b)}"
)

# Walking nodes in "level-order" fashion means walking down from the root breadth-first.
@@ -83,57 +85,195 @@ def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False):

def map_over_subtree(func):
"""
Decorator which turns a function which acts on (and returns) single Datasets into one which acts on DataTrees.
Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees.

Applies a function to every dataset in this subtree, returning a new tree which stores the results.
Applies a function to every dataset in one or more subtrees, returning new trees which store the results.

The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the
descendant nodes. The returned tree will have the same structure as the original subtree.
The function will be applied to any dataset stored in any of the nodes in the trees. The returned trees will have
the same structure as the supplied trees.

func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each
result will be assigned to its respective node of new tree via `DataTree.__setitem__`.
`func` needs to return one or more Datasets, DataArrays, or None in order to be able to rebuild the subtrees after
mapping, as each result will be assigned to its respective node of a new tree via `DataTree.__setitem__`. Any
returned value that is one of these types will be stacked into a separate tree before returning all of them.

The trees passed to the resulting function must all be isomorphic to one another. Their nodes need not be named
similarly, but all the output trees will have nodes named in the same way as the first tree passed.

Parameters
----------
func : callable
Function to apply to datasets with signature:
`func(node.ds, *args, **kwargs) -> Dataset`.

`func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.

(i.e. func must accept at least one Dataset and return at least one Dataset.)
Function will not be applied to any nodes without datasets.
*args : tuple, optional
Positional arguments passed on to `func`.
Positional arguments passed on to `func`. If DataTrees, any data-containing nodes will be converted to
Datasets via `.ds`.
**kwargs : Any
Keyword arguments passed on to `func`.
Keyword arguments passed on to `func`. If DataTrees, any data-containing nodes will be converted to
Datasets via `.ds`.

Returns
-------
mapped : callable
Wrapped function which returns tree created from results of applying ``func`` to the dataset at each node.
Wrapped function which returns one or more tree(s) created from results of applying ``func`` to the dataset at
each node.

See also
--------
DataTree.map_over_subtree
DataTree.map_over_subtree_inplace
DataTree.subtree
"""

# TODO examples in the docstring

# TODO inspect function to work out immediately if the wrong number of arguments were passed for it?

@functools.wraps(func)
def _map_over_subtree(tree, *args, **kwargs):
def _map_over_subtree(*args, **kwargs):
"""Internal function which maps func over every node in tree, returning a tree of the results."""
from .datatree import DataTree

all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [
a for a in kwargs.values() if isinstance(a, DataTree)
]

if len(all_tree_inputs) > 0:
first_tree, *other_trees = all_tree_inputs
else:
raise TypeError("Must pass at least one tree object")

for other_tree in other_trees:
# isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic
_check_isomorphic(first_tree, other_tree, require_names_equal=False)

# Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
# We don't know which arguments are DataTrees so we zip all arguments together as iterables
# Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
out_data_objects = {}
args_as_tree_length_iterables = [
a.subtree if isinstance(a, DataTree) else repeat(a) for a in args
]
n_args = len(args_as_tree_length_iterables)
kwargs_as_tree_length_iterables = {
k: v.subtree if isinstance(v, DataTree) else repeat(v)
for k, v in kwargs.items()
}
for node_of_first_tree, *all_node_args in zip(
first_tree.subtree,
*args_as_tree_length_iterables,
*list(kwargs_as_tree_length_iterables.values()),
):
node_args_as_datasets = [
a.ds if isinstance(a, DataTree) else a for a in all_node_args[:n_args]
]
node_kwargs_as_datasets = dict(
zip(
[k for k in kwargs_as_tree_length_iterables.keys()],
[
v.ds if isinstance(v, DataTree) else v
for v in all_node_args[n_args:]
],
)
)

# Recreate and act on root node
from .datatree import DataNode
# Now we can call func on the data in this particular set of corresponding nodes
results = (
func(*node_args_as_datasets, **node_kwargs_as_datasets)
if node_of_first_tree.has_data
else None
)

out_tree = DataNode(name=tree.name, data=tree.ds)
if out_tree.has_data:
out_tree.ds = func(out_tree.ds, *args, **kwargs)
# TODO implement mapping over multiple trees in-place using if conditions from here on?
out_data_objects[node_of_first_tree.pathstr] = results

# Find out how many return values we received
num_return_values = _check_all_return_values(out_data_objects)

# Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees
result_trees = []
for i in range(num_return_values):
out_tree_contents = {}
for n in first_tree.subtree:
p = n.pathstr
if p in out_data_objects.keys():
if isinstance(out_data_objects[p], tuple):
output_node_data = out_data_objects[p][i]
else:
output_node_data = out_data_objects[p]
else:
output_node_data = None
out_tree_contents[p] = output_node_data

new_tree = DataTree(name=first_tree.name, data_objects=out_tree_contents)
result_trees.append(new_tree)

# If only one result then don't wrap it in a tuple
if len(result_trees) == 1:
return result_trees[0]
else:
return tuple(result_trees)

# Act on every other node in the tree, and rebuild from results
for node in tree.descendants:
# TODO make a proper relative_path method
relative_path = node.pathstr.replace(tree.pathstr, "")
result = func(node.ds, *args, **kwargs) if node.has_data else None
out_tree[relative_path] = result
return _map_over_subtree

return out_tree

return _map_over_subtree
def _check_single_set_return_values(path_to_node, obj):
"""Check types returned from single evaluation of func, and return number of return values received from func."""
if isinstance(obj, (Dataset, DataArray)):
return 1
elif isinstance(obj, tuple):
for r in obj:
if not isinstance(r, (Dataset, DataArray)):
raise TypeError(
f"One of the results of calling func on datasets on the nodes at position {path_to_node} is "
f"of type {type(r)}, not Dataset or DataArray."
)
return len(obj)
else:
raise TypeError(
f"The result of calling func on the node at position {path_to_node} is of type {type(obj)}, not "
f"Dataset or DataArray, nor a tuple of such types."
)


def _check_all_return_values(returned_objects):
"""Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types."""

if all(r is None for r in returned_objects.values()):
raise TypeError(
"Called supplied function on all nodes but found a return value of None for"
"all of them."
)

result_data_objects = [
(path_to_node, r)
for path_to_node, r in returned_objects.items()
if r is not None
]

if len(result_data_objects) == 1:
# Only one node in the tree: no need to check consistency of results between nodes
path_to_node, result = result_data_objects[0]
num_return_values = _check_single_set_return_values(path_to_node, result)
else:
prev_path, _ = result_data_objects[0]
prev_num_return_values, num_return_values = None, None
for path_to_node, obj in result_data_objects[1:]:
num_return_values = _check_single_set_return_values(path_to_node, obj)

if (
num_return_values != prev_num_return_values
and prev_num_return_values is not None
):
raise TypeError(
f"Calling func on the nodes at position {path_to_node} returns {num_return_values} separate return "
f"values, whereas calling func on the nodes at position {prev_path} instead returns "
f"{prev_num_return_values} separate return values."
)

prev_path, prev_num_return_values = path_to_node, num_return_values

return num_return_values
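Taken together, the new mapping.py walks every supplied DataTree in lock-step (non-tree arguments are broadcast to every node via itertools.repeat), applies func to the corresponding Datasets, and rebuilds one output tree per value returned by func. A minimal usage sketch follows; it is not part of this PR, it assumes map_over_subtree can be imported from datatree.mapping and that DataTree accepts a data_objects mapping as the constructor call inside this diff suggests, and the variable names and example data are invented:

import xarray as xr

from datatree import DataTree  # assumed import location
from datatree.mapping import map_over_subtree

@map_over_subtree
def sum_and_diff(ds_a, ds_b):
    # Called once per data-containing node, with the Datasets stored at the
    # corresponding nodes of each input tree; returning two Datasets means the
    # wrapped call returns two DataTrees.
    return ds_a + ds_b, ds_a - ds_b

ds = xr.Dataset({"a": ("x", [1, 2, 3])})

# Two isomorphic trees: same structure, one data-containing node each
dt_1 = DataTree(name="root", data_objects={"group1": ds})
dt_2 = DataTree(name="root", data_objects={"group1": ds * 10})

# Both results are DataTrees with the same structure and node names as dt_1
sums, diffs = sum_and_diff(dt_1, dt_2)

If func instead returned a single Dataset, the wrapped call would return a single DataTree rather than a tuple, per the len(result_trees) == 1 branch above, and _check_all_return_values raises if different nodes return different numbers of results.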
6 changes: 1 addition & 5 deletions datatree/tests/test_datatree.py
@@ -8,11 +8,9 @@


def assert_tree_equal(dt_a, dt_b):
assert dt_a.name == dt_b.name
assert dt_a.parent is dt_b.parent

assert dt_a.ds.equals(dt_b.ds)
for a, b in zip(dt_a.descendants, dt_b.descendants):
for a, b in zip(dt_a.subtree, dt_b.subtree):
assert a.name == b.name
assert a.pathstr == b.pathstr
if a.has_data:
@@ -321,7 +319,6 @@ def test_to_netcdf(self, tmpdir):
original_dt.to_netcdf(filepath, engine="netcdf4")

roundtrip_dt = open_datatree(filepath)

assert_tree_equal(original_dt, roundtrip_dt)

def test_to_zarr(self, tmpdir):
@@ -332,5 +329,4 @@ def test_to_zarr(self, tmpdir):
original_dt.to_zarr(filepath)

roundtrip_dt = open_datatree(filepath, engine="zarr")

assert_tree_equal(original_dt, roundtrip_dt)