Skip to content

Commit 15f4a77

Browse files
committed
Apply changes from code review
1 parent 4026879 commit 15f4a77

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

tfjs-converter/python/tensorflowjs/BUILD

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ py_library(
5050
# This is a dummy rule used as a JAX dependency in open-source.
5151
# We expect JAX to already be installed on the system, e.g. via
5252
# `pip install jax`.
53-
deps = [requirement("jax")],
53+
deps = [
54+
requirement("jax"),
55+
requirement("importlib_resources"),
56+
],
5457
)
5558

5659
py_library(

tfjs-converter/python/tensorflowjs/converters/jax_conversion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from jax.experimental import jax2tf
2020
from jax.experimental.jax2tf import shape_poly
2121
import tensorflow as tf
22-
from tensorflowjs.converters import convert_tf_saved_model
22+
from tensorflowjs.converters import tf_saved_model_conversion_v2 as saved_model_conversion
2323

2424

2525
_TF_SERVING_KEY = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -146,4 +146,4 @@ def convert_jax(
146146
saved_model_dir,
147147
signatures=signatures,
148148
options=saved_model_options)
149-
convert_tf_saved_model(saved_model_dir, model_dir)
149+
saved_model_conversion.convert_tf_saved_model(saved_model_dir, model_dir)

0 commit comments

Comments
 (0)