Skip to content

Commit 3a5ef80

Browse files
BlaziusMaximusOrbax Authors
authored andcommitted
Refactor: Extract partial restore omission logic into a separate method.
PiperOrigin-RevId: 823234508
1 parent bcc1e34 commit 3a5ef80

File tree

1 file changed

+36
-23
lines changed

1 file changed

+36
-23
lines changed

checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,39 @@ async def _maybe_deserialize(
816816
flat_restored, target=item
817817
)
818818

819+
def _partial_restore_with_omission(
820+
self, item: PyTree, value_metadata_tree: PyTree, restore_args: PyTree
821+
) -> Tuple[PyTree, PyTree]:
822+
"""Restores leaves specified in `item`. Skips omitted leaves."""
823+
serialized_item = tree_utils.serialize_tree(item, keep_empty_nodes=True)
824+
825+
if not self._pytree_metadata_options.support_rich_types:
826+
# Replace empty containers with scalar values (zeros). During saving,
827+
# some empty containers (like named tuples) were given
828+
# ValueMetadataEntries as if they were scalars. We normalize these
829+
# containers to scalars so that tree_trim is none the wiser.
830+
serialized_item = jax.tree.map(
831+
lambda v: 0 if empty_values.is_empty_container(v) else v,
832+
serialized_item,
833+
is_leaf=tree_utils.is_empty_or_leaf,
834+
)
835+
836+
try:
837+
value_metadata_tree = tree_structure_utils.tree_trim(
838+
serialized_item, value_metadata_tree, strict=True
839+
)
840+
except ValueError as e:
841+
raise ValueError(
842+
'Missing keys were found in the user-provided restore item.'
843+
) from e
844+
845+
if restore_args is not None:
846+
restore_args = tree_structure_utils.tree_trim(
847+
item, restore_args, strict=True
848+
)
849+
850+
return value_metadata_tree, restore_args
851+
819852
def restore(
820853
self,
821854
directory: epath.Path,
@@ -935,29 +968,9 @@ class TrainState:
935968
if item is None:
936969
item = value_metadata_tree
937970
elif args.partial_restore:
938-
serialized_item = tree_utils.serialize_tree(item, keep_empty_nodes=True)
939-
if not self._pytree_metadata_options.support_rich_types:
940-
# Replace empty containers with scalar values (zeros). During saving,
941-
# some empty containers (like named tuples) were given
942-
# ValueMetadataEntries as if they were scalars. We normalize these
943-
# containers to scalars so that tree_trim is none the wiser.
944-
serialized_item = jax.tree.map(
945-
lambda v: 0 if empty_values.is_empty_container(v) else v,
946-
serialized_item,
947-
is_leaf=tree_utils.is_empty_or_leaf,
948-
)
949-
try:
950-
value_metadata_tree = tree_structure_utils.tree_trim(
951-
serialized_item, value_metadata_tree, strict=True
952-
)
953-
except ValueError as e:
954-
raise ValueError(
955-
'Missing keys were found in the user-provided restore item.'
956-
) from e
957-
if restore_args is not None:
958-
restore_args = tree_structure_utils.tree_trim(
959-
item, restore_args, strict=True
960-
)
971+
value_metadata_tree, restore_args = self._partial_restore_with_omission(
972+
item, value_metadata_tree, restore_args
973+
)
961974
else:
962975
# is_empty_or_leaf is necessary here to treat empty nodes (e.g. empty
963976
# dicts, lists, custom nodes) as leaves, as they do not contain any

0 commit comments

Comments
 (0)