Skip to content

Commit 9b55a1c

Browse files
nkovela1tensorflower-gardener
authored andcommitted
Increases ExtensionType support coverage for v3 Keras saving, including MaskedTensor support.
PiperOrigin-RevId: 529893705
1 parent 8261e7f commit 9b55a1c

File tree

2 files changed

+50
-7
lines changed

2 files changed

+50
-7
lines changed

keras/integration_test/extension_type_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Test Model inference and save/load with an ExtensionType."""
22

3+
import os
34
import typing
45

56
import tensorflow.compat.v2 as tf
@@ -89,6 +90,13 @@ def testKerasModel(self):
8990
serving_fn(args_0=mt.values, args_0_1=mt.mask)["lambda"], mt
9091
)
9192

93+
with self.subTest("keras v3"):
94+
path = os.path.join(self.create_tempdir().full_path, "model.keras")
95+
model.save(path)
96+
loaded_model = load_model(path, safe_mode=False)
97+
self.assertEqual(loaded_model.input.type_spec, mt_spec)
98+
self.assertEqual(loaded_model(mt), mt)
99+
92100

93101
if __name__ == "__main__":
94102
tf.test.main()

keras/saving/serialization_lib.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -220,12 +220,21 @@ def serialize_keras_object(obj):
220220
ts_config,
221221
)
222222
)
223+
spec_name = obj.__class__.__name__
224+
registered_name = None
225+
if hasattr(obj, "_tf_extension_type_fields"):
226+
# Special casing for ExtensionType
227+
ts_config = tf.experimental.extension_type.as_dict(obj)
228+
ts_config = serialize_dict(ts_config)
229+
registered_name = object_registration.get_registered_name(
230+
obj.__class__
231+
)
223232
return {
224233
"class_name": "__typespec__",
225-
"spec_name": obj.__class__.__name__,
234+
"spec_name": spec_name,
226235
"module": obj.__class__.__module__,
227236
"config": ts_config,
228-
"registered_name": None,
237+
"registered_name": registered_name,
229238
}
230239

231240
inner_config = _get_class_or_fn_config(obj)
@@ -638,23 +647,38 @@ class ModifiedMeanSquaredError(keras.losses.MeanSquaredError):
638647
"the loading function in order to allow `lambda` loading."
639648
)
640649
return generic_utils.func_load(inner_config["value"])
650+
641651
if config["class_name"] == "__typespec__":
642-
obj = _retrieve_class_or_fn(
652+
cls = _retrieve_class_or_fn(
643653
config["spec_name"],
644654
config["registered_name"],
645655
config["module"],
646656
obj_type="class",
647657
full_config=config,
648658
custom_objects=custom_objects,
649659
)
660+
661+
# Special casing for ExtensionType.Spec
662+
if hasattr(cls, "_tf_extension_type_fields"):
663+
inner_config = {
664+
key: deserialize_keras_object(
665+
value, custom_objects=custom_objects, safe_mode=safe_mode
666+
)
667+
for key, value in inner_config.items()
668+
} # Deserialization of dict created by ExtensionType.as_dict()
669+
return cls(**inner_config) # Instantiate ExtensionType.Spec
670+
671+
if config["registered_name"] is not None:
672+
return cls.from_config(inner_config)
673+
650674
# Conversion to TensorShape and tf.DType
651675
inner_config = map(
652676
lambda x: tf.TensorShape(x)
653677
if isinstance(x, list)
654678
else (getattr(tf, x) if hasattr(tf.dtypes, str(x)) else x),
655679
inner_config,
656680
)
657-
return obj._deserialize(tuple(inner_config))
681+
return cls._deserialize(tuple(inner_config))
658682

659683
# Below: classes and functions.
660684
module = config.get("module", None)
@@ -782,9 +806,20 @@ def _retrieve_class_or_fn(
782806
)
783807
obj = vars(mod).get(name, None)
784808

785-
# Special case for keras.metrics.metrics
786-
if obj is None and registered_name is not None:
787-
obj = vars(mod).get(registered_name, None)
809+
if obj is None:
810+
# Special case for keras.metrics.metrics
811+
if registered_name is not None:
812+
obj = vars(mod).get(registered_name, None)
813+
814+
# Support for `__qualname__`
815+
if name.count(".") == 1:
816+
outer_name, inner_name = name.split(".")
817+
outer_obj = vars(mod).get(outer_name, None)
818+
obj = (
819+
getattr(outer_obj, inner_name, None)
820+
if outer_obj is not None
821+
else None
822+
)
788823

789824
if obj is not None:
790825
return obj

0 commit comments

Comments
 (0)