@@ -220,12 +220,21 @@ def serialize_keras_object(obj):
220
220
ts_config ,
221
221
)
222
222
)
223
+ spec_name = obj .__class__ .__name__
224
+ registered_name = None
225
+ if hasattr (obj , "_tf_extension_type_fields" ):
226
+ # Special casing for ExtensionType
227
+ ts_config = tf .experimental .extension_type .as_dict (obj )
228
+ ts_config = serialize_dict (ts_config )
229
+ registered_name = object_registration .get_registered_name (
230
+ obj .__class__
231
+ )
223
232
return {
224
233
"class_name" : "__typespec__" ,
225
- "spec_name" : obj . __class__ . __name__ ,
234
+ "spec_name" : spec_name ,
226
235
"module" : obj .__class__ .__module__ ,
227
236
"config" : ts_config ,
228
- "registered_name" : None ,
237
+ "registered_name" : registered_name ,
229
238
}
230
239
231
240
inner_config = _get_class_or_fn_config (obj )
@@ -638,23 +647,38 @@ class ModifiedMeanSquaredError(keras.losses.MeanSquaredError):
638
647
"the loading function in order to allow `lambda` loading."
639
648
)
640
649
return generic_utils .func_load (inner_config ["value" ])
650
+
641
651
if config ["class_name" ] == "__typespec__" :
642
- obj = _retrieve_class_or_fn (
652
+ cls = _retrieve_class_or_fn (
643
653
config ["spec_name" ],
644
654
config ["registered_name" ],
645
655
config ["module" ],
646
656
obj_type = "class" ,
647
657
full_config = config ,
648
658
custom_objects = custom_objects ,
649
659
)
660
+
661
+ # Special casing for ExtensionType.Spec
662
+ if hasattr (cls , "_tf_extension_type_fields" ):
663
+ inner_config = {
664
+ key : deserialize_keras_object (
665
+ value , custom_objects = custom_objects , safe_mode = safe_mode
666
+ )
667
+ for key , value in inner_config .items ()
668
+ } # Deserialization of dict created by ExtensionType.as_dict()
669
+ return cls (** inner_config ) # Instantiate ExtensionType.Spec
670
+
671
+ if config ["registered_name" ] is not None :
672
+ return cls .from_config (inner_config )
673
+
650
674
# Conversion to TensorShape and tf.DType
651
675
inner_config = map (
652
676
lambda x : tf .TensorShape (x )
653
677
if isinstance (x , list )
654
678
else (getattr (tf , x ) if hasattr (tf .dtypes , str (x )) else x ),
655
679
inner_config ,
656
680
)
657
- return obj ._deserialize (tuple (inner_config ))
681
+ return cls ._deserialize (tuple (inner_config ))
658
682
659
683
# Below: classes and functions.
660
684
module = config .get ("module" , None )
@@ -782,9 +806,20 @@ def _retrieve_class_or_fn(
782
806
)
783
807
obj = vars (mod ).get (name , None )
784
808
785
- # Special case for keras.metrics.metrics
786
- if obj is None and registered_name is not None :
787
- obj = vars (mod ).get (registered_name , None )
809
+ if obj is None :
810
+ # Special case for keras.metrics.metrics
811
+ if registered_name is not None :
812
+ obj = vars (mod ).get (registered_name , None )
813
+
814
+ # Support for `__qualname__`
815
+ if name .count ("." ) == 1 :
816
+ outer_name , inner_name = name .split ("." )
817
+ outer_obj = vars (mod ).get (outer_name , None )
818
+ obj = (
819
+ getattr (outer_obj , inner_name , None )
820
+ if outer_obj is not None
821
+ else None
822
+ )
788
823
789
824
if obj is not None :
790
825
return obj
0 commit comments