@@ -816,6 +816,39 @@ async def _maybe_deserialize(
816816 flat_restored , target = item
817817 )
818818
819+ def _partial_restore_with_omission (
820+ self , item : PyTree , value_metadata_tree : PyTree , restore_args : PyTree
821+ ) -> Tuple [PyTree , PyTree ]:
822+ """Restores leaves specified in `item`. Skips omitted leaves."""
823+ serialized_item = tree_utils .serialize_tree (item , keep_empty_nodes = True )
824+
825+ if not self ._pytree_metadata_options .support_rich_types :
826+ # Replace empty containers with scalar values (zeros). During saving,
827+ # some empty containers (like named tuples) were given
828+ # ValueMetadataEntries as if they were scalars. We normalize these
829+ # containers to scalars so that tree_trim is none the wiser.
830+ serialized_item = jax .tree .map (
831+ lambda v : 0 if empty_values .is_empty_container (v ) else v ,
832+ serialized_item ,
833+ is_leaf = tree_utils .is_empty_or_leaf ,
834+ )
835+
836+ try :
837+ value_metadata_tree = tree_structure_utils .tree_trim (
838+ serialized_item , value_metadata_tree , strict = True
839+ )
840+ except ValueError as e :
841+ raise ValueError (
842+ 'Missing keys were found in the user-provided restore item.'
843+ ) from e
844+
845+ if restore_args is not None :
846+ restore_args = tree_structure_utils .tree_trim (
847+ item , restore_args , strict = True
848+ )
849+
850+ return value_metadata_tree , restore_args
851+
819852 def restore (
820853 self ,
821854 directory : epath .Path ,
@@ -935,29 +968,9 @@ class TrainState:
935968 if item is None :
936969 item = value_metadata_tree
937970 elif args .partial_restore :
938- serialized_item = tree_utils .serialize_tree (item , keep_empty_nodes = True )
939- if not self ._pytree_metadata_options .support_rich_types :
940- # Replace empty containers with scalar values (zeros). During saving,
941- # some empty containers (like named tuples) were given
942- # ValueMetadataEntries as if they were scalars. We normalize these
943- # containers to scalars so that tree_trim is none the wiser.
944- serialized_item = jax .tree .map (
945- lambda v : 0 if empty_values .is_empty_container (v ) else v ,
946- serialized_item ,
947- is_leaf = tree_utils .is_empty_or_leaf ,
948- )
949- try :
950- value_metadata_tree = tree_structure_utils .tree_trim (
951- serialized_item , value_metadata_tree , strict = True
952- )
953- except ValueError as e :
954- raise ValueError (
955- 'Missing keys were found in the user-provided restore item.'
956- ) from e
957- if restore_args is not None :
958- restore_args = tree_structure_utils .tree_trim (
959- item , restore_args , strict = True
960- )
971+ value_metadata_tree , restore_args = self ._partial_restore_with_omission (
972+ item , value_metadata_tree , restore_args
973+ )
961974 else :
962975 # is_empty_or_leaf is necessary here to treat empty nodes (e.g. empty
963976 # dicts, lists, custom nodes) as leaves, as they do not contain any
0 commit comments