Skip to content

Commit 45292eb

Browse files
authored
Fix Keras serialization naming collision (#2001)
* Structure for Namex export * Switch to programmatic API generation with Namex * Fix missing import * Fix preset class names * Fix serialization collision
1 parent 1f667e6 commit 45292eb

File tree

3 files changed

+12
-10
lines changed

3 files changed

+12
-10
lines changed

keras_cv/api_export.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,28 @@
2222
namex = None
2323

2424

25-
def maybe_register_serializable(symbol):
25+
def maybe_register_serializable(symbol, package):
2626
if isinstance(symbol, types.FunctionType) or hasattr(symbol, "get_config"):
27-
keras.saving.register_keras_serializable(package="keras_cv")(symbol)
27+
keras.saving.register_keras_serializable(package=package)(symbol)
2828

2929

3030
if namex:
3131

3232
class keras_cv_export(namex.export):
33-
def __init__(self, path):
33+
def __init__(self, path, package="keras_cv"):
3434
super().__init__(package="keras_cv", path=path)
35+
self.package = package
3536

3637
def __call__(self, symbol):
37-
maybe_register_serializable(symbol)
38+
maybe_register_serializable(symbol, self.package)
3839
return super().__call__(symbol)
3940

4041
else:
4142

4243
class keras_cv_export:
43-
def __init__(self, path):
44-
pass
44+
def __init__(self, path, package="keras_cv"):
45+
self.package = package
4546

4647
def __call__(self, symbol):
47-
maybe_register_serializable(symbol)
48+
maybe_register_serializable(symbol, self.package)
4849
return symbol

keras_cv/layers/feature_pyramid.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
from keras_cv.api_export import keras_cv_export
1818

1919

20-
# TODO(scottzhu): Register it later due to the conflict in the retinanet
21-
# @keras.utils.register_keras_serializable(package="keras_cv")
2220
@keras_cv_export("keras_cv.layers.FeaturePyramid")
2321
class FeaturePyramid(keras.layers.Layer):
2422
"""Implements a Feature Pyramid Network.

keras_cv/models/object_detection/retinanet/feature_pyramid.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
from keras_cv.backend import ops
1818

1919

20-
@keras_cv_export("keras_cv.models.retinanet.FeaturePyramid")
20+
@keras_cv_export(
21+
"keras_cv.models.retinanet.FeaturePyramid",
22+
package="keras_cv.models.retinanet",
23+
)
2124
class FeaturePyramid(keras.layers.Layer):
2225
"""Builds the Feature Pyramid with the feature maps from the backbone."""
2326

0 commit comments

Comments
 (0)