|
| 1 | +# Copyright 2022 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""Converts a JAX function to TensorFlow.js web format.""" |
| 16 | +import tempfile |
| 17 | +from typing import Any, Callable, Optional, Sequence, Tuple, Union |
| 18 | + |
| 19 | +from jax.experimental import jax2tf |
| 20 | +from jax.experimental.jax2tf import shape_poly |
| 21 | +import tensorflow as tf |
| 22 | +from tensorflowjs.converters import tf_saved_model_conversion_v2 as saved_model_conversion |
| 23 | + |
| 24 | + |
# Signature key under which the converted function is exported.
_TF_SERVING_KEY = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
# Loose aliases used in the type annotations below.
Array = Any
DType = Any
# Shape-polymorphism specification type re-exported from jax2tf.
PolyShape = shape_poly.PolyShape
| 29 | + |
| 30 | + |
class _ReusableSavedModelWrapper(tf.train.Checkpoint):
  """Trackable container pairing a converted function with its variables.

  Exposes the attributes of the interface described at
  https://www.tensorflow.org/hub/reusable_saved_models.
  """

  def __init__(self, tf_graph, param_vars):
    """Stores the function and its parameters as attributes of the wrapper.

    Args:
      tf_graph: the tf.function to export. It may have references to the
        tf.Variables in `param_vars`.
      param_vars: the parameters, as tuples/lists/dictionaries of tf.Variable,
        to be saved as the variables of the SavedModel.
    """
    super().__init__()
    # Flatten the (possibly nested) parameter structure once and reuse it.
    flat_vars = tf.nest.flatten(param_vars)
    self.variables = flat_vars
    self.trainable_variables = [var for var in flat_vars if var.trainable]
    # To prescribe regularization terms for users of the model, add them to
    # this list as @tf.functions with no inputs; it is empty otherwise.
    self.regularization_losses = []
    # The function is attached under the `__call__` attribute name.
    self.__call__ = tf_graph
| 55 | + |
| 56 | + |
def _to_tensor_spec(signature):
  """Returns a `tf.TensorSpec` for a `(shape, dtype)` pair, else the input."""
  is_shape_dtype_pair = (
      isinstance(signature, (tuple, list))
      and len(signature) == 2
      and isinstance(signature[0], (tuple, list))
      and all(dim is None or isinstance(dim, int) for dim in signature[0])
      # A nested structure or an existing spec in second position means this
      # is a structure of signatures, not a `(shape, dtype)` pair.
      and not isinstance(signature[1], (tuple, list, dict, tf.TensorSpec)))
  if is_shape_dtype_pair:
    shape, dtype = signature
    return tf.TensorSpec(shape, dtype)
  return signature


def convert_jax(
    apply_fn: Callable[..., Any],
    params: Array,
    *,
    input_signatures: Sequence[Tuple[Sequence[Union[int, None]], DType]],
    model_dir: str,
    polymorphic_shapes: Optional[Sequence[Union[str, PolyShape]]] = None):
  """Converts a JAX function `apply_fn` and model parameters to a TensorflowJS model.

  Example usage for a Flax Module:

  ```
  import numpy as np
  from flax import linen as nn
  from jax import random
  import jax.numpy as jnp
  from tensorflowjs.converters.jax_conversion import convert_jax

  module = nn.Dense(features=4)
  inputs = jnp.ones((3, 4))
  params = module.init(random.PRNGKey(0), inputs)['params']

  convert_jax(
    apply_fn=module.apply,
    params=params,
    input_signatures=[((3, 4), np.float32)],
    model_dir=tfjs_model_dir)
  ```

  Note that when using dynamic shapes, an additional argument
  `polymorphic_shapes` should be provided specifying values for the dynamic
  ("polymorphic") dimensions. See here for more details:
  https://github.com/google/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion

  This is an adaptation of the original implementation in jax2tf here:
  https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py

  Args:
    apply_fn: A JAX function that has one or more arguments, of which the first
      argument are the model parameters. This function typically is the forward
      pass of the network (e.g., `Module.apply()` in Flax).
    params: A Pytree containing the parameters of the module. These will all be
      converted to TF.Variables.
    input_signatures: the input signatures for the second and remaining
      arguments to `apply_fn` (the input). Each signature is either a
      `(shape, dtype)` pair, a `tensorflow.TensorSpec` instance, or a (nested)
      tuple/list/dictionary of `tensorflow.TensorSpec` with a structure
      matching the corresponding argument of `apply_fn`. `(shape, dtype)`
      pairs are converted to `tensorflow.TensorSpec` before tracing.
    model_dir: Directory where the TensorflowJS model will be written to.
    polymorphic_shapes: If given then it will be used as the
      `polymorphic_shapes` argument for the second parameter of `apply_fn`. In
      this case, a single `input_signatures` is supported, and should have
      `None` in the polymorphic (dynamic) dimensions.
  """
  if polymorphic_shapes is not None:
    # If polymorphic shapes are provided, add a polymorphic spec for the
    # first argument to `apply_fn`, which are the parameters.
    polymorphic_shapes = [None, *polymorphic_shapes]

  tf_fn = jax2tf.convert(
      apply_fn,
      # Gradients must be included as 'PreventGradient' is not supported.
      with_gradient=True,
      polymorphic_shapes=polymorphic_shapes,
      # Do not use TFXLA Ops because these aren't supported by TFjs, but use
      # workarounds instead. More information:
      # https://github.com/google/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops
      enable_xla=False)

  # Create tf.Variables for the parameters. If you want more useful variable
  # names, you can use `tree.map_structure_with_path` from the `dm-tree`
  # package.
  param_vars = tf.nest.map_structure(
      lambda param: tf.Variable(param, trainable=True), params)
  # Do not use TF's jit compilation on the function.
  tf_graph = tf.function(
      lambda *xs: tf_fn(param_vars, *xs), autograph=False, jit_compile=False)

  # `get_concrete_function` needs TensorSpecs to create symbolic inputs, so
  # `(shape, dtype)` pairs are normalized first; TensorSpecs and nested
  # structures of them are forwarded unchanged.
  tensor_specs = [_to_tensor_spec(signature) for signature in input_signatures]

  # This signature is needed for TensorFlow Serving use.
  signatures = {
      _TF_SERVING_KEY: tf_graph.get_concrete_function(*tensor_specs)
  }

  wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)
  saved_model_options = tf.saved_model.SaveOptions(
      experimental_custom_gradients=True)

  # The intermediate SavedModel only needs to live until the TFjs conversion
  # has read it, so both steps run inside the temporary directory's scope.
  with tempfile.TemporaryDirectory() as saved_model_dir:
    tf.saved_model.save(
        wrapper,
        saved_model_dir,
        signatures=signatures,
        options=saved_model_options)
    saved_model_conversion.convert_tf_saved_model(saved_model_dir, model_dir)
0 commit comments