Skip to content

Commit 47f1859

Browse files
jerryxyjOrbax Authors
authored andcommitted
Add prune_tree utility and integrate it into Orbax Model export.
This change introduces a `prune_tree` function to `tree_util.py` which removes tree leaves not matching a specified type. The Orbax Model export pipeline is updated with a new option, `prune_custom_pytree_nodes`. When this option is enabled, custom PyTree nodes that are not `ShloTensorSpec` instances are pruned (replaced with `None`) instead of causing an error during signature conversion. This allows for more flexible handling of PyTrees containing non-tensor metadata. PiperOrigin-RevId: 810549387
1 parent 27c35bd commit 47f1859

File tree

7 files changed

+55
-1
lines changed

7 files changed

+55
-1
lines changed

export/orbax/export/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ class ExportModelType(enum.Enum):
6262
# Jax2obm_kwargs key for input polymorphic constraints.
6363
POLYMORPHIC_CONSTRAINTS = 'polymorphic_constraints'
6464

65+
# Jax2obm_kwargs key for pruning custom pytree nodes.
66+
PRUNE_CUSTOM_PYTREE_NODES = 'prune_custom_pytree_nodes'
67+
6568
# Default weights name to use if a checkpoint path is provided but a weights_
6669
# name kwarg was not provided in the jax2obm_kwargs.
6770
DEFAULT_WEIGHTS_NAME = 'weights'

export/orbax/export/modules/obm_module.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def __init__(
104104
self._checkpoint_path: str = None
105105
# Set the Orbax checkpoint path if provided in the jax2obm_kwargs.
106106
self._maybe_set_orbax_checkpoint_path(jax2obm_kwargs)
107+
self._prune_custom_pytree_nodes = jax2obm_kwargs.get(
108+
constants.PRUNE_CUSTOM_PYTREE_NODES, False
109+
)
107110

108111
def _normalize_apply_fn_map(
109112
self,

export/orbax/export/obm_export_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from collections.abc import Mapping, Sequence
1616
import contextlib
1717
import os
18+
from typing import Any, Callable
1819

1920
from absl.testing import absltest
2021
from absl.testing import parameterized

model/orbax/experimental/model/core/python/tree_util.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,19 @@ def unflatten(tree: Tree[Any], leaves: Iterable[T6]) -> Tree[T6]:
142142
except StopIteration:
143143
return result
144144
raise ValueError("After unflattening, there are still leaves left.")
145+
146+
147+
def prune_tree(tree: Tree[Any], wanted_node_type: Any) -> Tree[Any]:
148+
"""Prunes the tree by removing unwanted leaves."""
149+
if isinstance(tree, (tuple, list)):
150+
return tuple_or_list_constructor(tree)(
151+
prune_tree(x, wanted_node_type) for x in tree
152+
)
153+
elif isinstance(tree, dict):
154+
tree: Dict[str, Tree[Any]]
155+
return {k: prune_tree(v, wanted_node_type) for k, v in tree.items()}
156+
elif isinstance(tree, wanted_node_type):
157+
return tree
158+
else:
159+
# If the node type is None or not wanted, we return None.
160+
return None

model/orbax/experimental/model/core/python/tree_util_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,28 @@ def assert_int(x: Any) -> None:
7777
lambda: tree_util.assert_tree(assert_int, wrong_tree),
7878
)
7979

80+
def test_prune_tree(self):
81+
tree = (1, "a", [2, "b", {"c": 3, "d": "e"}], None)
82+
pruned_int = tree_util.prune_tree(tree, int)
83+
self.assertEqual(
84+
pruned_int, (1, None, [2, None, {"c": 3, "d": None}], None)
85+
)
86+
pruned_str = tree_util.prune_tree(tree, str)
87+
self.assertEqual(
88+
pruned_str, (None, "a", [None, "b", {"c": None, "d": "e"}], None)
89+
)
90+
pruned_int_str = tree_util.prune_tree(tree, (int, str))
91+
self.assertEqual(pruned_int_str, tree)
92+
pruned_empty = tree_util.prune_tree(tree, ())
93+
self.assertEqual(
94+
pruned_empty, (None, None, [None, None, {"c": None, "d": None}], None)
95+
)
96+
pruned_not_present = tree_util.prune_tree(tree, float)
97+
self.assertEqual(
98+
pruned_not_present,
99+
(None, None, [None, None, {"c": None, "d": None}], None),
100+
)
101+
80102

81103
if __name__ == "__main__":
82104
absltest.main()

model/orbax/experimental/model/jax2obm/jax_specific_info.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ def _to_shlo_spec_tree_and_refinement_tuple(
311311
avals: Sequence[jax.core.AbstractValue],
312312
shardings: Sequence[Any],
313313
tree_def: Optional[jax.tree_util.PyTreeDef],
314+
prune_custom_pytree_nodes: bool = False,
314315
) -> Tuple[
315316
obm.Tree[obm.ShloTensorSpec], Tuple[ShapeDTypeRefinementPair, ...] | None
316317
]:
@@ -327,7 +328,8 @@ def assert_leaf(x: Any) -> None:
327328
raise ValueError(
328329
f"Leaf needs to be a ShloTensorSpec, but its type is: {type(x)}"
329330
)
330-
331+
if prune_custom_pytree_nodes:
332+
jax_tree = obm.tree_util.prune_tree(jax_tree, obm.ShloTensorSpec)
331333
obm.tree_util.assert_tree(assert_leaf, jax_tree)
332334
jax_tree: obm.Tree[obm.ShloTensorSpec]
333335
return jax_tree, refinements

model/orbax/experimental/model/jax2obm/main_lib.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def jax_exported_to_shlo_fn(
4141
exported: jax_export.Exported,
4242
xla_compile_options_per_platform: (
4343
obm.manifest_pb2.CompileOptionsProtoMap | None
44+
prune_custom_pytree_nodes: bool = False,
4445
) -> obm.ShloFunction:
4546
"""Converts a `jax.export.Exported` to an Orbax Model `ShloFunction`."""
4647

@@ -57,13 +58,15 @@ def jax_exported_to_shlo_fn(
5758
exported.in_avals,
5859
in_shardings_hlo,
5960
exported.in_tree,
61+
prune_custom_pytree_nodes=prune_custom_pytree_nodes,
6062
)
6163
)
6264
shlo_out_sig, jax_out_sig_refinements = (
6365
jax_specific_info._to_shlo_spec_tree_and_refinement_tuple(
6466
exported.out_avals,
6567
out_shardings_hlo,
6668
exported.out_tree,
69+
prune_custom_pytree_nodes=prune_custom_pytree_nodes,
6770
)
6871
)
6972
supplemental_info_ = {}
@@ -107,6 +110,7 @@ def convert(
107110
native_serialization_disabled_checks: Sequence[
108111
jax_export.DisabledSafetyCheck
109112
] = (),
113+
prune_custom_pytree_nodes: bool = False,
110114
) -> obm.ShloFunction:
111115
"""Converts a JAX function to an Orbax Model `ShloFunction`.
112116
@@ -138,6 +142,8 @@ def convert(
138142
model artifact, to ensure XLA compilation consistency and reproducibility
139143
between export time and serving time. Each map entry corresponds to a
140144
platform type (e.g. TPU, GPU, etc.).
145+
prune_custom_pytree_nodes: Optional. True if the custom pytree nodes should
146+
be pruned. False by default.
141147
142148
Returns:
143149
An Orbax Model `ShloFunction`.
@@ -156,6 +162,7 @@ def convert(
156162
exported = exported_creator(*args_spec, **kwargs_spec)
157163
return jax_exported_to_shlo_fn(
158164
exported,
165+
prune_custom_pytree_nodes=prune_custom_pytree_nodes,
159166
)
160167

161168

0 commit comments

Comments
 (0)