Skip to content

Commit 167fd6f

Browse files
marcvanzeepyu10055
andauthored
Adds JAX-->TFjs converter (#6744)
FEATURE * Adds JAX-->TFjs converter * Apply changes from code review Co-authored-by: Ping Yu <[email protected]>
1 parent d515f4d commit 167fd6f

File tree

9 files changed

+416
-6
lines changed

9 files changed

+416
-6
lines changed

tfjs-converter/README.md

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ using an already hosted model (e.g. MobileNet), skip this step.
1717
2. [JavaScript API](./src/executor/tf_model.ts), for loading and running
1818
inference.
1919

20-
## Step 1: Converting a [TensorFlow SavedModel](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md), [TensorFlow Hub module](https://www.tensorflow.org/hub/), [Keras HDF5](https://keras.io/getting-started/faq/#how-can-i-save-a-keras-model) or [tf.keras SavedModel](https://www.tensorflow.org/api_docs/python/tf/contrib/saved_model/save_keras_model) to a web-friendly format
20+
## Step 1: Converting a [TensorFlow SavedModel](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md), [TensorFlow Hub module](https://www.tensorflow.org/hub/), [Keras HDF5](https://keras.io/getting-started/faq/#how-can-i-save-a-keras-model), [tf.keras SavedModel](https://www.tensorflow.org/api_docs/python/tf/contrib/saved_model/save_keras_model), or [Flax/JAX model](http://github.com/google/flax) to a web-friendly format
2121

2222
__0. Please make sure that you run in a Docker container or a virtual environment.__
2323

@@ -54,10 +54,13 @@ Install the library with interactive CLI:
5454

5555
__2. Run the converter script provided by the pip package:__
5656

57-
There are two way to trigger the model conversion:
57+
There are three way to trigger the model conversion, explain below:
5858

59-
- The conversion wizard: `tensorflowjs_wizard`
60-
- Regular conversion script: `tensorflowjs_converter`
59+
- The conversion wizard: `tensorflowjs_wizard` ([go to section](#conversion-wizard-tensorflowjswizard))
60+
- Regular conversion script: `tensorflowjs_converter` ([go to section](#regular-conversion-script-tensorflowjsconverter))
61+
- Calling a converter function in Python (Flax/JAX) ([go to section](#calling-a-converter-function-in-python))
62+
63+
## Conversion wizard: `tensorflowjs_wizard`
6164

6265
To start the conversion wizard:
6366
```bash
@@ -81,7 +84,7 @@ tensorflowjs_wizard --dryrun
8184
To convert a batch of models or integrate the conversion process into your own
8285
script, you should use the tensorflowjs_converter script.
8386

84-
## Conversion flags
87+
## Regular conversion script: `tensorflowjs_converter`
8588

8689
The converter expects a __TensorFlow SavedModel__, __TensorFlow Hub module__,
8790
__TensorFlow.js JSON__ format, __Keras HDF5 model__, or __tf.keras SavedModel__
@@ -141,6 +144,8 @@ Note that the input path used above is a subfolder that has a Unix epoch
141144
time (1542211770) and is generated automatically by tensorflow when it
142145
saved a tf.keras model in the SavedModel format.
143146

147+
### Conversion Flags
148+
144149
|Positional Arguments | Description |
145150
|---|---|
146151
|`input_path` | Full path of the saved model directory or TensorFlow Hub module handle or path.|
@@ -271,6 +276,53 @@ following location:
271276
https://storage.cloud.google.com/tfjs-models/savedmodel/mobilenet_v1_1.0_224/group1-shard5of5
272277
```
273278

279+
## Calling a Converter Function in Python (Flax/JAX)
280+
281+
You can also convert your model to web format in Python by calling one of the
282+
conversion functions. This is currently the only way to convert a Flax or JAX
283+
model, since no standard serialization format exists to store a Module (only the
284+
checkpoints).
285+
286+
Here we provide an example of how to convert a Flax function using the
287+
conversion function `tfjs.jax_conversion.convert_jax()`.
288+
289+
```py
290+
import numpy as np
291+
from flax import linen as nn
292+
from jax import random
293+
import jax.numpy as jnp
294+
from tensorflowjs.converters import jax_conversion
295+
296+
module = nn.Dense(features=4)
297+
inputs = jnp.ones((3, 4))
298+
params = module.init(random.PRNKey(0), inputs)['params']
299+
300+
jax_conversion.convert_jax(
301+
apply_fn=module.apply,
302+
params=params,
303+
input_signatures=[((3, 4), np.float32)],
304+
model_dir=tfjs_model_dir)
305+
```
306+
307+
Note that when using dynamic shapes, an additional argument `polymorphic_shapes`
308+
should be provided specifying values for the dynamic ("polymorphic")
309+
dimensions). So in order to convert the same model as before, but now with a
310+
dynamic first dimension, one should call `convert_jax` as follows:
311+
312+
```py
313+
jax_conversion.convert_jax(
314+
apply_fn=module.apply,
315+
params=params,
316+
input_signatures=[((None, 4), np.float32)],
317+
polymorphic_shapes=["(b, 4)"],
318+
model_dir=tfjs_model_dir)
319+
```
320+
321+
See
322+
[here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion)
323+
for more details on the exact syntax for this argument.
324+
325+
274326
## Step 2: Loading and running in the browser
275327

276328
If the original model was a `SavedModel`, use

tfjs-converter/python/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ py_wheel(
4848
license = "Apache 2.0",
4949
python_tag = "py3",
5050
requires = [
51+
"flax>=0.5.3",
52+
"importlib_resources>=5.9.0",
53+
"jax>=0.3.16",
5154
"protobuf<3.20,>=3.9.2",
5255
"tensorflow>=2.1.0,<3",
5356
"six>=1.12.0,<2",

tfjs-converter/python/requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
flax>=0.5.3
2+
jax>=0.3.16
3+
importlib_resources>=5.9.0
14
protobuf<3.20,>=3.9.2
25
tensorflow>=2.1.0,<3
36
six>=1.12.0,<2

tfjs-converter/python/tensorflowjs/BUILD

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@ py_library(
2929
],
3030
)
3131

32+
py_library(
33+
name = "expect_flax_installed",
34+
# This is a dummy rule used as a Flax dependency in open-source.
35+
# We expect JAX to already be installed on the system, e.g. via
36+
# `pip install flax`.
37+
deps = [requirement("flax")],
38+
)
39+
3240
py_library(
3341
name = "expect_h5py_installed",
3442
# This is a dummy rule used as a h5py dependency in open-source.
@@ -37,6 +45,25 @@ py_library(
3745
deps = [requirement("h5py")],
3846
)
3947

48+
py_library(
49+
name = "expect_jax_installed",
50+
# This is a dummy rule used as a JAX dependency in open-source.
51+
# We expect JAX to already be installed on the system, e.g. via
52+
# `pip install jax`.
53+
deps = [
54+
requirement("jax"),
55+
requirement("importlib_resources"),
56+
],
57+
)
58+
59+
py_library(
60+
name = "expect_jax2tf_installed",
61+
# This is a dummy rule used as a jax2tf dependency in open-source.
62+
# We expect jax2tf to already be installed on the system, e.g. via
63+
# `pip install jax`.
64+
deps = [requirement("jax")],
65+
)
66+
4067
py_library(
4168
name = "expect_numpy_installed",
4269
# This is a dummy rule used as a numpy dependency in open-source.

tfjs-converter/python/tensorflowjs/converters/BUILD

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,32 @@ py_library(
195195
],
196196
)
197197

198+
py_test(
199+
name = "jax_conversion_test",
200+
srcs = ["jax_conversion_test.py"],
201+
imports = ["../.."],
202+
srcs_version = "PY3",
203+
tags = ["ci"],
204+
deps = [
205+
":jax_conversion",
206+
"//tfjs-converter/python/tensorflowjs:expect_flax_installed",
207+
"//tfjs-converter/python/tensorflowjs:expect_jax_installed",
208+
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_installed",
209+
],
210+
)
211+
212+
py_library(
213+
name = "jax_conversion",
214+
srcs = ["jax_conversion.py"],
215+
srcs_version = "PY3",
216+
deps = [
217+
":tf_saved_model_conversion_v2",
218+
"//tfjs-converter/python/tensorflowjs:expect_jax2tf_installed",
219+
"//tfjs-converter/python/tensorflowjs:expect_jax_installed",
220+
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_installed",
221+
],
222+
)
223+
198224
py_test(
199225
name = "wizard_test",
200226
srcs = ["wizard_test.py"],

tfjs-converter/python/tensorflowjs/converters/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@
2323
from tensorflowjs.converters.keras_tfjs_loader import deserialize_keras_model
2424
from tensorflowjs.converters.keras_tfjs_loader import load_keras_model
2525
from tensorflowjs.converters.tf_saved_model_conversion_v2 import convert_tf_saved_model
26+
from tensorflowjs.converters.jax_conversion import convert_jax
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Converts a JAX function to TensorFlow.js web format."""
16+
import tempfile
17+
from typing import Any, Callable, Optional, Sequence, Tuple, Union
18+
19+
from jax.experimental import jax2tf
20+
from jax.experimental.jax2tf import shape_poly
21+
import tensorflow as tf
22+
from tensorflowjs.converters import tf_saved_model_conversion_v2 as saved_model_conversion
23+
24+
25+
_TF_SERVING_KEY = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
26+
Array = Any
27+
DType = Any
28+
PolyShape = shape_poly.PolyShape
29+
30+
31+
class _ReusableSavedModelWrapper(tf.train.Checkpoint):
32+
"""Wraps a function and its parameters for saving to a SavedModel.
33+
34+
Implements the interface described at
35+
https://www.tensorflow.org/hub/reusable_saved_models.
36+
"""
37+
38+
def __init__(self, tf_graph, param_vars):
39+
"""Initializes a _ReusableSavedModelWrapper.
40+
41+
Args:
42+
tf_graph: a tf.function taking one argument (the inputs), which can be
43+
be tuples/lists/dictionaries of np.ndarray or tensors. The function
44+
may have references to the tf.Variables in `param_vars`.
45+
param_vars: the parameters, as tuples/lists/dictionaries of tf.Variable,
46+
to be saved as the variables of the SavedModel.
47+
"""
48+
super().__init__()
49+
self.variables = tf.nest.flatten(param_vars)
50+
self.trainable_variables = [v for v in self.variables if v.trainable]
51+
# If you intend to prescribe regularization terms for users of the model,
52+
# add them as @tf.functions with no inputs to this list. Else drop this.
53+
self.regularization_losses = []
54+
self.__call__ = tf_graph
55+
56+
57+
def convert_jax(
58+
apply_fn: Callable[..., Any],
59+
params: Array,
60+
*,
61+
input_signatures: Sequence[Tuple[Sequence[Union[int, None]], DType]],
62+
model_dir: str,
63+
polymorphic_shapes: Optional[Sequence[Union[str, PolyShape]]] = None):
64+
"""Converts a JAX function `jax_apply_fn` and model parameters to a TensorflowJS model.
65+
66+
Example usage for a Flax Module:
67+
68+
```
69+
import numpy as np
70+
from flax import linen as nn
71+
from jax import random
72+
import jax.numpy as jnp
73+
from tensorflowjs.converters.jax_conversion import convert_jax
74+
75+
module = nn.Dense(features=4)
76+
inputs = jnp.ones((3, 4))
77+
params = module.init(random.PRNKey(0), inputs)['params']
78+
79+
convert_jax(
80+
apply_fn=module.apply,
81+
params=params,
82+
input_signatures=[((3, 4), np.float32)],
83+
model_dir=tfjs_model_dir)
84+
```
85+
86+
Note that when using dynamic shapes, an additional argument
87+
`polymorphic_shapes` should be provided specifying values for the dynamic
88+
("polymorphic") dimensions). See here for more details:
89+
https://github.com/google/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion
90+
91+
This is an adaption of the original implementation in jax2tf here:
92+
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py
93+
94+
Arguments:
95+
apply_fn: A JAX function that has one or more arguments, of which the first
96+
argument are the model parameters. This function typically is the forward
97+
pass of the network (e.g., `Module.apply()` in Flax).
98+
params: A Pytree containing the parameters of the module. These will all be
99+
converted to TF.Variables.
100+
input_signatures: the input signatures for the second and remaining
101+
arguments to `apply_fn` (the input). A signature must be a
102+
`tensorflow.TensorSpec` instance, or a (nested) tuple/list/dictionary
103+
thereof with a structure matching the second argument of `apply_fn`.
104+
model_dir: Directory where the TensorflowJS model will be written to.
105+
polymorphic_shapes: If given then it will be used as the
106+
`polymorphic_shapes` argument for the second parameter of `apply_fn`. In
107+
this case, a single `input_signatures` is supported, and should have
108+
`None` in the polymorphic (dynamic) dimensions.
109+
"""
110+
if polymorphic_shapes is not None:
111+
# If polymorphic shapes are provided, add a polymorphic spec for the
112+
# first argument to `apply_fn`, which are the parameters.
113+
polymorphic_shapes = [None, *polymorphic_shapes]
114+
115+
tf_fn = jax2tf.convert(
116+
apply_fn,
117+
# Gradients must be included as 'PreventGradient' is not supported.
118+
with_gradient=True,
119+
polymorphic_shapes=polymorphic_shapes,
120+
# Do not use TFXLA Ops because these aren't supported by TFjs, but use
121+
# workarounds instead. More information:
122+
# https://github.com/google/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops
123+
enable_xla=False)
124+
125+
# Create tf.Variables for the parameters. If you want more useful variable
126+
# names, you can use `tree.map_structure_with_path` from the `dm-tree`
127+
# package.
128+
param_vars = tf.nest.map_structure(
129+
lambda param: tf.Variable(param, trainable=True), params)
130+
# Do not use TF's jit compilation on the function.
131+
tf_graph = tf.function(
132+
lambda *xs: tf_fn(param_vars, *xs), autograph=False, jit_compile=False)
133+
134+
# This signature is needed for TensorFlow Serving use.
135+
signatures = {
136+
_TF_SERVING_KEY: tf_graph.get_concrete_function(*input_signatures)
137+
}
138+
139+
wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)
140+
saved_model_options = tf.saved_model.SaveOptions(
141+
experimental_custom_gradients=True)
142+
143+
with tempfile.TemporaryDirectory() as saved_model_dir:
144+
tf.saved_model.save(
145+
wrapper,
146+
saved_model_dir,
147+
signatures=signatures,
148+
options=saved_model_options)
149+
saved_model_conversion.convert_tf_saved_model(saved_model_dir, model_dir)

0 commit comments

Comments
 (0)