From 09b425f156890caf0ac5baa1ed178509f4ccd50c Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 4 Mar 2021 16:04:08 -0800 Subject: [PATCH 01/34] Initial native function use Signed-off-by: Ryan Nett --- .../annotations/org/tensorflow/op/Ops.java | 18 + .../internal/c_api/TF_Function.java | 33 +- .../java/org/tensorflow/ConcreteFunction.java | 378 +++++++++++++----- .../java/org/tensorflow/EagerSession.java | 11 + .../org/tensorflow/ExecutionEnvironment.java | 8 + .../src/main/java/org/tensorflow/Graph.java | 11 + .../java/org/tensorflow/SavedModelBundle.java | 11 +- .../main/java/org/tensorflow/Signature.java | 48 ++- .../internal/c_api/AbstractTF_Function.java | 57 +++ .../internal/c_api/presets/tensorflow.java | 291 ++++++++------ .../java/org/tensorflow/op/core/Function.java | 45 +++ .../org/tensorflow/ConcreteFunctionTest.java | 39 -- .../org/tensorflow/SavedModelBundleTest.java | 8 +- .../org/tensorflow/op/core/FunctionTest.java | 70 ++++ 14 files changed, 723 insertions(+), 305 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java create mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/FunctionTest.java diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 92e4cabdbd1..53e87937a96 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -19,6 +19,8 @@ import java.nio.charset.Charset; import java.util.List; +import java.util.Map; +import org.tensorflow.ConcreteFunction; import org.tensorflow.DeviceSpec; import org.tensorflow.EagerSession; import org.tensorflow.ExecutionEnvironment; @@ -87,6 +89,7 @@ import org.tensorflow.op.core.ExtractVolumePatches; import org.tensorflow.op.core.Fill; import org.tensorflow.op.core.Fingerprint; +import org.tensorflow.op.core.Function; import org.tensorflow.op.core.Gather; import org.tensorflow.op.core.GatherNd; import org.tensorflow.op.core.GetSessionHandle; @@ -1116,6 +1119,21 @@ public Bucketize bucketize(Operand input, List boundar return Bucketize.create(scope, input, boundaries); } + /** + * empty + */ + public Operand call(ConcreteFunction function, Operand argument) { + return Function.call(scope, function, argument); + } + + /** + * empty + */ + public Map> call(ConcreteFunction function, + Map> arguments) { + return Function.call(scope, function, arguments); + } + /** * Clips tensor values to a specified min and max. * Given a tensor {@code t}, this operation returns a tensor of the same type and diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_Function.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_Function.java index e370b2f9f08..df77798beeb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_Function.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_Function.java @@ -2,20 +2,29 @@ package org.tensorflow.internal.c_api; -import java.nio.*; -import org.bytedeco.javacpp.*; -import org.bytedeco.javacpp.annotation.*; - -import static org.tensorflow.internal.c_api.global.tensorflow.*; +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.annotation.Opaque; +import org.bytedeco.javacpp.annotation.Properties; // TF_Function is a grouping of operations with defined inputs and outputs. // Once created and added to graphs, functions can be invoked by creating an // operation whose operation type matches the function name. -@Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) -public class TF_Function extends Pointer { - /** Empty constructor. Calls {@code super((Pointer)null)}. */ - public TF_Function() { super((Pointer)null); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public TF_Function(Pointer p) { super(p); } -} +@Opaque +@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) +public class TF_Function extends org.tensorflow.internal.c_api.AbstractTF_Function { + + /** + * Empty constructor. Calls {@code super((Pointer)null)}. + */ + public TF_Function() { + super((Pointer) null); + } + + /** + * Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. + */ + public TF_Function(Pointer p) { + super(p); + } +} \ No newline at end of file diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 71dc0f7cefc..95ac8360db3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -15,15 +15,33 @@ */ package org.tensorflow; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionName; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionSetAttrValueProto; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphToFunction; + import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; -import java.util.ListIterator; -import java.util.HashMap; import java.util.Map; import java.util.function.Function; +import java.util.stream.Collectors; +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.PointerPointer; +import org.bytedeco.javacpp.PointerScope; +import org.tensorflow.Graph.Reference; +import org.tensorflow.internal.c_api.TF_Function; +import org.tensorflow.internal.c_api.TF_Operation; +import org.tensorflow.internal.c_api.TF_Output; +import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.op.Ops; +import org.tensorflow.op.Scope; +import org.tensorflow.proto.framework.AttrValue; import org.tensorflow.proto.framework.SignatureDef; -import org.tensorflow.proto.framework.TensorInfo; +import org.tensorflow.types.family.TType; /** * A graph that can be invoked as a single function, with an input and output signature. @@ -39,16 +57,87 @@ */ public class ConcreteFunction implements AutoCloseable { + private static TF_Operation outputHandle(Operand operand) { + if (operand == null) { + throw new NullPointerException("Can't get output handle for null operand"); + } + + Pointer handle = operand.asOutput().getUnsafeNativeHandle(); + if (handle.isNull()) { + throw new NullPointerException("Native handle of operand is null, has it been closed?"); + } + + if (!(handle instanceof TF_Operation)) { + throw new IllegalArgumentException("Operand was not a graph operand"); + } + + return (TF_Operation) handle; + } + + private static TF_Output resolveToOutput(Graph graph, List> operands) { + TF_Output handles = new TF_Output(operands.size()); + for (int i = 0; i < operands.size(); i++) { + Operand input = operands.get(i); + graph.checkInput(input); + TF_Operation handle = outputHandle(input); + handles.position(i).oper(handle).index(input.asOutput().index()); + } + handles.position(0); + return handles; + } + + private static TF_Function createNative(Graph graph, Signature signature) { + try (PointerScope scope = new PointerScope(); + Reference ref = graph.ref()) { + TF_Status status = TF_Status.newStatus(); + + List> inputs = signature.getInputs().values().stream() + .map((x) -> graph.outputOrError(x.name)) + .collect(Collectors.toList()); + + List> outputs = signature.getOutputs().values().stream() + .map((x) -> graph.outputOrError(x.name)) + .collect(Collectors.toList()); + + List ops = new ArrayList<>(graph.completeSubgraph(new HashSet<>(inputs), new HashSet<>(outputs))); + + PointerPointer operations = new PointerPointer<>(ops.size()); + for (int i = 0; i < ops.size(); i++) { + operations.put(i, ops.get(i).getUnsafeNativeHandle()); + } + + TF_Function handle = TF_GraphToFunction( + ref.nativeHandle(), + new BytePointer(signature.key()), //TODO or methodName? + (byte) 1, + ops.size(), + operations, + inputs.size(), + resolveToOutput(graph, inputs), + outputs.size(), + resolveToOutput(graph, outputs), + null, + null, + new BytePointer(signature.methodName() != null ? signature.methodName() : "Method " + signature.key()), + status + ); + + status.throwExceptionIfNotOK(); + + return handle.withDeallocator(); + } + } + /** * Creates a function by building a new graph. * *

The {@code functionBuilder} must initialize the function graph from the provided - * {@link Ops} instance and return a valid signature that will be used to feed the input tensors - * and fetch the output tensors on execution. + * {@link Ops} instance and return a valid signature that will be used to feed the input tensors and fetch the output + * tensors on execution. * *

The function will be the owner of the new graph and its resulting session. Therefore, - * the function must be enclosed properly with a try-with-resources block to guarantee that - * all native resources will be freed once the function is discarded. For example: + * the function must be enclosed properly with a try-with-resources block to guarantee that all native resources will + * be freed once the function is discarded. For example: * *

{@code
    * public class MyModel {
@@ -72,14 +161,11 @@ public class ConcreteFunction implements AutoCloseable {
    * @return the new function
    */
   public static ConcreteFunction create(Function functionBuilder) {
-    Graph graph = new Graph();
-    try {
+    //TODO make sure this works oki with graph closing
+    try (Graph graph = new Graph()) {
       Ops tf = Ops.create(graph);
       Signature signature = functionBuilder.apply(tf);
-      return new ConcreteFunction(signature, graph, new Session(graph), Ownership.GRAPH_AND_SESSION);
-    } catch (Exception e) {
-      graph.close();
-      throw e;
+      return new ConcreteFunction(signature, graph);
     }
   }
 
@@ -87,8 +173,8 @@ public static ConcreteFunction create(Function functionBuilder)
    * Create a function from a signature and an existing graph.
    *
    * 

The function will keep the ownership of the session used to run the graph but not - * the graph itself, meaning that the lifetime of the latter can extend beyond the scope - * of the function. For example: + * the graph itself, meaning that the lifetime of the latter can extend beyond the scope of the function. For + * example: * *

{@code
    * try (Graph g = new Graph()) {
@@ -109,15 +195,15 @@ public static ConcreteFunction create(Function functionBuilder)
    * @return a new function
    */
   public static ConcreteFunction create(Signature signature, Graph graph) {
-    return new ConcreteFunction(signature, graph, new Session(graph), Ownership.SESSION_ONLY);
+    return new ConcreteFunction(signature, graph);
   }
 
   /**
    * Create a function from a signature and a valid graph session.
    *
    * 

The function will not own the session nor its graph, meaning that their lifetime - * can extend beyond the scope of the function. Therefore the function does not need to be - * closed after its usage. For example: + * can extend beyond the scope of the function. Therefore the function does not need to be closed after its usage. For + * example: * *

{@code
    * try (Graph g = new Graph()) {
@@ -143,7 +229,7 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
    * @return a new function
    */
   public static ConcreteFunction create(Signature signature, Session session) {
-    return new ConcreteFunction(signature, session.graph(), session, Ownership.NONE);
+    return new ConcreteFunction(signature, session.graph());
   }
 
   /**
@@ -154,121 +240,178 @@ public Signature signature() {
   }
 
   /**
-   * Invokes a function.
-   *
-   * 

Caller is responsible for closing all Tensors. - * - * @param arguments list of tensors to pass in input to the function, - * mapped by their signature name - * @return output tensors resulting from the execution of the function, - * mapped by their signature name + * Get the name of the function. */ - public Map call(Map arguments) - throws IllegalArgumentException { + public String getNativeFunctionName() { + try (PointerScope scope = new PointerScope()) { + return TF_FunctionName(nativeHandle()).getString(); + } + } - final SignatureDef signatureDef = signature.asSignatureDef(); - final Session.Runner runner = session.runner(); - signatureDef.getInputsMap().forEach((argName, t) -> { - Tensor tensor = arguments.get(argName); - if (tensor == null) { - throw new IllegalArgumentException(String.format("Missing argument [%s]", argName)); + public Map> call(Scope scope, + Map> arguments) { + List> inputList = new ArrayList<>(); + + for (String inputName : signature().inputNames()) { + Operand input = arguments.get(inputName); + if (input == null) { + throw new IllegalArgumentException( + "Function " + signature().methodName() + " has parameter \"" + inputName + + "\", but no argument was passed for it."); } - runner.feed(t.getName(), tensor); - }); + inputList.add(input); + } - Map outputToNode = signatureDef.getOutputsMap(); - outputToNode.values().forEach(t -> runner.fetch(t.getName())); + scope.env().attachFunction(this); + String name = getNativeFunctionName(); - List resultTensors = runner.run(); - try { - ListIterator resultTensorIter = resultTensors.listIterator(); - Map returnMap = new HashMap(); + OperationBuilder opBuilder = scope.env().opBuilder(name, scope.makeOpName(name)); + for (Operand input : inputList) { + opBuilder.addInput(input.asOutput()); + } + opBuilder = scope.apply(opBuilder); + Operation op = opBuilder.build(); - // Use the output names as present in the signature definition - for (String nodeName: outputToNode.keySet()) { - returnMap.put(nodeName, resultTensorIter.next()); - } - return returnMap; + int numOutputs1 = op.numOutputs(); + List> outputList = new ArrayList<>(signature().outputNames().size()); + + for (int i = 0; i < numOutputs1; i++) { + outputList.add(op.output(i)); + } - } catch (Exception e) { - // Release tensors before throwing exception - for (Tensor t : resultTensors) { - t.close(); + Map> namedOutputs = new LinkedHashMap<>(signature().outputNames().size()); + + List outputNames = new ArrayList<>(signature().outputNames()); + for (int i = 0; i < outputNames.size(); i++) { + String outputName = outputNames.get(i); + + if (i > outputList.size()) { + throw new IllegalStateException("Somehow, not all required outputs were returned from the function"); } - throw e; + + Operand output = outputList.get(i); + namedOutputs.put(outputName, output); } + + return Collections.unmodifiableMap(namedOutputs); } /** - * Invokes a function with a single input and output. - * - *

Caller is responsible for closing all Tensors. + * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only + * works for functions with a single input and output. * - * @param tensor input tensor - * @return output tensor - * @throws IllegalArgumentException if there are multiple input or output parameters defined - * in the function + * @param scope the scope to call the function in + * @param argument the argument to the call + * @return the output of the function */ - public Tensor call(Tensor tensor) throws IllegalArgumentException { + public Operand call(Scope scope, Operand argument) { final SignatureDef signatureDef = signature.asSignatureDef(); if (signatureDef.getInputsCount() != 1) { throw new IllegalArgumentException( - String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName())); + String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName())); } - String inputNodeName = signatureDef.getInputsMap().values().iterator().next().getName(); + String inputName = signatureDef.getInputsMap().keySet().iterator().next(); if (signatureDef.getOutputsCount() != 1) { throw new IllegalArgumentException( - String.format("Function [%s] has multiple outputs", signatureDef.getMethodName())); + String.format("Function [%s] has multiple outputs", signatureDef.getMethodName())); } - String outputNodeName = signatureDef.getOutputsMap().values().iterator().next().getName(); + String outputName = signatureDef.getOutputsMap().keySet().iterator().next(); - return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0); + Map> inputMap = new LinkedHashMap<>(); + inputMap.put(inputName, argument); + + return call(scope, inputMap).get(outputName); } /** - * Export this function as a saved model. + * Invokes a function. * - *

This method is convenient shortcut equivalent to - * {@code SavedModel.exporter(exportDir).withFunction(this).export()} + *

Caller is responsible for closing all Tensors. * - * @param exportDir directory where to export the saved model - * @throws IOException if saved model or variable state cannot be written on disk + * @param arguments list of tensors to pass in input to the function, mapped by their signature name + * @return output tensors resulting from the execution of the function, mapped by their signature name */ - public void save(String exportDir) throws IOException { - SavedModelBundle.exporter(exportDir).withFunction(this).export(); + public Map call(Map arguments) + throws IllegalArgumentException { + //TODO default device settings? Should probably execute on GPU if available + try (EagerSession session = EagerSession.create()) { + Ops tf = Ops.create(session); + Map> inputs = new LinkedHashMap<>(arguments.size()); + + for (String inputName : arguments.keySet()) { + Tensor argument = arguments.get(inputName); + inputs.put(inputName, tf.constantOf((TType) argument)); + } + Map> outputs = tf.call(this, inputs); + Map tensorOutputs = new LinkedHashMap<>(outputs.size()); + for (String outputName : outputs.keySet()) { + tensorOutputs.put(outputName, outputs.get(outputName).asTensor()); + } + return tensorOutputs; + } } /** - * Returns the session used to execute the graph when calling this function + * Invokes a function with a single input and output. * - *

In general, a user does not need to handle directly the session of a function and rely - * on {@link #call(Map)} to execute the graph instead. But in some cases, direct access to - * the session might be necessary, as it allows more running options. + *

Caller is responsible for closing all Tensors. * - * @return the function session + * @param tensor input tensor + * @return output tensor + * @throws IllegalArgumentException if there are multiple input or output parameters defined in the function */ - public Session session() { - return session; + public Tensor call(Tensor tensor) throws IllegalArgumentException { + try (EagerSession session = EagerSession.create()) { + Ops tf = Ops.create(session); + Operand argument = tf.constantOf((TType) tensor); + Operand output = call(tf, argument); + return output.asTensor(); + } } /** - * Returns the graph of this function + * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The + * inputs and outputs are keyed by the names set in the {@code Signature}. + * + * @param tf the scope to call the function in + * @param arguments the arguments to the call + * @return the outputs of the function */ - public Graph graph() { - return graph; + public Map> call(Ops tf, Map> arguments) { + return tf.call(this, arguments); + } + + /** + * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only + * works for functions with a single input and output. + * + * @param tf the scope to call the function in + * @param argument the argument to the call + * @return the output of the function + */ + public Operand call(Ops tf, Operand argument) { + return tf.call(this, argument); + } + + /** + * Export this function as a saved model. + * + *

This method is convenient shortcut equivalent to + * {@code SavedModel.exporter(exportDir).withFunction(this).export()} + * + * @param exportDir directory where to export the saved model + * @throws IOException if saved model or variable state cannot be written on disk + */ + public void save(String exportDir) throws IOException { + SavedModelBundle.exporter(exportDir).withFunction(this).export(); } @Override public void close() { - if (ownership != Ownership.NONE) { - session.close(); - if (ownership == Ownership.GRAPH_AND_SESSION) { - graph.close(); - } - } + scope.close(); } @Override @@ -276,19 +419,56 @@ public String toString() { return signature.toString(); } - private enum Ownership { - GRAPH_AND_SESSION, SESSION_ONLY, NONE; + /** + * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because how to enable XLA + * JIT is extremely non-obvious. + * + * Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered platform with id: + * 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails). + */ + private void makeJit() { + try (PointerScope scope = new PointerScope()) { + byte[] bytes = AttrValue.newBuilder().setB(true).build().toByteArray(); + BytePointer trueValue = new BytePointer(bytes); + + TF_Status status1 = TF_Status.newStatus(); + TF_FunctionSetAttrValueProto(nativeHandle(), "_XlaMustCompile", trueValue, bytes.length, status1); + status1.throwExceptionIfNotOK(); + + TF_Status status2 = TF_Status.newStatus(); + TF_FunctionSetAttrValueProto(nativeHandle(), "_noinline", trueValue, bytes.length, status2); + status2.throwExceptionIfNotOK(); + } + } + + TF_Function nativeHandle() { + if (nativeHandle.isNull()) { + throw new IllegalStateException("Function has been closed"); + } + return nativeHandle; + } + + /** + * Get the native handle of the function's gradient, so that it can be attached to a Graph. Not implemented yet. + * + * TODO implement + */ + TF_Function gradNativeHandle() { + return null; } - private final Graph graph; - private final Session session; private final Signature signature; - private final Ownership ownership; + private final TF_Function nativeHandle; + private final PointerScope scope; + + ConcreteFunction(Signature signature, Graph graph) { + this(signature, createNative(graph, signature)); + } - ConcreteFunction(Signature signature, Graph graph, Session session, Ownership ownership) { - this.graph = graph; - this.session = session; + ConcreteFunction(Signature signature, TF_Function nativeHandle) { this.signature = signature; - this.ownership = ownership; + scope = new PointerScope(); + this.nativeHandle = nativeHandle; + scope.attach(nativeHandle); } -} +} \ No newline at end of file diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index c5d67128406..8b592880929 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -15,6 +15,7 @@ */ package org.tensorflow; +import static org.tensorflow.internal.c_api.global.tensorflow.TFE_ContextAddFunction; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_ContextOptionsSetAsync; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_ContextOptionsSetConfig; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy; @@ -284,6 +285,16 @@ public OperationBuilder opBuilder(String type, String name) { return new EagerOperationBuilder(this, type, name); } + @Override + public void attachFunction(ConcreteFunction function) { + checkSession(); + try (PointerScope scope = new PointerScope()) { + TF_Status status = TF_Status.newStatus(); + TFE_ContextAddFunction(nativeHandle, function.nativeHandle(), status); + status.throwExceptionIfNotOK(); + } + } + @Override public Types environmentType() { return Types.EAGER; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java index a18c7fff38b..85a1b4a3355 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java @@ -37,6 +37,14 @@ enum Types { */ OperationBuilder opBuilder(String type, String name); + /** + * Attach the function to this execution environment, allowing it to be called by creating an op with the function + * name as it's {@code type}. + * + * Done automatically in the {@link org.tensorflow.op.Ops#call(ConcreteFunction, java.util.Map)} ops. + */ + void attachFunction(ConcreteFunction function); + /** * Returns true if the given operation is valid in this execution environment. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index b69fe89da0a..67cf9a765cc 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -18,6 +18,7 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_AddGradientsWithPrefix; import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteGraph; import static org.tensorflow.internal.c_api.global.tensorflow.TF_FinishWhile; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphCopyFunction; import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphImportGraphDef; import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphNextOperation; import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphOperationByName; @@ -378,6 +379,16 @@ public GraphOperationBuilder opBuilder(String type, String name) { return new GraphOperationBuilder(this, type, name); } + @Override + public void attachFunction(ConcreteFunction function) { + try (Reference ref = ref(); + PointerScope scope = new PointerScope()) { + TF_Status status = TF_Status.newStatus(); + TF_GraphCopyFunction(ref.nativeHandle(), function.nativeHandle(), function.gradNativeHandle(), status); + status.throwExceptionIfNotOK(); + } + } + @Override public Types environmentType() { return Types.GRAPH; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 6992e5eee37..cdc86245644 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -163,11 +163,12 @@ public Exporter withFunction(ConcreteFunction function) { throw new IllegalArgumentException("Function \"" + signature.key() + "\" was already added to the model"); } functions.put(signature.key(), function); - if (session == null) { - session = function.session(); - } else if (session != function.session()) { - throw new UnsupportedOperationException("Saving multiple functions with different graphs/sessions is not supported yet."); - } + //TODO fix saver +// if (session == null) { +// session = function.session(); +// } else if (session != function.session()) { +// throw new UnsupportedOperationException("Saving multiple functions with different graphs/sessions is not supported yet."); +// } metaGraphDefBuilder.putSignatureDef(signature.key(), signature.asSignatureDef()); return this; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java index 66b4dad4132..a609a711ca5 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -15,7 +15,8 @@ */ package org.tensorflow; -import java.util.HashMap; +import java.util.Collections; +import java.util.LinkedHashMap; import java.util.Map; import java.util.Set; import org.tensorflow.ndarray.Shape; @@ -35,12 +36,25 @@ public class Signature { public static final String DEFAULT_KEY = "serving_default"; public static class TensorDescription { + + /** + * The name of the tensor's operand in the graph + */ + public final String name; + /** + * The data type of the tensor + */ public final DataType dataType; + + /** + * The shape of the tensor + */ public final Shape shape; - public TensorDescription(DataType dataType, Shape shape) { + public TensorDescription(DataType dataType, Shape shape, String name) { this.dataType = dataType; this.shape = shape; + this.name = name; } } @@ -187,29 +201,33 @@ public String toString() { } private Map buildTensorDescriptionMap(Map dataMapIn) { - Map dataTypeMap = new HashMap<>(); - dataMapIn.forEach((a, b) -> { - long[] tensorDims = b.getTensorShape().getDimList().stream().mapToLong(d -> d.getSize()).toArray(); + Map dataTypeMap = new LinkedHashMap<>(); + dataMapIn.forEach((name, info) -> { + long[] tensorDims = info.getTensorShape().getDimList().stream().mapToLong(d -> d.getSize()).toArray(); Shape tensorShape = Shape.of(tensorDims); - dataTypeMap.put(a, new TensorDescription(b.getDtype(), - tensorShape)); + dataTypeMap.put(name, new TensorDescription(info.getDtype(), tensorShape, info.getName())); }); - return dataTypeMap; + return Collections.unmodifiableMap(dataTypeMap); } /** - * Returns the names of the inputs in this signature mapped to their expected data type and shape - * @return + * Returns the names of the inputs in this signature mapped to their expected data type, shape, and operand name */ public Map getInputs() { - return buildTensorDescriptionMap(signatureDef.getInputsMap()); + if (inputMap == null) { + inputMap = buildTensorDescriptionMap(signatureDef.getInputsMap()); + } + return inputMap; } /** - * Returns the names of the outputs in this signature mapped to their expected data type and shape + * Returns the names of the outputs in this signature mapped to their expected data type, shape, and operand name */ public Map getOutputs() { - return buildTensorDescriptionMap(signatureDef.getOutputsMap()); + if (outputMap == null) { + outputMap = buildTensorDescriptionMap(signatureDef.getOutputsMap()); + } + return outputMap; } Signature(String key, SignatureDef signatureDef) { @@ -223,6 +241,8 @@ SignatureDef asSignatureDef() { private final String key; private final SignatureDef signatureDef; + private Map inputMap; + private Map outputMap; private static void printTensorInfo(Map tensorMap, StringBuilder strBuilder) { tensorMap.forEach((key, tensorInfo) -> { @@ -240,4 +260,4 @@ private static void printTensorInfo(Map tensorMap, StringBui strBuilder.append(")\n"); }); } -} +} \ No newline at end of file diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java new file mode 100644 index 00000000000..0d021244c6b --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java @@ -0,0 +1,57 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ + +package org.tensorflow.internal.c_api; + +import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteFunction; + +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.annotation.Properties; + +@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) +public abstract class AbstractTF_Function extends Pointer { + + protected static class DeleteDeallocator extends TF_Function implements Deallocator { + + DeleteDeallocator(TF_Function s) { + super(s); + } + + @Override + public void deallocate() { + if (!isNull()) { + TF_DeleteFunction(this); + } + setNull(); + } + } + + public AbstractTF_Function(Pointer p) { + super(p); + } + + public TF_Function withDeallocator() { + return this.deallocator(new DeleteDeallocator((TF_Function) this)); + } + + /** + * Calls the deallocator, if registered, otherwise has no effect. + */ + public void delete() { + deallocate(); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java index 17bf9dbf79a..cb7916d0309 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java @@ -29,7 +29,6 @@ import org.bytedeco.javacpp.tools.InfoMapper; /** - * * @author Samuel Audet */ @Properties( @@ -61,17 +60,28 @@ @Platform( value = "windows", preload = { - "api-ms-win-crt-locale-l1-1-0", "api-ms-win-crt-string-l1-1-0", "api-ms-win-crt-stdio-l1-1-0", "api-ms-win-crt-math-l1-1-0", - "api-ms-win-crt-heap-l1-1-0", "api-ms-win-crt-runtime-l1-1-0", "api-ms-win-crt-convert-l1-1-0", "api-ms-win-crt-environment-l1-1-0", - "api-ms-win-crt-time-l1-1-0", "api-ms-win-crt-filesystem-l1-1-0", "api-ms-win-crt-utility-l1-1-0", "api-ms-win-crt-multibyte-l1-1-0", - "api-ms-win-core-string-l1-1-0", "api-ms-win-core-errorhandling-l1-1-0", "api-ms-win-core-timezone-l1-1-0", "api-ms-win-core-file-l1-1-0", - "api-ms-win-core-namedpipe-l1-1-0", "api-ms-win-core-handle-l1-1-0", "api-ms-win-core-file-l2-1-0", "api-ms-win-core-heap-l1-1-0", - "api-ms-win-core-libraryloader-l1-1-0", "api-ms-win-core-synch-l1-1-0", "api-ms-win-core-processthreads-l1-1-0", - "api-ms-win-core-processenvironment-l1-1-0", "api-ms-win-core-datetime-l1-1-0", "api-ms-win-core-localization-l1-2-0", - "api-ms-win-core-sysinfo-l1-1-0", "api-ms-win-core-synch-l1-2-0", "api-ms-win-core-console-l1-1-0", "api-ms-win-core-debug-l1-1-0", - "api-ms-win-core-rtlsupport-l1-1-0", "api-ms-win-core-processthreads-l1-1-1", "api-ms-win-core-file-l1-2-0", "api-ms-win-core-profile-l1-1-0", - "api-ms-win-core-memory-l1-1-0", "api-ms-win-core-util-l1-1-0", "api-ms-win-core-interlocked-l1-1-0", "ucrtbase", - "vcruntime140", "vcruntime140_1", "msvcp140", "concrt140", "vcomp140", "msvcr120", "libiomp5md", "mklml", "tensorflow_framework" + "api-ms-win-crt-locale-l1-1-0", "api-ms-win-crt-string-l1-1-0", "api-ms-win-crt-stdio-l1-1-0", + "api-ms-win-crt-math-l1-1-0", + "api-ms-win-crt-heap-l1-1-0", "api-ms-win-crt-runtime-l1-1-0", "api-ms-win-crt-convert-l1-1-0", + "api-ms-win-crt-environment-l1-1-0", + "api-ms-win-crt-time-l1-1-0", "api-ms-win-crt-filesystem-l1-1-0", "api-ms-win-crt-utility-l1-1-0", + "api-ms-win-crt-multibyte-l1-1-0", + "api-ms-win-core-string-l1-1-0", "api-ms-win-core-errorhandling-l1-1-0", + "api-ms-win-core-timezone-l1-1-0", "api-ms-win-core-file-l1-1-0", + "api-ms-win-core-namedpipe-l1-1-0", "api-ms-win-core-handle-l1-1-0", "api-ms-win-core-file-l2-1-0", + "api-ms-win-core-heap-l1-1-0", + "api-ms-win-core-libraryloader-l1-1-0", "api-ms-win-core-synch-l1-1-0", + "api-ms-win-core-processthreads-l1-1-0", + "api-ms-win-core-processenvironment-l1-1-0", "api-ms-win-core-datetime-l1-1-0", + "api-ms-win-core-localization-l1-2-0", + "api-ms-win-core-sysinfo-l1-1-0", "api-ms-win-core-synch-l1-2-0", "api-ms-win-core-console-l1-1-0", + "api-ms-win-core-debug-l1-1-0", + "api-ms-win-core-rtlsupport-l1-1-0", "api-ms-win-core-processthreads-l1-1-1", + "api-ms-win-core-file-l1-2-0", "api-ms-win-core-profile-l1-1-0", + "api-ms-win-core-memory-l1-1-0", "api-ms-win-core-util-l1-1-0", "api-ms-win-core-interlocked-l1-1-0", + "ucrtbase", + "vcruntime140", "vcruntime140_1", "msvcp140", "concrt140", "vcomp140", "msvcr120", "libiomp5md", + "mklml", "tensorflow_framework" } ), @Platform( @@ -100,132 +110,149 @@ @NoException public class tensorflow implements LoadEnabled, InfoMapper { - @Override public void init(ClassProperties properties) { - String platform = properties.getProperty("platform"); - String extension = properties.getProperty("platform.extension"); - List preloads = properties.get("platform.preload"); - List resources = properties.get("platform.preloadresource"); - List preloadpaths = properties.get("platform.preloadpath"); + @Override + public void init(ClassProperties properties) { + String platform = properties.getProperty("platform"); + String extension = properties.getProperty("platform.extension"); + List preloads = properties.get("platform.preload"); + List resources = properties.get("platform.preloadresource"); + List preloadpaths = properties.get("platform.preloadpath"); - String vcredistdir = System.getenv("VCToolsRedistDir"); - if (vcredistdir != null && vcredistdir.length() > 0) { - switch (platform) { - case "windows-x86": - preloadpaths.add(0, vcredistdir + "\\x86\\Microsoft.VC142.CRT"); - preloadpaths.add(1, vcredistdir + "\\x86\\Microsoft.VC142.OpenMP"); - preloadpaths.add(2, vcredistdir + "\\x86\\Microsoft.VC141.CRT"); - preloadpaths.add(3, vcredistdir + "\\x86\\Microsoft.VC141.OpenMP"); - break; - case "windows-x86_64": - preloadpaths.add(0, vcredistdir + "\\x64\\Microsoft.VC142.CRT"); - preloadpaths.add(1, vcredistdir + "\\x64\\Microsoft.VC142.OpenMP"); - preloadpaths.add(2, vcredistdir + "\\x64\\Microsoft.VC141.CRT"); - preloadpaths.add(3, vcredistdir + "\\x64\\Microsoft.VC141.OpenMP"); - break; - default: - // not Windows - } - } - - // Only apply this at load time - if (!Loader.isLoadLibraries()) { - return; - } + String vcredistdir = System.getenv("VCToolsRedistDir"); + if (vcredistdir != null && vcredistdir.length() > 0) { + switch (platform) { + case "windows-x86": + preloadpaths.add(0, vcredistdir + "\\x86\\Microsoft.VC142.CRT"); + preloadpaths.add(1, vcredistdir + "\\x86\\Microsoft.VC142.OpenMP"); + preloadpaths.add(2, vcredistdir + "\\x86\\Microsoft.VC141.CRT"); + preloadpaths.add(3, vcredistdir + "\\x86\\Microsoft.VC141.OpenMP"); + break; + case "windows-x86_64": + preloadpaths.add(0, vcredistdir + "\\x64\\Microsoft.VC142.CRT"); + preloadpaths.add(1, vcredistdir + "\\x64\\Microsoft.VC142.OpenMP"); + preloadpaths.add(2, vcredistdir + "\\x64\\Microsoft.VC141.CRT"); + preloadpaths.add(3, vcredistdir + "\\x64\\Microsoft.VC141.OpenMP"); + break; + default: + // not Windows + } + } - // Let users enable loading of the full version of MKL - String load = System.getProperty("org.bytedeco.openblas.load", - System.getProperty("org.bytedeco.mklml.load", "")).toLowerCase(); + // Only apply this at load time + if (!Loader.isLoadLibraries()) { + return; + } - int i = 0; - if (load.equals("mkl") || load.equals("mkl_rt")) { - String[] libs = {"iomp5", "libiomp5md", "mkl_core", "mkl_avx", "mkl_avx2", "mkl_avx512", "mkl_avx512_mic", - "mkl_def", "mkl_mc", "mkl_mc3", "mkl_intel_lp64", "mkl_intel_thread", "mkl_gnu_thread", "mkl_rt"}; - for (i = 0; i < libs.length; i++) { - preloads.add(i, libs[i] + "#" + libs[i]); - } - load = "mkl_rt"; - resources.add("/org/bytedeco/mkl/"); - } + // Let users enable loading of the full version of MKL + String load = System.getProperty("org.bytedeco.openblas.load", + System.getProperty("org.bytedeco.mklml.load", "")).toLowerCase(); - if (load.length() > 0) { - if (platform.startsWith("linux")) { - preloads.add(i, load + "#mklml_intel"); - } else if (platform.startsWith("macosx")) { - preloads.add(i, load + "#mklml"); - } else if (platform.startsWith("windows")) { - preloads.add(i, load + "#mklml"); - } - } + int i = 0; + if (load.equals("mkl") || load.equals("mkl_rt")) { + String[] libs = {"iomp5", "libiomp5md", "mkl_core", "mkl_avx", "mkl_avx2", "mkl_avx512", "mkl_avx512_mic", + "mkl_def", "mkl_mc", "mkl_mc3", "mkl_intel_lp64", "mkl_intel_thread", "mkl_gnu_thread", "mkl_rt"}; + for (i = 0; i < libs.length; i++) { + preloads.add(i, libs[i] + "#" + libs[i]); + } + load = "mkl_rt"; + resources.add("/org/bytedeco/mkl/"); + } - // Only apply this at load time since we don't want to copy the CUDA libraries here - if (!Loader.isLoadLibraries() || extension == null || !extension.endsWith("-gpu")) { - return; - } - String[] libs = {"cudart", "cublasLt", "cublas", "cufft", "curand", "cusolver", "cusparse", "cudnn", "nccl", "nvrtc", "myelin", "nvinfer", - "cudnn_ops_infer", "cudnn_ops_train", "cudnn_adv_infer", "cudnn_adv_train", "cudnn_cnn_infer", "cudnn_cnn_train"}; - for (String lib : libs) { - if (platform.startsWith("linux")) { - lib += lib.startsWith("cudnn") ? "@.8" - : lib.equals("nccl") ? "@.2" - : lib.equals("myelin") ? "@.1" - : lib.equals("nvinfer") ? "@.7" - : lib.equals("cufft") || lib.equals("curand") || lib.equals("cusolver") ? "@.10" - : lib.equals("cudart") ? "@.11.0" - : lib.equals("nvrtc") ? "@.11.0" - : "@.11"; - } else if (platform.startsWith("windows")) { - lib += lib.startsWith("cudnn") ? "64_8" - : lib.equals("nccl") ? "64_2" - : lib.equals("myelin") ? "64_1" - : lib.equals("nvinfer") ? "64_7" - : lib.equals("cufft") || lib.equals("curand") || lib.equals("cusolver") ? "64_10" - : lib.equals("cudart") ? "64_110" - : lib.equals("nvrtc") ? "64_110_0" - : "64_11"; - } else { - continue; // no CUDA - } - if (!preloads.contains(lib)) { - preloads.add(i++, lib); - } - } - if (i > 0) { - resources.add("/org/bytedeco/cuda/"); - resources.add("/org/bytedeco/tensorrt/"); - } + if (load.length() > 0) { + if (platform.startsWith("linux")) { + preloads.add(i, load + "#mklml_intel"); + } else if (platform.startsWith("macosx")) { + preloads.add(i, load + "#mklml"); + } else if (platform.startsWith("windows")) { + preloads.add(i, load + "#mklml"); + } } - public void map(InfoMap infoMap) { - infoMap.put(new Info("TF_CAPI_EXPORT", "TF_Bool").cppTypes().annotations()) - .put(new Info("TF_Buffer::data").javaText("public native @Const Pointer data(); public native TF_Buffer data(Pointer data);")) - .put(new Info("TF_Status").pointerTypes("TF_Status").base("org.tensorflow.internal.c_api.AbstractTF_Status")) - .put(new Info("TF_Buffer").pointerTypes("TF_Buffer").base("org.tensorflow.internal.c_api.AbstractTF_Buffer")) - .put(new Info("TF_Tensor").pointerTypes("TF_Tensor").base("org.tensorflow.internal.c_api.AbstractTF_Tensor")) - .put(new Info("TF_Session").pointerTypes("TF_Session").base("org.tensorflow.internal.c_api.AbstractTF_Session")) - .put(new Info("TF_SessionOptions").pointerTypes("TF_SessionOptions").base("org.tensorflow.internal.c_api.AbstractTF_SessionOptions")) - .put(new Info("TF_Graph").pointerTypes("TF_Graph").base("org.tensorflow.internal.c_api.AbstractTF_Graph")) - .put(new Info("TF_Graph::graph").javaText("public native @MemberGetter @ByRef Graph graph();")) - .put(new Info("TF_Graph::refiner").javaText("public native @MemberGetter @ByRef ShapeRefiner refiner();")) - .put(new Info("TF_ImportGraphDefOptions").pointerTypes("TF_ImportGraphDefOptions").base("org.tensorflow.internal.c_api.AbstractTF_ImportGraphDefOptions")) - .put(new Info("TF_Operation", "TF_WhileParams", "TFE_MonitoringCounterCell", "TFE_MonitoringSamplerCell", - "TFE_MonitoringCounter0", "TFE_MonitoringCounter1", "TFE_MonitoringCounter2", - "TFE_MonitoringIntGaugeCell", "TFE_MonitoringStringGaugeCell", "TFE_MonitoringBoolGaugeCell", - "TFE_MonitoringIntGauge0", "TFE_MonitoringIntGauge1", "TFE_MonitoringIntGauge2", - "TFE_MonitoringStringGauge0", "TFE_MonitoringStringGauge1", "TFE_MonitoringStringGauge2", - "TFE_MonitoringBoolGauge0", "TFE_MonitoringBoolGauge1", "TFE_MonitoringBoolGauge2", - "TFE_MonitoringSampler0", "TFE_MonitoringSampler1", "TFE_MonitoringSampler2").purify()) - .put(new Info("TF_Operation::node").javaText("public native @MemberGetter @ByRef Node node();")) - .put(new Info("TFE_MonitoringCounterCell::cell").javaText("public native @MemberGetter @ByRef CounterCell cell();")) - .put(new Info("TFE_MonitoringSamplerCell::cell").javaText("public native @MemberGetter @ByRef SamplerCell cell();")) - .put(new Info("TFE_MonitoringIntGaugeCell::cell").javaText("public native @MemberGetter @ByRef IntGaugeCell cell();")) - .put(new Info("TFE_MonitoringStringGaugeCell::cell").javaText("public native @MemberGetter @ByRef StringGaugeCell cell();")) - .put(new Info("TFE_MonitoringBoolGaugeCell::cell").javaText("public native @MemberGetter @ByRef BoolGaugeCell cell();")) - .put(new Info("TFE_Context").pointerTypes("TFE_Context").base("org.tensorflow.internal.c_api.AbstractTFE_Context")) - .put(new Info("TFE_ContextOptions").pointerTypes("TFE_ContextOptions").base("org.tensorflow.internal.c_api.AbstractTFE_ContextOptions")) - .put(new Info("TFE_Context::context").javaText("@MemberGetter public native @ByRef EagerContext context();")) - .put(new Info("TFE_Op").pointerTypes("TFE_Op").base("org.tensorflow.internal.c_api.AbstractTFE_Op")) - .put(new Info("TFE_Op::operation").javaText("@MemberGetter public native @ByRef EagerOperation operation();")) - .put(new Info("TFE_TensorHandle").pointerTypes("TFE_TensorHandle").base("org.tensorflow.internal.c_api.AbstractTFE_TensorHandle")) - .put(new Info("TF_ShapeInferenceContextDimValueKnown", "TFE_NewTensorHandle(const tensorflow::Tensor&, TF_Status*)").skip()); + // Only apply this at load time since we don't want to copy the CUDA libraries here + if (!Loader.isLoadLibraries() || extension == null || !extension.endsWith("-gpu")) { + return; } + String[] libs = {"cudart", "cublasLt", "cublas", "cufft", "curand", "cusolver", "cusparse", "cudnn", "nccl", + "nvrtc", "myelin", "nvinfer", + "cudnn_ops_infer", "cudnn_ops_train", "cudnn_adv_infer", "cudnn_adv_train", "cudnn_cnn_infer", + "cudnn_cnn_train"}; + for (String lib : libs) { + if (platform.startsWith("linux")) { + lib += lib.startsWith("cudnn") ? "@.8" + : lib.equals("nccl") ? "@.2" + : lib.equals("myelin") ? "@.1" + : lib.equals("nvinfer") ? "@.7" + : lib.equals("cufft") || lib.equals("curand") || lib.equals("cusolver") ? "@.10" + : lib.equals("cudart") ? "@.11.0" + : lib.equals("nvrtc") ? "@.11.0" + : "@.11"; + } else if (platform.startsWith("windows")) { + lib += lib.startsWith("cudnn") ? "64_8" + : lib.equals("nccl") ? "64_2" + : lib.equals("myelin") ? "64_1" + : lib.equals("nvinfer") ? "64_7" + : lib.equals("cufft") || lib.equals("curand") || lib.equals("cusolver") ? "64_10" + : lib.equals("cudart") ? "64_110" + : lib.equals("nvrtc") ? "64_110_0" + : "64_11"; + } else { + continue; // no CUDA + } + if (!preloads.contains(lib)) { + preloads.add(i++, lib); + } + } + if (i > 0) { + resources.add("/org/bytedeco/cuda/"); + resources.add("/org/bytedeco/tensorrt/"); + } + } + + public void map(InfoMap infoMap) { + infoMap.put(new Info("TF_CAPI_EXPORT", "TF_Bool").cppTypes().annotations()) + .put(new Info("TF_Buffer::data") + .javaText("public native @Const Pointer data(); public native TF_Buffer data(Pointer data);")) + .put(new Info("TF_Status").pointerTypes("TF_Status").base("org.tensorflow.internal.c_api.AbstractTF_Status")) + .put(new Info("TF_Buffer").pointerTypes("TF_Buffer").base("org.tensorflow.internal.c_api.AbstractTF_Buffer")) + .put(new Info("TF_Tensor").pointerTypes("TF_Tensor").base("org.tensorflow.internal.c_api.AbstractTF_Tensor")) + .put(new Info("TF_Session").pointerTypes("TF_Session").base("org.tensorflow.internal.c_api.AbstractTF_Session")) + .put(new Info("TF_SessionOptions").pointerTypes("TF_SessionOptions") + .base("org.tensorflow.internal.c_api.AbstractTF_SessionOptions")) + .put(new Info("TF_Graph").pointerTypes("TF_Graph").base("org.tensorflow.internal.c_api.AbstractTF_Graph")) + .put(new Info("TF_Graph::graph").javaText("public native @MemberGetter @ByRef Graph graph();")) + .put(new Info("TF_Graph::refiner").javaText("public native @MemberGetter @ByRef ShapeRefiner refiner();")) + .put(new Info("TF_Function").pointerTypes("TF_Function") + .base("org.tensorflow.internal.c_api.AbstractTF_Function")) + .put(new Info("TF_ImportGraphDefOptions").pointerTypes("TF_ImportGraphDefOptions") + .base("org.tensorflow.internal.c_api.AbstractTF_ImportGraphDefOptions")) + .put(new Info("TF_Operation", "TF_WhileParams", "TFE_MonitoringCounterCell", "TFE_MonitoringSamplerCell", + "TFE_MonitoringCounter0", "TFE_MonitoringCounter1", "TFE_MonitoringCounter2", + "TFE_MonitoringIntGaugeCell", "TFE_MonitoringStringGaugeCell", "TFE_MonitoringBoolGaugeCell", + "TFE_MonitoringIntGauge0", "TFE_MonitoringIntGauge1", "TFE_MonitoringIntGauge2", + "TFE_MonitoringStringGauge0", "TFE_MonitoringStringGauge1", "TFE_MonitoringStringGauge2", + "TFE_MonitoringBoolGauge0", "TFE_MonitoringBoolGauge1", "TFE_MonitoringBoolGauge2", + "TFE_MonitoringSampler0", "TFE_MonitoringSampler1", "TFE_MonitoringSampler2").purify()) + .put(new Info("TF_Operation::node").javaText("public native @MemberGetter @ByRef Node node();")) + .put(new Info("TFE_MonitoringCounterCell::cell") + .javaText("public native @MemberGetter @ByRef CounterCell cell();")) + .put(new Info("TFE_MonitoringSamplerCell::cell") + .javaText("public native @MemberGetter @ByRef SamplerCell cell();")) + .put(new Info("TFE_MonitoringIntGaugeCell::cell") + .javaText("public native @MemberGetter @ByRef IntGaugeCell cell();")) + .put(new Info("TFE_MonitoringStringGaugeCell::cell") + .javaText("public native @MemberGetter @ByRef StringGaugeCell cell();")) + .put(new Info("TFE_MonitoringBoolGaugeCell::cell") + .javaText("public native @MemberGetter @ByRef BoolGaugeCell cell();")) + .put(new Info("TFE_Context").pointerTypes("TFE_Context") + .base("org.tensorflow.internal.c_api.AbstractTFE_Context")) + .put(new Info("TFE_ContextOptions").pointerTypes("TFE_ContextOptions") + .base("org.tensorflow.internal.c_api.AbstractTFE_ContextOptions")) + .put(new Info("TFE_Context::context").javaText("@MemberGetter public native @ByRef EagerContext context();")) + .put(new Info("TFE_Op").pointerTypes("TFE_Op").base("org.tensorflow.internal.c_api.AbstractTFE_Op")) + .put(new Info("TFE_Op::operation").javaText("@MemberGetter public native @ByRef EagerOperation operation();")) + .put(new Info("TFE_TensorHandle").pointerTypes("TFE_TensorHandle") + .base("org.tensorflow.internal.c_api.AbstractTFE_TensorHandle")) + .put(new Info("TF_ShapeInferenceContextDimValueKnown", + "TFE_NewTensorHandle(const tensorflow::Tensor&, TF_Status*)").skip()); + } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java new file mode 100644 index 00000000000..b38b7e0bbbb --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java @@ -0,0 +1,45 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.op.core; + +import java.util.Map; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.Operand; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; + +/** + * Ops for calling {@link ConcreteFunction}. Even though the C API docs say the name of the Op needs to be the name of + * the function, they mean the type. + */ +@Operator(name = "call") +public abstract class Function { + + @Endpoint + public static Map> call(Scope scope, ConcreteFunction function, + Map> arguments) { + return function.call(scope, arguments); + } + + @Endpoint + public static Operand call(Scope scope, ConcreteFunction function, + Operand argument) { + return function.call(scope, argument); + } + +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java index b2b2c34e223..7f7492691ec 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -15,7 +15,6 @@ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; import org.junit.jupiter.api.Test; import org.tensorflow.op.Ops; @@ -80,42 +79,4 @@ public void chainFunctions() { assertEquals(6.0f, ((TFloat32)f2.call(f1.call(x))).getFloat()); } } - - @Test - public void closingFunctionReleaseAllResourcesItOwns() { - Graph g; - Session s; - try (ConcreteFunction f = ConcreteFunction.create(ConcreteFunctionTest::plusFive)) { - g = f.graph(); - s = f.session(); - } - assertThrows(IllegalStateException.class, () -> s.run("Add")); - assertThrows(IllegalStateException.class, () -> g.toGraphDef()); - } - - @Test - public void closingFunctionCreatedFromGraphOnlyReleaseResourcesItOwns() { - try (Graph g = new Graph()) { - Signature signature = plusFive(Ops.create(g)); - Session s; - try (ConcreteFunction f = ConcreteFunction.create(signature, g)) { - s = f.session(); - } - assertThrows(IllegalStateException.class, () -> s.run(Init.DEFAULT_NAME)); - g.toGraphDef(); // check that graph is still valid - } - } - - @Test - public void closingFunctionCreatedFromSessionDoesNotReleaseResources() { - try (Graph g = new Graph()) { - Signature signature = plusFive(Ops.create(g)); - try (Session s = new Session(g)) { - try (ConcreteFunction f = ConcreteFunction.create(signature, s)) { - } - s.run(Init.DEFAULT_NAME); // check that session is still valid - } - g.toGraphDef(); // check that graph is still valid - } - } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 032c835c0cc..2dae1077874 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -104,7 +104,7 @@ public void exportFunctionWithVariables() throws IOException { Shape xyShape = Shape.of(2, 3L); try (ConcreteFunction f = ConcreteFunction.create(tf -> buildGraphWithVariables(tf, xyShape))) { // Init variable state by running the Init operation directly - f.session().run(Init.DEFAULT_NAME); + //TODO f.session().run(Init.DEFAULT_NAME); // Call the graph and remember the result of computation for later try (TFloat32 xTensor = TFloat32.tensorOf(xValue); @@ -178,7 +178,7 @@ public void exportMultipleFunctions() throws IOException { try (Session s = new Session(g); ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s); ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) { - f1.session().run(Init.DEFAULT_NAME); + //TODO f1.session().run(Init.DEFAULT_NAME); try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); TFloat32 t = (TFloat32)f1.call(x)) { reducedSum = t.getFloat(); @@ -221,7 +221,7 @@ public void cannotExportMultipleFunctionsWithDifferentSessions() throws IOExcept Signature f2Signature = buildIdentityGraph(tf, "identity"); try (ConcreteFunction f1 = ConcreteFunction.create(f1Signature, g); ConcreteFunction f2 = ConcreteFunction.create(f2Signature, g)) { - f1.session().run(Init.DEFAULT_NAME); + //TODO f1.session().run(Init.DEFAULT_NAME); try { SavedModelBundle.exporter(testFolder.toString()) .withFunction(f1) @@ -245,7 +245,7 @@ public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOExcepti try (Session s = new Session(g); ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s); ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) { - f1.session().run(Init.DEFAULT_NAME); + //TODO f1.session().run(Init.DEFAULT_NAME); try { SavedModelBundle.exporter(testFolder.toString()) .withFunction(f1) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/FunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/FunctionTest.java new file mode 100644 index 00000000000..ea06933a7b9 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/FunctionTest.java @@ -0,0 +1,70 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.op.core; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ConcreteFunction; +import org.tensorflow.EagerSession; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.Signature; +import org.tensorflow.op.Ops; +import org.tensorflow.op.math.Add; +import org.tensorflow.types.TFloat32; + +/** + * Tests for GraphFunction and it's ops + */ +public class FunctionTest { + + private static Signature plusFive(Ops tf) { + Placeholder input = tf.placeholder(TFloat32.class); + Add output = tf.math.add(input, tf.constant(5.0f)); + Init init = tf.init(); // for native resource management tests + return Signature.builder().key("plusFive").input("x", input).output("y", output).build(); + } + + @Test + public void testConcreteFunctionEager() { + try (EagerSession sess = EagerSession.create(); + ConcreteFunction function = ConcreteFunction.create(FunctionTest::plusFive)) { + Ops tf = Ops.create(sess); + Operand a = tf.constant(10f); + Operand result = (Operand) function.call(tf, a); + try (TFloat32 t = result.asTensor()) { + assertEquals(15f, t.getFloat()); + } + } + } + + @Test + public void testConcreteFunctionGraph() { + try (Graph graph = new Graph(); + ConcreteFunction function = ConcreteFunction.create(FunctionTest::plusFive)) { + Ops tf = Ops.create(graph); + Operand a = tf.constant(10f); + Operand result = (Operand) function.call(tf, a); + try (Session sess = new Session(graph); + TFloat32 t = (TFloat32) sess.runner().fetch(result).run().get(0)) { + assertEquals(15f, t.getFloat()); + } + } + } +} From 278d7ccba441becf425b1aad4f1550b9d27c6357 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 4 Mar 2021 16:28:32 -0800 Subject: [PATCH 02/34] Allow body constants Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/ConcreteFunction.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 95ac8360db3..07560fa1740 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -99,7 +99,8 @@ private static TF_Function createNative(Graph graph, Signature signature) { .map((x) -> graph.outputOrError(x.name)) .collect(Collectors.toList()); - List ops = new ArrayList<>(graph.completeSubgraph(new HashSet<>(inputs), new HashSet<>(outputs))); + List ops = new ArrayList<>( + graph.completeSubgraph(new HashSet<>(inputs), new HashSet<>(outputs), true)); PointerPointer operations = new PointerPointer<>(ops.size()); for (int i = 0; i < ops.size(); i++) { @@ -123,8 +124,7 @@ private static TF_Function createNative(Graph graph, Signature signature) { ); status.throwExceptionIfNotOK(); - - return handle.withDeallocator(); + return handle; } } @@ -469,6 +469,6 @@ TF_Function gradNativeHandle() { this.signature = signature; scope = new PointerScope(); this.nativeHandle = nativeHandle; - scope.attach(nativeHandle); + scope.attach(nativeHandle.withDeallocator()); } } \ No newline at end of file From c2b0b60457a965b135ed29091c4f8863e7cd44aa Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 4 Mar 2021 17:00:13 -0800 Subject: [PATCH 03/34] Fix body forbids Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/ConcreteFunction.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 07560fa1740..f43b6f2e990 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -39,6 +39,7 @@ import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.op.Ops; import org.tensorflow.op.Scope; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.proto.framework.AttrValue; import org.tensorflow.proto.framework.SignatureDef; import org.tensorflow.types.family.TType; @@ -100,7 +101,8 @@ private static TF_Function createNative(Graph graph, Signature signature) { .collect(Collectors.toList()); List ops = new ArrayList<>( - graph.completeSubgraph(new HashSet<>(inputs), new HashSet<>(outputs), true)); + graph.completeSubgraph(new HashSet<>(inputs), new HashSet<>(outputs), + null, Collections.singleton(Placeholder.OP_NAME))); PointerPointer operations = new PointerPointer<>(ops.size()); for (int i = 0; i < ops.size(); i++) { From dec04e6822a2342135fad540aa1e7f900363ec67 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 4 Mar 2021 17:12:18 -0800 Subject: [PATCH 04/34] Use default eager session for tensor calls Signed-off-by: Ryan Nett --- .../java/org/tensorflow/ConcreteFunction.java | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index f43b6f2e990..ac032b96cc9 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -329,7 +329,7 @@ public Operand call(Scope scope, Operand argument) { } /** - * Invokes a function. + * Invokes a function using the default eager session. * *

Caller is responsible for closing all Tensors. * @@ -338,26 +338,24 @@ public Operand call(Scope scope, Operand argument) { */ public Map call(Map arguments) throws IllegalArgumentException { - //TODO default device settings? Should probably execute on GPU if available - try (EagerSession session = EagerSession.create()) { - Ops tf = Ops.create(session); - Map> inputs = new LinkedHashMap<>(arguments.size()); + //FIXME need to manage input/output operand lifetimes + Ops tf = Ops.create(); + Map> inputs = new LinkedHashMap<>(arguments.size()); - for (String inputName : arguments.keySet()) { - Tensor argument = arguments.get(inputName); - inputs.put(inputName, tf.constantOf((TType) argument)); - } - Map> outputs = tf.call(this, inputs); - Map tensorOutputs = new LinkedHashMap<>(outputs.size()); - for (String outputName : outputs.keySet()) { - tensorOutputs.put(outputName, outputs.get(outputName).asTensor()); - } - return tensorOutputs; + for (String inputName : arguments.keySet()) { + Tensor argument = arguments.get(inputName); + inputs.put(inputName, tf.constantOf((TType) argument)); + } + Map> outputs = tf.call(this, inputs); + Map tensorOutputs = new LinkedHashMap<>(outputs.size()); + for (String outputName : outputs.keySet()) { + tensorOutputs.put(outputName, outputs.get(outputName).asTensor()); } + return tensorOutputs; } /** - * Invokes a function with a single input and output. + * Invokes a function with a single input and output using the default eager session. * *

Caller is responsible for closing all Tensors. * From 047243bfb5660fb3cbf13e0002f39d96fb0a16e7 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 4 Mar 2021 17:24:23 -0800 Subject: [PATCH 05/34] Use default eager for single tensor call too Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/ConcreteFunction.java | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index ac032b96cc9..73e81f646c5 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -364,12 +364,10 @@ public Map call(Map arguments) * @throws IllegalArgumentException if there are multiple input or output parameters defined in the function */ public Tensor call(Tensor tensor) throws IllegalArgumentException { - try (EagerSession session = EagerSession.create()) { - Ops tf = Ops.create(session); - Operand argument = tf.constantOf((TType) tensor); - Operand output = call(tf, argument); - return output.asTensor(); - } + Ops tf = Ops.create(); + Operand argument = tf.constantOf((TType) tensor); + Operand output = call(tf, argument); + return output.asTensor(); } /** From 696ef67811737ac752d35689149c9d5b4df5e5ea Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 4 Mar 2021 19:23:25 -0800 Subject: [PATCH 06/34] Get functions from graph Signed-off-by: Ryan Nett --- .../java/org/tensorflow/ConcreteFunction.java | 48 +++++++++++++ .../src/main/java/org/tensorflow/Graph.java | 72 ++++++++++++++++++- .../main/java/org/tensorflow/Signature.java | 36 +++++++++- .../org/tensorflow/ConcreteFunctionTest.java | 20 +++++- 4 files changed, 172 insertions(+), 4 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 73e81f646c5..ea879f7942c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -17,8 +17,10 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionName; import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionSetAttrValueProto; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionToFunctionDef; import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphToFunction; +import com.google.protobuf.InvalidProtocolBufferException; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; @@ -33,6 +35,7 @@ import org.bytedeco.javacpp.PointerPointer; import org.bytedeco.javacpp.PointerScope; import org.tensorflow.Graph.Reference; +import org.tensorflow.internal.c_api.TF_Buffer; import org.tensorflow.internal.c_api.TF_Function; import org.tensorflow.internal.c_api.TF_Operation; import org.tensorflow.internal.c_api.TF_Output; @@ -41,7 +44,11 @@ import org.tensorflow.op.Scope; import org.tensorflow.op.core.Placeholder; import org.tensorflow.proto.framework.AttrValue; +import org.tensorflow.proto.framework.FunctionDef; +import org.tensorflow.proto.framework.OpDef.ArgDef; import org.tensorflow.proto.framework.SignatureDef; +import org.tensorflow.proto.framework.TensorInfo; +import org.tensorflow.proto.framework.TensorShapeProto; import org.tensorflow.types.family.TType; /** @@ -469,4 +476,45 @@ TF_Function gradNativeHandle() { this.nativeHandle = nativeHandle; scope.attach(nativeHandle.withDeallocator()); } + + /** + * Detects the signature from the handle + */ + static ConcreteFunction fromNativeHandle(TF_Function function) { + TF_Buffer funcDefBuffer = TF_Buffer.newBuffer(); + TF_Status status2 = TF_Status.newStatus(); + TF_FunctionToFunctionDef(function, funcDefBuffer, status2); + status2.throwExceptionIfNotOK(); + FunctionDef funcDef = null; + try { + funcDef = FunctionDef.parseFrom(funcDefBuffer.dataAsByteBuffer()); + } catch (InvalidProtocolBufferException e) { + throw new IllegalStateException("Failed to parse FunctionDef proto", e); + } + + Signature.Builder builder = Signature.builder().methodName(funcDef.getSignature().getName()) + .key(TF_FunctionName(function).getString()); + + for (ArgDef input : funcDef.getSignature().getInputArgList()) { + TensorInfo info = TensorInfo.newBuilder() + .setDtype(input.getType()) + .setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build()) + .setName(input.getName()) + .build(); + + builder.input(input.getName(), info); + } + + for (ArgDef outputDef : funcDef.getSignature().getOutputArgList()) { + TensorInfo info = TensorInfo.newBuilder() + .setDtype(outputDef.getType()) + .setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build()) + .setName(outputDef.getName()) + .build(); + + builder.output(outputDef.getName(), info); + } + + return new ConcreteFunction(builder.build(), function); + } } \ No newline at end of file diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 67cf9a765cc..0b3f4dd0be6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -16,11 +16,15 @@ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TF_AddGradientsWithPrefix; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteFunction; import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteGraph; import static org.tensorflow.internal.c_api.global.tensorflow.TF_FinishWhile; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionName; import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphCopyFunction; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphGetFunctions; import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphImportGraphDef; import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphNextOperation; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphNumFunctions; import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphOperationByName; import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphToGraphDef; import static org.tensorflow.internal.c_api.global.tensorflow.TF_ImportGraphDefOptionsSetPrefix; @@ -40,10 +44,12 @@ import java.util.stream.Collectors; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.PointerPointer; import org.bytedeco.javacpp.PointerScope; import org.bytedeco.javacpp.SizeTPointer; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.c_api.TF_Buffer; +import org.tensorflow.internal.c_api.TF_Function; import org.tensorflow.internal.c_api.TF_Graph; import org.tensorflow.internal.c_api.TF_ImportGraphDefOptions; import org.tensorflow.internal.c_api.TF_Operation; @@ -382,13 +388,77 @@ public GraphOperationBuilder opBuilder(String type, String name) { @Override public void attachFunction(ConcreteFunction function) { try (Reference ref = ref(); - PointerScope scope = new PointerScope()) { + PointerScope scope = new PointerScope()) { TF_Status status = TF_Status.newStatus(); TF_GraphCopyFunction(ref.nativeHandle(), function.nativeHandle(), function.gradNativeHandle(), status); status.throwExceptionIfNotOK(); } } + /** + * Get the function attached to the graph with the given native name. Returns {@code null} if none found. + * + * @param key the name of the native function. Note that this may include an argument hash. + * @return the found {@link ConcreteFunction}, or {@code null} if none were found with the correct name + */ + public synchronized ConcreteFunction getFunction(String key) { + try (Reference ref = ref(); + PointerScope scope = new PointerScope()) { + TF_Status status = TF_Status.newStatus(); + + int numFunctions = TF_GraphNumFunctions(ref.nativeHandle()); + + PointerPointer output = new PointerPointer<>(numFunctions); + + TF_GraphGetFunctions(ref.nativeHandle(), output, numFunctions, status); + status.throwExceptionIfNotOK(); + + ConcreteFunction func = null; + + for (int i = 0; i < numFunctions; i++) { + TF_Function function = output.get(TF_Function.class, i); + + String functionName = TF_FunctionName(function).getString(); + + if (functionName.equals(key) && func == null) { + func = ConcreteFunction.fromNativeHandle(function); + } else { + TF_DeleteFunction(function); + } + } + + return func; + } + } + + /** + * Get the functions attached to the graph. + * + * @return all functions attached to this graph. + */ + public synchronized List getFunctions() { + try (Reference ref = ref(); + PointerScope scope = new PointerScope()) { + TF_Status status = TF_Status.newStatus(); + + int numFunctions = TF_GraphNumFunctions(ref.nativeHandle()); + + PointerPointer output = new PointerPointer<>(numFunctions); + + TF_GraphGetFunctions(ref.nativeHandle(), output, numFunctions, status); + status.throwExceptionIfNotOK(); + + List funcs = new ArrayList<>(numFunctions); + for (int i = 0; i < numFunctions; i++) { + TF_Function function = output.get(TF_Function.class, i); + + funcs.add(ConcreteFunction.fromNativeHandle(function)); + } + + return funcs; + } + } + @Override public Types environmentType() { return Types.GRAPH; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java index a609a711ca5..658dbff205d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -96,6 +96,22 @@ public Builder input(String inputName, Operand input) { return this; } + /** + * Register a tensor as an input of the function. + * + * @param inputName user-friendly name for this input tensor + * @param input input tensor info + * @return this builder + * @throws IllegalArgumentException if {@code inputName} is already mapped to another input + */ + Builder input(String inputName, TensorInfo input) { + if (signatureBuilder.containsInputs(inputName)) { + throw new IllegalArgumentException("\"" + inputName + "\" is already being mapped to another input"); + } + signatureBuilder.putInputs(inputName, input); + return this; + } + /** * Register a tensor as an output of the function. * @@ -113,8 +129,24 @@ public Builder output(String outputName, Operand output) { } /** - * Provide extensible name information enabling third-party users to mark a signature as - * supporting a particular method + * Register a tensor as an output of the function. + * + * @param outputName user-friendly name for this output tensor + * @param output output tensor + * @return this builder + * @throws IllegalArgumentException if {@code outputName} is already mapped to another output + */ + Builder output(String outputName, TensorInfo output) { + if (signatureBuilder.containsOutputs(outputName)) { + throw new IllegalArgumentException("\"" + outputName + "\" is already being mapped to another output"); + } + signatureBuilder.putOutputs(outputName, output); + return this; + } + + /** + * Provide extensible name information enabling third-party users to mark a signature as supporting a particular + * method * * @param methodName method name or null for none (default) * @return this builder diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java index 7f7492691ec..e8b08e8f9e9 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -15,6 +15,7 @@ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; import org.junit.jupiter.api.Test; import org.tensorflow.op.Ops; @@ -76,7 +77,24 @@ public void chainFunctions() { try (ConcreteFunction f1 = ConcreteFunction.create(ConcreteFunctionTest::plusFive); ConcreteFunction f2 = ConcreteFunction.create(ConcreteFunctionTest::minusTwo); TFloat32 x = TFloat32.scalarOf(3.0f)) { - assertEquals(6.0f, ((TFloat32)f2.call(f1.call(x))).getFloat()); + assertEquals(6.0f, ((TFloat32) f2.call(f1.call(x))).getFloat()); + } + } + + @Test + public void getGraphFunctions() { + try (ConcreteFunction function = ConcreteFunction.create(ConcreteFunctionTest::plusFive); + Graph g = new Graph()) { + Ops tf = Ops.create(g); + tf.call(function, tf.constant(3f)); + + ConcreteFunction attached = g.getFunction(function.getNativeFunctionName()); + assertNotNull(attached); + + try (TFloat32 x = TFloat32.scalarOf(10f); + TFloat32 y = (TFloat32) attached.call(x)) { + assertEquals(15f, y.getFloat()); + } } } } From 71b8fab3f6ca129cd84e56e8b4096a58633e6d74 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 5 Mar 2021 14:03:22 -0800 Subject: [PATCH 07/34] Start of saver support Signed-off-by: Ryan Nett --- .../java/org/tensorflow/ConcreteFunction.java | 1 - .../src/main/java/org/tensorflow/Graph.java | 12 +- .../java/org/tensorflow/SavedModelBundle.java | 129 +++++++++--------- .../src/main/java/org/tensorflow/Session.java | 8 +- .../org/tensorflow/SavedModelBundleTest.java | 2 + 5 files changed, 82 insertions(+), 70 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index ea879f7942c..6e34f4ce13f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -170,7 +170,6 @@ private static TF_Function createNative(Graph graph, Signature signature) { * @return the new function */ public static ConcreteFunction create(Function functionBuilder) { - //TODO make sure this works oki with graph closing try (Graph graph = new Graph()) { Ops tf = Ops.create(graph); Signature signature = functionBuilder.apply(tf); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 0b3f4dd0be6..2df714caa16 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -1158,12 +1158,20 @@ private static SaverDef addVariableSaver(Graph graph) { } } + Placeholder saveFilename = tf.withName("filename").placeholder(TString.class); + + if (varNames.isEmpty()) { + return SaverDef.newBuilder() + .setFilenameTensorName(saveFilename.op().name()) + .setSaveTensorName(tf.withName("empty_save").noOp().op().name()) + .setRestoreOpName(tf.withName("restore_all").noOp().op().name()) + .build(); + } + // FIXME Need an easier way to initialize an NdArray from a list String[] tmp = new String[varNames.size()]; Constant varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp))); Operand varSlices = tf.zerosLike(varNamesTensor); - - Placeholder saveFilename = tf.withName("filename").placeholder(TString.class); Save saveVariables = tf.train.save(saveFilename, varNamesTensor, varSlices, varOutputs); Identity id = tf.withControlDependencies(Arrays.asList(saveFilename, saveVariables)) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index cdc86245644..004c4b43d6a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -51,19 +51,22 @@ * SavedModelBundle represents a model loaded from storage. * *

The model consists of a description of the computation (a {@link Graph}), a {@link Session} - * with tensors (e.g., parameters or variables in the graph) initialized to values saved in storage, - * and a description of the model as a MetaGraphDef + * with tensors (e.g., parameters or variables in the graph) initialized to values saved in storage, and a description + * of the model as a MetaGraphDef * protocol buffer. */ public class SavedModelBundle implements AutoCloseable { public static final String DEFAULT_TAG = "serve"; - /** Options for loading a SavedModel. */ + /** + * Options for loading a SavedModel. + */ public static final class Loader { - /** Load a SavedModelBundle with the configured options. */ + /** + * Load a SavedModelBundle with the configured options. + */ public SavedModelBundle load() { return SavedModelBundle.load(exportDir, tags, configProto, runOptions); } @@ -71,9 +74,8 @@ public SavedModelBundle load() { /** * Sets options to use when executing model initialization operations. * - * @param options A RunOptions - * protocol buffer. + * @param options A RunOptions + * protocol buffer. * @return this object */ public Loader withRunOptions(RunOptions options) { @@ -84,9 +86,8 @@ public Loader withRunOptions(RunOptions options) { /** * Set configuration of the Session object created when loading the model. * - * @param configProto A ConfigProto - * protocol buffer. + * @param configProto A ConfigProto + * protocol buffer. * @return this object */ public Loader withConfigProto(ConfigProto configProto) { @@ -114,12 +115,14 @@ private Loader(String exportDir) { } private String exportDir = null; - private String[] tags = { DEFAULT_TAG }; + private String[] tags = {DEFAULT_TAG}; private ConfigProto configProto = null; private RunOptions runOptions = null; } - /** Options for exporting a SavedModel. */ + /** + * Options for exporting a SavedModel. + */ public static final class Exporter { /** @@ -144,9 +147,9 @@ public Exporter withTags(String... tags) { * names to a graph) and a valid session to a graph to be saved in the model. * *

Note:Eventually, TensorFlow for Java will support the export of functions objects like - * the Python API does but right now, only session-centric models are supported (i.e. models that - * has a single main graph and one or more signatures). These models are compatible with those - * exported by TensorFlow 1.x or by TensorFlow 2.x estimators. + * the Python API does but right now, only session-centric models are supported (i.e. models that has a single main + * graph and one or more signatures). These models are compatible with those exported by TensorFlow 1.x or by + * TensorFlow 2.x estimators. * *
Therefore, all functions exported in a model should share the same session at the moment * or an exception will be thrown.
@@ -154,8 +157,8 @@ public Exporter withTags(String... tags) { * @param function a function carrying a signature and a valid session to the graph to be saved * @return this object * @throws IllegalArgumentException if a function with the same name has already been added to the model - * @throws UnsupportedOperationException if this function does not share the same session with the other - * functions added to this model + * @throws UnsupportedOperationException if this function does not share the same session with the other functions + * added to this model */ public Exporter withFunction(ConcreteFunction function) { Signature signature = function.signature(); @@ -163,12 +166,6 @@ public Exporter withFunction(ConcreteFunction function) { throw new IllegalArgumentException("Function \"" + signature.key() + "\" was already added to the model"); } functions.put(signature.key(), function); - //TODO fix saver -// if (session == null) { -// session = function.session(); -// } else if (session != function.session()) { -// throw new UnsupportedOperationException("Saving multiple functions with different graphs/sessions is not supported yet."); -// } metaGraphDefBuilder.putSignatureDef(signature.key(), signature.asSignatureDef()); return this; } @@ -179,33 +176,39 @@ public Exporter withFunction(ConcreteFunction function) { * @throws IOException if saved model or variable state cannot be written on disk */ public void export() throws IOException { - if (functions.isEmpty() || session == null) { + if (functions.isEmpty()) { throw new IllegalStateException("Model should contain at least one valid function"); } - Graph graph = session.graph(); - - // It is imperative to retrieve the graphDef after the saverDef, as the former might add - // new ops to the graph for saving and restoring the variables. - SaverDef saverDef = graph.saverDef(); - - MetaGraphDef.Builder metaGraphDef = metaGraphDefBuilder - .setSaverDef(saverDef) - .setGraphDef(graph.toGraphDef()) - .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(Arrays.asList(tags))); - functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signature().asSignatureDef())); - - // Make sure saved model directories exist - Path variableDir = Paths.get(exportDir, "variables"); - variableDir.toFile().mkdirs(); - - // Save the variables state - session.save(variableDir.resolve("variables").toString()); - - // Save the graph - SavedModel savedModelDef = SavedModel.newBuilder().addMetaGraphs(metaGraphDef).build(); - try (OutputStream file = - new FileOutputStream(Paths.get(exportDir, "saved_model.pb").toString())) { - savedModelDef.writeTo(file); + try (Graph graph = new Graph(); + Session session = new Session(graph)) { + + functions.values().forEach(graph::attachFunction); + + session.runInit(); + + // It is imperative to retrieve the graphDef after the saverDef, as the former might add + // new ops to the graph for saving and restoring the variables. + SaverDef saverDef = graph.saverDef(); + + MetaGraphDef.Builder metaGraphDef = metaGraphDefBuilder + .setSaverDef(saverDef) + .setGraphDef(graph.toGraphDef()) + .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(Arrays.asList(tags))); + functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signature().asSignatureDef())); + + // Make sure saved model directories exist + Path variableDir = Paths.get(exportDir, "variables"); + variableDir.toFile().mkdirs(); + + // Save the variables state + session.save(variableDir.resolve("variables").toString()); + + // Save the graph + SavedModel savedModelDef = SavedModel.newBuilder().addMetaGraphs(metaGraphDef).build(); + try (OutputStream file = + new FileOutputStream(Paths.get(exportDir, "saved_model.pb").toString())) { + savedModelDef.writeTo(file); + } } } @@ -214,16 +217,14 @@ public void export() throws IOException { } private final String exportDir; - private String[] tags = { DEFAULT_TAG }; + private String[] tags = {DEFAULT_TAG}; private final MetaGraphDef.Builder metaGraphDefBuilder = MetaGraphDef.newBuilder(); private final Map functions = new LinkedHashMap<>(); - private Session session; } /** - * Load a saved model from an export directory. The model that is being loaded should be created - * using the Saved Model - * API. + * Load a saved model from an export directory. The model that is being loaded should be created using the Saved Model API. * *

This method is a shorthand for: * @@ -268,15 +269,16 @@ public static Exporter exporter(String exportDir) { } /** - * Returns the MetaGraphDef + * Returns the MetaGraphDef * protocol buffer associated with the saved model. */ public MetaGraphDef metaGraphDef() { return metaGraphDef; } - /** Returns the graph that describes the computation performed by the model. */ + /** + * Returns the graph that describes the computation performed by the model. + */ public Graph graph() { return graph; } @@ -307,8 +309,7 @@ public List signatures() { * * @param signatureKey name of the {@code SignatureDef} in the saved model. * @return object that can be used to make calls to a function - * @throws IllegalArgumentException if {@code signatureKey} is not found in this - * saved model. + * @throws IllegalArgumentException if {@code signatureKey} is not found in this saved model. */ public ConcreteFunction function(String signatureKey) { ConcreteFunction function = functions.get(signatureKey); @@ -349,8 +350,7 @@ public Map call(Map arguments) { } /** - * Releases resources (the {@link Graph} and {@link Session}) associated with the saved model - * bundle. + * Releases resources (the {@link Graph} and {@link Session}) associated with the saved model bundle. */ @Override public void close() { @@ -363,7 +363,8 @@ public void close() { private final MetaGraphDef metaGraphDef; private final Map functions; - private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, Map functions) { + private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, + Map functions) { this.graph = graph; this.session = session; this.metaGraphDef = metaGraphDef; @@ -371,8 +372,8 @@ private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef } /** - * Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session - * object, plus the MetaGraphDef. + * Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session object, plus the + * MetaGraphDef. * *

Invoked from the native load method. Takes ownership of the handles. */ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 58fb62b5fee..6890a959f82 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -511,9 +511,11 @@ public void run(Op op) { *

This method is equivalent to {@code session.run(Ops.create(session.graph).init())}. */ public void runInit() { - Runner runner = runner(); - graph.initializers().forEach(runner::addTarget); - runner.run(); + if (!graph.initializers().isEmpty()) { + Runner runner = runner(); + graph.initializers().forEach(runner::addTarget); + runner.run(); + } } /** diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 2dae1077874..9cd4b3d4c87 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -29,6 +29,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import org.junit.Ignore; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.ndarray.FloatNdArray; @@ -213,6 +214,7 @@ public void exportMultipleFunctions() throws IOException { } @Test + @Ignore // this is supported now public void cannotExportMultipleFunctionsWithDifferentSessions() throws IOException { Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); try (Graph g = new Graph()) { From 36971221733be753744bdebc2b69b75afe4c4128 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 5 Mar 2021 19:47:05 -0800 Subject: [PATCH 08/34] Update loading, detect statefulness, use PartitionedCall Signed-off-by: Ryan Nett --- .../java/org/tensorflow/ConcreteFunction.java | 291 ++++++++++++------ .../org/tensorflow/EagerOperationBuilder.java | 16 +- .../src/main/java/org/tensorflow/Graph.java | 2 +- .../org/tensorflow/GraphOperationBuilder.java | 17 + .../java/org/tensorflow/OperationBuilder.java | 13 +- .../java/org/tensorflow/SavedModelBundle.java | 124 +++++++- .../main/java/org/tensorflow/TensorFlow.java | 45 ++- .../internal/c_api/AbstractTF_Function.java | 25 ++ .../org/tensorflow/SavedModelBundleTest.java | 25 -- 9 files changed, 418 insertions(+), 140 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 6e34f4ce13f..c3a93b5a244 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -19,6 +19,7 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionSetAttrValueProto; import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionToFunctionDef; import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphToFunction; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationGetAttrValueProto; import com.google.protobuf.InvalidProtocolBufferException; import java.io.IOException; @@ -44,11 +45,13 @@ import org.tensorflow.op.Scope; import org.tensorflow.op.core.Placeholder; import org.tensorflow.proto.framework.AttrValue; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.FunctionDef; import org.tensorflow.proto.framework.OpDef.ArgDef; import org.tensorflow.proto.framework.SignatureDef; import org.tensorflow.proto.framework.TensorInfo; import org.tensorflow.proto.framework.TensorShapeProto; +import org.tensorflow.types.TBool; import org.tensorflow.types.family.TType; /** @@ -65,77 +68,6 @@ */ public class ConcreteFunction implements AutoCloseable { - private static TF_Operation outputHandle(Operand operand) { - if (operand == null) { - throw new NullPointerException("Can't get output handle for null operand"); - } - - Pointer handle = operand.asOutput().getUnsafeNativeHandle(); - if (handle.isNull()) { - throw new NullPointerException("Native handle of operand is null, has it been closed?"); - } - - if (!(handle instanceof TF_Operation)) { - throw new IllegalArgumentException("Operand was not a graph operand"); - } - - return (TF_Operation) handle; - } - - private static TF_Output resolveToOutput(Graph graph, List> operands) { - TF_Output handles = new TF_Output(operands.size()); - for (int i = 0; i < operands.size(); i++) { - Operand input = operands.get(i); - graph.checkInput(input); - TF_Operation handle = outputHandle(input); - handles.position(i).oper(handle).index(input.asOutput().index()); - } - handles.position(0); - return handles; - } - - private static TF_Function createNative(Graph graph, Signature signature) { - try (PointerScope scope = new PointerScope(); - Reference ref = graph.ref()) { - TF_Status status = TF_Status.newStatus(); - - List> inputs = signature.getInputs().values().stream() - .map((x) -> graph.outputOrError(x.name)) - .collect(Collectors.toList()); - - List> outputs = signature.getOutputs().values().stream() - .map((x) -> graph.outputOrError(x.name)) - .collect(Collectors.toList()); - - List ops = new ArrayList<>( - graph.completeSubgraph(new HashSet<>(inputs), new HashSet<>(outputs), - null, Collections.singleton(Placeholder.OP_NAME))); - - PointerPointer operations = new PointerPointer<>(ops.size()); - for (int i = 0; i < ops.size(); i++) { - operations.put(i, ops.get(i).getUnsafeNativeHandle()); - } - - TF_Function handle = TF_GraphToFunction( - ref.nativeHandle(), - new BytePointer(signature.key()), //TODO or methodName? - (byte) 1, - ops.size(), - operations, - inputs.size(), - resolveToOutput(graph, inputs), - outputs.size(), - resolveToOutput(graph, outputs), - null, - null, - new BytePointer(signature.methodName() != null ? signature.methodName() : "Method " + signature.key()), - status - ); - - status.throwExceptionIfNotOK(); - return handle; - } - } /** * Creates a function by building a new graph. @@ -173,7 +105,7 @@ public static ConcreteFunction create(Function functionBuilder) try (Graph graph = new Graph()) { Ops tf = Ops.create(graph); Signature signature = functionBuilder.apply(tf); - return new ConcreteFunction(signature, graph); + return buildFromGraph(graph, signature); } } @@ -203,7 +135,7 @@ public static ConcreteFunction create(Function functionBuilder) * @return a new function */ public static ConcreteFunction create(Signature signature, Graph graph) { - return new ConcreteFunction(signature, graph); + return buildFromGraph(graph, signature); } /** @@ -237,7 +169,7 @@ public static ConcreteFunction create(Signature signature, Graph graph) { * @return a new function */ public static ConcreteFunction create(Signature signature, Session session) { - return new ConcreteFunction(signature, session.graph()); + return buildFromGraph(session.graph(), signature); } /** @@ -256,6 +188,8 @@ public String getNativeFunctionName() { } } + public static final String CALL_OP = "PartitionedCall"; + public static final String STATEFUL_CALL_OP = "StatefulPartitionedCall"; public Map> call(Scope scope, Map> arguments) { @@ -274,10 +208,17 @@ public Map> call(Scope scope, scope.env().attachFunction(this); String name = getNativeFunctionName(); - OperationBuilder opBuilder = scope.env().opBuilder(name, scope.makeOpName(name)); - for (Operand input : inputList) { - opBuilder.addInput(input.asOutput()); - } + String displayName = Scope.isValidOpName(name) ? name : "FunctionCall"; + + OperationBuilder opBuilder = scope.env() + .opBuilder(stateful ? STATEFUL_CALL_OP : CALL_OP, scope.makeOpName(displayName)); + + opBuilder.addInputList(inputList.stream().map(Operand::asOutput).toArray(Output[]::new)); + + opBuilder.setFunctionName("f", name); + opBuilder.setAttr("Tin", inputList.stream().map(x -> x.asOutput().dataType()).toArray(DataType[]::new)); + opBuilder.setAttr("Tout", signature().getOutputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new)); + opBuilder = scope.apply(opBuilder); Operation op = opBuilder.build(); @@ -413,6 +354,10 @@ public void save(String exportDir) throws IOException { SavedModelBundle.exporter(exportDir).withFunction(this).export(); } + public boolean isStateful() { + return stateful; + } + @Override public void close() { scope.close(); @@ -464,31 +409,35 @@ TF_Function gradNativeHandle() { private final Signature signature; private final TF_Function nativeHandle; private final PointerScope scope; + private final boolean stateful; - ConcreteFunction(Signature signature, Graph graph) { - this(signature, createNative(graph, signature)); - } - - ConcreteFunction(Signature signature, TF_Function nativeHandle) { + ConcreteFunction(Signature signature, TF_Function nativeHandle, boolean stateful) { this.signature = signature; - scope = new PointerScope(); - this.nativeHandle = nativeHandle; - scope.attach(nativeHandle.withDeallocator()); + try (PointerScope scope = new PointerScope()) { + scope.extend(); + this.nativeHandle = nativeHandle.withDeallocator(); + scope.attach(nativeHandle); + this.scope = scope; + } + this.stateful = stateful; } /** * Detects the signature from the handle */ static ConcreteFunction fromNativeHandle(TF_Function function) { - TF_Buffer funcDefBuffer = TF_Buffer.newBuffer(); - TF_Status status2 = TF_Status.newStatus(); - TF_FunctionToFunctionDef(function, funcDefBuffer, status2); - status2.throwExceptionIfNotOK(); + FunctionDef funcDef = null; - try { - funcDef = FunctionDef.parseFrom(funcDefBuffer.dataAsByteBuffer()); - } catch (InvalidProtocolBufferException e) { - throw new IllegalStateException("Failed to parse FunctionDef proto", e); + try (PointerScope scope = new PointerScope()) { + TF_Buffer funcDefBuffer = TF_Buffer.newBuffer(); + TF_Status status2 = TF_Status.newStatus(); + TF_FunctionToFunctionDef(function, funcDefBuffer, status2); + status2.throwExceptionIfNotOK(); + try { + funcDef = FunctionDef.parseFrom(funcDefBuffer.dataAsByteBuffer()); + } catch (InvalidProtocolBufferException e) { + throw new IllegalStateException("Failed to parse FunctionDef proto", e); + } } Signature.Builder builder = Signature.builder().methodName(funcDef.getSignature().getName()) @@ -514,6 +463,158 @@ static ConcreteFunction fromNativeHandle(TF_Function function) { builder.output(outputDef.getName(), info); } - return new ConcreteFunction(builder.build(), function); + return new ConcreteFunction( + builder.build(), + function, + funcDef.getNodeDefList().stream().anyMatch(x -> TensorFlow.isOpStateful(x.getOp())) + ); + } + + + private static TF_Operation outputHandle(Operand operand) { + if (operand == null) { + throw new NullPointerException("Can't get output handle for null operand"); + } + + Pointer handle = operand.asOutput().getUnsafeNativeHandle(); + if (handle.isNull()) { + throw new NullPointerException("Native handle of operand is null, has it been closed?"); + } + + if (!(handle instanceof TF_Operation)) { + throw new IllegalArgumentException("Operand was not a graph operand"); + } + + return (TF_Operation) handle; + } + + private static TF_Output resolveToOutput(Graph graph, List> operands) { + TF_Output handles = new TF_Output(operands.size()); + for (int i = 0; i < operands.size(); i++) { + Operand input = operands.get(i); + graph.checkInput(input); + TF_Operation handle = outputHandle(input); + + handles.position(i).oper(handle).index(input.asOutput().index()); + } + handles.position(0); + return handles; + } + + /** + * Returns the function name if {@code op} is a function call op, or null otherwise. + */ + static String findFunctionCall(GraphOperation op) { + if (op.type().equals(STATEFUL_CALL_OP) || op.type().equals(CALL_OP)) { + try (PointerScope scope = new PointerScope()) { + TF_Status status = TF_Status.newStatus(); + TF_Buffer buff = TF_Buffer.newBuffer(); + TF_OperationGetAttrValueProto(op.getUnsafeNativeHandle(), "f", buff, status); + status.throwExceptionIfNotOK(); + AttrValue def = AttrValue.parseFrom(buff.dataAsByteBuffer()); + + return def.getFunc().getName(); + } catch (InvalidProtocolBufferException e) { + return null; + } + } + + return null; + } + + private static ConcreteFunction buildFromGraph(Graph graph, Signature signature) { + try (PointerScope scope = new PointerScope(); + Reference ref = graph.ref()) { + TF_Status status = TF_Status.newStatus(); + + List> inputs = signature.getInputs().values().stream() + .map((x) -> graph.outputOrError(x.name)) + .collect(Collectors.toList()); + + List> outputs = signature.getOutputs().values().stream() + .map((x) -> graph.outputOrError(x.name)) + .collect(Collectors.toList()); + + List ops = new ArrayList<>( + graph.completeSubgraph(new HashSet<>(inputs), new HashSet<>(outputs), + null, Collections.singleton(Placeholder.OP_NAME))); + + inputs.forEach(input -> ops.remove(input.op())); + + // Python sometimes has NoOps as outputs + Ops tf = Ops.create(graph).withSubScope("functionControlOutputs"); + for (int i = 0; i < outputs.size(); i++) { + Operand output = outputs.get(i); + if (output.op().numOutputs() < 1) { + Operand realOutput = tf + .withControlDependencies(Collections.singletonList(output)) + .withName(output.op().name() + "_control") + .constant(false); + ops.add((GraphOperation) realOutput.op()); + outputs.set(i, realOutput); + } + } + + PointerPointer operations = new PointerPointer<>(ops.size()); + for (int i = 0; i < ops.size(); i++) { + operations.put(i, ops.get(i).getUnsafeNativeHandle()); + } + + TF_Function handle = TF_GraphToFunction( + ref.nativeHandle(), + new BytePointer(signature.key()), + (byte) 1, + ops.size(), + operations, + inputs.size(), + resolveToOutput(graph, inputs), + outputs.size(), + resolveToOutput(graph, outputs), + null, + null, + new BytePointer(signature.methodName() != null ? signature.methodName() : "Method " + signature.key()), + status + ); + + status.throwExceptionIfNotOK(); + return new ConcreteFunction(signature, handle, ops.stream().anyMatch(x -> TensorFlow.isOpStateful(x.type()))); + } + } + + ConcreteFunction withNewSignature(Signature signature) { + if (this.signature.getInputs().size() != signature.getInputs().size()) { + throw new IllegalArgumentException( + "New signature must have the same number of inputs. Expected " + this.signature.getInputs().size() + ", got " + + signature.getInputs().size()); + } + + if (this.signature.getOutputs().size() != signature.getOutputs().size()) { + throw new IllegalArgumentException( + "New signature must have the same number of inputs. Expected " + this.signature.getInputs().size() + ", got " + + signature.getInputs().size()); + } + + List inputs = this.signature.getInputs().values().stream().map(x -> x.dataType) + .collect(Collectors.toList()); + List newInputs = signature.getInputs().values().stream().map(x -> x.dataType) + .collect(Collectors.toList()); + + if (!inputs.equals(newInputs)) { + throw new IllegalArgumentException( + "Data types of the new signature's inputs (in order) must match. Expected " + inputs + ", got " + newInputs); + } + + List outputs = this.signature.getOutputs().values().stream().map(x -> x.dataType) + .collect(Collectors.toList()); + List newOutputs = signature.getOutputs().values().stream().map(x -> x.dataType) + .collect(Collectors.toList()); + + if (!outputs.equals(newOutputs)) { + throw new IllegalArgumentException( + "Data types of the new signature's outputs (in order) must match. Expected " + outputs + ", got " + + newOutputs); + } + + return new ConcreteFunction(signature, nativeHandle, stateful); } } \ No newline at end of file diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java index f1dd6216a79..87bd6ee21c1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -22,6 +22,7 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrBoolList; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrFloat; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrFloatList; +import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrFunctionName; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrInt; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrIntList; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrShape; @@ -217,6 +218,12 @@ public EagerOperationBuilder setAttr(String name, Shape[] values) { return this; } + @Override + public OperationBuilder setFunctionName(String attrName, String functionName) { + setAttrFunctionName(opHandle, attrName, functionName); + return this; + } + private TFE_Op opHandle; private final EagerSession session; @@ -409,7 +416,14 @@ private static void setAttrShapeList(TFE_Op opHandle, String name, long[] shapes } TF_Status status = TF_Status.newStatus(); TFE_OpSetAttrShapeList(opHandle, new BytePointer(name), shapesPointers, new IntPointer(numDims), - numDims.length, status); + numDims.length, status); + } + } + + private static void setAttrFunctionName(TFE_Op opHandle, String attrName, String functionName) { + requireOp(opHandle); + try (PointerScope scope = new PointerScope()) { + TFE_OpSetAttrFunctionName(opHandle, attrName, functionName, functionName.length()); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 2df714caa16..6bdb4a4a625 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -1163,7 +1163,7 @@ private static SaverDef addVariableSaver(Graph graph) { if (varNames.isEmpty()) { return SaverDef.newBuilder() .setFilenameTensorName(saveFilename.op().name()) - .setSaveTensorName(tf.withName("empty_save").noOp().op().name()) + .setSaveTensorName(tf.withName("empty_save").identity(saveFilename).op().name()) .setRestoreOpName(tf.withName("restore_all").noOp().op().name()) .build(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java index 72858ece572..3d638069fee 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java @@ -24,6 +24,7 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrBoolList; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrFloat; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrFloatList; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrFuncName; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrInt; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrIntList; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrShape; @@ -45,6 +46,7 @@ import org.bytedeco.javacpp.PointerPointer; import org.bytedeco.javacpp.PointerScope; import org.bytedeco.javacpp.SizeTPointer; +import org.tensorflow.Graph.Reference; import org.tensorflow.internal.c_api.TF_Graph; import org.tensorflow.internal.c_api.TF_Operation; import org.tensorflow.internal.c_api.TF_OperationDescription; @@ -344,6 +346,14 @@ public GraphOperationBuilder setAttr(String name, String[] value) { return this; } + @Override + public OperationBuilder setFunctionName(String attrName, String functionName) { + try (Reference r = graph.ref()) { + setAttrFunctionName(unsafeNativeHandle, attrName, functionName); + } + return this; + } + private TF_OperationDescription unsafeNativeHandle; private Graph graph; @@ -539,4 +549,11 @@ private static void setAttrStringList(TF_OperationDescription handle, String nam TF_SetAttrStringList(handle, new BytePointer(name), valuePointers, lengths, value.length); } } + + private static void setAttrFunctionName(TF_OperationDescription opHandle, String attrName, String functionName) { + requireHandle(opHandle); + try (PointerScope scope = new PointerScope()) { + TF_SetAttrFuncName(opHandle, attrName, functionName, functionName.length()); + } + } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java index a487d8b9237..dca709389f3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java @@ -214,7 +214,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, Shape value); + OperationBuilder setAttr(String name, Shape value); /** * Set the shape values of an attribute of the operation being built. @@ -223,5 +223,14 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, Shape[] value); + OperationBuilder setAttr(String name, Shape[] value); + + /** + * Set a function name attribute of the operation being build. + * + * @param attrName the attribute to set + * @param functionName the function name + * @return the OperationBuilder instance for chaining. + */ + OperationBuilder setFunctionName(String attrName, String functionName); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 004c4b43d6a..8b9d87cb2f7 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -17,6 +17,7 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_LoadSessionFromSavedModel; import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewGraph; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationGetAttrValueProto; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; import com.google.protobuf.InvalidProtocolBufferException; @@ -25,11 +26,14 @@ import java.io.OutputStream; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.stream.Collectors; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.PointerPointer; @@ -37,14 +41,18 @@ import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.c_api.TF_Buffer; import org.tensorflow.internal.c_api.TF_Graph; +import org.tensorflow.internal.c_api.TF_Operation; import org.tensorflow.internal.c_api.TF_Session; import org.tensorflow.internal.c_api.TF_SessionOptions; import org.tensorflow.internal.c_api.TF_Status; +import org.tensorflow.proto.framework.AttrValue; import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.MetaGraphDef; import org.tensorflow.proto.framework.MetaGraphDef.MetaInfoDef; import org.tensorflow.proto.framework.RunOptions; import org.tensorflow.proto.framework.SavedModel; +import org.tensorflow.proto.framework.SignatureDef; +import org.tensorflow.proto.framework.TensorInfo; import org.tensorflow.proto.util.SaverDef; /** @@ -300,7 +308,9 @@ public List signatures() { } /** - * Return a {@link ConcreteFunction} corresponding to the function signature. + * Return a {@link ConcreteFunction} corresponding to the function signature. The function may depend on other + * functions in the bundle, which will need to be attached to the execution environment used to call this function (or + * the default eager environment if called with tensors). * *

{@code
    * ConcreteFunction myFunction = savedModelBundle.function("mySignatureKey");
@@ -320,6 +330,13 @@ public ConcreteFunction function(String signatureKey) {
     return function;
   }
 
+  /**
+   * Get all functions in the bundle.
+   */
+  public List functions() {
+    return new ArrayList<>(functions.values());
+  }
+
   /**
    * Invokes the default function directly from this model.
    *
@@ -371,6 +388,53 @@ private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef
     this.functions = functions;
   }
 
+  private static final Pattern INFERENCE_FUNCTION_NAME_PATTERN = Pattern
+      .compile("__inference_(.+)_\\d+", Pattern.DOTALL);
+
+  /**
+   * Check that all outputs of the signature come from a single call op that takes the inputs.
+   */
+  private static GraphOperation findFunctionWrapper(Graph graph, SignatureDef signatureDef) {
+
+    GraphOperation callOp = null;
+    for (TensorInfo output : signatureDef.getOutputsMap().values()) {
+      GraphOperation op = (GraphOperation) graph.outputOrError(output.getName()).op();
+      if (callOp == null) {
+        callOp = op;
+      } else if (!callOp.equals(op)) {
+        return null;
+      }
+    }
+
+    if (callOp == null) {
+      return null;
+    }
+
+    if (callOp != null) {
+
+      if (callOp.numInputs() != signatureDef.getInputsCount() || callOp.numOutputs() != signatureDef
+          .getOutputsCount()) {
+        return null;
+      }
+
+      int i = 0;
+      List> opInputs = callOp.inputs();
+
+      for (TensorInfo input : signatureDef.getInputsMap().values()) {
+        if (!graph.outputOrError(input.getName()).equals(opInputs.get(i))) {
+          return null;
+        }
+        i++;
+      }
+    }
+
+    if (!callOp.type().equals(ConcreteFunction.CALL_OP) && !callOp.type().equals(ConcreteFunction.STATEFUL_CALL_OP)) {
+      return null;
+    }
+
+    return callOp;
+  }
+
   /**
    * Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session object, plus the
    * MetaGraphDef.
@@ -388,9 +452,63 @@ private static SavedModelBundle fromHandle(
     // that the functions do not need to be closed by the user and if it does, it should have
     // no effect.
     final Map functions = new HashMap<>(metaGraphDef.getSignatureDefCount());
+    List graphFunctions = graph.getFunctions();
     metaGraphDef.getSignatureDefMap().forEach((signatureName, signatureDef) -> {
-      Signature signature = new Signature(signatureName, signatureDef);
-      functions.put(signatureName, ConcreteFunction.create(signature, session));
+
+      GraphOperation callOp = findFunctionWrapper(graph, signatureDef);
+
+      // if the function is a thin wrapper around a function call, unwrap it
+      if (callOp != null) {
+
+        //TODO my problem is with __inference_signature_wrapper_66
+
+        try (PointerScope scope = new PointerScope()) {
+          TF_Operation op = ((GraphOperation) graph
+              .outputOrError(signatureDef.getOutputsMap().values().iterator().next().getName()).op())
+              .getUnsafeNativeHandle();
+          TF_Status status = TF_Status.newStatus();
+          TF_Buffer buff = TF_Buffer.newBuffer();
+          TF_OperationGetAttrValueProto(op, "f", buff, status);
+          status.throwExceptionIfNotOK();
+          AttrValue def = AttrValue.parseFrom(buff.dataAsByteBuffer());
+
+          String functionName = def.getFunc().getName();
+
+          ConcreteFunction function = null;
+          for (ConcreteFunction fn : graphFunctions) {
+            if (fn.getNativeFunctionName().equals(functionName)) {
+              function = fn;
+              break;
+            }
+          }
+
+          if (function != null) {
+            functions.put(signatureName, function.withNewSignature(new Signature(signatureName, signatureDef)));
+          }
+        } catch (InvalidProtocolBufferException | IllegalArgumentException ignored) {
+
+        }
+      }
+
+      // try to do the unwrapping based on name if there are no outputs (and thus we can't find the call op)
+      if (!functions.containsKey(signatureName) && signatureDef.getOutputsCount() < 1) {
+        for (ConcreteFunction fn : graphFunctions) {
+          Matcher matcher = INFERENCE_FUNCTION_NAME_PATTERN.matcher(fn.getNativeFunctionName());
+          if (matcher.find()) {
+            String fnName = matcher.group(1);
+            if (fnName.equals(signatureName)) {
+              functions.put(signatureName, fn);
+              break;
+            }
+          }
+        }
+      }
+
+      // otherwise use the wrapper
+      if (!functions.containsKey(signatureName)) {
+        Signature signature = new Signature(signatureName, signatureDef);
+        functions.put(signatureName, ConcreteFunction.create(signature, session));
+      }
     });
     return new SavedModelBundle(graph, session, metaGraphDef, functions);
   }
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java
index de481d256a3..f5edfec09a2 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java
@@ -23,6 +23,8 @@
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_Version;
 
 import com.google.protobuf.InvalidProtocolBufferException;
+import java.util.Set;
+import java.util.stream.Collectors;
 import org.bytedeco.javacpp.PointerScope;
 import org.tensorflow.exceptions.TensorFlowException;
 import org.tensorflow.internal.c_api.TF_Buffer;
@@ -30,10 +32,14 @@
 import org.tensorflow.internal.c_api.TF_Status;
 import org.tensorflow.proto.framework.OpList;
 
-/** Static utility methods describing the TensorFlow runtime. */
+/**
+ * Static utility methods describing the TensorFlow runtime.
+ */
 public final class TensorFlow {
 
-  /** Returns the version of the underlying TensorFlow runtime. */
+  /**
+   * Returns the version of the underlying TensorFlow runtime.
+   */
   public static String version() {
     return TF_Version().getString();
   }
@@ -41,11 +47,10 @@ public static String version() {
   /**
    * All the TensorFlow operations available in this address space.
    *
-   * @return A OpList
-   *     protocol buffer, which lists all the available TensorFlow operations.
+   * @return A OpList protocol
+   * buffer, which lists all the available TensorFlow operations.
    */
-  public static OpList registeredOpList() {
+  public static synchronized OpList registeredOpList() {
     TF_Buffer buf = TF_GetAllOpList();
     try {
       return OpList.parseFrom(buf.dataAsByteBuffer());
@@ -56,14 +61,25 @@ public static OpList registeredOpList() {
     }
   }
 
+  private static Set statefulOps;
+
+  public static synchronized boolean isOpStateful(String opType) {
+    if (statefulOps == null) {
+      statefulOps = registeredOpList().getOpList().stream()
+          .filter(x -> x.getIsStateful())
+          .map(x -> x.getName())
+          .collect(Collectors.toSet());
+    }
+
+    return statefulOps.contains(opType);
+  }
+
   /**
-   * Load the dynamic library in filename and register the operations and kernels present in that
-   * library.
+   * Load the dynamic library in filename and register the operations and kernels present in that library.
    *
    * @param filename Path of the dynamic library containing operations and kernels to load.
-   * @return A OpList
-   *     protocol buffer message defining the operations defined in the library.
+   * @return A OpList protocol
+   * buffer message defining the operations defined in the library.
    * @throws UnsatisfiedLinkError if filename cannot be loaded.
    */
   public static OpList loadLibrary(String filename) {
@@ -104,9 +120,12 @@ private static OpList libraryOpList(TF_Library handle) {
     }
   }
 
-  private TensorFlow() {}
+  private TensorFlow() {
+  }
 
-  /** Load the TensorFlow runtime C library. */
+  /**
+   * Load the TensorFlow runtime C library.
+   */
   static {
     try {
       NativeLibrary.load();
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java
index 0d021244c6b..707012e7aaa 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java
@@ -19,7 +19,9 @@
 
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteFunction;
 
+import java.util.Iterator;
 import org.bytedeco.javacpp.Pointer;
+import org.bytedeco.javacpp.PointerScope;
 import org.bytedeco.javacpp.annotation.Properties;
 
 @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
@@ -44,7 +46,30 @@ public AbstractTF_Function(Pointer p) {
     super(p);
   }
 
+
+  private boolean hasDeallocator = false;
+
+  /**
+   * Adds a deallocator if there isn't already one.  Attaches to the current scope regardless.
+   */
   public TF_Function withDeallocator() {
+    if (hasDeallocator) {
+      Iterator it = PointerScope.getScopeIterator();
+      if (it != null) {
+        while (it.hasNext()) {
+          try {
+            it.next().attach(this);
+          } catch (IllegalArgumentException e) {
+            // try the next scope down the stack
+            continue;
+          }
+          break;
+        }
+      }
+
+      return (TF_Function) this;
+    }
+    hasDeallocator = true;
     return this.deallocator(new DeleteDeallocator((TF_Function) this));
   }
 
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java
index 9cd4b3d4c87..f93826c6092 100644
--- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java
@@ -29,7 +29,6 @@
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
-import org.junit.Ignore;
 import org.junit.jupiter.api.Test;
 import org.tensorflow.exceptions.TensorFlowException;
 import org.tensorflow.ndarray.FloatNdArray;
@@ -213,30 +212,6 @@ public void exportMultipleFunctions() throws IOException {
     }
   }
 
-  @Test
-  @Ignore // this is supported now
-  public void cannotExportMultipleFunctionsWithDifferentSessions() throws IOException {
-    Path testFolder = Files.createTempDirectory("tf-saved-model-export-test");
-    try (Graph g = new Graph()) {
-      Ops tf = Ops.create(g);
-      Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1));
-      Signature f2Signature = buildIdentityGraph(tf, "identity");
-      try (ConcreteFunction f1 = ConcreteFunction.create(f1Signature, g);
-          ConcreteFunction f2 = ConcreteFunction.create(f2Signature, g)) {
-        //TODO f1.session().run(Init.DEFAULT_NAME);
-        try {
-          SavedModelBundle.exporter(testFolder.toString())
-              .withFunction(f1)
-              .withFunction(f2)
-              .export();
-          fail();
-        } catch (UnsupportedOperationException e) {
-          // as expected
-        }
-      }
-    }
-  }
-
   @Test
   public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOException {
     Path testFolder = Files.createTempDirectory("tf-saved-model-export-test");

From abf0ed1e7dd70aa1a7c8ea2bae4afa677f0f282a Mon Sep 17 00:00:00 2001
From: Ryan Nett 
Date: Fri, 5 Mar 2021 19:57:16 -0800
Subject: [PATCH 09/34] Start of dependencies

Signed-off-by: Ryan Nett 
---
 .../java/org/tensorflow/ConcreteFunction.java | 71 +++++++++++++------
 1 file changed, 50 insertions(+), 21 deletions(-)

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
index c3a93b5a244..02ba4c535ea 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
@@ -27,8 +27,10 @@
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 import org.bytedeco.javacpp.BytePointer;
@@ -47,6 +49,7 @@
 import org.tensorflow.proto.framework.AttrValue;
 import org.tensorflow.proto.framework.DataType;
 import org.tensorflow.proto.framework.FunctionDef;
+import org.tensorflow.proto.framework.NodeDef;
 import org.tensorflow.proto.framework.OpDef.ArgDef;
 import org.tensorflow.proto.framework.SignatureDef;
 import org.tensorflow.proto.framework.TensorInfo;
@@ -501,27 +504,6 @@ private static TF_Output resolveToOutput(Graph graph, List> operands)
     return handles;
   }
 
-  /**
-   * Returns the function name if {@code op} is a function call op, or null otherwise.
-   */
-  static String findFunctionCall(GraphOperation op) {
-    if (op.type().equals(STATEFUL_CALL_OP) || op.type().equals(CALL_OP)) {
-      try (PointerScope scope = new PointerScope()) {
-        TF_Status status = TF_Status.newStatus();
-        TF_Buffer buff = TF_Buffer.newBuffer();
-        TF_OperationGetAttrValueProto(op.getUnsafeNativeHandle(), "f", buff, status);
-        status.throwExceptionIfNotOK();
-        AttrValue def = AttrValue.parseFrom(buff.dataAsByteBuffer());
-
-        return def.getFunc().getName();
-      } catch (InvalidProtocolBufferException e) {
-        return null;
-      }
-    }
-
-    return null;
-  }
-
   private static ConcreteFunction buildFromGraph(Graph graph, Signature signature) {
     try (PointerScope scope = new PointerScope();
         Reference ref = graph.ref()) {
@@ -617,4 +599,51 @@ ConcreteFunction withNewSignature(Signature signature) {
 
     return new ConcreteFunction(signature, nativeHandle, stateful);
   }
+
+  /**
+   * Returns the function name if {@code op} is a function call op, or null otherwise.
+   */
+  static String findFunctionCall(GraphOperation op) {
+    if (op.type().equals(STATEFUL_CALL_OP) || op.type().equals(CALL_OP)) {
+      try (PointerScope scope = new PointerScope()) {
+        TF_Status status = TF_Status.newStatus();
+        TF_Buffer buff = TF_Buffer.newBuffer();
+        TF_OperationGetAttrValueProto(op.getUnsafeNativeHandle(), "f", buff, status);
+        status.throwExceptionIfNotOK();
+        AttrValue def = AttrValue.parseFrom(buff.dataAsByteBuffer());
+
+        return def.getFunc().getName();
+      } catch (InvalidProtocolBufferException e) {
+        return null;
+      }
+    }
+
+    return null;
+  }
+
+  FunctionDef functionDef() {
+    try (PointerScope scope = new PointerScope()) {
+      TF_Buffer funcDefBuffer = TF_Buffer.newBuffer();
+      TF_Status status2 = TF_Status.newStatus();
+      TF_FunctionToFunctionDef(nativeHandle(), funcDefBuffer, status2);
+      status2.throwExceptionIfNotOK();
+      try {
+        return FunctionDef.parseFrom(funcDefBuffer.dataAsByteBuffer());
+      } catch (InvalidProtocolBufferException e) {
+        throw new IllegalStateException("Failed to parse FunctionDef proto", e);
+      }
+    }
+  }
+
+  static List findDependencies(FunctionDef def) {
+    Set deps = new LinkedHashSet<>();
+
+    for (NodeDef node : def.getNodeDefList()) {
+      if (node.getOp().equals(CALL_OP) || node.getOp().equals(STATEFUL_CALL_OP)) {
+        deps.add(node.getAttrMap().get("f").getFunc().getName());
+      }
+    }
+
+    return new ArrayList<>(deps);
+  }
 }
\ No newline at end of file

From ba65103bcf3dda1af155942f01be1dc5688ffe2a Mon Sep 17 00:00:00 2001
From: Ryan Nett 
Date: Sat, 6 Mar 2021 15:38:52 -0800
Subject: [PATCH 10/34] Support dependencies

Signed-off-by: Ryan Nett 
---
 .../java/org/tensorflow/ConcreteFunction.java | 248 ++++++++----------
 .../java/org/tensorflow/EagerSession.java     |  12 +-
 .../org/tensorflow/ExecutionEnvironment.java  |   3 +-
 .../src/main/java/org/tensorflow/Graph.java   |  74 +++---
 .../java/org/tensorflow/NativeFunction.java   | 162 ++++++++++++
 .../java/org/tensorflow/SavedModelBundle.java |  31 +--
 .../internal/c_api/AbstractTF_Function.java   |   2 +-
 .../org/tensorflow/ConcreteFunctionTest.java  |  39 ++-
 8 files changed, 375 insertions(+), 196 deletions(-)
 create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
index 02ba4c535ea..8ab491c8d57 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
@@ -15,19 +15,15 @@
  */
 package org.tensorflow;
 
-import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionName;
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionSetAttrValueProto;
-import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionToFunctionDef;
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphToFunction;
-import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationGetAttrValueProto;
 
-import com.google.protobuf.InvalidProtocolBufferException;
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
-import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -38,7 +34,6 @@
 import org.bytedeco.javacpp.PointerPointer;
 import org.bytedeco.javacpp.PointerScope;
 import org.tensorflow.Graph.Reference;
-import org.tensorflow.internal.c_api.TF_Buffer;
 import org.tensorflow.internal.c_api.TF_Function;
 import org.tensorflow.internal.c_api.TF_Operation;
 import org.tensorflow.internal.c_api.TF_Output;
@@ -49,7 +44,6 @@
 import org.tensorflow.proto.framework.AttrValue;
 import org.tensorflow.proto.framework.DataType;
 import org.tensorflow.proto.framework.FunctionDef;
-import org.tensorflow.proto.framework.NodeDef;
 import org.tensorflow.proto.framework.OpDef.ArgDef;
 import org.tensorflow.proto.framework.SignatureDef;
 import org.tensorflow.proto.framework.TensorInfo;
@@ -186,9 +180,35 @@ public Signature signature() {
    * Get the name of the function.
    */
   public String getNativeFunctionName() {
-    try (PointerScope scope = new PointerScope()) {
-      return TF_FunctionName(nativeHandle()).getString();
-    }
+    return nativeFunction.getName();
+  }
+
+  /**
+   * Get the {@link FunctionDef} proto.
+   */
+  public FunctionDef getFunctionDef() {
+    return nativeFunction.getFunctionDef();
+  }
+
+  /**
+   * Get whether the function is stateful.
+   */
+  public boolean isStateful() {
+    return nativeFunction.isStateful();
+  }
+
+  Set getDependencies() {
+    return dependencies;
+  }
+
+  @Override
+  public void close() {
+    scope.close();
+  }
+
+  @Override
+  public String toString() {
+    return signature.toString();
   }
 
   public static final String CALL_OP = "PartitionedCall";
@@ -214,7 +234,7 @@ public Map> call(Scope scope,
     String displayName = Scope.isValidOpName(name) ? name : "FunctionCall";
 
     OperationBuilder opBuilder = scope.env()
-        .opBuilder(stateful ? STATEFUL_CALL_OP : CALL_OP, scope.makeOpName(displayName));
+        .opBuilder(isStateful() ? STATEFUL_CALL_OP : CALL_OP, scope.makeOpName(displayName));
 
     opBuilder.addInputList(inputList.stream().map(Operand::asOutput).toArray(Output[]::new));
 
@@ -357,20 +377,6 @@ public void save(String exportDir) throws IOException {
     SavedModelBundle.exporter(exportDir).withFunction(this).export();
   }
 
-  public boolean isStateful() {
-    return stateful;
-  }
-
-  @Override
-  public void close() {
-    scope.close();
-  }
-
-  @Override
-  public String toString() {
-    return signature.toString();
-  }
-
   /**
    * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because how to enable XLA
    * JIT is extremely non-obvious.
@@ -394,10 +400,10 @@ private void makeJit() {
   }
 
   TF_Function nativeHandle() {
-    if (nativeHandle.isNull()) {
+    if (nativeFunction.getNativeHandle().isNull()) {
       throw new IllegalStateException("Function has been closed");
     }
-    return nativeHandle;
+    return nativeFunction.getNativeHandle();
   }
 
   /**
@@ -410,43 +416,93 @@ TF_Function gradNativeHandle() {
   }
 
   private final Signature signature;
-  private final TF_Function nativeHandle;
+  private final NativeFunction nativeFunction;
   private final PointerScope scope;
-  private final boolean stateful;
+  private final Set dependencies;
+
+  ConcreteFunction(Signature signature, NativeFunction nativeFunction, Collection availableFunctions) {
+    this(signature, nativeFunction, nativeFunction.getAllDependencies(availableFunctions));
+  }
+
+  private static boolean dataTypesMatch(List a, List b) {
+    if (a.size() != b.size()) {
+      return false;
+    }
+
+    for (int i = 0; i < a.size(); i++) {
+      DataType aType = a.get(i);
+      DataType bType = b.get(i);
+
+      if (aType != DataType.DT_INVALID && bType != DataType.DT_INVALID && !a.equals(b)) {
+        return false;
+      }
+    }
 
-  ConcreteFunction(Signature signature, TF_Function nativeHandle, boolean stateful) {
+    return true;
+  }
+
+  private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set dependencies) {
     this.signature = signature;
+    this.nativeFunction = nativeFunction;
+    this.dependencies = Collections.unmodifiableSet(dependencies);
+
+    if (this.signature.getInputs().size() != nativeFunction.getFunctionDef().getSignature().getInputArgCount()) {
+      throw new IllegalArgumentException(
+          "Signature must have the same number of inputs as the native function.  Expected "
+              + nativeFunction.getFunctionDef().getSignature().getInputArgCount() + ", got "
+              + this.signature.getInputs().size());
+    }
+
+    if (this.signature.getOutputs().size() != nativeFunction.getFunctionDef().getSignature().getOutputArgCount()) {
+      throw new IllegalArgumentException(
+          "New signature must have the same number of outputs as the native function.  Expected "
+              + nativeFunction.getFunctionDef().getSignature().getOutputArgCount() + ", got "
+              + this.signature.getOutputs().size());
+    }
+
+    List inputs = this.signature.getInputs().values().stream().map(x -> x.dataType)
+        .collect(Collectors.toList());
+    List nativeInputs = nativeFunction.getFunctionDef().getSignature().getInputArgList().stream()
+        .map(ArgDef::getType)
+        .collect(Collectors.toList());
+
+    if (!dataTypesMatch(inputs, nativeInputs)) {
+      throw new IllegalArgumentException(
+          "Data types of the signature's inputs must match the native function's (in order).  Expected "
+              + nativeInputs + ", got " + inputs);
+    }
+
+    List outputs = this.signature.getOutputs().values().stream().map(x -> x.dataType)
+        .collect(Collectors.toList());
+    List nativeOutputs = nativeFunction.getFunctionDef().getSignature().getOutputArgList().stream()
+        .map(ArgDef::getType)
+        .collect(Collectors.toList());
+
+    if (!dataTypesMatch(outputs, nativeOutputs)) {
+      throw new IllegalArgumentException(
+          "Data types of the signature's outputs must match the native function's (in order).  Expected "
+              + nativeOutputs + ", got "
+              + outputs);
+    }
+
     try (PointerScope scope = new PointerScope()) {
-      scope.extend();
-      this.nativeHandle = nativeHandle.withDeallocator();
-      scope.attach(nativeHandle);
       this.scope = scope;
+      scope.extend();
+      this.nativeFunction.getNativeHandle().withDeallocatorInScope();
+      this.dependencies.forEach(TF_Function::withDeallocatorInScope);
     }
-    this.stateful = stateful;
   }
 
   /**
    * Detects the signature from the handle
    */
-  static ConcreteFunction fromNativeHandle(TF_Function function) {
+  static ConcreteFunction fromNativeHandle(NativeFunction nativeFunction,
+      Collection availableFunctions) {
 
-    FunctionDef funcDef = null;
-    try (PointerScope scope = new PointerScope()) {
-      TF_Buffer funcDefBuffer = TF_Buffer.newBuffer();
-      TF_Status status2 = TF_Status.newStatus();
-      TF_FunctionToFunctionDef(function, funcDefBuffer, status2);
-      status2.throwExceptionIfNotOK();
-      try {
-        funcDef = FunctionDef.parseFrom(funcDefBuffer.dataAsByteBuffer());
-      } catch (InvalidProtocolBufferException e) {
-        throw new IllegalStateException("Failed to parse FunctionDef proto", e);
-      }
-    }
-
-    Signature.Builder builder = Signature.builder().methodName(funcDef.getSignature().getName())
-        .key(TF_FunctionName(function).getString());
+    Signature.Builder builder = Signature.builder().methodName(nativeFunction.getFunctionDef().getSignature().getName())
+        .key(nativeFunction.getName());
 
-    for (ArgDef input : funcDef.getSignature().getInputArgList()) {
+    for (ArgDef input : nativeFunction.getFunctionDef().getSignature().getInputArgList()) {
       TensorInfo info = TensorInfo.newBuilder()
           .setDtype(input.getType())
           .setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build())
@@ -456,7 +512,7 @@ static ConcreteFunction fromNativeHandle(TF_Function function) {
       builder.input(input.getName(), info);
     }
 
-    for (ArgDef outputDef : funcDef.getSignature().getOutputArgList()) {
+    for (ArgDef outputDef : nativeFunction.getFunctionDef().getSignature().getOutputArgList()) {
       TensorInfo info = TensorInfo.newBuilder()
           .setDtype(outputDef.getType())
           .setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build())
@@ -468,8 +524,8 @@ static ConcreteFunction fromNativeHandle(TF_Function function) {
 
     return new ConcreteFunction(
         builder.build(),
-        function,
-        funcDef.getNodeDefList().stream().anyMatch(x -> TensorFlow.isOpStateful(x.getOp()))
+        nativeFunction,
+        availableFunctions
     );
   }
 
@@ -559,91 +615,11 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature)
       );
 
       status.throwExceptionIfNotOK();
-      return new ConcreteFunction(signature, handle, ops.stream().anyMatch(x -> TensorFlow.isOpStateful(x.type())));
+      return new ConcreteFunction(signature, new NativeFunction(handle), graph.getNativeFunctions());
     }
   }
 
   ConcreteFunction withNewSignature(Signature signature) {
-    if (this.signature.getInputs().size() != signature.getInputs().size()) {
-      throw new IllegalArgumentException(
-          "New signature must have the same number of inputs.  Expected " + this.signature.getInputs().size() + ", got "
-              + signature.getInputs().size());
-    }
-
-    if (this.signature.getOutputs().size() != signature.getOutputs().size()) {
-      throw new IllegalArgumentException(
-          "New signature must have the same number of inputs.  Expected " + this.signature.getInputs().size() + ", got "
-              + signature.getInputs().size());
-    }
-
-    List inputs = this.signature.getInputs().values().stream().map(x -> x.dataType)
-        .collect(Collectors.toList());
-    List newInputs = signature.getInputs().values().stream().map(x -> x.dataType)
-        .collect(Collectors.toList());
-
-    if (!inputs.equals(newInputs)) {
-      throw new IllegalArgumentException(
-          "Data types of the new signature's inputs (in order) must match.  Expected " + inputs + ", got " + newInputs);
-    }
-
-    List outputs = this.signature.getOutputs().values().stream().map(x -> x.dataType)
-        .collect(Collectors.toList());
-    List newOutputs = signature.getOutputs().values().stream().map(x -> x.dataType)
-        .collect(Collectors.toList());
-
-    if (!outputs.equals(newOutputs)) {
-      throw new IllegalArgumentException(
-          "Data types of the new signature's outputs (in order) must match.  Expected " + outputs + ", got "
-              + newOutputs);
-    }
-
-    return new ConcreteFunction(signature, nativeHandle, stateful);
-  }
-
-  /**
-   * Returns the function name if {@code op} is a function call op, or null otherwise.
-   */
-  static String findFunctionCall(GraphOperation op) {
-    if (op.type().equals(STATEFUL_CALL_OP) || op.type().equals(CALL_OP)) {
-      try (PointerScope scope = new PointerScope()) {
-        TF_Status status = TF_Status.newStatus();
-        TF_Buffer buff = TF_Buffer.newBuffer();
-        TF_OperationGetAttrValueProto(op.getUnsafeNativeHandle(), "f", buff, status);
-        status.throwExceptionIfNotOK();
-        AttrValue def = AttrValue.parseFrom(buff.dataAsByteBuffer());
-
-        return def.getFunc().getName();
-      } catch (InvalidProtocolBufferException e) {
-        return null;
-      }
-    }
-
-    return null;
-  }
-
-  FunctionDef functionDef() {
-    try (PointerScope scope = new PointerScope()) {
-      TF_Buffer funcDefBuffer = TF_Buffer.newBuffer();
-      TF_Status status2 = TF_Status.newStatus();
-      TF_FunctionToFunctionDef(nativeHandle(), funcDefBuffer, status2);
-      status2.throwExceptionIfNotOK();
-      try {
-        return FunctionDef.parseFrom(funcDefBuffer.dataAsByteBuffer());
-      } catch (InvalidProtocolBufferException e) {
-        throw new IllegalStateException("Failed to parse FunctionDef proto", e);
-      }
-    }
-  }
-
-  static List findDependencies(FunctionDef def) {
-    Set deps = new LinkedHashSet<>();
-
-    for (NodeDef node : def.getNodeDefList()) {
-      if (node.getOp().equals(CALL_OP) || node.getOp().equals(STATEFUL_CALL_OP)) {
-        deps.add(node.getAttrMap().get("f").getFunc().getName());
-      }
-    }
-
-    return new ArrayList<>(deps);
+    return new ConcreteFunction(signature, nativeFunction, dependencies);
   }
 }
\ No newline at end of file
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java
index 8b592880929..d0a73bf480e 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java
@@ -28,6 +28,7 @@
 import org.tensorflow.internal.WeakPointerScope;
 import org.tensorflow.internal.c_api.TFE_Context;
 import org.tensorflow.internal.c_api.TFE_ContextOptions;
+import org.tensorflow.internal.c_api.TF_Function;
 import org.tensorflow.internal.c_api.TF_Status;
 import org.tensorflow.op.Op;
 import org.tensorflow.op.Scope;
@@ -289,12 +290,17 @@ public OperationBuilder opBuilder(String type, String name) {
   public void attachFunction(ConcreteFunction function) {
     checkSession();
     try (PointerScope scope = new PointerScope()) {
-      TF_Status status = TF_Status.newStatus();
-      TFE_ContextAddFunction(nativeHandle, function.nativeHandle(), status);
-      status.throwExceptionIfNotOK();
+      attachNativeFunction(function.nativeHandle());
+      function.getDependencies().forEach(this::attachNativeFunction);
     }
   }
 
+  private void attachNativeFunction(TF_Function fn) {
+    TF_Status status = TF_Status.newStatus();
+    TFE_ContextAddFunction(nativeHandle, fn, status);
+    status.throwExceptionIfNotOK();
+  }
+
   @Override
   public Types environmentType() {
     return Types.EAGER;
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java
index 85a1b4a3355..eafc2698789 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java
@@ -38,8 +38,7 @@ enum Types {
   OperationBuilder opBuilder(String type, String name);
 
   /**
-   * Attach the function to this execution environment, allowing it to be called by creating an op with the function
-   * name as it's {@code type}.
+   * Attach the function and its dependencies to this execution environment, allowing it to be called.
    *
    * Done automatically in the {@link org.tensorflow.op.Ops#call(ConcreteFunction, java.util.Map)} ops.
    */
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
index 6bdb4a4a625..b2160dcf2d7 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
@@ -16,10 +16,8 @@
 package org.tensorflow;
 
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_AddGradientsWithPrefix;
-import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteFunction;
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteGraph;
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_FinishWhile;
-import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionName;
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphCopyFunction;
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphGetFunctions;
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphImportGraphDef;
@@ -389,19 +387,18 @@ public GraphOperationBuilder opBuilder(String type, String name) {
   public void attachFunction(ConcreteFunction function) {
     try (Reference ref = ref();
         PointerScope scope = new PointerScope()) {
-      TF_Status status = TF_Status.newStatus();
-      TF_GraphCopyFunction(ref.nativeHandle(), function.nativeHandle(), function.gradNativeHandle(), status);
-      status.throwExceptionIfNotOK();
+      attachNativeFunction(ref.nativeHandle(), function.nativeHandle(), function.gradNativeHandle());
+      function.getDependencies().forEach(x -> attachNativeFunction(ref.nativeHandle(), x, null));
     }
   }
 
-  /**
-   * Get the function attached to the graph with the given native name.  Returns {@code null} if none found.
-   *
-   * @param key the name of the native function.  Note that this may include an argument hash.
-   * @return the found {@link ConcreteFunction}, or {@code null} if none were found with the correct name
-   */
-  public synchronized ConcreteFunction getFunction(String key) {
+  private void attachNativeFunction(TF_Graph graph, TF_Function fn, TF_Function grad) {
+    TF_Status status = TF_Status.newStatus();
+    TF_GraphCopyFunction(graph, fn, grad, status);
+    status.throwExceptionIfNotOK();
+  }
+
+  synchronized List getNativeFunctions() {
     try (Reference ref = ref();
         PointerScope scope = new PointerScope()) {
       TF_Status status = TF_Status.newStatus();
@@ -413,49 +410,54 @@ public synchronized ConcreteFunction getFunction(String key) {
       TF_GraphGetFunctions(ref.nativeHandle(), output, numFunctions, status);
       status.throwExceptionIfNotOK();
 
-      ConcreteFunction func = null;
-
+      List funcs = new ArrayList<>(numFunctions);
       for (int i = 0; i < numFunctions; i++) {
         TF_Function function = output.get(TF_Function.class, i);
 
-        String functionName = TF_FunctionName(function).getString();
-
-        if (functionName.equals(key) && func == null) {
-          func = ConcreteFunction.fromNativeHandle(function);
-        } else {
-          TF_DeleteFunction(function);
-        }
+        funcs.add(new NativeFunction(function));
       }
 
-      return func;
+      return funcs;
     }
   }
 
   /**
-   * Get the functions attached to the graph.
+   * Get the function attached to the graph with the given native name.  Returns {@code null} if none found.
    *
-   * @return all functions attached to this graph.
+   * @param key the name of the native function.  Note that this may include an argument hash.
+   * @return the found {@link ConcreteFunction}, or {@code null} if none were found with the correct name
    */
-  public synchronized List getFunctions() {
+  public synchronized ConcreteFunction getFunction(String key) {
     try (Reference ref = ref();
         PointerScope scope = new PointerScope()) {
-      TF_Status status = TF_Status.newStatus();
-
-      int numFunctions = TF_GraphNumFunctions(ref.nativeHandle());
+      List funcs = getNativeFunctions();
 
-      PointerPointer output = new PointerPointer<>(numFunctions);
+      // will close unused functions when method ends
+      funcs.forEach(x -> x.getNativeHandle().withDeallocatorInScope());
 
-      TF_GraphGetFunctions(ref.nativeHandle(), output, numFunctions, status);
-      status.throwExceptionIfNotOK();
+      ConcreteFunction func = null;
 
-      List funcs = new ArrayList<>(numFunctions);
-      for (int i = 0; i < numFunctions; i++) {
-        TF_Function function = output.get(TF_Function.class, i);
+      for (int i = 0; i < funcs.size(); i++) {
 
-        funcs.add(ConcreteFunction.fromNativeHandle(function));
+        if (funcs.get(i).getName().equals(key) && func == null) {
+          func = ConcreteFunction.fromNativeHandle(funcs.get(i), funcs);
+        }
       }
 
-      return funcs;
+      return func;
+    }
+  }
+
+  /**
+   * Get the functions attached to the graph.
+   *
+   * @return all functions attached to this graph.
+   */
+  public synchronized List getFunctions() {
+    try (Reference ref = ref()) {
+      List funcs = getNativeFunctions();
+
+      return funcs.stream().map(x -> ConcreteFunction.fromNativeHandle(x, funcs)).collect(Collectors.toList());
     }
   }
 
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java
new file mode 100644
index 00000000000..b0405e4cf8e
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java
@@ -0,0 +1,162 @@
+/*
+  Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================
+ */
+package org.tensorflow;
+
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionName;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionToFunctionDef;
+
+import com.google.protobuf.InvalidProtocolBufferException;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.bytedeco.javacpp.PointerScope;
+import org.tensorflow.internal.c_api.TF_Buffer;
+import org.tensorflow.internal.c_api.TF_Function;
+import org.tensorflow.internal.c_api.TF_Status;
+import org.tensorflow.proto.framework.FunctionDef;
+import org.tensorflow.proto.framework.NodeDef;
+
+/**
+ * A class holding a native function handle and providing cached access to it's {@link FunctionDef}.
+ */
+class NativeFunction {
+
+  private final TF_Function nativeHandle;
+
+  private FunctionDef functionDef = null;
+  private List dependencies = null;
+  private Boolean stateful = null;
+  private String name = null;
+
+  public NativeFunction(TF_Function nativeHandle) {
+    this.nativeHandle = nativeHandle;
+  }
+
+  /**
+   * Get the native handle.  No guarantees about liveness are made.
+   */
+  public TF_Function getNativeHandle() {
+    return nativeHandle;
+  }
+
+  /**
+   * Get the function's {@link FunctionDef}
+   */
+  public synchronized FunctionDef getFunctionDef() {
+    if (functionDef == null) {
+      try (PointerScope scope = new PointerScope()) {
+        TF_Buffer funcDefBuffer = TF_Buffer.newBuffer();
+        TF_Status status = TF_Status.newStatus();
+
+        TF_FunctionToFunctionDef(nativeHandle, funcDefBuffer, status);
+        status.throwExceptionIfNotOK();
+
+        try {
+          functionDef = FunctionDef.parseFrom(funcDefBuffer.dataAsByteBuffer());
+        } catch (InvalidProtocolBufferException e) {
+          throw new IllegalStateException("Failed to parse FunctionDef proto", e);
+        }
+      }
+    }
+
+    return functionDef;
+  }
+
+  /**
+   * Get the first-level dependencies of the function.
+   */
+  public synchronized List getDependencies() {
+    if (dependencies == null) {
+      Set deps = new LinkedHashSet<>();
+
+      for (NodeDef node : getFunctionDef().getNodeDefList()) {
+        if (node.getOp().equals(ConcreteFunction.CALL_OP) || node.getOp().equals(ConcreteFunction.STATEFUL_CALL_OP)) {
+          deps.add(node.getAttrMap().get("f").getFunc().getName());
+        }
+      }
+      dependencies = Collections.unmodifiableList(new ArrayList<>(deps));
+    }
+
+    return dependencies;
+  }
+
+  /**
+   * Get whether the function is stateful (whether it has stateful ops).
+   */
+  public synchronized boolean isStateful() {
+    if (stateful == null) {
+      stateful = getFunctionDef().getSignature().getIsStateful()
+          || getFunctionDef().getNodeDefList().stream().anyMatch(x -> TensorFlow.isOpStateful(x.getOp()));
+    }
+    return stateful;
+  }
+
+  /**
+   * Get the name of the function.
+   */
+  public synchronized String getName() {
+    if (name == null) {
+      try (PointerScope scope = new PointerScope()) {
+        return TF_FunctionName(nativeHandle).getString();
+      }
+    }
+
+    return name;
+  }
+
+  synchronized Set getAllDependencies(Collection availableFunctions) {
+    Map fnMap = availableFunctions.stream()
+        .collect(Collectors.toMap(NativeFunction::getName, e -> e));
+    Set done = new LinkedHashSet<>(1 + getDependencies().size());
+
+    Queue todo = new ArrayDeque<>(1 + getDependencies().size());
+    todo.add(this);
+
+    while (!todo.isEmpty()) {
+      NativeFunction next = todo.remove();
+
+      if (!done.add(next.getName())) {
+        continue;
+      }
+
+      for (String dep : next.getDependencies()) {
+        if (!done.contains(dep)) {
+          NativeFunction fn = fnMap.get(dep);
+
+          if (fn == null) {
+            throw new IllegalStateException("Function " + dep + " is required, but not present in graph.");
+          }
+
+          todo.add(fn);
+        }
+      }
+    }
+
+    done.remove(getName());
+
+    return done.stream().map(fnMap::get)
+        .map(NativeFunction::getNativeHandle)
+        .collect(Collectors.toSet());
+  }
+}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java
index 8b9d87cb2f7..fb392d195a9 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java
@@ -32,7 +32,6 @@
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 import org.bytedeco.javacpp.BytePointer;
@@ -460,8 +459,6 @@ private static SavedModelBundle fromHandle(
       // if the function is a thin wrapper around a function call, unwrap it
       if (callOp != null) {
 
-        //TODO my problem is with __inference_signature_wrapper_66
-
         try (PointerScope scope = new PointerScope()) {
           TF_Operation op = ((GraphOperation) graph
               .outputOrError(signatureDef.getOutputsMap().values().iterator().next().getName()).op())
@@ -489,20 +486,20 @@ private static SavedModelBundle fromHandle(
 
         }
       }
-
-      // try to do the unwrapping based on name if there are no outputs (and thus we can't find the call op)
-      if (!functions.containsKey(signatureName) && signatureDef.getOutputsCount() < 1) {
-        for (ConcreteFunction fn : graphFunctions) {
-          Matcher matcher = INFERENCE_FUNCTION_NAME_PATTERN.matcher(fn.getNativeFunctionName());
-          if (matcher.find()) {
-            String fnName = matcher.group(1);
-            if (fnName.equals(signatureName)) {
-              functions.put(signatureName, fn);
-              break;
-            }
-          }
-        }
-      }
+//
+//      // try to do the unwrapping based on name if there are no outputs (and thus we can't find the call op)
+//      if (!functions.containsKey(signatureName) && signatureDef.getOutputsCount() < 1) {
+//        for (ConcreteFunction fn : graphFunctions) {
+//          Matcher matcher = INFERENCE_FUNCTION_NAME_PATTERN.matcher(fn.getNativeFunctionName());
+//          if (matcher.find()) {
+//            String fnName = matcher.group(1);
+//            if (fnName.equals(signatureName)) {
+//              functions.put(signatureName, fn);
+//              break;
+//            }
+//          }
+//        }
+//      }
 
       // otherwise use the wrapper
       if (!functions.containsKey(signatureName)) {
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java
index 707012e7aaa..dfd34ad3cc1 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java
@@ -52,7 +52,7 @@ public AbstractTF_Function(Pointer p) {
   /**
    * Adds a deallocator if there isn't already one.  Attaches to the current scope regardless.
    */
-  public TF_Function withDeallocator() {
+  public TF_Function withDeallocatorInScope() {
     if (hasDeallocator) {
       Iterator it = PointerScope.getScopeIterator();
       if (it != null) {
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java
index e8b08e8f9e9..74132ccb010 100644
--- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java
@@ -40,11 +40,21 @@ private static Signature minusTwo(Ops tf) {
     return Signature.builder().key("minusTwo").input("x", input).output("y", output).build();
   }
 
+  @SuppressWarnings("unchecked")
+  private static Signature plusFiveMinusTwo(Ops tf) {
+    Placeholder input = tf.placeholder(TFloat32.class);
+    try (ConcreteFunction plusFive = ConcreteFunction.create(ConcreteFunctionTest::plusFive);
+        ConcreteFunction minusTwo = ConcreteFunction.create(ConcreteFunctionTest::minusTwo)) {
+      Operand result = (Operand) minusTwo.call(tf, plusFive.call(tf, input));
+      return Signature.builder().key("plusFiveMinusTwo").input("x", input).output("y", result).build();
+    }
+  }
+
   @Test
   public void createFunction() {
     try (ConcreteFunction f = ConcreteFunction.create(ConcreteFunctionTest::plusFive);
         TFloat32 x = TFloat32.scalarOf(3.0f)) {
-      assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat());
+      assertEquals(8.0f, ((TFloat32) f.call(x)).getFloat());
     }
   }
 
@@ -97,4 +107,31 @@ public void getGraphFunctions() {
       }
     }
   }
+
+  @Test
+  public void testNestedFunctionEager() {
+    try (EagerSession sess = EagerSession.create();
+        ConcreteFunction function = ConcreteFunction.create(ConcreteFunctionTest::plusFiveMinusTwo)) {
+      Ops tf = Ops.create(sess);
+      Operand a = tf.constant(10f);
+      Operand result = (Operand) function.call(tf, a);
+      try (TFloat32 t = result.asTensor()) {
+        assertEquals(13f, t.getFloat());
+      }
+    }
+  }
+
+  @Test
+  public void testNestedFunctionGraph() {
+    try (Graph graph = new Graph();
+        ConcreteFunction function = ConcreteFunction.create(ConcreteFunctionTest::plusFiveMinusTwo)) {
+      Ops tf = Ops.create(graph);
+      Operand a = tf.constant(10f);
+      Operand result = (Operand) function.call(tf, a);
+      try (Session sess = new Session(graph);
+          TFloat32 t = (TFloat32) sess.runner().fetch(result).run().get(0)) {
+        assertEquals(13f, t.getFloat());
+      }
+    }
+  }
 }

From bafcec23509d3a8c134c44d1aec3e9e238d73b9d Mon Sep 17 00:00:00 2001
From: Ryan Nett 
Date: Sat, 6 Mar 2021 15:43:12 -0800
Subject: [PATCH 11/34] Remove unwrapping

Signed-off-by: Ryan Nett 
---
 .../java/org/tensorflow/SavedModelBundle.java | 100 ------------------
 1 file changed, 100 deletions(-)

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java
index fb392d195a9..869dc7a6d21 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java
@@ -17,7 +17,6 @@
 
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_LoadSessionFromSavedModel;
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewGraph;
-import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationGetAttrValueProto;
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig;
 
 import com.google.protobuf.InvalidProtocolBufferException;
@@ -32,7 +31,6 @@
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 import org.bytedeco.javacpp.BytePointer;
 import org.bytedeco.javacpp.PointerPointer;
@@ -40,18 +38,14 @@
 import org.tensorflow.exceptions.TensorFlowException;
 import org.tensorflow.internal.c_api.TF_Buffer;
 import org.tensorflow.internal.c_api.TF_Graph;
-import org.tensorflow.internal.c_api.TF_Operation;
 import org.tensorflow.internal.c_api.TF_Session;
 import org.tensorflow.internal.c_api.TF_SessionOptions;
 import org.tensorflow.internal.c_api.TF_Status;
-import org.tensorflow.proto.framework.AttrValue;
 import org.tensorflow.proto.framework.ConfigProto;
 import org.tensorflow.proto.framework.MetaGraphDef;
 import org.tensorflow.proto.framework.MetaGraphDef.MetaInfoDef;
 import org.tensorflow.proto.framework.RunOptions;
 import org.tensorflow.proto.framework.SavedModel;
-import org.tensorflow.proto.framework.SignatureDef;
-import org.tensorflow.proto.framework.TensorInfo;
 import org.tensorflow.proto.util.SaverDef;
 
 /**
@@ -387,53 +381,6 @@ private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef
     this.functions = functions;
   }
 
-  private static final Pattern INFERENCE_FUNCTION_NAME_PATTERN = Pattern
-      .compile("__inference_(.+)_\\d+", Pattern.DOTALL);
-
-  /**
-   * Check that all outputs of the signature come from a single call op that takes the inputs.
-   */
-  private static GraphOperation findFunctionWrapper(Graph graph, SignatureDef signatureDef) {
-
-    GraphOperation callOp = null;
-    for (TensorInfo output : signatureDef.getOutputsMap().values()) {
-      GraphOperation op = (GraphOperation) graph.outputOrError(output.getName()).op();
-      if (callOp == null) {
-        callOp = op;
-      } else if (!callOp.equals(op)) {
-        return null;
-      }
-    }
-
-    if (callOp == null) {
-      return null;
-    }
-
-    if (callOp != null) {
-
-      if (callOp.numInputs() != signatureDef.getInputsCount() || callOp.numOutputs() != signatureDef
-          .getOutputsCount()) {
-        return null;
-      }
-
-      int i = 0;
-      List> opInputs = callOp.inputs();
-
-      for (TensorInfo input : signatureDef.getInputsMap().values()) {
-        if (!graph.outputOrError(input.getName()).equals(opInputs.get(i))) {
-          return null;
-        }
-        i++;
-      }
-    }
-
-    if (!callOp.type().equals(ConcreteFunction.CALL_OP) && !callOp.type().equals(ConcreteFunction.STATEFUL_CALL_OP)) {
-      return null;
-    }
-
-    return callOp;
-  }
-
   /**
    * Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session object, plus the
    * MetaGraphDef.
@@ -454,53 +401,6 @@ private static SavedModelBundle fromHandle(
     List graphFunctions = graph.getFunctions();
     metaGraphDef.getSignatureDefMap().forEach((signatureName, signatureDef) -> {
 
-      GraphOperation callOp = findFunctionWrapper(graph, signatureDef);
-
-      // if the function is a thin wrapper around a function call, unwrap it
-      if (callOp != null) {
-
-        try (PointerScope scope = new PointerScope()) {
-          TF_Operation op = ((GraphOperation) graph
-              .outputOrError(signatureDef.getOutputsMap().values().iterator().next().getName()).op())
-              .getUnsafeNativeHandle();
-          TF_Status status = TF_Status.newStatus();
-          TF_Buffer buff = TF_Buffer.newBuffer();
-          TF_OperationGetAttrValueProto(op, "f", buff, status);
-          status.throwExceptionIfNotOK();
-          AttrValue def = AttrValue.parseFrom(buff.dataAsByteBuffer());
-
-          String functionName = def.getFunc().getName();
-
-          ConcreteFunction function = null;
-          for (ConcreteFunction fn : graphFunctions) {
-            if (fn.getNativeFunctionName().equals(functionName)) {
-              function = fn;
-              break;
-            }
-          }
-
-          if (function != null) {
-            functions.put(signatureName, function.withNewSignature(new Signature(signatureName, signatureDef)));
-          }
-        } catch (InvalidProtocolBufferException | IllegalArgumentException ignored) {
-
-        }
-      }
-//
-//      // try to do the unwrapping based on name if there are no outputs (and thus we can't find the call op)
-//      if (!functions.containsKey(signatureName) && signatureDef.getOutputsCount() < 1) {
-//        for (ConcreteFunction fn : graphFunctions) {
-//          Matcher matcher = INFERENCE_FUNCTION_NAME_PATTERN.matcher(fn.getNativeFunctionName());
-//          if (matcher.find()) {
-//            String fnName = matcher.group(1);
-//            if (fnName.equals(signatureName)) {
-//              functions.put(signatureName, fn);
-//              break;
-//            }
-//          }
-//        }
-//      }
-
       // otherwise use the wrapper
       if (!functions.containsKey(signatureName)) {
         Signature signature = new Signature(signatureName, signatureDef);

From bb641e343cdecec226d2bef955b2a1cd146d19db Mon Sep 17 00:00:00 2001
From: Ryan Nett 
Date: Sat, 6 Mar 2021 19:45:18 -0800
Subject: [PATCH 12/34] Proper attribute setters

Signed-off-by: Ryan Nett 
---
 .../java/org/tensorflow/ConcreteFunction.java |  2 +-
 .../org/tensorflow/EagerOperationBuilder.java | 40 ++++++++++++++++--
 .../org/tensorflow/GraphOperationBuilder.java | 42 ++++++++++++++++++-
 .../java/org/tensorflow/OperationBuilder.java | 19 +++++++--
 .../tensorflow/EagerOperationBuilderTest.java |  5 ++-
 .../tensorflow/GraphOperationBuilderTest.java |  7 ++--
 6 files changed, 100 insertions(+), 15 deletions(-)

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
index 8ab491c8d57..282ba51a02f 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
@@ -238,7 +238,7 @@ public Map> call(Scope scope,
 
     opBuilder.addInputList(inputList.stream().map(Operand::asOutput).toArray(Output[]::new));
 
-    opBuilder.setFunctionName("f", name);
+    opBuilder.setAttr("f", this);
     opBuilder.setAttr("Tin", inputList.stream().map(x -> x.asOutput().dataType()).toArray(DataType[]::new));
     opBuilder.setAttr("Tout", signature().getOutputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new));
 
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java
index 87bd6ee21c1..98bc59abaaa 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java
@@ -22,6 +22,7 @@
 import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrBoolList;
 import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrFloat;
 import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrFloatList;
+import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrFunctionList;
 import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrFunctionName;
 import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrInt;
 import static org.tensorflow.internal.c_api.global.tensorflow.TFE_OpSetAttrIntList;
@@ -36,6 +37,9 @@
 
 import java.nio.charset.Charset;
 import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
 import org.bytedeco.javacpp.BooleanPointer;
 import org.bytedeco.javacpp.BytePointer;
 import org.bytedeco.javacpp.IntPointer;
@@ -219,8 +223,22 @@ public EagerOperationBuilder setAttr(String name, Shape[] values) {
   }
 
   @Override
-  public OperationBuilder setFunctionName(String attrName, String functionName) {
-    setAttrFunctionName(opHandle, attrName, functionName);
+  public OperationBuilder setAttr(String name, ConcreteFunction value) {
+    session.attachFunction(value);
+    setAttrFunctionName(opHandle, name, value.getNativeFunctionName());
+    return this;
+  }
+
+  @Override
+  public OperationBuilder setAttr(String name, ConcreteFunction[] value) {
+    for (ConcreteFunction fn : value) {
+      session.attachFunction(fn);
+    }
+
+    setAttrFunctionList(opHandle, session.nativeHandle(), name, Arrays.stream(value)
+        .map(ConcreteFunction::getNativeFunctionName)
+        .collect(Collectors.toList()));
+
     return this;
   }
 
@@ -416,7 +434,7 @@ private static void setAttrShapeList(TFE_Op opHandle, String name, long[] shapes
       }
       TF_Status status = TF_Status.newStatus();
       TFE_OpSetAttrShapeList(opHandle, new BytePointer(name), shapesPointers, new IntPointer(numDims),
-              numDims.length, status);
+          numDims.length, status);
     }
   }
 
@@ -426,4 +444,20 @@ private static void setAttrFunctionName(TFE_Op opHandle, String attrName, String
       TFE_OpSetAttrFunctionName(opHandle, attrName, functionName, functionName.length());
     }
   }
+
+  private static void setAttrFunctionList(TFE_Op opHandle, TFE_Context context, String attrName,
+      List functionNames) {
+    requireOp(opHandle);
+    requireContext(context);
+    try (PointerScope scope = new PointerScope()) {
+      PointerPointer fns = new PointerPointer<>(functionNames.size());
+      for (int i = 0; i < functionNames.size(); i++) {
+        TF_Status status = TF_Status.newStatus();
+        TFE_Op op = TFE_Op.newOp(context, functionNames.get(i), status);
+        status.throwExceptionIfNotOK();
+        fns.put(i, op);
+      }
+      TFE_OpSetAttrFunctionList(opHandle, new BytePointer(attrName), fns, functionNames.size());
+    }
+  }
 }
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java
index 3d638069fee..c8ec73c0346 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java
@@ -35,9 +35,13 @@
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrTensorList;
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrType;
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrTypeList;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrValueProto;
 import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetDevice;
 
 import java.nio.charset.Charset;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
 import org.bytedeco.javacpp.BooleanPointer;
 import org.bytedeco.javacpp.BytePointer;
 import org.bytedeco.javacpp.IntPointer;
@@ -54,7 +58,10 @@
 import org.tensorflow.internal.c_api.TF_Status;
 import org.tensorflow.internal.c_api.TF_Tensor;
 import org.tensorflow.ndarray.Shape;
+import org.tensorflow.proto.framework.AttrValue;
+import org.tensorflow.proto.framework.AttrValue.ListValue;
 import org.tensorflow.proto.framework.DataType;
+import org.tensorflow.proto.framework.NameAttrList;
 
 /**
  * An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}.
@@ -347,9 +354,24 @@ public GraphOperationBuilder setAttr(String name, String[] value) {
   }
 
   @Override
-  public OperationBuilder setFunctionName(String attrName, String functionName) {
+  public OperationBuilder setAttr(String name, ConcreteFunction value) {
+    graph.attachFunction(value);
     try (Reference r = graph.ref()) {
-      setAttrFunctionName(unsafeNativeHandle, attrName, functionName);
+      setAttrFunctionName(unsafeNativeHandle, name, value.getNativeFunctionName());
+    }
+    return this;
+  }
+
+  @Override
+  public OperationBuilder setAttr(String name, ConcreteFunction[] value) {
+    for (ConcreteFunction f : value) {
+      graph.attachFunction(f);
+    }
+
+    try (Reference r = graph.ref()) {
+      setAttrFunctionList(unsafeNativeHandle, name, Arrays.stream(value)
+          .map(ConcreteFunction::getNativeFunctionName)
+          .collect(Collectors.toList()));
     }
     return this;
   }
@@ -556,4 +578,20 @@ private static void setAttrFunctionName(TF_OperationDescription opHandle, String
       TF_SetAttrFuncName(opHandle, attrName, functionName, functionName.length());
     }
   }
+
+  private static void setAttrFunctionList(TF_OperationDescription opHandle, String attrName,
+      List functionNames) {
+    requireHandle(opHandle);
+    try (PointerScope scope = new PointerScope()) {
+      TF_Status status = TF_Status.newStatus();
+      AttrValue value = AttrValue.newBuilder().setList(ListValue.newBuilder().addAllFunc(
+          functionNames.stream()
+              .map(x -> NameAttrList.newBuilder().setName(x).build())
+              .collect(Collectors.toList())
+      ).build()).build();
+      byte[] bytes = value.toByteArray();
+      TF_SetAttrValueProto(opHandle, attrName, new BytePointer(bytes), bytes.length, status);
+      status.throwExceptionIfNotOK();
+    }
+  }
 }
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java
index dca709389f3..e09de39b6c6 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java
@@ -226,11 +226,22 @@ public interface OperationBuilder {
   OperationBuilder setAttr(String name, Shape[] value);
 
   /**
-   * Set a function name attribute of the operation being build.
+   * Set the function value of an attribute of the operation being built. Also attaches the function and dependencies to
+   * the execution environment.
    *
-   * @param attrName the attribute to set
-   * @param functionName the function name
+   * @param name attribute name
+   * @param value attribute value
+   * @return the OperationBuilder instance for chaining.
+   */
+  OperationBuilder setAttr(String name, ConcreteFunction value);
+
+  /**
+   * Set the function values of an attribute of the operation being built. Also attaches the functions and dependencies
+   * to the execution environment.
+   *
+   * @param name attribute name
+   * @param value attribute value
    * @return the OperationBuilder instance for chaining.
    */
-  OperationBuilder setFunctionName(String attrName, String functionName);
+  OperationBuilder setAttr(String name, ConcreteFunction[] value);
 }
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java
index b39ecec9881..8a3d56bd37f 100644
--- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java
@@ -124,7 +124,7 @@ public void setAttrs() {
           .build();
       // bool
       opBuilder(session, "All", "Bool")
-          .addInput(tf.constant(new boolean[] {true, true, false}).asOutput())
+          .addInput(tf.constant(new boolean[]{true, true, false}).asOutput())
           .addInput(tf.constant(0).asOutput())
           .setAttr("keep_dims", false)
           .build();
@@ -134,7 +134,8 @@ public void setAttrs() {
           .addInput(tf.constant(10.00000f).asOutput())
           .setAttr("tolerance", 0.1f)
           .build();
-      // Missing tests: list(string), list(byte), list(bool), list(type)
+      // Missing tests: list(string), list(byte), list(bool), list(type), list(func)
+      // func is done via ConcreteFunction execution
     }
   }
 
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java
index 33ae979ccbd..66ba2122501 100644
--- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java
@@ -96,17 +96,18 @@ public void setAttr() {
       g.opBuilder("MaxPool", "IntList")
           .addInput(tf.constant(new float[2][2][2][2]).asOutput())
           .setAttr("ksize", new long[] {1, 1, 1, 1})
-          .setAttr("strides", new long[] {1, 1, 1, 1})
+          .setAttr("strides", new long[]{1, 1, 1, 1})
           .setAttr("padding", "SAME")
           .build();
       assertTrue(hasNode(g, "IntList"));
       // list(float)
       g.opBuilder("FractionalMaxPool", "FloatList")
           .addInput(tf.constant(new float[2][2][2][2]).asOutput())
-          .setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f})
+          .setAttr("pooling_ratio", new float[]{1.0f, 1.44f, 1.73f, 1.0f})
           .build();
       assertTrue(hasNode(g, "FloatList"));
-      // Missing tests: float, list(dtype), list(tensor), list(string), list(bool)
+      // Missing tests: float, list(dtype), list(tensor), list(string), list(bool), list(func)
+      // func is done via ConcreteFunction execution
     }
   }
 

From 9e686c41b7f1c3e3e7e67286ed7afd38566d65a9 Mon Sep 17 00:00:00 2001
From: Ryan Nett 
Date: Sat, 6 Mar 2021 20:40:44 -0800
Subject: [PATCH 13/34] Add ignored gradient test

Signed-off-by: Ryan Nett 
---
 .../java/org/tensorflow/ConcreteFunction.java |  9 ----
 .../src/main/java/org/tensorflow/Graph.java   |  8 +--
 .../org/tensorflow/ConcreteFunctionTest.java  | 52 +++++++++++++++++++
 3 files changed, 56 insertions(+), 13 deletions(-)

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
index 282ba51a02f..4e8da0230a5 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
@@ -406,15 +406,6 @@ TF_Function nativeHandle() {
     return nativeFunction.getNativeHandle();
   }
 
-  /**
-   * Get the native handle of the function's gradient, so that it can be attached to a Graph.  Not implemented yet.
-   *
-   * TODO implement
-   */
-  TF_Function gradNativeHandle() {
-    return null;
-  }
-
   private final Signature signature;
   private final NativeFunction nativeFunction;
   private final PointerScope scope;
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
index b2160dcf2d7..6482889ccab 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
@@ -387,14 +387,14 @@ public GraphOperationBuilder opBuilder(String type, String name) {
   public void attachFunction(ConcreteFunction function) {
     try (Reference ref = ref();
         PointerScope scope = new PointerScope()) {
-      attachNativeFunction(ref.nativeHandle(), function.nativeHandle(), function.gradNativeHandle());
-      function.getDependencies().forEach(x -> attachNativeFunction(ref.nativeHandle(), x, null));
+      attachNativeFunction(ref.nativeHandle(), function.nativeHandle());
+      function.getDependencies().forEach(x -> attachNativeFunction(ref.nativeHandle(), x));
     }
   }
 
-  private void attachNativeFunction(TF_Graph graph, TF_Function fn, TF_Function grad) {
+  private void attachNativeFunction(TF_Graph graph, TF_Function fn) {
     TF_Status status = TF_Status.newStatus();
-    TF_GraphCopyFunction(graph, fn, grad, status);
+    TF_GraphCopyFunction(graph, fn, null, status);
     status.throwExceptionIfNotOK();
   }
 
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java
index 74132ccb010..5eed335bf32 100644
--- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java
@@ -17,12 +17,14 @@
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
 
+import java.util.Arrays;
 import org.junit.jupiter.api.Test;
 import org.tensorflow.op.Ops;
 import org.tensorflow.op.core.Init;
 import org.tensorflow.op.core.Placeholder;
 import org.tensorflow.op.math.Add;
 import org.tensorflow.op.math.Sub;
+import org.tensorflow.proto.framework.DataType;
 import org.tensorflow.types.TFloat32;
 
 public class ConcreteFunctionTest {
@@ -134,4 +136,54 @@ public void testNestedFunctionGraph() {
       }
     }
   }
+
+  private static Signature square(Ops tf) {
+    Placeholder input = tf.placeholder(TFloat32.class);
+    Operand output = tf.math.square(input);
+    return Signature.builder().methodName("square").key("square").input("x", input).output("y", output).build();
+  }
+
+  // call op gradients are not defined in c++
+//  @Test
+  public void testGradientsGraph() {
+    try (Graph g = new Graph();
+        ConcreteFunction square = ConcreteFunction.create(ConcreteFunctionTest::square);
+        Session s = new Session(g)) {
+      Ops tf = Ops.create(g);
+
+      Output x1 = tf.placeholder(TFloat32.class).output();
+      Output x2 = tf.placeholder(TFloat32.class).output();
+      Output y0 = (Output) square.call(tf, x1);
+      Output y1 = (Output) square.call(tf, y0);
+      Output y2 = tf.math.addN(Arrays.asList(y0, x2)).sum();
+
+      Output[] grads0 = g.addGradients(y1, new Output[]{x1});
+      assertNotNull(grads0);
+      assertEquals(1, grads0.length);
+      assertEquals(DataType.DT_FLOAT, grads0[0].dataType());
+
+      Output[] grads1 = g.addGradients(y2, new Output[]{x1, x2});
+      assertNotNull(grads1);
+      assertEquals(2, grads1.length);
+      assertEquals(DataType.DT_FLOAT, grads1[0].dataType());
+      assertEquals(DataType.DT_FLOAT, grads1[1].dataType());
+
+      try (TFloat32 c1 = TFloat32.scalarOf(3.0f);
+          TFloat32 c2 = TFloat32.scalarOf(2.0f);
+          AutoCloseableList outputs = new AutoCloseableList<>(
+              s.runner()
+                  .feed(x1, c1)
+                  .feed(x2, c2)
+                  .fetch(grads0[0])
+                  .fetch(grads1[0])
+                  .fetch(grads1[1])
+                  .run())) {
+
+        assertEquals(3, outputs.size());
+        assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f);
+        assertEquals(6.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f);
+        assertEquals(1.0f, ((TFloat32) outputs.get(2)).getFloat(), 0.0f);
+      }
+    }
+  }
 }

From b4bf605a263da1f0b9dbf2be874213e775805411 Mon Sep 17 00:00:00 2001
From: Ryan Nett 
Date: Mon, 8 Mar 2021 19:55:03 -0800
Subject: [PATCH 14/34] Rebase fix

Signed-off-by: Ryan Nett 
---
 .../src/main/java/org/tensorflow/ConcreteFunction.java        | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
index 4e8da0230a5..78e667a223b 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
@@ -557,11 +557,11 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature)
       TF_Status status = TF_Status.newStatus();
 
       List> inputs = signature.getInputs().values().stream()
-          .map((x) -> graph.outputOrError(x.name))
+          .map((x) -> graph.outputOrThrow(x.name))
           .collect(Collectors.toList());
 
       List> outputs = signature.getOutputs().values().stream()
-          .map((x) -> graph.outputOrError(x.name))
+          .map((x) -> graph.outputOrThrow(x.name))
           .collect(Collectors.toList());
 
       List ops = new ArrayList<>(

From b7ab76cbc7b5fb65c4c1f3b1461fb5b073f1b388 Mon Sep 17 00:00:00 2001
From: Ryan Nett 
Date: Sat, 13 Mar 2021 11:32:35 -0800
Subject: [PATCH 15/34] Op generation for functions

Signed-off-by: Ryan Nett 
---
 .../internal/c_api/TF_Function.java           | 31 +++++++------------
 1 file changed, 11 insertions(+), 20 deletions(-)

diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_Function.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_Function.java
index df77798beeb..829d1cede3c 100644
--- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_Function.java
+++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_Function.java
@@ -2,29 +2,20 @@
 
 package org.tensorflow.internal.c_api;
 
-import org.bytedeco.javacpp.Pointer;
-import org.bytedeco.javacpp.annotation.Opaque;
-import org.bytedeco.javacpp.annotation.Properties;
+import java.nio.*;
+import org.bytedeco.javacpp.*;
+import org.bytedeco.javacpp.annotation.*;
+
+import static org.tensorflow.internal.c_api.global.tensorflow.*;
 
 
 // TF_Function is a grouping of operations with defined inputs and outputs.
 // Once created and added to graphs, functions can be invoked by creating an
 // operation whose operation type matches the function name.
-@Opaque
-@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
+@Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
 public class TF_Function extends org.tensorflow.internal.c_api.AbstractTF_Function {
-
-    /**
-     * Empty constructor. Calls {@code super((Pointer)null)}.
-     */
-    public TF_Function() {
-        super((Pointer) null);
-    }
-
-    /**
-     * Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}.
-     */
-    public TF_Function(Pointer p) {
-        super(p);
-    }
-}
\ No newline at end of file
+    /** Empty constructor. Calls {@code super((Pointer)null)}. */
+    public TF_Function() { super((Pointer)null); }
+    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
+    public TF_Function(Pointer p) { super(p); }
+}

From 6d8308e57b13baa98e1754b3e30c22bfaf55e541 Mon Sep 17 00:00:00 2001
From: Ryan Nett 
Date: Wed, 17 Mar 2021 13:58:09 -0700
Subject: [PATCH 16/34] Rebase fix

Signed-off-by: Ryan Nett 
---
 .../main/java/org/tensorflow/ConcreteFunction.java    | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
index 78e667a223b..d97bf944ed2 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java
@@ -41,6 +41,7 @@
 import org.tensorflow.op.Ops;
 import org.tensorflow.op.Scope;
 import org.tensorflow.op.core.Placeholder;
+import org.tensorflow.op.core.PlaceholderWithDefault;
 import org.tensorflow.proto.framework.AttrValue;
 import org.tensorflow.proto.framework.DataType;
 import org.tensorflow.proto.framework.FunctionDef;
@@ -565,11 +566,17 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature)
           .collect(Collectors.toList());
 
       List ops = new ArrayList<>(
-          graph.completeSubgraph(new HashSet<>(inputs), new HashSet<>(outputs),
-              null, Collections.singleton(Placeholder.OP_NAME)));
+          graph.completeSubgraph(new HashSet<>(inputs), new HashSet<>(outputs)));
 
       inputs.forEach(input -> ops.remove(input.op()));
 
+      ops.forEach(x -> {
+        if(x.type().equals(Placeholder.OP_NAME) || x.type().equals(PlaceholderWithDefault.OP_NAME)){
+          throw new IllegalArgumentException("Can't calculate outputs (" + outputs + ") from inputs (" + inputs + "), "
+              + "they also depend on \"" + x + "\"");
+        }
+      });
+
       // Python sometimes has NoOps as outputs
       Ops tf = Ops.create(graph).withSubScope("functionControlOutputs");
       for (int i = 0; i < outputs.size(); i++) {

From eaeaf6ec6f076f443f7adcebee8843bcf15bb462 Mon Sep 17 00:00:00 2001
From: Ryan Nett 
Date: Sun, 11 Apr 2021 16:11:44 -0700
Subject: [PATCH 17/34] SavedFunction for running functions from
 SavedModelBundles

Signed-off-by: Ryan Nett 
---
 .../java/org/tensorflow/SavedModelBundle.java | 122 ++++++++++++++++--
 .../org/tensorflow/SavedModelBundleTest.java  |  10 +-
 2 files changed, 116 insertions(+), 16 deletions(-)

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java
index 869dc7a6d21..1b0b47c5455 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java
@@ -31,6 +31,7 @@
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 import java.util.stream.Collectors;
 import org.bytedeco.javacpp.BytePointer;
 import org.bytedeco.javacpp.PointerPointer;
@@ -223,6 +224,104 @@ public void export() throws IOException {
     private final Map functions = new LinkedHashMap<>();
   }
 
+  /**
+   * A function loaded from a saved model.  It can be called using the saved model's session.
+   *
+   * All resources are owned by the SavedModel.
+   *
+   * TODO initializing the session
+   */
+  public final class SavedFunction {
+
+    private final Signature signature;
+
+    private SavedFunction(Signature signature) {
+      this.signature = signature;
+    }
+
+    /**
+     * The signature of the function.
+     */
+    public Signature signature() {
+      return signature;
+    }
+
+    /**
+     * The name of the function.
+     */
+    public String name() {
+      return signature.key();
+    }
+
+    /**
+     * Call this function using the SavedModel's session.
+     *
+     * 

Caller is responsible for closing all returned Tensors. + * + * @throws IllegalArgumentException if an argument is missing, an argument is passed for an unknown parameter, + * or an argument has the wrong type + */ + public Map call(Map arguments) { + Session.Runner runner = session.runner(); + arguments.forEach((name, value) -> { + Signature.TensorDescription parameter = signature.getInputs().get(name); + if (parameter == null) { + throw new IllegalArgumentException("Function \"" + name() + "\" has no argument \"" + name + "\"."); + } + if (value.dataType() != parameter.dataType) { + throw new IllegalArgumentException("Function \"" + name() + "\"'s argument \"" + name + + "\" has data type " + parameter.dataType + ", but a tensor of data type " + value.dataType() + + " was passed."); + } + runner.feed(parameter.name, value); + }); + + signature.inputNames().forEach((param) -> { + if (!arguments.containsKey(param)) { + throw new IllegalArgumentException( + "Function \"" + name() + "\" has a parameter \"" + param + "\", but no argument was passed for it."); + } + }); + + List resultNames = new ArrayList<>(signature.getOutputs().size()); + signature.getOutputs().forEach((name, desc) -> { + runner.fetch(desc.name); + resultNames.add(name); + }); + + List result = runner.run(); + Map namedResults = new LinkedHashMap<>(result.size()); + + for (int i = 0; i < result.size(); i++) { + namedResults.put(resultNames.get(i), result.get(i)); + } + return namedResults; + } + + + /** + * Call this single-argument single-result function using the SavedModel's session. + * + *

Caller is responsible for closing the returned Tensor. + * + * @throws IllegalStateException if this function does not have exactly one input and output. + */ + public Tensor call(Tensor argument) { + if (signature.getInputs().size() != 1) { + throw new IllegalStateException("Can only use this call method on functions with exactly one input, function \"" + + name() + "\" has " + signature.getInputs().size() + "."); + } + if (signature.getOutputs().size() != 1) { + throw new IllegalStateException("Can only use this call method on functions with exactly one input, function \"" + + name() + "\" has " + signature.getInputs().size() + "."); + } + Map inputMap = new LinkedHashMap<>(1); + inputMap.put(signature.inputNames().iterator().next(), argument); + Map results = call(inputMap); + return results.get(signature.outputNames().iterator().next()); + } + } + /** * Load a saved model from an export directory. The model that is being loaded should be created using the Saved Model API. @@ -314,8 +413,8 @@ public List signatures() { * @return object that can be used to make calls to a function * @throws IllegalArgumentException if {@code signatureKey} is not found in this saved model. */ - public ConcreteFunction function(String signatureKey) { - ConcreteFunction function = functions.get(signatureKey); + public SavedFunction function(String signatureKey) { + SavedFunction function = functions.get(signatureKey); if (function == null) { throw new IllegalArgumentException( String.format("Function with signature [%s] not found", signatureKey)); @@ -326,7 +425,7 @@ public ConcreteFunction function(String signatureKey) { /** * Get all functions in the bundle. */ - public List functions() { + public List functions() { return new ArrayList<>(functions.values()); } @@ -347,7 +446,7 @@ public List functions() { * @throws IllegalArgumentException if no function can be selected by default */ public Map call(Map arguments) { - ConcreteFunction function = null; + SavedFunction function = null; if (functions.size() == 1) { function = functions.values().iterator().next(); } else { @@ -371,14 +470,15 @@ public void close() { private final Graph graph; private final Session session; private final MetaGraphDef metaGraphDef; - private final Map functions; + private final Map functions; private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, - Map functions) { + Map signatures) { this.graph = graph; this.session = session; this.metaGraphDef = metaGraphDef; - this.functions = functions; + this.functions = signatures.entrySet().stream() + .collect(Collectors.toMap(Entry::getKey, e -> new SavedFunction(e.getValue()))); } /** @@ -397,14 +497,12 @@ private static SavedModelBundle fromHandle( // Note that the saved model will remain the owner of the graph and the session, meaning // that the functions do not need to be closed by the user and if it does, it should have // no effect. - final Map functions = new HashMap<>(metaGraphDef.getSignatureDefCount()); - List graphFunctions = graph.getFunctions(); - metaGraphDef.getSignatureDefMap().forEach((signatureName, signatureDef) -> { + final Map functions = new HashMap<>(metaGraphDef.getSignatureDefCount()); - // otherwise use the wrapper + metaGraphDef.getSignatureDefMap().forEach((signatureName, signatureDef) -> { if (!functions.containsKey(signatureName)) { Signature signature = new Signature(signatureName, signatureDef); - functions.put(signatureName, ConcreteFunction.create(signature, session)); + functions.put(signatureName, signature); } }); return new SavedModelBundle(graph, session, metaGraphDef, functions); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index f93826c6092..24ec48a1cbd 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -30,6 +30,7 @@ import java.util.HashMap; import java.util.Map; import org.junit.jupiter.api.Test; +import org.tensorflow.SavedModelBundle.SavedFunction; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; @@ -128,7 +129,7 @@ public void exportFunctionWithVariables() throws IOException { assertEquals(Signature.DEFAULT_KEY, savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next()); - ConcreteFunction function = savedModel.function(Signature.DEFAULT_KEY); + SavedFunction function = savedModel.function(Signature.DEFAULT_KEY); assertNotNull(function); Signature signature = function.signature(); @@ -191,13 +192,13 @@ public void exportMultipleFunctions() throws IOException { } try (SavedModelBundle model = SavedModelBundle.load(testFolder.toString())) { assertEquals(2, model.signatures().size()); - ConcreteFunction f1 = model.function(Signature.DEFAULT_KEY); + SavedFunction f1 = model.function(Signature.DEFAULT_KEY); assertNotNull(f1); try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); TFloat32 t = (TFloat32)f1.call(x)) { assertEquals(reducedSum, t.getFloat(), EPSILON); } - ConcreteFunction f2 = model.function("identity"); + SavedFunction f2 = model.function("identity"); assertNotNull(f2); try (TFloat32 x = TFloat32.scalarOf(10.0f); TFloat32 t = (TFloat32)f2.call(x)) { @@ -266,10 +267,11 @@ public void pythonTfFunction() { * Test model was created in python * Signature name used for saving 'add', argument names 'a' and 'b' */ - ConcreteFunction add = bundle.function("add"); + SavedFunction add = bundle.function("add"); Map args = new HashMap<>(); try (TFloat32 a = TFloat32.scalarOf(10.0f); TFloat32 b = TFloat32.scalarOf(15.5f)) { + System.out.println(add.signature()); args.put("a", a); args.put("b", b); Map result = add.call(args); From f32fbf2afa5d231414491d7ac4696765fc3ed422 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sat, 17 Apr 2021 16:32:38 -0700 Subject: [PATCH 18/34] Review fixes Signed-off-by: Ryan Nett --- .../java/org/tensorflow/ConcreteFunction.java | 167 ++++++++++-------- .../java/org/tensorflow/EagerSession.java | 18 +- .../src/main/java/org/tensorflow/Graph.java | 31 ++-- .../java/org/tensorflow/NativeFunction.java | 16 +- .../java/org/tensorflow/SavedModelBundle.java | 12 +- .../src/main/java/org/tensorflow/Session.java | 5 +- .../main/java/org/tensorflow/TensorFlow.java | 2 +- .../java/org/tensorflow/op/core/Function.java | 20 +++ 8 files changed, 155 insertions(+), 116 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index d97bf944ed2..04c854ffa90 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashSet; @@ -212,13 +213,27 @@ public String toString() { return signature.toString(); } + //TODO migrate to the actual ops once they are generated public static final String CALL_OP = "PartitionedCall"; + //TODO migrate to the actual ops once they are generated public static final String STATEFUL_CALL_OP = "StatefulPartitionedCall"; + + /** + * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The + * inputs and outputs are keyed by the names set in the {@code Signature}. + * + * @param scope the scope to call the function in + * @param arguments the arguments to the call + * @return the outputs of the function + */ public Map> call(Scope scope, Map> arguments) { List> inputList = new ArrayList<>(); + Output[] inputs = new Output[signature().inputNames().size()]; + + int i = 0; for (String inputName : signature().inputNames()) { Operand input = arguments.get(inputName); if (input == null) { @@ -226,7 +241,8 @@ public Map> call(Scope scope, "Function " + signature().methodName() + " has parameter \"" + inputName + "\", but no argument was passed for it."); } - inputList.add(input); + inputs[i] = input.asOutput(); + i++; } scope.env().attachFunction(this); @@ -237,11 +253,11 @@ public Map> call(Scope scope, OperationBuilder opBuilder = scope.env() .opBuilder(isStateful() ? STATEFUL_CALL_OP : CALL_OP, scope.makeOpName(displayName)); - opBuilder.addInputList(inputList.stream().map(Operand::asOutput).toArray(Output[]::new)); + opBuilder.addInputList(inputs); opBuilder.setAttr("f", this); - opBuilder.setAttr("Tin", inputList.stream().map(x -> x.asOutput().dataType()).toArray(DataType[]::new)); - opBuilder.setAttr("Tout", signature().getOutputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new)); + opBuilder.setAttr("Tin", inputDtypes); + opBuilder.setAttr("Tout", outputDtypes); opBuilder = scope.apply(opBuilder); Operation op = opBuilder.build(); @@ -249,14 +265,14 @@ public Map> call(Scope scope, int numOutputs1 = op.numOutputs(); List> outputList = new ArrayList<>(signature().outputNames().size()); - for (int i = 0; i < numOutputs1; i++) { + for (i = 0; i < numOutputs1; i++) { outputList.add(op.output(i)); } Map> namedOutputs = new LinkedHashMap<>(signature().outputNames().size()); List outputNames = new ArrayList<>(signature().outputNames()); - for (int i = 0; i < outputNames.size(); i++) { + for (i = 0; i < outputNames.size(); i++) { String outputName = outputNames.get(i); if (i > outputList.size()) { @@ -378,28 +394,6 @@ public void save(String exportDir) throws IOException { SavedModelBundle.exporter(exportDir).withFunction(this).export(); } - /** - * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because how to enable XLA - * JIT is extremely non-obvious. - * - * Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered platform with id: - * 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails). - */ - private void makeJit() { - try (PointerScope scope = new PointerScope()) { - byte[] bytes = AttrValue.newBuilder().setB(true).build().toByteArray(); - BytePointer trueValue = new BytePointer(bytes); - - TF_Status status1 = TF_Status.newStatus(); - TF_FunctionSetAttrValueProto(nativeHandle(), "_XlaMustCompile", trueValue, bytes.length, status1); - status1.throwExceptionIfNotOK(); - - TF_Status status2 = TF_Status.newStatus(); - TF_FunctionSetAttrValueProto(nativeHandle(), "_noinline", trueValue, bytes.length, status2); - status2.throwExceptionIfNotOK(); - } - } - TF_Function nativeHandle() { if (nativeFunction.getNativeHandle().isNull()) { throw new IllegalStateException("Function has been closed"); @@ -407,32 +401,53 @@ TF_Function nativeHandle() { return nativeFunction.getNativeHandle(); } - private final Signature signature; - private final NativeFunction nativeFunction; - private final PointerScope scope; - private final Set dependencies; - ConcreteFunction(Signature signature, NativeFunction nativeFunction, Collection availableFunctions) { this(signature, nativeFunction, nativeFunction.getAllDependencies(availableFunctions)); } - private static boolean dataTypesMatch(List a, List b) { - if (a.size() != b.size()) { - return false; + /** + * Detects the signature from the handle + */ + static ConcreteFunction fromNativeHandle(NativeFunction nativeFunction, + Collection availableFunctions) { + + Signature.Builder builder = Signature.builder().methodName(nativeFunction.getFunctionDef().getSignature().getName()) + .key(nativeFunction.getName()); + + for (ArgDef input : nativeFunction.getFunctionDef().getSignature().getInputArgList()) { + TensorInfo info = TensorInfo.newBuilder() + .setDtype(input.getType()) + .setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build()) + .setName(input.getName()) + .build(); + + builder.input(input.getName(), info); } - for (int i = 0; i < a.size(); i++) { - DataType aType = a.get(i); - DataType bType = b.get(i); + for (ArgDef outputDef : nativeFunction.getFunctionDef().getSignature().getOutputArgList()) { + TensorInfo info = TensorInfo.newBuilder() + .setDtype(outputDef.getType()) + .setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build()) + .setName(outputDef.getName()) + .build(); - if (aType != DataType.DT_INVALID && bType != DataType.DT_INVALID && !a.equals(b)) { - return false; - } + builder.output(outputDef.getName(), info); } - return true; + return new ConcreteFunction( + builder.build(), + nativeFunction, + availableFunctions + ); } + private final Signature signature; + private final NativeFunction nativeFunction; + private final PointerScope scope; + private final Set dependencies; + private final DataType[] inputDtypes; + private final DataType[] outputDtypes; + private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set dependencies) { this.signature = signature; this.nativeFunction = nativeFunction; @@ -452,8 +467,10 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set + this.signature.getOutputs().size()); } - List inputs = this.signature.getInputs().values().stream().map(x -> x.dataType) - .collect(Collectors.toList()); + inputDtypes = this.signature.getInputs().values().stream().map(x -> x.dataType) + .toArray(DataType[]::new); + + List inputs = Arrays.asList(inputDtypes); List nativeInputs = nativeFunction.getFunctionDef().getSignature().getInputArgList().stream() .map(ArgDef::getType) .collect(Collectors.toList()); @@ -464,8 +481,9 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set + nativeInputs + ", got " + inputs); } - List outputs = this.signature.getOutputs().values().stream().map(x -> x.dataType) - .collect(Collectors.toList()); + outputDtypes = signature().getOutputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new); + + List outputs = Arrays.asList(outputDtypes); List nativeOutputs = nativeFunction.getFunctionDef().getSignature().getOutputArgList().stream() .map(ArgDef::getType) .collect(Collectors.toList()); @@ -486,39 +504,42 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set } /** - * Detects the signature from the handle + * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because how to enable XLA + * JIT is extremely non-obvious. + * + * Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered platform with id: + * 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails). */ - static ConcreteFunction fromNativeHandle(NativeFunction nativeFunction, - Collection availableFunctions) { + private void makeJit() { + try (PointerScope scope = new PointerScope()) { + byte[] bytes = AttrValue.newBuilder().setB(true).build().toByteArray(); + BytePointer trueValue = new BytePointer(bytes); - Signature.Builder builder = Signature.builder().methodName(nativeFunction.getFunctionDef().getSignature().getName()) - .key(nativeFunction.getName()); + TF_Status status1 = TF_Status.newStatus(); + TF_FunctionSetAttrValueProto(nativeHandle(), "_XlaMustCompile", trueValue, bytes.length, status1); + status1.throwExceptionIfNotOK(); - for (ArgDef input : nativeFunction.getFunctionDef().getSignature().getInputArgList()) { - TensorInfo info = TensorInfo.newBuilder() - .setDtype(input.getType()) - .setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build()) - .setName(input.getName()) - .build(); + TF_Status status2 = TF_Status.newStatus(); + TF_FunctionSetAttrValueProto(nativeHandle(), "_noinline", trueValue, bytes.length, status2); + status2.throwExceptionIfNotOK(); + } + } - builder.input(input.getName(), info); + private static boolean dataTypesMatch(List a, List b) { + if (a.size() != b.size()) { + return false; } - for (ArgDef outputDef : nativeFunction.getFunctionDef().getSignature().getOutputArgList()) { - TensorInfo info = TensorInfo.newBuilder() - .setDtype(outputDef.getType()) - .setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build()) - .setName(outputDef.getName()) - .build(); + for (int i = 0; i < a.size(); i++) { + DataType aType = a.get(i); + DataType bType = b.get(i); - builder.output(outputDef.getName(), info); + if (aType != DataType.DT_INVALID && bType != DataType.DT_INVALID && !a.equals(b)) { + return false; + } } - return new ConcreteFunction( - builder.build(), - nativeFunction, - availableFunctions - ); + return true; } @@ -616,8 +637,4 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature) return new ConcreteFunction(signature, new NativeFunction(handle), graph.getNativeFunctions()); } } - - ConcreteFunction withNewSignature(Signature signature) { - return new ConcreteFunction(signature, nativeFunction, dependencies); - } } \ No newline at end of file diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index d0a73bf480e..a2c87285e9a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -28,7 +28,6 @@ import org.tensorflow.internal.WeakPointerScope; import org.tensorflow.internal.c_api.TFE_Context; import org.tensorflow.internal.c_api.TFE_ContextOptions; -import org.tensorflow.internal.c_api.TF_Function; import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.op.Op; import org.tensorflow.op.Scope; @@ -290,15 +289,16 @@ public OperationBuilder opBuilder(String type, String name) { public void attachFunction(ConcreteFunction function) { checkSession(); try (PointerScope scope = new PointerScope()) { - attachNativeFunction(function.nativeHandle()); - function.getDependencies().forEach(this::attachNativeFunction); - } - } + TF_Status status = TF_Status.newStatus(); + TFE_ContextAddFunction(nativeHandle, function.nativeHandle(), status); + status.throwExceptionIfNotOK(); - private void attachNativeFunction(TF_Function fn) { - TF_Status status = TF_Status.newStatus(); - TFE_ContextAddFunction(nativeHandle, fn, status); - status.throwExceptionIfNotOK(); + function.getDependencies().forEach(fn -> { + TF_Status status2 = TF_Status.newStatus(); + TFE_ContextAddFunction(nativeHandle, fn, status2); + status2.throwExceptionIfNotOK(); + }); + } } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 6482889ccab..32927719e5f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -387,17 +387,21 @@ public GraphOperationBuilder opBuilder(String type, String name) { public void attachFunction(ConcreteFunction function) { try (Reference ref = ref(); PointerScope scope = new PointerScope()) { - attachNativeFunction(ref.nativeHandle(), function.nativeHandle()); - function.getDependencies().forEach(x -> attachNativeFunction(ref.nativeHandle(), x)); - } - } + TF_Status status = TF_Status.newStatus(); + TF_GraphCopyFunction(ref.nativeHandle(), function.nativeHandle(), null, status); + status.throwExceptionIfNotOK(); - private void attachNativeFunction(TF_Graph graph, TF_Function fn) { - TF_Status status = TF_Status.newStatus(); - TF_GraphCopyFunction(graph, fn, null, status); - status.throwExceptionIfNotOK(); + function.getDependencies().forEach(x -> { + TF_Status status2 = TF_Status.newStatus(); + TF_GraphCopyFunction(ref.nativeHandle(), x, null, status2); + status2.throwExceptionIfNotOK(); + }); + } } + /** + * Get the graph's functions. Deallocating the function pointers is the caller's responsibility. + */ synchronized List getNativeFunctions() { try (Reference ref = ref(); PointerScope scope = new PointerScope()) { @@ -435,17 +439,14 @@ public synchronized ConcreteFunction getFunction(String key) { // will close unused functions when method ends funcs.forEach(x -> x.getNativeHandle().withDeallocatorInScope()); - ConcreteFunction func = null; + for (NativeFunction f : funcs) { - for (int i = 0; i < funcs.size(); i++) { - - if (funcs.get(i).getName().equals(key) && func == null) { - func = ConcreteFunction.fromNativeHandle(funcs.get(i), funcs); + if (f.getName().equals(key)) { + return ConcreteFunction.fromNativeHandle(f, funcs); } } - - return func; } + return null; } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java index b0405e4cf8e..7fc68fa8133 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java @@ -41,14 +41,6 @@ * A class holding a native function handle and providing cached access to it's {@link FunctionDef}. */ class NativeFunction { - - private final TF_Function nativeHandle; - - private FunctionDef functionDef = null; - private List dependencies = null; - private Boolean stateful = null; - private String name = null; - public NativeFunction(TF_Function nativeHandle) { this.nativeHandle = nativeHandle; } @@ -159,4 +151,12 @@ synchronized Set getAllDependencies(Collection avai .map(NativeFunction::getNativeHandle) .collect(Collectors.toSet()); } + + private final TF_Function nativeHandle; + + private FunctionDef functionDef = null; + private List dependencies = null; + private Boolean stateful = null; + private String name = null; + } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 1b0b47c5455..30ed8038a65 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -233,12 +233,6 @@ public void export() throws IOException { */ public final class SavedFunction { - private final Signature signature; - - private SavedFunction(Signature signature) { - this.signature = signature; - } - /** * The signature of the function. */ @@ -320,6 +314,12 @@ public Tensor call(Tensor argument) { Map results = call(inputMap); return results.get(signature.outputNames().iterator().next()); } + + private final Signature signature; + + private SavedFunction(Signature signature) { + this.signature = signature; + } } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 6890a959f82..51079dbbcea 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -511,9 +511,10 @@ public void run(Op op) { *

This method is equivalent to {@code session.run(Ops.create(session.graph).init())}. */ public void runInit() { - if (!graph.initializers().isEmpty()) { + List initializers = graph.initializers(); + if (!initializers.isEmpty()) { Runner runner = runner(); - graph.initializers().forEach(runner::addTarget); + initializers.forEach(runner::addTarget); runner.run(); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java index f5edfec09a2..946a02e0b88 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java @@ -50,7 +50,7 @@ public static String version() { * @return A OpList protocol * buffer, which lists all the available TensorFlow operations. */ - public static synchronized OpList registeredOpList() { + public static OpList registeredOpList() { TF_Buffer buf = TF_GetAllOpList(); try { return OpList.parseFrom(buf.dataAsByteBuffer()); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java index b38b7e0bbbb..0fe171602e3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java @@ -19,6 +19,7 @@ import java.util.Map; import org.tensorflow.ConcreteFunction; import org.tensorflow.Operand; +import org.tensorflow.op.Ops; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; @@ -30,12 +31,31 @@ @Operator(name = "call") public abstract class Function { + /** + * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The + * inputs and outputs are keyed by the names set in the {@code Signature}. + * + * @param scope the scope to call the function in + * @param arguments the arguments to the call + * @return the outputs of the function + * @see ConcreteFunction#call(Ops, Map) + */ @Endpoint public static Map> call(Scope scope, ConcreteFunction function, Map> arguments) { return function.call(scope, arguments); } + + /** + * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only + * works for functions with a single input and output. + * + * @param scope the scope to call the function in + * @param argument the argument to the call + * @return the output of the function + * @see ConcreteFunction#call(Ops, Operand) + */ @Endpoint public static Operand call(Scope scope, ConcreteFunction function, Operand argument) { From f892c546d1cbf5053330821fd5c2f1c18ce456e4 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sat, 17 Apr 2021 16:38:20 -0700 Subject: [PATCH 19/34] Generation and better javadoc Signed-off-by: Ryan Nett --- .../gen/annotations/org/tensorflow/op/Ops.java | 16 ++++++++++++++-- .../java/org/tensorflow/SavedModelBundle.java | 3 ++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 53e87937a96..ec0c10bfd76 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -1120,14 +1120,26 @@ public Bucketize bucketize(Operand input, List boundar } /** - * empty + * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only + * works for functions with a single input and output. + * + * @param scope the scope to call the function in + * @param argument the argument to the call + * @return the output of the function + * @see ConcreteFunction#call(Ops, Operand) */ public Operand call(ConcreteFunction function, Operand argument) { return Function.call(scope, function, argument); } /** - * empty + * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The + * inputs and outputs are keyed by the names set in the {@code Signature}. + * + * @param scope the scope to call the function in + * @param arguments the arguments to the call + * @return the outputs of the function + * @see ConcreteFunction#call(Ops, Map) */ public Map> call(ConcreteFunction function, Map> arguments) { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 30ed8038a65..d98c1507083 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -229,7 +229,8 @@ public void export() throws IOException { * * All resources are owned by the SavedModel. * - * TODO initializing the session + * The session is not initialized in any way, you can use {@link Session#runInit()} on {@link SavedModelBundle#session()} + * if this is necessary. */ public final class SavedFunction { From f485ccfb93e59fb5531d9d9e684b224a9bb09ab1 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sun, 18 Apr 2021 18:38:56 -0700 Subject: [PATCH 20/34] Rework pointer scopes Signed-off-by: Ryan Nett --- .../java/org/tensorflow/ConcreteFunction.java | 122 +++++++++++------- .../src/main/java/org/tensorflow/Graph.java | 28 ++-- .../internal/c_api/AbstractTF_Function.java | 27 +--- 3 files changed, 92 insertions(+), 85 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 04c854ffa90..5f5931b7c9c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -72,12 +72,12 @@ public class ConcreteFunction implements AutoCloseable { * Creates a function by building a new graph. * *

The {@code functionBuilder} must initialize the function graph from the provided - * {@link Ops} instance and return a valid signature that will be used to feed the input tensors and fetch the output - * tensors on execution. + * {@link Ops} instance and return a valid signature that will be used to feed the input tensors + * and fetch the output tensors on execution. * *

The function will be the owner of the new graph and its resulting session. Therefore, - * the function must be enclosed properly with a try-with-resources block to guarantee that all native resources will - * be freed once the function is discarded. For example: + * the function must be enclosed properly with a try-with-resources block to guarantee that all + * native resources will be freed once the function is discarded. For example: * *

{@code
    * public class MyModel {
@@ -112,8 +112,8 @@ public static ConcreteFunction create(Function functionBuilder)
    * Create a function from a signature and an existing graph.
    *
    * 

The function will keep the ownership of the session used to run the graph but not - * the graph itself, meaning that the lifetime of the latter can extend beyond the scope of the function. For - * example: + * the graph itself, meaning that the lifetime of the latter can extend beyond the scope of the + * function. For example: * *

{@code
    * try (Graph g = new Graph()) {
@@ -130,7 +130,7 @@ public static ConcreteFunction create(Function functionBuilder)
    * }
* * @param signature signature of the function to create - * @param graph a valid and initialized graph + * @param graph a valid and initialized graph * @return a new function */ public static ConcreteFunction create(Signature signature, Graph graph) { @@ -141,8 +141,8 @@ public static ConcreteFunction create(Signature signature, Graph graph) { * Create a function from a signature and a valid graph session. * *

The function will not own the session nor its graph, meaning that their lifetime - * can extend beyond the scope of the function. Therefore the function does not need to be closed after its usage. For - * example: + * can extend beyond the scope of the function. Therefore the function does not need to be closed + * after its usage. For example: * *

{@code
    * try (Graph g = new Graph()) {
@@ -164,7 +164,7 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
    * }
* * @param signature signature of the function to create - * @param session a valid session to an initialized graph + * @param session a valid session to an initialized graph * @return a new function */ public static ConcreteFunction create(Signature signature, Session session) { @@ -220,10 +220,10 @@ public String toString() { /** - * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The - * inputs and outputs are keyed by the names set in the {@code Signature}. + * Calls the function in an execution environment, adding it's graph as a function if it isn't + * already present. The inputs and outputs are keyed by the names set in the {@code Signature}. * - * @param scope the scope to call the function in + * @param scope the scope to call the function in * @param arguments the arguments to the call * @return the outputs of the function */ @@ -276,7 +276,8 @@ public Map> call(Scope scope, String outputName = outputNames.get(i); if (i > outputList.size()) { - throw new IllegalStateException("Somehow, not all required outputs were returned from the function"); + throw new IllegalStateException( + "Somehow, not all required outputs were returned from the function"); } Operand output = outputList.get(i); @@ -287,10 +288,10 @@ public Map> call(Scope scope, } /** - * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only - * works for functions with a single input and output. + * Calls the function in an execution environment, adding it's graph as a function if it isn't + * already present. Only works for functions with a single input and output. * - * @param scope the scope to call the function in + * @param scope the scope to call the function in * @param argument the argument to the call * @return the output of the function */ @@ -320,8 +321,10 @@ public Operand call(Scope scope, Operand argument) { * *

Caller is responsible for closing all Tensors. * - * @param arguments list of tensors to pass in input to the function, mapped by their signature name - * @return output tensors resulting from the execution of the function, mapped by their signature name + * @param arguments list of tensors to pass in input to the function, mapped by their signature + * name + * @return output tensors resulting from the execution of the function, mapped by their signature + * name */ public Map call(Map arguments) throws IllegalArgumentException { @@ -348,7 +351,8 @@ public Map call(Map arguments) * * @param tensor input tensor * @return output tensor - * @throws IllegalArgumentException if there are multiple input or output parameters defined in the function + * @throws IllegalArgumentException if there are multiple input or output parameters defined in + * the function */ public Tensor call(Tensor tensor) throws IllegalArgumentException { Ops tf = Ops.create(); @@ -358,10 +362,10 @@ public Tensor call(Tensor tensor) throws IllegalArgumentException { } /** - * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The - * inputs and outputs are keyed by the names set in the {@code Signature}. + * Calls the function in an execution environment, adding it's graph as a function if it isn't + * already present. The inputs and outputs are keyed by the names set in the {@code Signature}. * - * @param tf the scope to call the function in + * @param tf the scope to call the function in * @param arguments the arguments to the call * @return the outputs of the function */ @@ -370,10 +374,10 @@ public Map> call(Ops tf, Map> arguments) { } /** - * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only - * works for functions with a single input and output. + * Calls the function in an execution environment, adding it's graph as a function if it isn't + * already present. Only works for functions with a single input and output. * - * @param tf the scope to call the function in + * @param tf the scope to call the function in * @param argument the argument to the call * @return the output of the function */ @@ -401,17 +405,23 @@ TF_Function nativeHandle() { return nativeFunction.getNativeHandle(); } - ConcreteFunction(Signature signature, NativeFunction nativeFunction, Collection availableFunctions) { + /** + * All native functions should have deallocators registered + */ + ConcreteFunction(Signature signature, NativeFunction nativeFunction, + Collection availableFunctions) { this(signature, nativeFunction, nativeFunction.getAllDependencies(availableFunctions)); } /** - * Detects the signature from the handle + * Detects the signature from the handle. Does not close passed functions. All passed functions + * should have deallocators. */ static ConcreteFunction fromNativeHandle(NativeFunction nativeFunction, Collection availableFunctions) { - Signature.Builder builder = Signature.builder().methodName(nativeFunction.getFunctionDef().getSignature().getName()) + Signature.Builder builder = Signature.builder() + .methodName(nativeFunction.getFunctionDef().getSignature().getName()) .key(nativeFunction.getName()); for (ArgDef input : nativeFunction.getFunctionDef().getSignature().getInputArgList()) { @@ -448,19 +458,26 @@ static ConcreteFunction fromNativeHandle(NativeFunction nativeFunction, private final DataType[] inputDtypes; private final DataType[] outputDtypes; - private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set dependencies) { + + /** + * All native functions should have deallocators registered + */ + private ConcreteFunction(Signature signature, NativeFunction nativeFunction, + Set dependencies) { this.signature = signature; this.nativeFunction = nativeFunction; this.dependencies = Collections.unmodifiableSet(dependencies); - if (this.signature.getInputs().size() != nativeFunction.getFunctionDef().getSignature().getInputArgCount()) { + if (this.signature.getInputs().size() != nativeFunction.getFunctionDef().getSignature() + .getInputArgCount()) { throw new IllegalArgumentException( "Signature must have the same number of inputs as the native function. Expected " + nativeFunction.getFunctionDef().getSignature().getInputArgCount() + ", got " + this.signature.getInputs().size()); } - if (this.signature.getOutputs().size() != nativeFunction.getFunctionDef().getSignature().getOutputArgCount()) { + if (this.signature.getOutputs().size() != nativeFunction.getFunctionDef().getSignature() + .getOutputArgCount()) { throw new IllegalArgumentException( "New signature must have the same number of outputs as the native function. Expected " + nativeFunction.getFunctionDef().getSignature().getOutputArgCount() + ", got " @@ -471,7 +488,8 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set .toArray(DataType[]::new); List inputs = Arrays.asList(inputDtypes); - List nativeInputs = nativeFunction.getFunctionDef().getSignature().getInputArgList().stream() + List nativeInputs = nativeFunction.getFunctionDef().getSignature().getInputArgList() + .stream() .map(ArgDef::getType) .collect(Collectors.toList()); @@ -481,10 +499,12 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set + nativeInputs + ", got " + inputs); } - outputDtypes = signature().getOutputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new); + outputDtypes = signature().getOutputs().values().stream().map(x -> x.dataType) + .toArray(DataType[]::new); List outputs = Arrays.asList(outputDtypes); - List nativeOutputs = nativeFunction.getFunctionDef().getSignature().getOutputArgList().stream() + List nativeOutputs = nativeFunction.getFunctionDef().getSignature().getOutputArgList() + .stream() .map(ArgDef::getType) .collect(Collectors.toList()); @@ -498,17 +518,17 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set try (PointerScope scope = new PointerScope()) { this.scope = scope; scope.extend(); - this.nativeFunction.getNativeHandle().withDeallocatorInScope(); - this.dependencies.forEach(TF_Function::withDeallocatorInScope); + scope.attach(this.nativeFunction.getNativeHandle()); + this.dependencies.forEach(scope::attach); } } /** - * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because how to enable XLA - * JIT is extremely non-obvious. - * - * Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered platform with id: - * 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails). + * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because + * how to enable XLA JIT is extremely non-obvious. + *

+ * Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered + * platform with id: 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails). */ private void makeJit() { try (PointerScope scope = new PointerScope()) { @@ -516,7 +536,8 @@ private void makeJit() { BytePointer trueValue = new BytePointer(bytes); TF_Status status1 = TF_Status.newStatus(); - TF_FunctionSetAttrValueProto(nativeHandle(), "_XlaMustCompile", trueValue, bytes.length, status1); + TF_FunctionSetAttrValueProto(nativeHandle(), "_XlaMustCompile", trueValue, bytes.length, + status1); status1.throwExceptionIfNotOK(); TF_Status status2 = TF_Status.newStatus(); @@ -592,9 +613,11 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature) inputs.forEach(input -> ops.remove(input.op())); ops.forEach(x -> { - if(x.type().equals(Placeholder.OP_NAME) || x.type().equals(PlaceholderWithDefault.OP_NAME)){ - throw new IllegalArgumentException("Can't calculate outputs (" + outputs + ") from inputs (" + inputs + "), " - + "they also depend on \"" + x + "\""); + if (x.type().equals(Placeholder.OP_NAME) || x.type() + .equals(PlaceholderWithDefault.OP_NAME)) { + throw new IllegalArgumentException( + "Can't calculate outputs (" + outputs + ") from inputs (" + inputs + "), " + + "they also depend on \"" + x + "\""); } }); @@ -629,12 +652,15 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature) resolveToOutput(graph, outputs), null, null, - new BytePointer(signature.methodName() != null ? signature.methodName() : "Method " + signature.key()), + new BytePointer(signature.methodName() != null ? signature.methodName() + : "Method " + signature.key()), status ); + handle.withDeallocator(); status.throwExceptionIfNotOK(); - return new ConcreteFunction(signature, new NativeFunction(handle), graph.getNativeFunctions()); + return new ConcreteFunction(signature, new NativeFunction(handle), + graph.getNativeFunctions(scope)); } } } \ No newline at end of file diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 32927719e5f..489429ef49a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -400,9 +400,11 @@ public void attachFunction(ConcreteFunction function) { } /** - * Get the graph's functions. Deallocating the function pointers is the caller's responsibility. + * Get the graph's functions. + * + * @param outerScope the pointer scope to attach the functions to. */ - synchronized List getNativeFunctions() { + synchronized List getNativeFunctions(PointerScope outerScope) { try (Reference ref = ref(); PointerScope scope = new PointerScope()) { TF_Status status = TF_Status.newStatus(); @@ -418,6 +420,9 @@ synchronized List getNativeFunctions() { for (int i = 0; i < numFunctions; i++) { TF_Function function = output.get(TF_Function.class, i); + function.withDeallocator(); + outerScope.attach(function); + funcs.add(new NativeFunction(function)); } @@ -426,18 +431,17 @@ synchronized List getNativeFunctions() { } /** - * Get the function attached to the graph with the given native name. Returns {@code null} if none found. + * Get the function attached to the graph with the given native name. Returns {@code null} if + * none found. * * @param key the name of the native function. Note that this may include an argument hash. - * @return the found {@link ConcreteFunction}, or {@code null} if none were found with the correct name + * @return the found {@link ConcreteFunction}, or {@code null} if none were found with the correct + * name */ public synchronized ConcreteFunction getFunction(String key) { try (Reference ref = ref(); PointerScope scope = new PointerScope()) { - List funcs = getNativeFunctions(); - - // will close unused functions when method ends - funcs.forEach(x -> x.getNativeHandle().withDeallocatorInScope()); + List funcs = getNativeFunctions(scope); for (NativeFunction f : funcs) { @@ -455,10 +459,12 @@ public synchronized ConcreteFunction getFunction(String key) { * @return all functions attached to this graph. */ public synchronized List getFunctions() { - try (Reference ref = ref()) { - List funcs = getNativeFunctions(); + try (Reference ref = ref(); + PointerScope scope = new PointerScope()) { + List funcs = getNativeFunctions(scope); - return funcs.stream().map(x -> ConcreteFunction.fromNativeHandle(x, funcs)).collect(Collectors.toList()); + return funcs.stream().map(x -> ConcreteFunction.fromNativeHandle(x, funcs)) + .collect(Collectors.toList()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java index dfd34ad3cc1..0d021244c6b 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java @@ -19,9 +19,7 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteFunction; -import java.util.Iterator; import org.bytedeco.javacpp.Pointer; -import org.bytedeco.javacpp.PointerScope; import org.bytedeco.javacpp.annotation.Properties; @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) @@ -46,30 +44,7 @@ public AbstractTF_Function(Pointer p) { super(p); } - - private boolean hasDeallocator = false; - - /** - * Adds a deallocator if there isn't already one. Attaches to the current scope regardless. - */ - public TF_Function withDeallocatorInScope() { - if (hasDeallocator) { - Iterator it = PointerScope.getScopeIterator(); - if (it != null) { - while (it.hasNext()) { - try { - it.next().attach(this); - } catch (IllegalArgumentException e) { - // try the next scope down the stack - continue; - } - break; - } - } - - return (TF_Function) this; - } - hasDeallocator = true; + public TF_Function withDeallocator() { return this.deallocator(new DeleteDeallocator((TF_Function) this)); } From 5977c4009a3373e89ef80a1b48eee5a5ff22a00d Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 14 May 2021 19:24:27 -0700 Subject: [PATCH 21/34] SessionFunction instead of SavedModelBundle specific class Signed-off-by: Ryan Nett --- .../java/org/tensorflow/CallableFunction.java | 97 ++++++++ .../java/org/tensorflow/ConcreteFunction.java | 123 ++++------ .../java/org/tensorflow/SavedModelBundle.java | 212 ++++++------------ .../src/main/java/org/tensorflow/Session.java | 19 ++ .../java/org/tensorflow/SessionFunction.java | 100 +++++++++ .../org/tensorflow/SavedModelBundleTest.java | 180 ++++++++------- 6 files changed, 421 insertions(+), 310 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java new file mode 100644 index 00000000000..7470c0a6a26 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java @@ -0,0 +1,97 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow; + +import java.util.LinkedHashMap; +import java.util.Map; +import org.tensorflow.Signature.TensorDescription; + +public interface CallableFunction { + + /** + * Returns the signature of this function + */ + Signature signature(); + + /** + * Invokes a function using the default eager session. + * + *

Caller is responsible for closing all Tensors. + * + * @param arguments list of tensors to pass in input to the function, mapped by their signature name + * @return output tensors resulting from the execution of the function, mapped by their signature name + * @throws IllegalArgumentException if the passed arguments don't match up to the function's parameters. + */ + Map call(Map arguments); + + /** + * Invokes a function with a single input and output using the default eager session. + * + *

Caller is responsible for closing all Tensors. + * + * @param tensor input tensor + * @return output tensor + * @throws IllegalArgumentException if there are multiple input or output parameters defined in the function + */ + default Tensor call(Tensor tensor) { + if (signature().inputNames().size() > 1) { + throw new IllegalArgumentException( + "Can't use call(Tensor) on function \"" + signature().methodName() + "\" with more than one input."); + } + if (signature().inputNames().size() < 1) { + throw new IllegalArgumentException( + "Can't use call(Tensor) on function \"" + signature().methodName() + "\" with no inputs."); + } + if (signature().outputNames().size() > 1) { + throw new IllegalArgumentException( + "Can't use call(Tensor) on function \"" + signature().methodName() + "\" with more than one output."); + } + if (signature().outputNames().size() < 1) { + throw new IllegalArgumentException( + "Can't use call(Tensor) on function \"" + signature().methodName() + "\" with no outputs."); + } + + String inputName = signature().inputNames().iterator().next(); + String outputName = signature().outputNames().iterator().next(); + + Map inputMap = new LinkedHashMap<>(); + inputMap.put(inputName, tensor); + + return call(inputMap).get(outputName); + } + + static Operand validateDescription(TensorDescription description, Graph graph, String name, String prefix) { + Output operand = graph.output(description.name); + if (operand == null) { + throw new IllegalArgumentException( + prefix + " \"" + name + "\"'s operand \"" + description.name + "\" does not exist on the session's graph."); + } + + if (operand.dataType() != description.dataType) { + throw new IllegalArgumentException( + prefix + " \"" + name + "\"'s operand \"" + description.name + "\" has data type " + operand.dataType() + + " in the session's graph, but the signature requires data type " + description.dataType + "."); + } + + if (!operand.shape().isCompatibleWith(description.shape)) { + throw new IllegalArgumentException( + prefix + " \"" + name + "\"'s operand \"" + description.name + "\" has shape " + operand.shape() + + ", which is incompatible with the signature's required shape of " + description.shape + "."); + } + return operand; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 5f5931b7c9c..a66b66998f5 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -18,7 +18,6 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionSetAttrValueProto; import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphToFunction; -import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -65,19 +64,19 @@ * Map outputTensorMap = myFunction.call(inputTensorMap); * }

*/ -public class ConcreteFunction implements AutoCloseable { +public class ConcreteFunction implements AutoCloseable, CallableFunction { /** * Creates a function by building a new graph. * *

The {@code functionBuilder} must initialize the function graph from the provided - * {@link Ops} instance and return a valid signature that will be used to feed the input tensors - * and fetch the output tensors on execution. + * {@link Ops} instance and return a valid signature that will be used to feed the input tensors and fetch the output + * tensors on execution. * *

The function will be the owner of the new graph and its resulting session. Therefore, - * the function must be enclosed properly with a try-with-resources block to guarantee that all - * native resources will be freed once the function is discarded. For example: + * the function must be enclosed properly with a try-with-resources block to guarantee that all native resources will + * be freed once the function is discarded. For example: * *

{@code
    * public class MyModel {
@@ -112,8 +111,8 @@ public static ConcreteFunction create(Function functionBuilder)
    * Create a function from a signature and an existing graph.
    *
    * 

The function will keep the ownership of the session used to run the graph but not - * the graph itself, meaning that the lifetime of the latter can extend beyond the scope of the - * function. For example: + * the graph itself, meaning that the lifetime of the latter can extend beyond the scope of the function. For + * example: * *

{@code
    * try (Graph g = new Graph()) {
@@ -130,7 +129,7 @@ public static ConcreteFunction create(Function functionBuilder)
    * }
* * @param signature signature of the function to create - * @param graph a valid and initialized graph + * @param graph a valid and initialized graph * @return a new function */ public static ConcreteFunction create(Signature signature, Graph graph) { @@ -141,8 +140,8 @@ public static ConcreteFunction create(Signature signature, Graph graph) { * Create a function from a signature and a valid graph session. * *

The function will not own the session nor its graph, meaning that their lifetime - * can extend beyond the scope of the function. Therefore the function does not need to be closed - * after its usage. For example: + * can extend beyond the scope of the function. Therefore the function does not need to be closed after its usage. For + * example: * *

{@code
    * try (Graph g = new Graph()) {
@@ -164,7 +163,7 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
    * }
* * @param signature signature of the function to create - * @param session a valid session to an initialized graph + * @param session a valid session to an initialized graph * @return a new function */ public static ConcreteFunction create(Signature signature, Session session) { @@ -174,6 +173,7 @@ public static ConcreteFunction create(Signature signature, Session session) { /** * Returns the signature of this function */ + @Override public Signature signature() { return signature; } @@ -220,10 +220,10 @@ public String toString() { /** - * Calls the function in an execution environment, adding it's graph as a function if it isn't - * already present. The inputs and outputs are keyed by the names set in the {@code Signature}. + * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The + * inputs and outputs are keyed by the names set in the {@code Signature}. * - * @param scope the scope to call the function in + * @param scope the scope to call the function in * @param arguments the arguments to the call * @return the outputs of the function */ @@ -235,12 +235,17 @@ public Map> call(Scope scope, int i = 0; for (String inputName : signature().inputNames()) { - Operand input = arguments.get(inputName); - if (input == null) { + if (!arguments.containsKey(inputName)) { throw new IllegalArgumentException( "Function " + signature().methodName() + " has parameter \"" + inputName + "\", but no argument was passed for it."); } + + Operand input = arguments.get(inputName); + if (input == null) { + throw new IllegalArgumentException( + "Can't pass null as an argument to a function. Argument \"" + inputName + "\" was null."); + } inputs[i] = input.asOutput(); i++; } @@ -288,10 +293,10 @@ public Map> call(Scope scope, } /** - * Calls the function in an execution environment, adding it's graph as a function if it isn't - * already present. Only works for functions with a single input and output. + * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only + * works for functions with a single input and output. * - * @param scope the scope to call the function in + * @param scope the scope to call the function in * @param argument the argument to the call * @return the output of the function */ @@ -316,18 +321,8 @@ public Operand call(Scope scope, Operand argument) { return call(scope, inputMap).get(outputName); } - /** - * Invokes a function using the default eager session. - * - *

Caller is responsible for closing all Tensors. - * - * @param arguments list of tensors to pass in input to the function, mapped by their signature - * name - * @return output tensors resulting from the execution of the function, mapped by their signature - * name - */ - public Map call(Map arguments) - throws IllegalArgumentException { + @Override + public Map call(Map arguments) { //FIXME need to manage input/output operand lifetimes Ops tf = Ops.create(); Map> inputs = new LinkedHashMap<>(arguments.size()); @@ -345,27 +340,10 @@ public Map call(Map arguments) } /** - * Invokes a function with a single input and output using the default eager session. + * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The + * inputs and outputs are keyed by the names set in the {@code Signature}. * - *

Caller is responsible for closing all Tensors. - * - * @param tensor input tensor - * @return output tensor - * @throws IllegalArgumentException if there are multiple input or output parameters defined in - * the function - */ - public Tensor call(Tensor tensor) throws IllegalArgumentException { - Ops tf = Ops.create(); - Operand argument = tf.constantOf((TType) tensor); - Operand output = call(tf, argument); - return output.asTensor(); - } - - /** - * Calls the function in an execution environment, adding it's graph as a function if it isn't - * already present. The inputs and outputs are keyed by the names set in the {@code Signature}. - * - * @param tf the scope to call the function in + * @param tf the scope to call the function in * @param arguments the arguments to the call * @return the outputs of the function */ @@ -374,10 +352,10 @@ public Map> call(Ops tf, Map> arguments) { } /** - * Calls the function in an execution environment, adding it's graph as a function if it isn't - * already present. Only works for functions with a single input and output. + * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only + * works for functions with a single input and output. * - * @param tf the scope to call the function in + * @param tf the scope to call the function in * @param argument the argument to the call * @return the output of the function */ @@ -385,19 +363,6 @@ public Operand call(Ops tf, Operand argument) { return tf.call(this, argument); } - /** - * Export this function as a saved model. - * - *

This method is convenient shortcut equivalent to - * {@code SavedModel.exporter(exportDir).withFunction(this).export()} - * - * @param exportDir directory where to export the saved model - * @throws IOException if saved model or variable state cannot be written on disk - */ - public void save(String exportDir) throws IOException { - SavedModelBundle.exporter(exportDir).withFunction(this).export(); - } - TF_Function nativeHandle() { if (nativeFunction.getNativeHandle().isNull()) { throw new IllegalStateException("Function has been closed"); @@ -414,8 +379,8 @@ TF_Function nativeHandle() { } /** - * Detects the signature from the handle. Does not close passed functions. All passed functions - * should have deallocators. + * Detects the signature from the handle. Does not close passed functions. All passed functions should have + * deallocators. */ static ConcreteFunction fromNativeHandle(NativeFunction nativeFunction, Collection availableFunctions) { @@ -524,11 +489,11 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction, } /** - * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because - * how to enable XLA JIT is extremely non-obvious. + * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because how to enable XLA + * JIT is extremely non-obvious. *

- * Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered - * platform with id: 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails). + * Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered platform with id: + * 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails). */ private void makeJit() { try (PointerScope scope = new PointerScope()) { @@ -599,18 +564,18 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature) Reference ref = graph.ref()) { TF_Status status = TF_Status.newStatus(); - List> inputs = signature.getInputs().values().stream() - .map((x) -> graph.outputOrThrow(x.name)) + List> inputs = signature.getInputs().entrySet().stream() + .map((x) -> CallableFunction.validateDescription(x.getValue(), graph, x.getKey(), "Input")) .collect(Collectors.toList()); - List> outputs = signature.getOutputs().values().stream() - .map((x) -> graph.outputOrThrow(x.name)) + List> outputs = signature.getOutputs().entrySet().stream() + .map((x) -> CallableFunction.validateDescription(x.getValue(), graph, x.getKey(), "Output")) .collect(Collectors.toList()); List ops = new ArrayList<>( graph.completeSubgraph(new HashSet<>(inputs), new HashSet<>(outputs))); - inputs.forEach(input -> ops.remove(input.op())); + inputs.forEach(input -> ops.remove((GraphOperation) input.op())); ops.forEach(x -> { if (x.type().equals(Placeholder.OP_NAME) || x.type() diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index d98c1507083..bbd848c0ceb 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -142,6 +142,21 @@ public Exporter withTags(String... tags) { return this; } + /** + * Set the session to export, without adding any signatures. This enables the use of {@link + * #withSignature(Signature)} + * + * @throws IllegalStateException if the session is already set to a different session + */ + public Exporter withSession(Session session) { + if (this.session != null && this.session != session) { + throw new IllegalStateException( + "This exporter already has a session that differs from the passed session"); + } + this.session = session; + return this; + } + /** * Save a concrete function of this model. * @@ -154,24 +169,45 @@ public Exporter withTags(String... tags) { * TensorFlow 2.x estimators. * *
Therefore, all functions exported in a model should share the same session at the moment - * or an exception will be thrown. + * or an exception will be thrown. This applies to sessions set via {@link #withSession(Session)} as well, the + * exporter can only even have one session. * * @param function a function carrying a signature and a valid session to the graph to be saved * @return this object * @throws IllegalArgumentException if a function with the same name has already been added to the model - * @throws UnsupportedOperationException if this function does not share the same session with the other functions - * added to this model + * @throws UnsupportedOperationException if the session is already set to a different session */ - public Exporter withFunction(ConcreteFunction function) { + public Exporter withFunction(SessionFunction function) { Signature signature = function.signature(); if (functions.containsKey(signature.key())) { throw new IllegalArgumentException("Function \"" + signature.key() + "\" was already added to the model"); } + if (session != null && session != function.session()) { + throw new UnsupportedOperationException( + "This exporter already has a session that differs from the passed function's session"); + } + + session = function.session(); functions.put(signature.key(), function); metaGraphDefBuilder.putSignatureDef(signature.key(), signature.asSignatureDef()); return this; } + /** + * Add a signature to the model. This wraps the signature in a {@link SessionFunction} using the exporter's + * already-set session. As such, either {@link #withSession(Session)} or {@link #withFunction(SessionFunction)} + * must be called before this method. + * + * @throws IllegalStateException if no session has been set + */ + public Exporter withSignature(Signature signature) { + if (session == null) { + throw new IllegalStateException( + "Session has not been set yet, you must call withSession or withFunction first."); + } + return withFunction(session.function(signature)); + } + /** * Save the model into the export directory. * @@ -181,36 +217,30 @@ public void export() throws IOException { if (functions.isEmpty()) { throw new IllegalStateException("Model should contain at least one valid function"); } - try (Graph graph = new Graph(); - Session session = new Session(graph)) { - - functions.values().forEach(graph::attachFunction); - - session.runInit(); - - // It is imperative to retrieve the graphDef after the saverDef, as the former might add - // new ops to the graph for saving and restoring the variables. - SaverDef saverDef = graph.saverDef(); - - MetaGraphDef.Builder metaGraphDef = metaGraphDefBuilder - .setSaverDef(saverDef) - .setGraphDef(graph.toGraphDef()) - .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(Arrays.asList(tags))); - functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signature().asSignatureDef())); - - // Make sure saved model directories exist - Path variableDir = Paths.get(exportDir, "variables"); - variableDir.toFile().mkdirs(); - - // Save the variables state - session.save(variableDir.resolve("variables").toString()); - - // Save the graph - SavedModel savedModelDef = SavedModel.newBuilder().addMetaGraphs(metaGraphDef).build(); - try (OutputStream file = - new FileOutputStream(Paths.get(exportDir, "saved_model.pb").toString())) { - savedModelDef.writeTo(file); - } + Graph graph = session.graph(); + + // It is imperative to retrieve the graphDef after the saverDef, as the former might add + // new ops to the graph for saving and restoring the variables. + SaverDef saverDef = graph.saverDef(); + + MetaGraphDef.Builder metaGraphDef = metaGraphDefBuilder + .setSaverDef(saverDef) + .setGraphDef(graph.toGraphDef()) + .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(Arrays.asList(tags))); + functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signature().asSignatureDef())); + + // Make sure saved model directories exist + Path variableDir = Paths.get(exportDir, "variables"); + variableDir.toFile().mkdirs(); + + // Save the variables state + session.save(variableDir.resolve("variables").toString()); + + // Save the graph + SavedModel savedModelDef = SavedModel.newBuilder().addMetaGraphs(metaGraphDef).build(); + try (OutputStream file = + new FileOutputStream(Paths.get(exportDir, "saved_model.pb").toString())) { + savedModelDef.writeTo(file); } } @@ -221,106 +251,8 @@ public void export() throws IOException { private final String exportDir; private String[] tags = {DEFAULT_TAG}; private final MetaGraphDef.Builder metaGraphDefBuilder = MetaGraphDef.newBuilder(); - private final Map functions = new LinkedHashMap<>(); - } - - /** - * A function loaded from a saved model. It can be called using the saved model's session. - * - * All resources are owned by the SavedModel. - * - * The session is not initialized in any way, you can use {@link Session#runInit()} on {@link SavedModelBundle#session()} - * if this is necessary. - */ - public final class SavedFunction { - - /** - * The signature of the function. - */ - public Signature signature() { - return signature; - } - - /** - * The name of the function. - */ - public String name() { - return signature.key(); - } - - /** - * Call this function using the SavedModel's session. - * - *

Caller is responsible for closing all returned Tensors. - * - * @throws IllegalArgumentException if an argument is missing, an argument is passed for an unknown parameter, - * or an argument has the wrong type - */ - public Map call(Map arguments) { - Session.Runner runner = session.runner(); - arguments.forEach((name, value) -> { - Signature.TensorDescription parameter = signature.getInputs().get(name); - if (parameter == null) { - throw new IllegalArgumentException("Function \"" + name() + "\" has no argument \"" + name + "\"."); - } - if (value.dataType() != parameter.dataType) { - throw new IllegalArgumentException("Function \"" + name() + "\"'s argument \"" + name + - "\" has data type " + parameter.dataType + ", but a tensor of data type " + value.dataType() - + " was passed."); - } - runner.feed(parameter.name, value); - }); - - signature.inputNames().forEach((param) -> { - if (!arguments.containsKey(param)) { - throw new IllegalArgumentException( - "Function \"" + name() + "\" has a parameter \"" + param + "\", but no argument was passed for it."); - } - }); - - List resultNames = new ArrayList<>(signature.getOutputs().size()); - signature.getOutputs().forEach((name, desc) -> { - runner.fetch(desc.name); - resultNames.add(name); - }); - - List result = runner.run(); - Map namedResults = new LinkedHashMap<>(result.size()); - - for (int i = 0; i < result.size(); i++) { - namedResults.put(resultNames.get(i), result.get(i)); - } - return namedResults; - } - - - /** - * Call this single-argument single-result function using the SavedModel's session. - * - *

Caller is responsible for closing the returned Tensor. - * - * @throws IllegalStateException if this function does not have exactly one input and output. - */ - public Tensor call(Tensor argument) { - if (signature.getInputs().size() != 1) { - throw new IllegalStateException("Can only use this call method on functions with exactly one input, function \"" - + name() + "\" has " + signature.getInputs().size() + "."); - } - if (signature.getOutputs().size() != 1) { - throw new IllegalStateException("Can only use this call method on functions with exactly one input, function \"" - + name() + "\" has " + signature.getInputs().size() + "."); - } - Map inputMap = new LinkedHashMap<>(1); - inputMap.put(signature.inputNames().iterator().next(), argument); - Map results = call(inputMap); - return results.get(signature.outputNames().iterator().next()); - } - - private final Signature signature; - - private SavedFunction(Signature signature) { - this.signature = signature; - } + private Session session; + private final Map functions = new LinkedHashMap<>(); } /** @@ -414,8 +346,8 @@ public List signatures() { * @return object that can be used to make calls to a function * @throws IllegalArgumentException if {@code signatureKey} is not found in this saved model. */ - public SavedFunction function(String signatureKey) { - SavedFunction function = functions.get(signatureKey); + public SessionFunction function(String signatureKey) { + SessionFunction function = functions.get(signatureKey); if (function == null) { throw new IllegalArgumentException( String.format("Function with signature [%s] not found", signatureKey)); @@ -426,7 +358,7 @@ public SavedFunction function(String signatureKey) { /** * Get all functions in the bundle. */ - public List functions() { + public List functions() { return new ArrayList<>(functions.values()); } @@ -447,7 +379,7 @@ public List functions() { * @throws IllegalArgumentException if no function can be selected by default */ public Map call(Map arguments) { - SavedFunction function = null; + SessionFunction function = null; if (functions.size() == 1) { function = functions.values().iterator().next(); } else { @@ -471,7 +403,7 @@ public void close() { private final Graph graph; private final Session session; private final MetaGraphDef metaGraphDef; - private final Map functions; + private final Map functions; private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, Map signatures) { @@ -479,7 +411,7 @@ private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef this.session = session; this.metaGraphDef = metaGraphDef; this.functions = signatures.entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, e -> new SavedFunction(e.getValue()))); + .collect(Collectors.toMap(Entry::getKey, e -> new SessionFunction(e.getValue(), session))); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 51079dbbcea..aa2767a9352 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -23,6 +23,7 @@ import com.google.protobuf.InvalidProtocolBufferException; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; @@ -504,6 +505,24 @@ public void run(Op op) { runner().addTarget(op.op()).run(); } + /** + * Create a new session function, backed by this session. + * + * @param signature the signature of the function. + */ + public SessionFunction function(Signature signature) { + return new SessionFunction(signature, this); + } + + /** + * Create and call a function, returning the results. + * + * @param signature the signature of the function + * @param arguments the arguments to call with. + */ + public Map run(Signature signature, Map arguments) { + return function(signature).call(arguments); + } /** * Execute the graph's initializers. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java new file mode 100644 index 00000000000..1100d5d849f --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java @@ -0,0 +1,100 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow; + +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * A callable function backed by a session. All calls of this function will be ran on the same session. + * + * Does no resource management, the session and all returned tensors are the caller's responsibility. + * + * Does not initialize the session, since it may be shared. + */ +public class SessionFunction implements CallableFunction { + + private final Signature signature; + private final Session session; + + public SessionFunction(Signature signature, Session session) { + this.signature = signature; + this.session = session; + + signature.getInputs().forEach((name, description) -> { + CallableFunction.validateDescription(description, session.graph(), name, "Input"); + }); + + signature.getInputs().forEach((name, description) -> { + CallableFunction.validateDescription(description, session.graph(), name, "Output"); + }); + } + + public static SessionFunction create(Signature signature, Session session) { + return new SessionFunction(signature, session); + } + + @Override + public Signature signature() { + return signature; + } + + public Session session() { + return session; + } + + /** + * Get a new function with the same signature, but backed by a new session. + * + * @param session the new backing session. + */ + public SessionFunction withNewSession(Session session) { + return new SessionFunction(signature, session); + } + + @Override + public Map call(Map arguments) { + Session.Runner runner = session.runner(); + signature.getInputs().forEach((argName, operand) -> { + if (!arguments.containsKey(argName)) { + throw new IllegalArgumentException("No argument found for parameter \"" + argName + "\""); + } + Tensor value = arguments.get(argName); + + if (value == null) { + throw new IllegalArgumentException( + "Can't pass null as an argument to a function. Argument \"" + argName + "\" was null."); + } + + runner.feed(operand.name, value); + }); + + signature.getOutputs().values().forEach(x -> runner.fetch(x.name)); + + List results = runner.run(); + + Map outputs = new LinkedHashMap<>(results.size()); + int i = 0; + for (String outputName : signature.outputNames()) { + outputs.put(outputName, results.get(i)); + i++; + } + + return outputs; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 24ec48a1cbd..6f4ce78a63e 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -26,13 +26,10 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; -import java.util.Collections; import java.util.HashMap; import java.util.Map; import org.junit.jupiter.api.Test; -import org.tensorflow.SavedModelBundle.SavedFunction; import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.op.Ops; @@ -43,11 +40,11 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.RunOptions; -import org.tensorflow.proto.framework.SignatureDef; -import org.tensorflow.proto.framework.TensorInfo; import org.tensorflow.types.TFloat32; -/** Unit tests for {@link org.tensorflow.SavedModelBundle}. */ +/** + * Unit tests for {@link org.tensorflow.SavedModelBundle}. + */ public class SavedModelBundleTest { private static final float EPSILON = 1e-7f; @@ -57,7 +54,8 @@ public class SavedModelBundleTest { static { try { SAVED_MODEL_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model").toURI()).toString(); - SAVED_MODEL_PY_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model_using_python/model").toURI()).toString(); + SAVED_MODEL_PY_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model_using_python/model").toURI()) + .toString(); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -97,76 +95,76 @@ public void loader() { } } - @Test - public void exportFunctionWithVariables() throws IOException { - Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); - float reducedSum; - FloatNdArray xValue = StdArrays.ndCopyOf(new float[][]{{0, 1, 2}, {3, 4, 5}}); - Shape xyShape = Shape.of(2, 3L); - try (ConcreteFunction f = ConcreteFunction.create(tf -> buildGraphWithVariables(tf, xyShape))) { - // Init variable state by running the Init operation directly - //TODO f.session().run(Init.DEFAULT_NAME); - - // Call the graph and remember the result of computation for later - try (TFloat32 xTensor = TFloat32.tensorOf(xValue); - TFloat32 zTensor = (TFloat32)f.call(xTensor)) { - reducedSum = zTensor.getFloat(); - } - // Save/export the model (which is a single function in this case) - f.save(testFolder.toString()); - } - assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.index")))); - assertTrue(Files - .exists(testFolder.resolve(Paths.get("variables", "variables.data-00000-of-00001")))); - assertTrue(Files.exists(testFolder.resolve("saved_model.pb"))); - - // Reload the model just saved and validate its data - try (SavedModelBundle savedModel = - SavedModelBundle.load(testFolder.toString(), SavedModelBundle.DEFAULT_TAG)) { - assertNotNull(savedModel.metaGraphDef()); - assertNotNull(savedModel.metaGraphDef().getSaverDef()); - assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount()); - assertEquals(Signature.DEFAULT_KEY, - savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next()); - - SavedFunction function = savedModel.function(Signature.DEFAULT_KEY); - assertNotNull(function); - - Signature signature = function.signature(); - assertNotNull(signature); - assertEquals(1, signature.inputNames().size()); - assertEquals("input", signature.inputNames().iterator().next()); - assertEquals(1, signature.outputNames().size()); - assertEquals("reducedSum", signature.outputNames().iterator().next()); - - SignatureDef signatureDef = signature.asSignatureDef(); - assertEquals(1, signatureDef.getInputsCount()); - assertEquals(1, signatureDef.getOutputsCount()); - - TensorInfo inputInfo = signatureDef.getInputsMap().get("input"); - assertNotNull(inputInfo); - assertEquals(xyShape.numDimensions(), inputInfo.getTensorShape().getDimCount()); - for (int i = 0; i < xyShape.numDimensions(); ++i) { - assertEquals(xyShape.size(i), inputInfo.getTensorShape().getDim(i).getSize()); - } - - TensorInfo outputInfo = signatureDef.getOutputsMap().get("reducedSum"); - assertNotNull(outputInfo); - assertEquals(0, outputInfo.getTensorShape().getDimCount()); - - try (TFloat32 xTensor = TFloat32.tensorOf(xValue)) { - // Call the saved model function and make sure it returns the same result as before - try (TFloat32 zTensor = (TFloat32)function.call(xTensor)) { - assertEquals(reducedSum, zTensor.getFloat(), EPSILON); - } - // Now call the same function directly from the model - try (TFloat32 zTensor = - (TFloat32)savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum")) { - assertEquals(reducedSum, zTensor.getFloat(), EPSILON); - } - } - } - } +// @Test +// public void exportFunctionWithVariables() throws IOException { +// Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); +// float reducedSum; +// FloatNdArray xValue = StdArrays.ndCopyOf(new float[][]{{0, 1, 2}, {3, 4, 5}}); +// Shape xyShape = Shape.of(2, 3L); +// try (ConcreteFunction f = ConcreteFunction.create(tf -> buildGraphWithVariables(tf, xyShape))) { +// // Init variable state by running the Init operation directly +// //TODO f.session().run(Init.DEFAULT_NAME); +// +// // Call the graph and remember the result of computation for later +// try (TFloat32 xTensor = TFloat32.tensorOf(xValue); +// TFloat32 zTensor = (TFloat32)f.call(xTensor)) { +// reducedSum = zTensor.getFloat(); +// } +// // Save/export the model (which is a single function in this case) +// f.save(testFolder.toString()); +// } +// assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.index")))); +// assertTrue(Files +// .exists(testFolder.resolve(Paths.get("variables", "variables.data-00000-of-00001")))); +// assertTrue(Files.exists(testFolder.resolve("saved_model.pb"))); +// +// // Reload the model just saved and validate its data +// try (SavedModelBundle savedModel = +// SavedModelBundle.load(testFolder.toString(), SavedModelBundle.DEFAULT_TAG)) { +// assertNotNull(savedModel.metaGraphDef()); +// assertNotNull(savedModel.metaGraphDef().getSaverDef()); +// assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount()); +// assertEquals(Signature.DEFAULT_KEY, +// savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next()); +// +// SessionFunction function = savedModel.function(Signature.DEFAULT_KEY); +// assertNotNull(function); +// +// Signature signature = function.signature(); +// assertNotNull(signature); +// assertEquals(1, signature.inputNames().size()); +// assertEquals("input", signature.inputNames().iterator().next()); +// assertEquals(1, signature.outputNames().size()); +// assertEquals("reducedSum", signature.outputNames().iterator().next()); +// +// SignatureDef signatureDef = signature.asSignatureDef(); +// assertEquals(1, signatureDef.getInputsCount()); +// assertEquals(1, signatureDef.getOutputsCount()); +// +// TensorInfo inputInfo = signatureDef.getInputsMap().get("input"); +// assertNotNull(inputInfo); +// assertEquals(xyShape.numDimensions(), inputInfo.getTensorShape().getDimCount()); +// for (int i = 0; i < xyShape.numDimensions(); ++i) { +// assertEquals(xyShape.size(i), inputInfo.getTensorShape().getDim(i).getSize()); +// } +// +// TensorInfo outputInfo = signatureDef.getOutputsMap().get("reducedSum"); +// assertNotNull(outputInfo); +// assertEquals(0, outputInfo.getTensorShape().getDimCount()); +// +// try (TFloat32 xTensor = TFloat32.tensorOf(xValue)) { +// // Call the saved model function and make sure it returns the same result as before +// try (TFloat32 zTensor = (TFloat32)function.call(xTensor)) { +// assertEquals(reducedSum, zTensor.getFloat(), EPSILON); +// } +// // Now call the same function directly from the model +// try (TFloat32 zTensor = +// (TFloat32)savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum")) { +// assertEquals(reducedSum, zTensor.getFloat(), EPSILON); +// } +// } +// } +// } @Test public void exportMultipleFunctions() throws IOException { @@ -176,12 +174,12 @@ public void exportMultipleFunctions() throws IOException { Ops tf = Ops.create(g); Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); Signature f2Signature = buildIdentityGraph(tf, "identity"); - try (Session s = new Session(g); - ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s); - ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) { - //TODO f1.session().run(Init.DEFAULT_NAME); + try (Session s = new Session(g);) { + SessionFunction f1 = SessionFunction.create(f1Signature, s); + SessionFunction f2 = SessionFunction.create(f2Signature, s); + s.runInit(); try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); - TFloat32 t = (TFloat32)f1.call(x)) { + TFloat32 t = (TFloat32) f1.call(x)) { reducedSum = t.getFloat(); } SavedModelBundle.exporter(testFolder.toString()) @@ -192,16 +190,16 @@ public void exportMultipleFunctions() throws IOException { } try (SavedModelBundle model = SavedModelBundle.load(testFolder.toString())) { assertEquals(2, model.signatures().size()); - SavedFunction f1 = model.function(Signature.DEFAULT_KEY); + SessionFunction f1 = model.function(Signature.DEFAULT_KEY); assertNotNull(f1); try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); - TFloat32 t = (TFloat32)f1.call(x)) { + TFloat32 t = (TFloat32) f1.call(x)) { assertEquals(reducedSum, t.getFloat(), EPSILON); } - SavedFunction f2 = model.function("identity"); + SessionFunction f2 = model.function("identity"); assertNotNull(f2); try (TFloat32 x = TFloat32.scalarOf(10.0f); - TFloat32 t = (TFloat32)f2.call(x)) { + TFloat32 t = (TFloat32) f2.call(x)) { assertEquals(10.0f, t.getFloat(), 0.0f); } try { @@ -220,10 +218,10 @@ public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOExcepti Ops tf = Ops.create(g); Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); Signature f2Signature = buildIdentityGraph(tf, Signature.DEFAULT_KEY); - try (Session s = new Session(g); - ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s); - ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) { - //TODO f1.session().run(Init.DEFAULT_NAME); + try (Session s = new Session(g);) { + SessionFunction f1 = SessionFunction.create(f1Signature, s); + SessionFunction f2 = SessionFunction.create(f2Signature, s); + s.runInit(); try { SavedModelBundle.exporter(testFolder.toString()) .withFunction(f1) @@ -267,7 +265,7 @@ public void pythonTfFunction() { * Test model was created in python * Signature name used for saving 'add', argument names 'a' and 'b' */ - SavedFunction add = bundle.function("add"); + SessionFunction add = bundle.function("add"); Map args = new HashMap<>(); try (TFloat32 a = TFloat32.scalarOf(10.0f); TFloat32 b = TFloat32.scalarOf(15.5f)) { From 12af3277bdfd0f7e2442480f53755c70b779a7f5 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 14 May 2021 19:25:10 -0700 Subject: [PATCH 22/34] Add CallableFunction javadoc Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/CallableFunction.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java index 7470c0a6a26..1aec98198ac 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java @@ -20,6 +20,9 @@ import java.util.Map; import org.tensorflow.Signature.TensorDescription; +/** + * A function that can be called with tensors. + */ public interface CallableFunction { /** From 1b4bf5954a159519735f0e60e042cfefb34d5b13 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 14 May 2021 19:26:07 -0700 Subject: [PATCH 23/34] Remove obsolete test Signed-off-by: Ryan Nett --- .../org/tensorflow/SavedModelBundleTest.java | 71 ------------------- 1 file changed, 71 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 6f4ce78a63e..34c81ca4260 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -95,77 +95,6 @@ public void loader() { } } -// @Test -// public void exportFunctionWithVariables() throws IOException { -// Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); -// float reducedSum; -// FloatNdArray xValue = StdArrays.ndCopyOf(new float[][]{{0, 1, 2}, {3, 4, 5}}); -// Shape xyShape = Shape.of(2, 3L); -// try (ConcreteFunction f = ConcreteFunction.create(tf -> buildGraphWithVariables(tf, xyShape))) { -// // Init variable state by running the Init operation directly -// //TODO f.session().run(Init.DEFAULT_NAME); -// -// // Call the graph and remember the result of computation for later -// try (TFloat32 xTensor = TFloat32.tensorOf(xValue); -// TFloat32 zTensor = (TFloat32)f.call(xTensor)) { -// reducedSum = zTensor.getFloat(); -// } -// // Save/export the model (which is a single function in this case) -// f.save(testFolder.toString()); -// } -// assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.index")))); -// assertTrue(Files -// .exists(testFolder.resolve(Paths.get("variables", "variables.data-00000-of-00001")))); -// assertTrue(Files.exists(testFolder.resolve("saved_model.pb"))); -// -// // Reload the model just saved and validate its data -// try (SavedModelBundle savedModel = -// SavedModelBundle.load(testFolder.toString(), SavedModelBundle.DEFAULT_TAG)) { -// assertNotNull(savedModel.metaGraphDef()); -// assertNotNull(savedModel.metaGraphDef().getSaverDef()); -// assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount()); -// assertEquals(Signature.DEFAULT_KEY, -// savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next()); -// -// SessionFunction function = savedModel.function(Signature.DEFAULT_KEY); -// assertNotNull(function); -// -// Signature signature = function.signature(); -// assertNotNull(signature); -// assertEquals(1, signature.inputNames().size()); -// assertEquals("input", signature.inputNames().iterator().next()); -// assertEquals(1, signature.outputNames().size()); -// assertEquals("reducedSum", signature.outputNames().iterator().next()); -// -// SignatureDef signatureDef = signature.asSignatureDef(); -// assertEquals(1, signatureDef.getInputsCount()); -// assertEquals(1, signatureDef.getOutputsCount()); -// -// TensorInfo inputInfo = signatureDef.getInputsMap().get("input"); -// assertNotNull(inputInfo); -// assertEquals(xyShape.numDimensions(), inputInfo.getTensorShape().getDimCount()); -// for (int i = 0; i < xyShape.numDimensions(); ++i) { -// assertEquals(xyShape.size(i), inputInfo.getTensorShape().getDim(i).getSize()); -// } -// -// TensorInfo outputInfo = signatureDef.getOutputsMap().get("reducedSum"); -// assertNotNull(outputInfo); -// assertEquals(0, outputInfo.getTensorShape().getDimCount()); -// -// try (TFloat32 xTensor = TFloat32.tensorOf(xValue)) { -// // Call the saved model function and make sure it returns the same result as before -// try (TFloat32 zTensor = (TFloat32)function.call(xTensor)) { -// assertEquals(reducedSum, zTensor.getFloat(), EPSILON); -// } -// // Now call the same function directly from the model -// try (TFloat32 zTensor = -// (TFloat32)savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum")) { -// assertEquals(reducedSum, zTensor.getFloat(), EPSILON); -// } -// } -// } -// } - @Test public void exportMultipleFunctions() throws IOException { Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); From 54e88555faa19f250c5a547d85ec5febb6f152a5 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 14 May 2021 19:40:55 -0700 Subject: [PATCH 24/34] Rebase fix Signed-off-by: Ryan Nett --- .../src/test/java/org/tensorflow/SavedModelBundleTest.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 34c81ca4260..3501a77b590 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -210,13 +210,13 @@ public void pythonTfFunction() { args.clear(); // variable unwrapping happens in Session, which is used by ConcreteFunction.call - ConcreteFunction getVariable = bundle.function("get_variable"); + SessionFunction getVariable = bundle.function("get_variable"); try (TFloat32 dummy = TFloat32.scalarOf(1.0f)) { - args.put("dummy",dummy); + args.put("dummy", dummy); // TF functions always require an input, so we supply a dummy one here // This test actually checks that resource variables can be loaded correctly. try (TFloat32 v = (TFloat32) getVariable.call(args) - .get(getVariable.signature().outputNames().iterator().next())) { + .get(getVariable.signature().outputNames().iterator().next())) { assertEquals(2f, v.getFloat()); } } From 117b3918ec26e881c25d1c3616f02d2f2e7da484 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 21 May 2021 16:51:08 -0700 Subject: [PATCH 25/34] Formatting fixes and nits Signed-off-by: Ryan Nett --- .../annotations/org/tensorflow/op/Ops.java | 6 +- .../java/org/tensorflow/CallableFunction.java | 94 ++-- .../java/org/tensorflow/ConcreteFunction.java | 306 ++++++------ .../org/tensorflow/EagerOperationBuilder.java | 39 +- .../java/org/tensorflow/EagerSession.java | 13 +- .../src/main/java/org/tensorflow/Graph.java | 24 +- .../java/org/tensorflow/NativeFunction.java | 64 ++- .../java/org/tensorflow/SavedModelBundle.java | 145 +++--- .../src/main/java/org/tensorflow/Session.java | 163 ++++--- .../main/java/org/tensorflow/Signature.java | 118 +++-- .../main/java/org/tensorflow/TensorFlow.java | 37 +- .../internal/c_api/presets/tensorflow.java | 442 ++++++++++++------ .../java/org/tensorflow/op/core/Function.java | 46 +- 13 files changed, 851 insertions(+), 646 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index ec0c10bfd76..b405b8a39d0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -1120,10 +1120,9 @@ public Bucketize bucketize(Operand input, List boundar } /** - * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only + * Calls the function in an execution environment, adding its graph as a function if it isn't already present. Only * works for functions with a single input and output. * - * @param scope the scope to call the function in * @param argument the argument to the call * @return the output of the function * @see ConcreteFunction#call(Ops, Operand) @@ -1133,10 +1132,9 @@ public Operand call(ConcreteFunction function, Operand argument) { } /** - * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The + * Calls the function in an execution environment, adding its graph as a function if it isn't already present. The * inputs and outputs are keyed by the names set in the {@code Signature}. * - * @param scope the scope to call the function in * @param arguments the arguments to the call * @return the outputs of the function * @see ConcreteFunction#call(Ops, Map) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java index 1aec98198ac..bcb5c775f74 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java @@ -1,33 +1,29 @@ /* - Copyright 2021 The TensorFlow Authors. All Rights Reserved. + Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ============================================================================== - */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ package org.tensorflow; import java.util.LinkedHashMap; import java.util.Map; import org.tensorflow.Signature.TensorDescription; -/** - * A function that can be called with tensors. - */ +/** A function that can be called with tensors. */ public interface CallableFunction { - /** - * Returns the signature of this function - */ + /** Returns the signature of this function */ Signature signature(); /** @@ -35,9 +31,12 @@ public interface CallableFunction { * *

Caller is responsible for closing all Tensors. * - * @param arguments list of tensors to pass in input to the function, mapped by their signature name - * @return output tensors resulting from the execution of the function, mapped by their signature name - * @throws IllegalArgumentException if the passed arguments don't match up to the function's parameters. + * @param arguments list of tensors to pass in input to the function, mapped by their signature + * name + * @return output tensors resulting from the execution of the function, mapped by their signature + * name + * @throws IllegalArgumentException if the passed arguments don't match up to the function's + * parameters. */ Map call(Map arguments); @@ -48,24 +47,33 @@ public interface CallableFunction { * * @param tensor input tensor * @return output tensor - * @throws IllegalArgumentException if there are multiple input or output parameters defined in the function + * @throws IllegalArgumentException if there are multiple input or output parameters defined in + * the function */ default Tensor call(Tensor tensor) { if (signature().inputNames().size() > 1) { throw new IllegalArgumentException( - "Can't use call(Tensor) on function \"" + signature().methodName() + "\" with more than one input."); + "Can't use call(Tensor) on function \"" + + signature().methodName() + + "\" with more than one input."); } if (signature().inputNames().size() < 1) { throw new IllegalArgumentException( - "Can't use call(Tensor) on function \"" + signature().methodName() + "\" with no inputs."); + "Can't use call(Tensor) on function \"" + + signature().methodName() + + "\" with no inputs."); } if (signature().outputNames().size() > 1) { throw new IllegalArgumentException( - "Can't use call(Tensor) on function \"" + signature().methodName() + "\" with more than one output."); + "Can't use call(Tensor) on function \"" + + signature().methodName() + + "\" with more than one output."); } if (signature().outputNames().size() < 1) { throw new IllegalArgumentException( - "Can't use call(Tensor) on function \"" + signature().methodName() + "\" with no outputs."); + "Can't use call(Tensor) on function \"" + + signature().methodName() + + "\" with no outputs."); } String inputName = signature().inputNames().iterator().next(); @@ -77,23 +85,45 @@ default Tensor call(Tensor tensor) { return call(inputMap).get(outputName); } - static Operand validateDescription(TensorDescription description, Graph graph, String name, String prefix) { + static Operand validateDescription( + TensorDescription description, Graph graph, String name, String prefix) { Output operand = graph.output(description.name); if (operand == null) { throw new IllegalArgumentException( - prefix + " \"" + name + "\"'s operand \"" + description.name + "\" does not exist on the session's graph."); + prefix + + " \"" + + name + + "\"'s operand \"" + + description.name + + "\" does not exist on the graph."); } if (operand.dataType() != description.dataType) { throw new IllegalArgumentException( - prefix + " \"" + name + "\"'s operand \"" + description.name + "\" has data type " + operand.dataType() - + " in the session's graph, but the signature requires data type " + description.dataType + "."); + prefix + + " \"" + + name + + "\"'s operand \"" + + description.name + + "\" has data type " + + operand.dataType() + + " in the graph, but the signature requires data type " + + description.dataType + + "."); } if (!operand.shape().isCompatibleWith(description.shape)) { throw new IllegalArgumentException( - prefix + " \"" + name + "\"'s operand \"" + description.name + "\" has shape " + operand.shape() - + ", which is incompatible with the signature's required shape of " + description.shape + "."); + prefix + + " \"" + + name + + "\"'s operand \"" + + description.name + + "\" has shape " + + operand.shape() + + ", which is incompatible with the signature's required shape of " + + description.shape + + "."); } return operand; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index a66b66998f5..16b9c2817fc 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -55,9 +55,9 @@ /** * A graph that can be invoked as a single function, with an input and output signature. * - *

A function can also invoke a - * tf.function - * defined in a {@link SavedModelBundle}. + *

A function can also invoke a tf.function defined in a {@link + * SavedModelBundle}. * *

{@code
  * ConcreteFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
@@ -66,17 +66,16 @@
  */
 public class ConcreteFunction implements AutoCloseable, CallableFunction {
 
-
   /**
    * Creates a function by building a new graph.
    *
-   * 

The {@code functionBuilder} must initialize the function graph from the provided - * {@link Ops} instance and return a valid signature that will be used to feed the input tensors and fetch the output - * tensors on execution. + *

The {@code functionBuilder} must initialize the function graph from the provided {@link Ops} + * instance and return a valid signature that will be used to feed the input tensors and fetch the + * output tensors on execution. * - *

The function will be the owner of the new graph and its resulting session. Therefore, - * the function must be enclosed properly with a try-with-resources block to guarantee that all native resources will - * be freed once the function is discarded. For example: + *

The function will be the owner of the new graph and its resulting session. Therefore, the + * function must be enclosed properly with a try-with-resources block to guarantee that all native + * resources will be freed once the function is discarded. For example: * *

{@code
    * public class MyModel {
@@ -110,9 +109,9 @@ public static ConcreteFunction create(Function functionBuilder)
   /**
    * Create a function from a signature and an existing graph.
    *
-   * 

The function will keep the ownership of the session used to run the graph but not - * the graph itself, meaning that the lifetime of the latter can extend beyond the scope of the function. For - * example: + *

The function will keep the ownership of the session used to run the graph but not the graph + * itself, meaning that the lifetime of the latter can extend beyond the scope of the function. + * For example: * *

{@code
    * try (Graph g = new Graph()) {
@@ -139,9 +138,9 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
   /**
    * Create a function from a signature and a valid graph session.
    *
-   * 

The function will not own the session nor its graph, meaning that their lifetime - * can extend beyond the scope of the function. Therefore the function does not need to be closed after its usage. For - * example: + *

The function will not own the session nor its graph, meaning that their lifetime can extend + * beyond the scope of the function. Therefore the function does not need to be closed after its + * usage. For example: * *

{@code
    * try (Graph g = new Graph()) {
@@ -170,31 +169,26 @@ public static ConcreteFunction create(Signature signature, Session session) {
     return buildFromGraph(session.graph(), signature);
   }
 
-  /**
-   * Returns the signature of this function
-   */
+  /** Returns the signature of this function */
   @Override
   public Signature signature() {
     return signature;
   }
 
   /**
-   * Get the name of the function.
+   * Get the name of the function. This is what it will show up under in the graph and any exported
+   * GraphDefs.
    */
   public String getNativeFunctionName() {
     return nativeFunction.getName();
   }
 
-  /**
-   * Get the {@link FunctionDef} proto.
-   */
+  /** Get the {@link FunctionDef} proto. */
   public FunctionDef getFunctionDef() {
     return nativeFunction.getFunctionDef();
   }
 
-  /**
-   * Get whether the function is stateful.
-   */
+  /** Get whether the function is stateful. */
   public boolean isStateful() {
     return nativeFunction.isStateful();
   }
@@ -213,22 +207,20 @@ public String toString() {
     return signature.toString();
   }
 
-  //TODO migrate to the actual ops once they are generated
+  // TODO migrate to the actual ops once they are generated
   public static final String CALL_OP = "PartitionedCall";
-  //TODO migrate to the actual ops once they are generated
+  // TODO migrate to the actual ops once they are generated
   public static final String STATEFUL_CALL_OP = "StatefulPartitionedCall";
 
-
   /**
-   * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The
-   * inputs and outputs are keyed by the names set in the {@code Signature}.
+   * Calls the function in an execution environment, adding its graph as a function if it isn't
+   * already present. The inputs and outputs are keyed by the names set in the {@code Signature}.
    *
    * @param scope the scope to call the function in
    * @param arguments the arguments to the call
    * @return the outputs of the function
    */
-  public Map> call(Scope scope,
-      Map> arguments) {
+  public Map> call(Scope scope, Map> arguments) {
     List> inputList = new ArrayList<>();
 
     Output[] inputs = new Output[signature().inputNames().size()];
@@ -237,14 +229,19 @@ public Map> call(Scope scope,
     for (String inputName : signature().inputNames()) {
       if (!arguments.containsKey(inputName)) {
         throw new IllegalArgumentException(
-            "Function " + signature().methodName() + " has parameter \"" + inputName
+            "Function "
+                + signature().methodName()
+                + " has parameter \""
+                + inputName
                 + "\", but no argument was passed for it.");
       }
 
       Operand input = arguments.get(inputName);
       if (input == null) {
         throw new IllegalArgumentException(
-            "Can't pass null as an argument to a function.  Argument \"" + inputName + "\" was null.");
+            "Can't pass null as an argument to a function.  Argument \""
+                + inputName
+                + "\" was null.");
       }
       inputs[i] = input.asOutput();
       i++;
@@ -255,8 +252,10 @@ public Map> call(Scope scope,
 
     String displayName = Scope.isValidOpName(name) ? name : "FunctionCall";
 
-    OperationBuilder opBuilder = scope.env()
-        .opBuilder(isStateful() ? STATEFUL_CALL_OP : CALL_OP, scope.makeOpName(displayName));
+    OperationBuilder opBuilder =
+        scope
+            .env()
+            .opBuilder(isStateful() ? STATEFUL_CALL_OP : CALL_OP, scope.makeOpName(displayName));
 
     opBuilder.addInputList(inputs);
 
@@ -293,8 +292,8 @@ public Map> call(Scope scope,
   }
 
   /**
-   * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only
-   * works for functions with a single input and output.
+   * Calls the function in an execution environment, adding its graph as a function if it isn't
+   * already present. Only works for functions with a single input and output.
    *
    * @param scope the scope to call the function in
    * @param argument the argument to the call
@@ -323,7 +322,7 @@ public Operand call(Scope scope, Operand argument) {
 
   @Override
   public Map call(Map arguments) {
-    //FIXME need to manage input/output operand lifetimes
+    // FIXME need to manage input/output operand lifetimes
     Ops tf = Ops.create();
     Map> inputs = new LinkedHashMap<>(arguments.size());
 
@@ -340,8 +339,8 @@ public Map call(Map arguments) {
   }
 
   /**
-   * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The
-   * inputs and outputs are keyed by the names set in the {@code Signature}.
+   * Calls the function in an execution environment, adding its graph as a function if it isn't
+   * already present. The inputs and outputs are keyed by the names set in the {@code Signature}.
    *
    * @param tf the scope to call the function in
    * @param arguments the arguments to the call
@@ -352,8 +351,8 @@ public Map> call(Ops tf, Map> arguments) {
   }
 
   /**
-   * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only
-   * works for functions with a single input and output.
+   * Calls the function in an execution environment, adding its graph as a function if it isn't
+   * already present. Only works for functions with a single input and output.
    *
    * @param tf the scope to call the function in
    * @param argument the argument to the call
@@ -370,50 +369,49 @@ TF_Function nativeHandle() {
     return nativeFunction.getNativeHandle();
   }
 
-  /**
-   * All native functions should have deallocators registered
-   */
-  ConcreteFunction(Signature signature, NativeFunction nativeFunction,
+  /** All native functions should have deallocators registered */
+  ConcreteFunction(
+      Signature signature,
+      NativeFunction nativeFunction,
       Collection availableFunctions) {
     this(signature, nativeFunction, nativeFunction.getAllDependencies(availableFunctions));
   }
 
   /**
-   * Detects the signature from the handle. Does not close passed functions. All passed functions should have
-   * deallocators.
+   * Detects the signature from the handle. Does not close passed functions. All passed functions
+   * should have deallocators.
    */
-  static ConcreteFunction fromNativeHandle(NativeFunction nativeFunction,
-      Collection availableFunctions) {
+  static ConcreteFunction fromNativeHandle(
+      NativeFunction nativeFunction, Collection availableFunctions) {
 
-    Signature.Builder builder = Signature.builder()
-        .methodName(nativeFunction.getFunctionDef().getSignature().getName())
-        .key(nativeFunction.getName());
+    Signature.Builder builder =
+        Signature.builder()
+            .methodName(nativeFunction.getFunctionDef().getSignature().getName())
+            .key(nativeFunction.getName());
 
     for (ArgDef input : nativeFunction.getFunctionDef().getSignature().getInputArgList()) {
-      TensorInfo info = TensorInfo.newBuilder()
-          .setDtype(input.getType())
-          .setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build())
-          .setName(input.getName())
-          .build();
+      TensorInfo info =
+          TensorInfo.newBuilder()
+              .setDtype(input.getType())
+              .setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build())
+              .setName(input.getName())
+              .build();
 
       builder.input(input.getName(), info);
     }
 
     for (ArgDef outputDef : nativeFunction.getFunctionDef().getSignature().getOutputArgList()) {
-      TensorInfo info = TensorInfo.newBuilder()
-          .setDtype(outputDef.getType())
-          .setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build())
-          .setName(outputDef.getName())
-          .build();
+      TensorInfo info =
+          TensorInfo.newBuilder()
+              .setDtype(outputDef.getType())
+              .setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build())
+              .setName(outputDef.getName())
+              .build();
 
       builder.output(outputDef.getName(), info);
     }
 
-    return new ConcreteFunction(
-        builder.build(),
-        nativeFunction,
-        availableFunctions
-    );
+    return new ConcreteFunction(builder.build(), nativeFunction, availableFunctions);
   }
 
   private final Signature signature;
@@ -423,60 +421,62 @@ static ConcreteFunction fromNativeHandle(NativeFunction nativeFunction,
   private final DataType[] inputDtypes;
   private final DataType[] outputDtypes;
 
-
-  /**
-   * All native functions should have deallocators registered
-   */
-  private ConcreteFunction(Signature signature, NativeFunction nativeFunction,
-      Set dependencies) {
+  /** All native functions should have deallocators registered */
+  private ConcreteFunction(
+      Signature signature, NativeFunction nativeFunction, Set dependencies) {
     this.signature = signature;
     this.nativeFunction = nativeFunction;
     this.dependencies = Collections.unmodifiableSet(dependencies);
 
-    if (this.signature.getInputs().size() != nativeFunction.getFunctionDef().getSignature()
-        .getInputArgCount()) {
+    if (this.signature.getInputs().size()
+        != nativeFunction.getFunctionDef().getSignature().getInputArgCount()) {
       throw new IllegalArgumentException(
           "Signature must have the same number of inputs as the native function.  Expected "
-              + nativeFunction.getFunctionDef().getSignature().getInputArgCount() + ", got "
+              + nativeFunction.getFunctionDef().getSignature().getInputArgCount()
+              + ", got "
               + this.signature.getInputs().size());
     }
 
-    if (this.signature.getOutputs().size() != nativeFunction.getFunctionDef().getSignature()
-        .getOutputArgCount()) {
+    if (this.signature.getOutputs().size()
+        != nativeFunction.getFunctionDef().getSignature().getOutputArgCount()) {
       throw new IllegalArgumentException(
           "New signature must have the same number of outputs as the native function.  Expected "
-              + nativeFunction.getFunctionDef().getSignature().getOutputArgCount() + ", got "
+              + nativeFunction.getFunctionDef().getSignature().getOutputArgCount()
+              + ", got "
               + this.signature.getOutputs().size());
     }
 
-    inputDtypes = this.signature.getInputs().values().stream().map(x -> x.dataType)
-        .toArray(DataType[]::new);
+    inputDtypes =
+        this.signature.getInputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new);
 
     List inputs = Arrays.asList(inputDtypes);
-    List nativeInputs = nativeFunction.getFunctionDef().getSignature().getInputArgList()
-        .stream()
-        .map(ArgDef::getType)
-        .collect(Collectors.toList());
+    List nativeInputs =
+        nativeFunction.getFunctionDef().getSignature().getInputArgList().stream()
+            .map(ArgDef::getType)
+            .collect(Collectors.toList());
 
     if (!dataTypesMatch(inputs, nativeInputs)) {
       throw new IllegalArgumentException(
           "Data types of the signature's inputs must match the native function's (in order).  Expected "
-              + nativeInputs + ", got " + inputs);
+              + nativeInputs
+              + ", got "
+              + inputs);
     }
 
-    outputDtypes = signature().getOutputs().values().stream().map(x -> x.dataType)
-        .toArray(DataType[]::new);
+    outputDtypes =
+        signature().getOutputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new);
 
     List outputs = Arrays.asList(outputDtypes);
-    List nativeOutputs = nativeFunction.getFunctionDef().getSignature().getOutputArgList()
-        .stream()
-        .map(ArgDef::getType)
-        .collect(Collectors.toList());
+    List nativeOutputs =
+        nativeFunction.getFunctionDef().getSignature().getOutputArgList().stream()
+            .map(ArgDef::getType)
+            .collect(Collectors.toList());
 
     if (!dataTypesMatch(outputs, nativeOutputs)) {
       throw new IllegalArgumentException(
           "Data types of the signature's outputs must match the native function's (in order).  Expected "
-              + nativeOutputs + ", got "
+              + nativeOutputs
+              + ", got "
               + outputs);
     }
 
@@ -489,11 +489,11 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction,
   }
 
   /**
-   * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because how to enable XLA
-   * JIT is extremely non-obvious.
-   * 

- * Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered platform with id: - * 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails). + * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because + * how to enable XLA JIT is extremely non-obvious. + * + *

Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered + * platform with id: 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails). */ private void makeJit() { try (PointerScope scope = new PointerScope()) { @@ -501,8 +501,8 @@ private void makeJit() { BytePointer trueValue = new BytePointer(bytes); TF_Status status1 = TF_Status.newStatus(); - TF_FunctionSetAttrValueProto(nativeHandle(), "_XlaMustCompile", trueValue, bytes.length, - status1); + TF_FunctionSetAttrValueProto( + nativeHandle(), "_XlaMustCompile", trueValue, bytes.length, status1); status1.throwExceptionIfNotOK(); TF_Status status2 = TF_Status.newStatus(); @@ -528,7 +528,6 @@ private static boolean dataTypesMatch(List a, List b) { return true; } - private static TF_Operation outputHandle(Operand operand) { if (operand == null) { throw new NullPointerException("Can't get output handle for null operand"); @@ -564,37 +563,52 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature) Reference ref = graph.ref()) { TF_Status status = TF_Status.newStatus(); - List> inputs = signature.getInputs().entrySet().stream() - .map((x) -> CallableFunction.validateDescription(x.getValue(), graph, x.getKey(), "Input")) - .collect(Collectors.toList()); - - List> outputs = signature.getOutputs().entrySet().stream() - .map((x) -> CallableFunction.validateDescription(x.getValue(), graph, x.getKey(), "Output")) - .collect(Collectors.toList()); - - List ops = new ArrayList<>( - graph.completeSubgraph(new HashSet<>(inputs), new HashSet<>(outputs))); + List> inputs = + signature.getInputs().entrySet().stream() + .map( + (x) -> + CallableFunction.validateDescription( + x.getValue(), graph, x.getKey(), "Input")) + .collect(Collectors.toList()); + + List> outputs = + signature.getOutputs().entrySet().stream() + .map( + (x) -> + CallableFunction.validateDescription( + x.getValue(), graph, x.getKey(), "Output")) + .collect(Collectors.toList()); + + List ops = + new ArrayList<>(graph.completeSubgraph(new HashSet<>(inputs), new HashSet<>(outputs))); inputs.forEach(input -> ops.remove((GraphOperation) input.op())); - ops.forEach(x -> { - if (x.type().equals(Placeholder.OP_NAME) || x.type() - .equals(PlaceholderWithDefault.OP_NAME)) { - throw new IllegalArgumentException( - "Can't calculate outputs (" + outputs + ") from inputs (" + inputs + "), " - + "they also depend on \"" + x + "\""); - } - }); + ops.forEach( + x -> { + if (x.type().equals(Placeholder.OP_NAME) + || x.type().equals(PlaceholderWithDefault.OP_NAME)) { + throw new IllegalArgumentException( + "Can't calculate outputs (" + + outputs + + ") from inputs (" + + inputs + + "), " + + "they also depend on \"" + + x + + "\""); + } + }); // Python sometimes has NoOps as outputs Ops tf = Ops.create(graph).withSubScope("functionControlOutputs"); for (int i = 0; i < outputs.size(); i++) { Operand output = outputs.get(i); if (output.op().numOutputs() < 1) { - Operand realOutput = tf - .withControlDependencies(Collections.singletonList(output)) - .withName(output.op().name() + "_control") - .constant(false); + Operand realOutput = + tf.withControlDependencies(Collections.singletonList(output)) + .withName(output.op().name() + "_control") + .constant(false); ops.add((GraphOperation) realOutput.op()); outputs.set(i, realOutput); } @@ -605,27 +619,29 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature) operations.put(i, ops.get(i).getUnsafeNativeHandle()); } - TF_Function handle = TF_GraphToFunction( - ref.nativeHandle(), - new BytePointer(signature.key()), - (byte) 1, - ops.size(), - operations, - inputs.size(), - resolveToOutput(graph, inputs), - outputs.size(), - resolveToOutput(graph, outputs), - null, - null, - new BytePointer(signature.methodName() != null ? signature.methodName() - : "Method " + signature.key()), - status - ); + TF_Function handle = + TF_GraphToFunction( + ref.nativeHandle(), + new BytePointer(signature.key()), + (byte) 1, + ops.size(), + operations, + inputs.size(), + resolveToOutput(graph, inputs), + outputs.size(), + resolveToOutput(graph, outputs), + null, + null, + new BytePointer( + signature.methodName() != null + ? signature.methodName() + : "Method " + signature.key()), + status); handle.withDeallocator(); status.throwExceptionIfNotOK(); - return new ConcreteFunction(signature, new NativeFunction(handle), - graph.getNativeFunctions(scope)); + return new ConcreteFunction( + signature, new NativeFunction(handle), graph.getNativeFunctions(scope)); } } -} \ No newline at end of file +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java index 98bc59abaaa..fd9c436c251 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -93,7 +93,8 @@ public EagerOperationBuilder addInputList(Output[] inputs) { @Override public OperationBuilder addControlInput(Operation control) { - // No-op. Any operations passed to this method will already be evaluated (b/c eager evaluation). + // No-op. Any operations passed to this method will already be evaluated (b/c eager + // evaluation). return this; } @@ -235,9 +236,13 @@ public OperationBuilder setAttr(String name, ConcreteFunction[] value) { session.attachFunction(fn); } - setAttrFunctionList(opHandle, session.nativeHandle(), name, Arrays.stream(value) - .map(ConcreteFunction::getNativeFunctionName) - .collect(Collectors.toList())); + setAttrFunctionList( + opHandle, + session.nativeHandle(), + name, + Arrays.stream(value) + .map(ConcreteFunction::getNativeFunctionName) + .collect(Collectors.toList())); return this; } @@ -248,9 +253,7 @@ public OperationBuilder setAttr(String name, ConcreteFunction[] value) { private final String type; private final String name; - /** - * This value should be >= to the maximum number of outputs in any op - */ + /** This value should be >= to the maximum number of outputs in any op */ private static final int MAX_OUTPUTS_PER_OP = 1000; private static void requireOp(TFE_Op handle) { @@ -292,7 +295,8 @@ private static TFE_TensorHandle[] execute(TFE_Op opHandle, EagerSession session) requireOp(opHandle); try (PointerScope scope = new PointerScope()) { IntPointer numRetvals = new IntPointer(1).put(MAX_OUTPUTS_PER_OP); - PointerPointer retvals = new PointerPointer(MAX_OUTPUTS_PER_OP); + PointerPointer retvals = + new PointerPointer(MAX_OUTPUTS_PER_OP); TF_Status status = TF_Status.newStatus(); TFE_Execute(opHandle, retvals, numRetvals, status); status.throwExceptionIfNotOK(); @@ -319,7 +323,8 @@ private static void addInput(TFE_Op opHandle, TFE_TensorHandle tensorHandle) { private static void addInputList(TFE_Op opHandle, TFE_TensorHandle[] tensorHandles) { requireOp(opHandle); try (PointerScope scope = new PointerScope()) { - PointerPointer tensorPointers = new PointerPointer(tensorHandles.length); + PointerPointer tensorPointers = + new PointerPointer(tensorHandles.length); for (int i = 0; i < tensorHandles.length; ++i) { requireTensorHandle(tensorHandles[i]); tensorPointers.put(i, tensorHandles[i]); @@ -388,7 +393,8 @@ private static void setAttrBool(TFE_Op opHandle, String name, boolean value) { private static void setAttrBoolList(TFE_Op opHandle, String name, boolean[] values) { requireOp(opHandle); try (PointerScope scope = new PointerScope()) { - TFE_OpSetAttrBoolList(opHandle, name, new BytePointer(new BooleanPointer(values)), values.length); + TFE_OpSetAttrBoolList( + opHandle, name, new BytePointer(new BooleanPointer(values)), values.length); } } @@ -433,8 +439,13 @@ private static void setAttrShapeList(TFE_Op opHandle, String name, long[] shapes shapesPointer.position(shapesPointer.position() + numDims[i] * 8); } TF_Status status = TF_Status.newStatus(); - TFE_OpSetAttrShapeList(opHandle, new BytePointer(name), shapesPointers, new IntPointer(numDims), - numDims.length, status); + TFE_OpSetAttrShapeList( + opHandle, + new BytePointer(name), + shapesPointers, + new IntPointer(numDims), + numDims.length, + status); } } @@ -445,8 +456,8 @@ private static void setAttrFunctionName(TFE_Op opHandle, String attrName, String } } - private static void setAttrFunctionList(TFE_Op opHandle, TFE_Context context, String attrName, - List functionNames) { + private static void setAttrFunctionList( + TFE_Op opHandle, TFE_Context context, String attrName, List functionNames) { requireOp(opHandle); requireContext(context); try (PointerScope scope = new PointerScope()) { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index a2c87285e9a..84fe7675c40 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -293,11 +293,14 @@ public void attachFunction(ConcreteFunction function) { TFE_ContextAddFunction(nativeHandle, function.nativeHandle(), status); status.throwExceptionIfNotOK(); - function.getDependencies().forEach(fn -> { - TF_Status status2 = TF_Status.newStatus(); - TFE_ContextAddFunction(nativeHandle, fn, status2); - status2.throwExceptionIfNotOK(); - }); + function + .getDependencies() + .forEach( + fn -> { + TF_Status status2 = TF_Status.newStatus(); + TFE_ContextAddFunction(nativeHandle, fn, status2); + status2.throwExceptionIfNotOK(); + }); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 489429ef49a..1dd4dde9711 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -391,11 +391,14 @@ public void attachFunction(ConcreteFunction function) { TF_GraphCopyFunction(ref.nativeHandle(), function.nativeHandle(), null, status); status.throwExceptionIfNotOK(); - function.getDependencies().forEach(x -> { - TF_Status status2 = TF_Status.newStatus(); - TF_GraphCopyFunction(ref.nativeHandle(), x, null, status2); - status2.throwExceptionIfNotOK(); - }); + function + .getDependencies() + .forEach( + x -> { + TF_Status status2 = TF_Status.newStatus(); + TF_GraphCopyFunction(ref.nativeHandle(), x, null, status2); + status2.throwExceptionIfNotOK(); + }); } } @@ -431,12 +434,12 @@ synchronized List getNativeFunctions(PointerScope outerScope) { } /** - * Get the function attached to the graph with the given native name. Returns {@code null} if - * none found. + * Get the function attached to the graph with the given native name. Returns {@code null} if none + * found. * - * @param key the name of the native function. Note that this may include an argument hash. + * @param key the name of the native function. Note that this may include an argument hash. * @return the found {@link ConcreteFunction}, or {@code null} if none were found with the correct - * name + * name */ public synchronized ConcreteFunction getFunction(String key) { try (Reference ref = ref(); @@ -463,7 +466,8 @@ public synchronized List getFunctions() { PointerScope scope = new PointerScope()) { List funcs = getNativeFunctions(scope); - return funcs.stream().map(x -> ConcreteFunction.fromNativeHandle(x, funcs)) + return funcs.stream() + .map(x -> ConcreteFunction.fromNativeHandle(x, funcs)) .collect(Collectors.toList()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java index 7fc68fa8133..0144dca1e59 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java @@ -1,19 +1,19 @@ /* - Copyright 2021 The TensorFlow Authors. All Rights Reserved. + Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ============================================================================== - */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionName; @@ -45,16 +45,12 @@ public NativeFunction(TF_Function nativeHandle) { this.nativeHandle = nativeHandle; } - /** - * Get the native handle. No guarantees about liveness are made. - */ + /** Get the native handle. No guarantees about liveness are made. */ public TF_Function getNativeHandle() { return nativeHandle; } - /** - * Get the function's {@link FunctionDef} - */ + /** Get the function's {@link FunctionDef} */ public synchronized FunctionDef getFunctionDef() { if (functionDef == null) { try (PointerScope scope = new PointerScope()) { @@ -75,15 +71,14 @@ public synchronized FunctionDef getFunctionDef() { return functionDef; } - /** - * Get the first-level dependencies of the function. - */ + /** Get the first-level dependencies of the function. */ public synchronized List getDependencies() { if (dependencies == null) { Set deps = new LinkedHashSet<>(); for (NodeDef node : getFunctionDef().getNodeDefList()) { - if (node.getOp().equals(ConcreteFunction.CALL_OP) || node.getOp().equals(ConcreteFunction.STATEFUL_CALL_OP)) { + if (node.getOp().equals(ConcreteFunction.CALL_OP) + || node.getOp().equals(ConcreteFunction.STATEFUL_CALL_OP)) { deps.add(node.getAttrMap().get("f").getFunc().getName()); } } @@ -93,20 +88,18 @@ public synchronized List getDependencies() { return dependencies; } - /** - * Get whether the function is stateful (whether it has stateful ops). - */ + /** Get whether the function is stateful (whether it has stateful ops). */ public synchronized boolean isStateful() { if (stateful == null) { - stateful = getFunctionDef().getSignature().getIsStateful() - || getFunctionDef().getNodeDefList().stream().anyMatch(x -> TensorFlow.isOpStateful(x.getOp())); + stateful = + getFunctionDef().getSignature().getIsStateful() + || getFunctionDef().getNodeDefList().stream() + .anyMatch(x -> TensorFlow.isOpStateful(x.getOp())); } return stateful; } - /** - * Get the name of the function. - */ + /** Get the name of the function. */ public synchronized String getName() { if (name == null) { try (PointerScope scope = new PointerScope()) { @@ -118,8 +111,8 @@ public synchronized String getName() { } synchronized Set getAllDependencies(Collection availableFunctions) { - Map fnMap = availableFunctions.stream() - .collect(Collectors.toMap(NativeFunction::getName, e -> e)); + Map fnMap = + availableFunctions.stream().collect(Collectors.toMap(NativeFunction::getName, e -> e)); Set done = new LinkedHashSet<>(1 + getDependencies().size()); Queue todo = new ArrayDeque<>(1 + getDependencies().size()); @@ -137,7 +130,8 @@ synchronized Set getAllDependencies(Collection avai NativeFunction fn = fnMap.get(dep); if (fn == null) { - throw new IllegalStateException("Function " + dep + " is required, but not present in graph."); + throw new IllegalStateException( + "Function " + dep + " is required, but not present in graph."); } todo.add(fn); @@ -147,7 +141,8 @@ synchronized Set getAllDependencies(Collection avai done.remove(getName()); - return done.stream().map(fnMap::get) + return done.stream() + .map(fnMap::get) .map(NativeFunction::getNativeHandle) .collect(Collectors.toSet()); } @@ -158,5 +153,4 @@ synchronized Set getAllDependencies(Collection avai private List dependencies = null; private Boolean stateful = null; private String name = null; - } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index bbd848c0ceb..4e7f2776710 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -53,22 +53,19 @@ * SavedModelBundle represents a model loaded from storage. * *

The model consists of a description of the computation (a {@link Graph}), a {@link Session} - * with tensors (e.g., parameters or variables in the graph) initialized to values saved in storage, and a description - * of the model as a MetaGraphDef + * with tensors (e.g., parameters or variables in the graph) initialized to values saved in storage, + * and a description of the model as a MetaGraphDef * protocol buffer. */ public class SavedModelBundle implements AutoCloseable { public static final String DEFAULT_TAG = "serve"; - /** - * Options for loading a SavedModel. - */ + /** Options for loading a SavedModel. */ public static final class Loader { - /** - * Load a SavedModelBundle with the configured options. - */ + /** Load a SavedModelBundle with the configured options. */ public SavedModelBundle load() { return SavedModelBundle.load(exportDir, tags, configProto, runOptions); } @@ -76,8 +73,9 @@ public SavedModelBundle load() { /** * Sets options to use when executing model initialization operations. * - * @param options A RunOptions - * protocol buffer. + * @param options A RunOptions + * protocol buffer. * @return this object */ public Loader withRunOptions(RunOptions options) { @@ -88,8 +86,9 @@ public Loader withRunOptions(RunOptions options) { /** * Set configuration of the Session object created when loading the model. * - * @param configProto A ConfigProto - * protocol buffer. + * @param configProto A ConfigProto + * protocol buffer. * @return this object */ public Loader withConfigProto(ConfigProto configProto) { @@ -122,9 +121,7 @@ private Loader(String exportDir) { private RunOptions runOptions = null; } - /** - * Options for exporting a SavedModel. - */ + /** Options for exporting a SavedModel. */ public static final class Exporter { /** @@ -158,29 +155,30 @@ public Exporter withSession(Session session) { } /** - * Save a concrete function of this model. + * Save a function of this model. * - *

The concrete function carries a signature (i.e. a list of user-friendly input and outputs - * names to a graph) and a valid session to a graph to be saved in the model. + *

The function carries a signature (i.e. a list of user-friendly input and outputs names to + * a graph) and a valid session to a graph to be saved in the model. * *

Note:Eventually, TensorFlow for Java will support the export of functions objects like - * the Python API does but right now, only session-centric models are supported (i.e. models that has a single main - * graph and one or more signatures). These models are compatible with those exported by TensorFlow 1.x or by - * TensorFlow 2.x estimators. - * - *
Therefore, all functions exported in a model should share the same session at the moment - * or an exception will be thrown.
This applies to sessions set via {@link #withSession(Session)} as well, the - * exporter can only even have one session. + * the Python API does but right now, only session-centric models are supported (i.e. models + * that has a single main graph and one or more signatures). These models are compatible with + * those exported by TensorFlow 1.x or by TensorFlow 2.x estimators.
+ * Therefore, all functions exported in a model should share the same session at the moment or + * an exception will be thrown. This applies to sessions set via {@link + * #withSession(Session)} as well, the exporter can only even have one session. * * @param function a function carrying a signature and a valid session to the graph to be saved * @return this object - * @throws IllegalArgumentException if a function with the same name has already been added to the model + * @throws IllegalArgumentException if a function with the same name has already been added to + * the model * @throws UnsupportedOperationException if the session is already set to a different session */ public Exporter withFunction(SessionFunction function) { Signature signature = function.signature(); if (functions.containsKey(signature.key())) { - throw new IllegalArgumentException("Function \"" + signature.key() + "\" was already added to the model"); + throw new IllegalArgumentException( + "Function \"" + signature.key() + "\" was already added to the model"); } if (session != null && session != function.session()) { throw new UnsupportedOperationException( @@ -194,9 +192,9 @@ public Exporter withFunction(SessionFunction function) { } /** - * Add a signature to the model. This wraps the signature in a {@link SessionFunction} using the exporter's - * already-set session. As such, either {@link #withSession(Session)} or {@link #withFunction(SessionFunction)} - * must be called before this method. + * Add a signature to the model. This wraps the signature in a {@link SessionFunction} using the + * exporter's already-set session. As such, either {@link #withSession(Session)} or {@link + * #withFunction(SessionFunction)} must be called before this method. * * @throws IllegalStateException if no session has been set */ @@ -223,10 +221,11 @@ public void export() throws IOException { // new ops to the graph for saving and restoring the variables. SaverDef saverDef = graph.saverDef(); - MetaGraphDef.Builder metaGraphDef = metaGraphDefBuilder - .setSaverDef(saverDef) - .setGraphDef(graph.toGraphDef()) - .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(Arrays.asList(tags))); + MetaGraphDef.Builder metaGraphDef = + metaGraphDefBuilder + .setSaverDef(saverDef) + .setGraphDef(graph.toGraphDef()) + .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(Arrays.asList(tags))); functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signature().asSignatureDef())); // Make sure saved model directories exist @@ -256,8 +255,9 @@ public void export() throws IOException { } /** - * Load a saved model from an export directory. The model that is being loaded should be created using the Saved Model API. + * Load a saved model from an export directory. The model that is being loaded should be created + * using the Saved Model + * API. * *

This method is a shorthand for: * @@ -302,16 +302,15 @@ public static Exporter exporter(String exportDir) { } /** - * Returns the MetaGraphDef + * Returns the MetaGraphDef * protocol buffer associated with the saved model. */ public MetaGraphDef metaGraphDef() { return metaGraphDef; } - /** - * Returns the graph that describes the computation performed by the model. - */ + /** Returns the graph that describes the computation performed by the model. */ public Graph graph() { return graph; } @@ -325,17 +324,13 @@ public Session session() { return session; } - /** - * Return the signature of all functions available in this saved model. - */ + /** Return the signature of all functions available in this saved model. */ public List signatures() { return functions.values().stream().map(f -> f.signature()).collect(Collectors.toList()); } /** - * Return a {@link ConcreteFunction} corresponding to the function signature. The function may depend on other - * functions in the bundle, which will need to be attached to the execution environment used to call this function (or - * the default eager environment if called with tensors). + * Return a {@link ConcreteFunction} corresponding to the function signature. * *

{@code
    * ConcreteFunction myFunction = savedModelBundle.function("mySignatureKey");
@@ -355,9 +350,7 @@ public SessionFunction function(String signatureKey) {
     return function;
   }
 
-  /**
-   * Get all functions in the bundle.
-   */
+  /** Get all functions in the bundle. */
   public List functions() {
     return new ArrayList<>(functions.values());
   }
@@ -367,9 +360,11 @@ public List functions() {
    *
    * 

The default function selection is done based on the first of the following conditions that * is true: + * *

    - *
  • The function is the only signature available attached to the main graph of this saved model
  • - *
  • The function is mapped to the default signature name, which is "serving_default"
  • + *
  • The function is the only signature available attached to the main graph of this saved + * model + *
  • The function is mapped to the default signature name, which is "serving_default" *
* *

Caller is responsible for closing all returned Tensors. @@ -392,7 +387,8 @@ public Map call(Map arguments) { } /** - * Releases resources (the {@link Graph} and {@link Session}) associated with the saved model bundle. + * Releases resources (the {@link Graph} and {@link Session}) associated with the saved model + * bundle. */ @Override public void close() { @@ -405,18 +401,20 @@ public void close() { private final MetaGraphDef metaGraphDef; private final Map functions; - private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, - Map signatures) { + private SavedModelBundle( + Graph graph, Session session, MetaGraphDef metaGraphDef, Map signatures) { this.graph = graph; this.session = session; this.metaGraphDef = metaGraphDef; - this.functions = signatures.entrySet().stream() - .collect(Collectors.toMap(Entry::getKey, e -> new SessionFunction(e.getValue(), session))); + this.functions = + signatures.entrySet().stream() + .collect( + Collectors.toMap(Entry::getKey, e -> new SessionFunction(e.getValue(), session))); } /** - * Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session object, plus the - * MetaGraphDef. + * Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session + * object, plus the MetaGraphDef. * *

Invoked from the native load method. Takes ownership of the handles. */ @@ -432,12 +430,15 @@ private static SavedModelBundle fromHandle( // no effect. final Map functions = new HashMap<>(metaGraphDef.getSignatureDefCount()); - metaGraphDef.getSignatureDefMap().forEach((signatureName, signatureDef) -> { - if (!functions.containsKey(signatureName)) { - Signature signature = new Signature(signatureName, signatureDef); - functions.put(signatureName, signature); - } - }); + metaGraphDef + .getSignatureDefMap() + .forEach( + (signatureName, signatureDef) -> { + if (!functions.containsKey(signatureName)) { + Signature signature = new Signature(signatureName, signatureDef); + functions.put(signatureName, signature); + } + }); return new SavedModelBundle(graph, session, metaGraphDef, functions); } @@ -460,14 +461,22 @@ private static SavedModelBundle load( // load the session TF_Graph graph = TF_NewGraph(); TF_Buffer metagraphDef = TF_Buffer.newBuffer(); - TF_Session session = TF_LoadSessionFromSavedModel( - opts, runOpts, new BytePointer(exportDir), new PointerPointer(tags), - tags.length, graph, metagraphDef, status); + TF_Session session = + TF_LoadSessionFromSavedModel( + opts, + runOpts, + new BytePointer(exportDir), + new PointerPointer(tags), + tags.length, + graph, + metagraphDef, + status); status.throwExceptionIfNotOK(); // handle the result try { - bundle = fromHandle(graph, session, MetaGraphDef.parseFrom(metagraphDef.dataAsByteBuffer())); + bundle = + fromHandle(graph, session, MetaGraphDef.parseFrom(metagraphDef.dataAsByteBuffer())); } catch (InvalidProtocolBufferException e) { throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index aa2767a9352..d5adae161c0 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -90,9 +90,11 @@ public Session(Graph g) { * Construct a new session with the associated {@link Graph} and configuration options. * * @param g The {@link Graph} the created Session will operate on. - * @param config Configuration parameters for the session specified as a ConfigProto - * protocol buffer. - * @throws IllegalArgumentException if the config is not a valid serialization of the ConfigProto protocol buffer. + * @param config Configuration parameters for the session specified as a ConfigProto + * protocol buffer. + * @throws IllegalArgumentException if the config is not a valid serialization of the ConfigProto + * protocol buffer. */ public Session(Graph g, ConfigProto config) { graph = g; @@ -105,9 +107,7 @@ public Session(Graph g, ConfigProto config) { } } - /** - * Wrap an existing session with the associated {@link Graph}. - */ + /** Wrap an existing session with the associated {@link Graph}. */ Session(Graph g, TF_Session nativeHandle) { graph = g; this.nativeHandle = nativeHandle; @@ -145,20 +145,22 @@ public void close() { * Run {@link Operation}s and evaluate {@link Tensor Tensors}. * *

A Runner runs the necessary graph fragments to execute every {@link Operation} required to - * evaluate the {@link Tensor Tensors} to fetch. The {@link #feed(String, int, Tensor)} call allows callers to - * override the value of {@link Tensor Tensors} in the graph by substituting the provided {@link Tensor Tensors} for - * the outputs of the operations provided to {@link #feed(String, int, Tensor)}. + * evaluate the {@link Tensor Tensors} to fetch. The {@link #feed(String, int, Tensor)} call + * allows callers to override the value of {@link Tensor Tensors} in the graph by substituting the + * provided {@link Tensor Tensors} for the outputs of the operations provided to {@link + * #feed(String, int, Tensor)}. */ public final class Runner { /** * Avoid evaluating {@code operation} and substitute {@code t} for the value it produces. * - * @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code - * feed(operation, 0)}, or it is a string of the form - * operation_name:output_index , in which case this method acts like {@code - * feed(operation_name, output_index)}. These colon-separated names are commonly used in the {@code SignatureDef} - * protocol buffer messages that are included in {@link SavedModelBundle#metaGraphDef()}. + * @param operation Is either the string name of the operation, in which case this method is a + * shorthand for {@code feed(operation, 0)}, or it is a string of the form + * operation_name:output_index , in which case this method acts like {@code + * feed(operation_name, output_index)}. These colon-separated names are commonly used in the + * {@code SignatureDef} protocol buffer messages that are included in {@link + * SavedModelBundle#metaGraphDef()}. * @param t the tensor substituting the operation * @return this session runner * @throws IllegalArgumentException if no output exists with the provided name @@ -168,8 +170,8 @@ public Runner feed(String operation, Tensor t) { } /** - * Avoid evaluating the {@code index}-th output of {@code operation} by substituting {@code t} for the value it - * produces. + * Avoid evaluating the {@code index}-th output of {@code operation} by substituting {@code t} + * for the value it produces. * *

Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which * one {@code t} is being provided for. @@ -188,7 +190,8 @@ public Runner feed(String operation, int index, Tensor t) { } /** - * Use {@code t} instead of the Tensor referred to by executing the operation referred to by {@code operand}. + * Use {@code t} instead of the Tensor referred to by executing the operation referred to by + * {@code operand}. * * @param operand the node in the graph representing the operation to substitute * @param t the tensor substituting the operation @@ -196,8 +199,12 @@ public Runner feed(String operation, int index, Tensor t) { */ public Runner feed(Operand operand, Tensor t) { if (operand.env() != graph) { - throw new IllegalStateException("Can't feed value for operand " + operand + ", it is from " + - (operand.env().isEager() ? "an eager session" : "a different graph") + "."); + throw new IllegalStateException( + "Can't feed value for operand " + + operand + + ", it is from " + + (operand.env().isEager() ? "an eager session" : "a different graph") + + "."); } inputs.add(operand.asOutput()); @@ -208,13 +215,14 @@ public Runner feed(Operand operand, Tensor t) { /** * Make {@link #run()} return the output of {@code operation}. * - * If the output is a resource variable, will fetch the value. + *

If the output is a resource variable, will fetch the value. * - * @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code - * fetch(operation, 0)}, or it is a string of the form - * operation_name:output_index , in which case this method acts like {@code - * fetch(operation_name, output_index)}. These colon-separated names are commonly used in the {@code SignatureDef} - * protocol buffer messages that are included in {@link SavedModelBundle#metaGraphDef()}. + * @param operation Is either the string name of the operation, in which case this method is a + * shorthand for {@code fetch(operation, 0)}, or it is a string of the form + * operation_name:output_index , in which case this method acts like {@code + * fetch(operation_name, output_index)}. These colon-separated names are commonly used in + * the {@code SignatureDef} protocol buffer messages that are included in {@link + * SavedModelBundle#metaGraphDef()}. * @return this session runner * @throws IllegalArgumentException if no output exists with the provided name */ @@ -225,7 +233,7 @@ public Runner fetch(String operation) { /** * Make {@link #run()} return the {@code index}-th output of {@code operation}. * - * If the output is a resource variable, will fetch the value. + *

If the output is a resource variable, will fetch the value. * *

Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which * one to return. @@ -243,15 +251,19 @@ public Runner fetch(String operation, int index) { /** * Makes {@link #run()} return the Tensor referred to by {@code output}. * - * If {@code output} is a resource variable, will fetch the value. + *

If {@code output} is a resource variable, will fetch the value. * * @param output the node to fetch the tensor from * @return this session runner */ public Runner fetch(Output output) { if (output.env() != graph) { - throw new IllegalStateException("Can't fetch output " + output + ", it is from " + - (output.env().isEager() ? "an eager session" : "a different graph") + "."); + throw new IllegalStateException( + "Can't fetch output " + + output + + ", it is from " + + (output.env().isEager() ? "an eager session" : "a different graph") + + "."); } if (output.dataType() == DataType.DT_RESOURCE) { @@ -276,8 +288,11 @@ public Runner fetch(Output output) { } if (read == null) { - read = Ops.create(graph).withSubScope("session_reads").withName(output.op().name() + "_read") - .readVariableOp(output, TensorTypeRegistry.find(valueDt).type()); + read = + Ops.create(graph) + .withSubScope("session_reads") + .withName(output.op().name() + "_read") + .readVariableOp(output, TensorTypeRegistry.find(valueDt).type()); } outputs.add(read.asOutput()); @@ -290,7 +305,7 @@ public Runner fetch(Output output) { /** * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}. * - * If {@code operand} is a resource variable, will fetch the value. + *

If {@code operand} is a resource variable, will fetch the value. * * @param operand the node to fetch the tensor from, as an operand * @return this session runner @@ -300,7 +315,8 @@ public Runner fetch(Operand operand) { } /** - * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor Tensors}. + * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor + * Tensors}. * * @param operation the string name of the operation to execute * @return this session runner @@ -311,7 +327,8 @@ public Runner addTarget(String operation) { } /** - * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor Tensors}. + * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor + * Tensors}. * * @param operation the operation to execute * @return this session runner @@ -320,8 +337,12 @@ public Runner addTarget(String operation) { */ public Runner addTarget(Operation operation) { if (operation.env() != graph) { - throw new IllegalStateException("Can't target operation " + operation + ", it is from " + - (operation.env().isEager() ? "an eager session" : "a different graph") + "."); + throw new IllegalStateException( + "Can't target operation " + + operation + + ", it is from " + + (operation.env().isEager() ? "an eager session" : "a different graph") + + "."); } targets.add((GraphOperation) operation); return this; @@ -341,7 +362,8 @@ public Runner addTarget(Op op) { * Set options (typically for debugging) for this run. * *

The options are presented as a RunOptions protocol buffer. + * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions + * protocol buffer. * * @param options a {@code RunOptions} proto * @return this session runner @@ -355,11 +377,13 @@ public Runner setOptions(RunOptions options) { * Execute the graph fragments necessary to compute all requested fetches. * *

WARNING: The caller assumes ownership of all returned {@link Tensor Tensors}, i.e., - * the caller must call {@link Tensor#close} on all elements of the returned list to free up resources. + * the caller must call {@link Tensor#close} on all elements of the returned list to free up + * resources. * *

TODO(ashankar): Reconsider the return type here. Two things in particular: (a) Make it - * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in SessionTest.java), and - * (b) Evaluate whether the return value should be a list, or maybe a {@code Map}? + * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in + * SessionTest.java), and (b) Evaluate whether the return value should be a list, or maybe a + * {@code Map}? * *

TODO(andrewmyers): It would also be good if whatever is returned here made it easier to * extract output tensors in a type-safe way. @@ -374,7 +398,8 @@ public List run() { * Execute graph fragments to compute requested fetches and return metadata about the run. * *

This is exactly like {@link #run()}, but in addition to the requested Tensors, also - * returns metadata about the graph execution in the form of a RunMetadata + * returns metadata about the graph execution in the form of a RunMetadata * protocol buffer. * * @return list of resulting tensors fetched by this session runner, with execution metadata @@ -475,9 +500,7 @@ public void close() { private RunOptions runOptions = null; } - /** - * Create a Runner to execute graph operations and evaluate Tensors. - */ + /** Create a Runner to execute graph operations and evaluate Tensors. */ public Runner runner() { return new Runner(); } @@ -546,14 +569,15 @@ public void runInit() { * mymodel/myvariables/variables, then the generated files will be located under * mymodel/myvariables and named variables.data-*-of-* * - *

Note that this method might alter the underlying graph if it is the first time that one - * of its sessions is saved, see {@link Graph#saverDef()} for more details. + *

Note that this method might alter the underlying graph if it is the first time that one of + * its sessions is saved, see {@link Graph#saverDef()} for more details. * * @param prefix prefix to the variable files to save */ public void save(String prefix) { SaverDef saverDef = graph.saverDef(); - runner().addTarget(saverDef.getSaveTensorName()) + runner() + .addTarget(saverDef.getSaveTensorName()) .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) .run(); } @@ -561,19 +585,20 @@ public void save(String prefix) { /** * Restore the actual state of the variables of this session's graph. * - *

{@code prefix} is the path where the files containing the variables state live, - * followed by the filename prefix. For example, if {@code prefix} is set to - * mymodel/myvariables/variables, then the files are loaded from - * mymodel/myvariables and named variables.data-*-of-* + *

{@code prefix} is the path where the files containing the variables state live, followed by + * the filename prefix. For example, if {@code prefix} is set to + * mymodel/myvariables/variables, then the files are loaded from mymodel/myvariables + * and named variables.data-*-of-* * - *

Note that this method might alter the underlying graph if it is the first time that one - * of its sessions is saved, see {@link Graph#saverDef()} for more details. + *

Note that this method might alter the underlying graph if it is the first time that one of + * its sessions is saved, see {@link Graph#saverDef()} for more details. * * @param prefix prefix to restore from */ public void restore(String prefix) { SaverDef saverDef = graph.saverDef(); - runner().addTarget(saverDef.getRestoreOpName()) + runner() + .addTarget(saverDef.getRestoreOpName()) .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) .run(); } @@ -585,16 +610,15 @@ public void restore(String prefix) { */ public static final class Run { - /** - * Tensors from requested fetches. - */ + /** Tensors from requested fetches. */ public List outputs; /** * Metadata about the run. * *

A RunMetadata protocol buffer. + * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata + * protocol buffer. */ public RunMetadata metadata; } @@ -661,21 +685,22 @@ private static void delete(TF_Session handle) { * * @param handle to the C API TF_Session object (Session.nativeHandle) * @param runOptions A RunOptions protocol buffer, or null - * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values that are being "fed" - * (do not need to be computed) during graph execution. inputTensorHandles[i] (which corresponds to a - * Tensor.nativeHandle) is considered to be the inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, - * it is required that inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length. + * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values + * that are being "fed" (do not need to be computed) during graph execution. + * inputTensorHandles[i] (which corresponds to a Tensor.nativeHandle) is considered to be the + * inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, it is required that + * inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length. * @param inputOpHandles (see inputOpIndices) * @param inputOpIndices (see inputTensorHandles) * @param outputOpHandles (see outputOpIndices) - * @param outputOpIndices together with outputOpHandles identifies the set of values that should be computed. The - * outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is required that outputOpHandles.length == - * outputOpIndices.length. - * @param targetOpHandles is the set of Operations in the graph that are to be executed but whose output will not be - * returned + * @param outputOpIndices together with outputOpHandles identifies the set of values that should + * be computed. The outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is + * required that outputOpHandles.length == outputOpIndices.length. + * @param targetOpHandles is the set of Operations in the graph that are to be executed but whose + * output will not be returned * @param wantRunMetadata indicates whether metadata about this execution should be returned. - * @param outputTensors will be filled in with tensors to the outputs requested. It is required that outputs.length == - * outputOpHandles.length. + * @param outputTensors will be filled in with tensors to the outputs requested. It is required + * that outputs.length == outputOpHandles.length. * @return if wantRunMetadata is true, a RunMetadata protocol buffer, false otherwise. */ private static RunMetadata run( diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java index 658dbff205d..8780543a1e7 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -27,28 +27,22 @@ import org.tensorflow.proto.framework.TensorShapeProto.Dim; /** - * Describe the inputs and outputs of an executable entity, such as a {@link ConcreteFunction}, among - * other useful metadata. + * Describe the inputs and outputs of an executable entity, such as a {@link ConcreteFunction}, + * among other useful metadata. */ -public class Signature { +public class Signature { /** The default signature key, when not provided */ public static final String DEFAULT_KEY = "serving_default"; public static class TensorDescription { - /** - * The name of the tensor's operand in the graph - */ + /** The name of the tensor's operand in the graph */ public final String name; - /** - * The data type of the tensor - */ + /** The data type of the tensor */ public final DataType dataType; - /** - * The shape of the tensor - */ + /** The shape of the tensor */ public final Shape shape; public TensorDescription(DataType dataType, Shape shape, String name) { @@ -58,9 +52,7 @@ public TensorDescription(DataType dataType, Shape shape, String name) { } } - /** - * Builds a new function signature. - */ + /** Builds a new function signature. */ public static class Builder { /** @@ -90,7 +82,8 @@ public Builder key(String key) { */ public Builder input(String inputName, Operand input) { if (signatureBuilder.containsInputs(inputName)) { - throw new IllegalArgumentException("\"" + inputName + "\" is already being mapped to another input"); + throw new IllegalArgumentException( + "\"" + inputName + "\" is already being mapped to another input"); } signatureBuilder.putInputs(inputName, toTensorInfo(input.asOutput())); return this; @@ -106,7 +99,8 @@ public Builder input(String inputName, Operand input) { */ Builder input(String inputName, TensorInfo input) { if (signatureBuilder.containsInputs(inputName)) { - throw new IllegalArgumentException("\"" + inputName + "\" is already being mapped to another input"); + throw new IllegalArgumentException( + "\"" + inputName + "\" is already being mapped to another input"); } signatureBuilder.putInputs(inputName, input); return this; @@ -122,7 +116,8 @@ Builder input(String inputName, TensorInfo input) { */ public Builder output(String outputName, Operand output) { if (signatureBuilder.containsOutputs(outputName)) { - throw new IllegalArgumentException("\"" + outputName + "\" is already being mapped to another output"); + throw new IllegalArgumentException( + "\"" + outputName + "\" is already being mapped to another output"); } signatureBuilder.putOutputs(outputName, toTensorInfo(output.asOutput())); return this; @@ -138,15 +133,16 @@ public Builder output(String outputName, Operand output) { */ Builder output(String outputName, TensorInfo output) { if (signatureBuilder.containsOutputs(outputName)) { - throw new IllegalArgumentException("\"" + outputName + "\" is already being mapped to another output"); + throw new IllegalArgumentException( + "\"" + outputName + "\" is already being mapped to another output"); } signatureBuilder.putOutputs(outputName, output); return this; } /** - * Provide extensible name information enabling third-party users to mark a signature as supporting a particular - * method + * Provide extensible name information enabling third-party users to mark a signature as + * supporting a particular method * * @param methodName method name or null for none (default) * @return this builder @@ -156,9 +152,7 @@ public Builder methodName(String methodName) { return this; } - /** - * Returns a signature from the provided data. - */ + /** Returns a signature from the provided data. */ public Signature build() { return new Signature(key, signatureBuilder.build()); } @@ -180,44 +174,34 @@ private static TensorInfo toTensorInfo(Output operand) { private final SignatureDef.Builder signatureBuilder = SignatureDef.newBuilder(); } - /** - * Returns a new builder for creating a signature - */ + /** Returns a new builder for creating a signature */ public static Builder builder() { return new Builder(); } - /** - * Return the key of this signature - */ + /** Return the key of this signature */ public String key() { return key; } - /** - * Returns the method name of this signature (e.g. as exposed by TF serving) or null if none - */ + /** Returns the method name of this signature (e.g. as exposed by TF serving) or null if none */ public String methodName() { return signatureDef.getMethodName().isEmpty() ? null : signatureDef.getMethodName(); } - /** - * Returns the names of the inputs in this signature - */ + /** Returns the names of the inputs in this signature */ public Set inputNames() { return signatureDef.getInputsMap().keySet(); } - /** - * Returns the names of the outputs in this signature - */ + /** Returns the names of the outputs in this signature */ public Set outputNames() { return signatureDef.getOutputsMap().keySet(); } @Override public String toString() { - StringBuilder strBuilder = new StringBuilder("Signature for \"" + key +"\":\n"); + StringBuilder strBuilder = new StringBuilder("Signature for \"" + key + "\":\n"); if (!methodName().isEmpty()) { strBuilder.append("\tMethod: \"").append(methodName()).append("\"\n"); } @@ -232,18 +216,23 @@ public String toString() { return strBuilder.toString(); } - private Map buildTensorDescriptionMap(Map dataMapIn) { + private Map buildTensorDescriptionMap( + Map dataMapIn) { Map dataTypeMap = new LinkedHashMap<>(); - dataMapIn.forEach((name, info) -> { - long[] tensorDims = info.getTensorShape().getDimList().stream().mapToLong(d -> d.getSize()).toArray(); - Shape tensorShape = Shape.of(tensorDims); - dataTypeMap.put(name, new TensorDescription(info.getDtype(), tensorShape, info.getName())); - }); + dataMapIn.forEach( + (name, info) -> { + long[] tensorDims = + info.getTensorShape().getDimList().stream().mapToLong(d -> d.getSize()).toArray(); + Shape tensorShape = Shape.of(tensorDims); + dataTypeMap.put( + name, new TensorDescription(info.getDtype(), tensorShape, info.getName())); + }); return Collections.unmodifiableMap(dataTypeMap); } /** - * Returns the names of the inputs in this signature mapped to their expected data type, shape, and operand name + * Returns the names of the inputs in this signature mapped to their expected data type, shape, + * and operand name */ public Map getInputs() { if (inputMap == null) { @@ -253,7 +242,8 @@ public Map getInputs() { } /** - * Returns the names of the outputs in this signature mapped to their expected data type, shape, and operand name + * Returns the names of the outputs in this signature mapped to their expected data type, shape, + * and operand name */ public Map getOutputs() { if (outputMap == null) { @@ -277,19 +267,21 @@ SignatureDef asSignatureDef() { private Map outputMap; private static void printTensorInfo(Map tensorMap, StringBuilder strBuilder) { - tensorMap.forEach((key, tensorInfo) -> { - strBuilder.append("\t\t\"") - .append(key) - .append("\": dtype=") - .append(tensorInfo.getDtype().name()) - .append(", shape=("); - for (int i = 0; i < tensorInfo.getTensorShape().getDimCount(); ++i) { - strBuilder.append(tensorInfo.getTensorShape().getDim(i).getSize()); - if (i < tensorInfo.getTensorShape().getDimCount() - 1) { - strBuilder.append(", "); - } - } - strBuilder.append(")\n"); - }); + tensorMap.forEach( + (key, tensorInfo) -> { + strBuilder + .append("\t\t\"") + .append(key) + .append("\": dtype=") + .append(tensorInfo.getDtype().name()) + .append(", shape=("); + for (int i = 0; i < tensorInfo.getTensorShape().getDimCount(); ++i) { + strBuilder.append(tensorInfo.getTensorShape().getDim(i).getSize()); + if (i < tensorInfo.getTensorShape().getDimCount() - 1) { + strBuilder.append(", "); + } + } + strBuilder.append(")\n"); + }); } -} \ No newline at end of file +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java index 946a02e0b88..b930da217f6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java @@ -32,14 +32,10 @@ import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.proto.framework.OpList; -/** - * Static utility methods describing the TensorFlow runtime. - */ +/** Static utility methods describing the TensorFlow runtime. */ public final class TensorFlow { - /** - * Returns the version of the underlying TensorFlow runtime. - */ + /** Returns the version of the underlying TensorFlow runtime. */ public static String version() { return TF_Version().getString(); } @@ -47,8 +43,9 @@ public static String version() { /** * All the TensorFlow operations available in this address space. * - * @return A OpList protocol - * buffer, which lists all the available TensorFlow operations. + * @return A OpList + * protocol buffer, which lists all the available TensorFlow operations. */ public static OpList registeredOpList() { TF_Buffer buf = TF_GetAllOpList(); @@ -65,21 +62,24 @@ public static OpList registeredOpList() { public static synchronized boolean isOpStateful(String opType) { if (statefulOps == null) { - statefulOps = registeredOpList().getOpList().stream() - .filter(x -> x.getIsStateful()) - .map(x -> x.getName()) - .collect(Collectors.toSet()); + statefulOps = + registeredOpList().getOpList().stream() + .filter(x -> x.getIsStateful()) + .map(x -> x.getName()) + .collect(Collectors.toSet()); } return statefulOps.contains(opType); } /** - * Load the dynamic library in filename and register the operations and kernels present in that library. + * Load the dynamic library in filename and register the operations and kernels present in that + * library. * * @param filename Path of the dynamic library containing operations and kernels to load. - * @return A OpList protocol - * buffer message defining the operations defined in the library. + * @return A OpList + * protocol buffer message defining the operations defined in the library. * @throws UnsatisfiedLinkError if filename cannot be loaded. */ public static OpList loadLibrary(String filename) { @@ -120,12 +120,9 @@ private static OpList libraryOpList(TF_Library handle) { } } - private TensorFlow() { - } + private TensorFlow() {} - /** - * Load the TensorFlow runtime C library. - */ + /** Load the TensorFlow runtime C library. */ static { try { NativeLibrary.load(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java index cb7916d0309..3c691f6f23d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java @@ -1,19 +1,19 @@ /* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Copyright 2019 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow.internal.c_api.presets; @@ -28,82 +28,100 @@ import org.bytedeco.javacpp.tools.InfoMap; import org.bytedeco.javacpp.tools.InfoMapper; -/** - * @author Samuel Audet - */ +/** @author Samuel Audet */ @Properties( value = { - @Platform( - value = {"linux", "macosx", "windows"}, - compiler = "cpp11", - include = { - "tensorflow/core/platform/ctstring_internal.h", - "tensorflow/core/platform/ctstring.h", - "tensorflow/core/util/port.h", - "tensorflow/c/tf_attrtype.h", - "tensorflow/c/c_api_macros.h", - "tensorflow/c/tf_datatype.h", - "tensorflow/c/tf_status.h", - "tensorflow/c/tf_tensor.h", - "tensorflow/c/tf_tstring.h", - "tensorflow/c/c_api.h", -// "tensorflow/c/env.h", - "tensorflow/c/kernels.h", - "tensorflow/c/ops.h", - "tensorflow/c/eager/c_api.h" - }, - link = "tensorflow_cc@.2", - preload = {"iomp5", "mklml", "mklml_intel", "tensorflow_framework@.2"}, - preloadresource = "/org/bytedeco/mkldnn/", - resource = {"LICENSE", "THIRD_PARTY_TF_JNI_LICENSES"} - ), - @Platform( - value = "windows", - preload = { - "api-ms-win-crt-locale-l1-1-0", "api-ms-win-crt-string-l1-1-0", "api-ms-win-crt-stdio-l1-1-0", - "api-ms-win-crt-math-l1-1-0", - "api-ms-win-crt-heap-l1-1-0", "api-ms-win-crt-runtime-l1-1-0", "api-ms-win-crt-convert-l1-1-0", - "api-ms-win-crt-environment-l1-1-0", - "api-ms-win-crt-time-l1-1-0", "api-ms-win-crt-filesystem-l1-1-0", "api-ms-win-crt-utility-l1-1-0", - "api-ms-win-crt-multibyte-l1-1-0", - "api-ms-win-core-string-l1-1-0", "api-ms-win-core-errorhandling-l1-1-0", - "api-ms-win-core-timezone-l1-1-0", "api-ms-win-core-file-l1-1-0", - "api-ms-win-core-namedpipe-l1-1-0", "api-ms-win-core-handle-l1-1-0", "api-ms-win-core-file-l2-1-0", - "api-ms-win-core-heap-l1-1-0", - "api-ms-win-core-libraryloader-l1-1-0", "api-ms-win-core-synch-l1-1-0", - "api-ms-win-core-processthreads-l1-1-0", - "api-ms-win-core-processenvironment-l1-1-0", "api-ms-win-core-datetime-l1-1-0", - "api-ms-win-core-localization-l1-2-0", - "api-ms-win-core-sysinfo-l1-1-0", "api-ms-win-core-synch-l1-2-0", "api-ms-win-core-console-l1-1-0", - "api-ms-win-core-debug-l1-1-0", - "api-ms-win-core-rtlsupport-l1-1-0", "api-ms-win-core-processthreads-l1-1-1", - "api-ms-win-core-file-l1-2-0", "api-ms-win-core-profile-l1-1-0", - "api-ms-win-core-memory-l1-1-0", "api-ms-win-core-util-l1-1-0", "api-ms-win-core-interlocked-l1-1-0", - "ucrtbase", - "vcruntime140", "vcruntime140_1", "msvcp140", "concrt140", "vcomp140", "msvcr120", "libiomp5md", - "mklml", "tensorflow_framework" - } - ), - @Platform( - value = "windows-x86", - preloadpath = { - "C:/Program Files (x86)/Microsoft Visual Studio 14.0/VC/redist/x86/Microsoft.VC140.CRT/", - "C:/Program Files (x86)/Microsoft Visual Studio 14.0/VC/redist/x86/Microsoft.VC140.OpenMP/", - "C:/Program Files (x86)/Windows Kits/10/Redist/ucrt/DLLs/x86/" - } - ), - @Platform( - value = "windows-x86_64", - preloadpath = { - "C:/Program Files (x86)/Microsoft Visual Studio 14.0/VC/redist/x64/Microsoft.VC140.CRT/", - "C:/Program Files (x86)/Microsoft Visual Studio 14.0/VC/redist/x64/Microsoft.VC140.OpenMP/", - "C:/Program Files (x86)/Windows Kits/10/Redist/ucrt/DLLs/x64/" - } - ), - @Platform( - value = {"linux", "macosx", "windows"}, - extension = {"-mkl", "-gpu", "-mkl-gpu"} - ) + @Platform( + value = {"linux", "macosx", "windows"}, + compiler = "cpp11", + include = { + "tensorflow/core/platform/ctstring_internal.h", + "tensorflow/core/platform/ctstring.h", + "tensorflow/core/util/port.h", + "tensorflow/c/tf_attrtype.h", + "tensorflow/c/c_api_macros.h", + "tensorflow/c/tf_datatype.h", + "tensorflow/c/tf_status.h", + "tensorflow/c/tf_tensor.h", + "tensorflow/c/tf_tstring.h", + "tensorflow/c/c_api.h", + // "tensorflow/c/env.h", + "tensorflow/c/kernels.h", + "tensorflow/c/ops.h", + "tensorflow/c/eager/c_api.h" + }, + link = "tensorflow_cc@.2", + preload = {"iomp5", "mklml", "mklml_intel", "tensorflow_framework@.2"}, + preloadresource = "/org/bytedeco/mkldnn/", + resource = {"LICENSE", "THIRD_PARTY_TF_JNI_LICENSES"}), + @Platform( + value = "windows", + preload = { + "api-ms-win-crt-locale-l1-1-0", + "api-ms-win-crt-string-l1-1-0", + "api-ms-win-crt-stdio-l1-1-0", + "api-ms-win-crt-math-l1-1-0", + "api-ms-win-crt-heap-l1-1-0", + "api-ms-win-crt-runtime-l1-1-0", + "api-ms-win-crt-convert-l1-1-0", + "api-ms-win-crt-environment-l1-1-0", + "api-ms-win-crt-time-l1-1-0", + "api-ms-win-crt-filesystem-l1-1-0", + "api-ms-win-crt-utility-l1-1-0", + "api-ms-win-crt-multibyte-l1-1-0", + "api-ms-win-core-string-l1-1-0", + "api-ms-win-core-errorhandling-l1-1-0", + "api-ms-win-core-timezone-l1-1-0", + "api-ms-win-core-file-l1-1-0", + "api-ms-win-core-namedpipe-l1-1-0", + "api-ms-win-core-handle-l1-1-0", + "api-ms-win-core-file-l2-1-0", + "api-ms-win-core-heap-l1-1-0", + "api-ms-win-core-libraryloader-l1-1-0", + "api-ms-win-core-synch-l1-1-0", + "api-ms-win-core-processthreads-l1-1-0", + "api-ms-win-core-processenvironment-l1-1-0", + "api-ms-win-core-datetime-l1-1-0", + "api-ms-win-core-localization-l1-2-0", + "api-ms-win-core-sysinfo-l1-1-0", + "api-ms-win-core-synch-l1-2-0", + "api-ms-win-core-console-l1-1-0", + "api-ms-win-core-debug-l1-1-0", + "api-ms-win-core-rtlsupport-l1-1-0", + "api-ms-win-core-processthreads-l1-1-1", + "api-ms-win-core-file-l1-2-0", + "api-ms-win-core-profile-l1-1-0", + "api-ms-win-core-memory-l1-1-0", + "api-ms-win-core-util-l1-1-0", + "api-ms-win-core-interlocked-l1-1-0", + "ucrtbase", + "vcruntime140", + "vcruntime140_1", + "msvcp140", + "concrt140", + "vcomp140", + "msvcr120", + "libiomp5md", + "mklml", + "tensorflow_framework" + }), + @Platform( + value = "windows-x86", + preloadpath = { + "C:/Program Files (x86)/Microsoft Visual Studio 14.0/VC/redist/x86/Microsoft.VC140.CRT/", + "C:/Program Files (x86)/Microsoft Visual Studio 14.0/VC/redist/x86/Microsoft.VC140.OpenMP/", + "C:/Program Files (x86)/Windows Kits/10/Redist/ucrt/DLLs/x86/" + }), + @Platform( + value = "windows-x86_64", + preloadpath = { + "C:/Program Files (x86)/Microsoft Visual Studio 14.0/VC/redist/x64/Microsoft.VC140.CRT/", + "C:/Program Files (x86)/Microsoft Visual Studio 14.0/VC/redist/x64/Microsoft.VC140.OpenMP/", + "C:/Program Files (x86)/Windows Kits/10/Redist/ucrt/DLLs/x64/" + }), + @Platform( + value = {"linux", "macosx", "windows"}, + extension = {"-mkl", "-gpu", "-mkl-gpu"}) }, target = "org.tensorflow.internal.c_api", global = "org.tensorflow.internal.c_api.global.tensorflow") @@ -144,13 +162,29 @@ public void init(ClassProperties properties) { } // Let users enable loading of the full version of MKL - String load = System.getProperty("org.bytedeco.openblas.load", - System.getProperty("org.bytedeco.mklml.load", "")).toLowerCase(); + String load = + System.getProperty( + "org.bytedeco.openblas.load", System.getProperty("org.bytedeco.mklml.load", "")) + .toLowerCase(); int i = 0; if (load.equals("mkl") || load.equals("mkl_rt")) { - String[] libs = {"iomp5", "libiomp5md", "mkl_core", "mkl_avx", "mkl_avx2", "mkl_avx512", "mkl_avx512_mic", - "mkl_def", "mkl_mc", "mkl_mc3", "mkl_intel_lp64", "mkl_intel_thread", "mkl_gnu_thread", "mkl_rt"}; + String[] libs = { + "iomp5", + "libiomp5md", + "mkl_core", + "mkl_avx", + "mkl_avx2", + "mkl_avx512", + "mkl_avx512_mic", + "mkl_def", + "mkl_mc", + "mkl_mc3", + "mkl_intel_lp64", + "mkl_intel_thread", + "mkl_gnu_thread", + "mkl_rt" + }; for (i = 0; i < libs.length; i++) { preloads.add(i, libs[i] + "#" + libs[i]); } @@ -172,29 +206,57 @@ public void init(ClassProperties properties) { if (!Loader.isLoadLibraries() || extension == null || !extension.endsWith("-gpu")) { return; } - String[] libs = {"cudart", "cublasLt", "cublas", "cufft", "curand", "cusolver", "cusparse", "cudnn", "nccl", - "nvrtc", "myelin", "nvinfer", - "cudnn_ops_infer", "cudnn_ops_train", "cudnn_adv_infer", "cudnn_adv_train", "cudnn_cnn_infer", - "cudnn_cnn_train"}; + String[] libs = { + "cudart", + "cublasLt", + "cublas", + "cufft", + "curand", + "cusolver", + "cusparse", + "cudnn", + "nccl", + "nvrtc", + "myelin", + "nvinfer", + "cudnn_ops_infer", + "cudnn_ops_train", + "cudnn_adv_infer", + "cudnn_adv_train", + "cudnn_cnn_infer", + "cudnn_cnn_train" + }; for (String lib : libs) { if (platform.startsWith("linux")) { - lib += lib.startsWith("cudnn") ? "@.8" - : lib.equals("nccl") ? "@.2" - : lib.equals("myelin") ? "@.1" - : lib.equals("nvinfer") ? "@.7" - : lib.equals("cufft") || lib.equals("curand") || lib.equals("cusolver") ? "@.10" - : lib.equals("cudart") ? "@.11.0" - : lib.equals("nvrtc") ? "@.11.0" - : "@.11"; + lib += + lib.startsWith("cudnn") + ? "@.8" + : lib.equals("nccl") + ? "@.2" + : lib.equals("myelin") + ? "@.1" + : lib.equals("nvinfer") + ? "@.7" + : lib.equals("cufft") || lib.equals("curand") || lib.equals("cusolver") + ? "@.10" + : lib.equals("cudart") + ? "@.11.0" + : lib.equals("nvrtc") ? "@.11.0" : "@.11"; } else if (platform.startsWith("windows")) { - lib += lib.startsWith("cudnn") ? "64_8" - : lib.equals("nccl") ? "64_2" - : lib.equals("myelin") ? "64_1" - : lib.equals("nvinfer") ? "64_7" - : lib.equals("cufft") || lib.equals("curand") || lib.equals("cusolver") ? "64_10" - : lib.equals("cudart") ? "64_110" - : lib.equals("nvrtc") ? "64_110_0" - : "64_11"; + lib += + lib.startsWith("cudnn") + ? "64_8" + : lib.equals("nccl") + ? "64_2" + : lib.equals("myelin") + ? "64_1" + : lib.equals("nvinfer") + ? "64_7" + : lib.equals("cufft") || lib.equals("curand") || lib.equals("cusolver") + ? "64_10" + : lib.equals("cudart") + ? "64_110" + : lib.equals("nvrtc") ? "64_110_0" : "64_11"; } else { continue; // no CUDA } @@ -208,51 +270,121 @@ public void init(ClassProperties properties) { } } + @Override public void map(InfoMap infoMap) { - infoMap.put(new Info("TF_CAPI_EXPORT", "TF_Bool").cppTypes().annotations()) - .put(new Info("TF_Buffer::data") - .javaText("public native @Const Pointer data(); public native TF_Buffer data(Pointer data);")) - .put(new Info("TF_Status").pointerTypes("TF_Status").base("org.tensorflow.internal.c_api.AbstractTF_Status")) - .put(new Info("TF_Buffer").pointerTypes("TF_Buffer").base("org.tensorflow.internal.c_api.AbstractTF_Buffer")) - .put(new Info("TF_Tensor").pointerTypes("TF_Tensor").base("org.tensorflow.internal.c_api.AbstractTF_Tensor")) - .put(new Info("TF_Session").pointerTypes("TF_Session").base("org.tensorflow.internal.c_api.AbstractTF_Session")) - .put(new Info("TF_SessionOptions").pointerTypes("TF_SessionOptions") - .base("org.tensorflow.internal.c_api.AbstractTF_SessionOptions")) - .put(new Info("TF_Graph").pointerTypes("TF_Graph").base("org.tensorflow.internal.c_api.AbstractTF_Graph")) - .put(new Info("TF_Graph::graph").javaText("public native @MemberGetter @ByRef Graph graph();")) - .put(new Info("TF_Graph::refiner").javaText("public native @MemberGetter @ByRef ShapeRefiner refiner();")) - .put(new Info("TF_Function").pointerTypes("TF_Function") - .base("org.tensorflow.internal.c_api.AbstractTF_Function")) - .put(new Info("TF_ImportGraphDefOptions").pointerTypes("TF_ImportGraphDefOptions") - .base("org.tensorflow.internal.c_api.AbstractTF_ImportGraphDefOptions")) - .put(new Info("TF_Operation", "TF_WhileParams", "TFE_MonitoringCounterCell", "TFE_MonitoringSamplerCell", - "TFE_MonitoringCounter0", "TFE_MonitoringCounter1", "TFE_MonitoringCounter2", - "TFE_MonitoringIntGaugeCell", "TFE_MonitoringStringGaugeCell", "TFE_MonitoringBoolGaugeCell", - "TFE_MonitoringIntGauge0", "TFE_MonitoringIntGauge1", "TFE_MonitoringIntGauge2", - "TFE_MonitoringStringGauge0", "TFE_MonitoringStringGauge1", "TFE_MonitoringStringGauge2", - "TFE_MonitoringBoolGauge0", "TFE_MonitoringBoolGauge1", "TFE_MonitoringBoolGauge2", - "TFE_MonitoringSampler0", "TFE_MonitoringSampler1", "TFE_MonitoringSampler2").purify()) - .put(new Info("TF_Operation::node").javaText("public native @MemberGetter @ByRef Node node();")) - .put(new Info("TFE_MonitoringCounterCell::cell") - .javaText("public native @MemberGetter @ByRef CounterCell cell();")) - .put(new Info("TFE_MonitoringSamplerCell::cell") - .javaText("public native @MemberGetter @ByRef SamplerCell cell();")) - .put(new Info("TFE_MonitoringIntGaugeCell::cell") - .javaText("public native @MemberGetter @ByRef IntGaugeCell cell();")) - .put(new Info("TFE_MonitoringStringGaugeCell::cell") - .javaText("public native @MemberGetter @ByRef StringGaugeCell cell();")) - .put(new Info("TFE_MonitoringBoolGaugeCell::cell") - .javaText("public native @MemberGetter @ByRef BoolGaugeCell cell();")) - .put(new Info("TFE_Context").pointerTypes("TFE_Context") - .base("org.tensorflow.internal.c_api.AbstractTFE_Context")) - .put(new Info("TFE_ContextOptions").pointerTypes("TFE_ContextOptions") - .base("org.tensorflow.internal.c_api.AbstractTFE_ContextOptions")) - .put(new Info("TFE_Context::context").javaText("@MemberGetter public native @ByRef EagerContext context();")) - .put(new Info("TFE_Op").pointerTypes("TFE_Op").base("org.tensorflow.internal.c_api.AbstractTFE_Op")) - .put(new Info("TFE_Op::operation").javaText("@MemberGetter public native @ByRef EagerOperation operation();")) - .put(new Info("TFE_TensorHandle").pointerTypes("TFE_TensorHandle") - .base("org.tensorflow.internal.c_api.AbstractTFE_TensorHandle")) - .put(new Info("TF_ShapeInferenceContextDimValueKnown", - "TFE_NewTensorHandle(const tensorflow::Tensor&, TF_Status*)").skip()); + infoMap + .put(new Info("TF_CAPI_EXPORT", "TF_Bool").cppTypes().annotations()) + .put( + new Info("TF_Buffer::data") + .javaText( + "public native @Const Pointer data(); public native TF_Buffer data(Pointer data);")) + .put( + new Info("TF_Status") + .pointerTypes("TF_Status") + .base("org.tensorflow.internal.c_api.AbstractTF_Status")) + .put( + new Info("TF_Buffer") + .pointerTypes("TF_Buffer") + .base("org.tensorflow.internal.c_api.AbstractTF_Buffer")) + .put( + new Info("TF_Tensor") + .pointerTypes("TF_Tensor") + .base("org.tensorflow.internal.c_api.AbstractTF_Tensor")) + .put( + new Info("TF_Session") + .pointerTypes("TF_Session") + .base("org.tensorflow.internal.c_api.AbstractTF_Session")) + .put( + new Info("TF_SessionOptions") + .pointerTypes("TF_SessionOptions") + .base("org.tensorflow.internal.c_api.AbstractTF_SessionOptions")) + .put( + new Info("TF_Graph") + .pointerTypes("TF_Graph") + .base("org.tensorflow.internal.c_api.AbstractTF_Graph")) + .put( + new Info("TF_Graph::graph") + .javaText("public native @MemberGetter @ByRef Graph graph();")) + .put( + new Info("TF_Graph::refiner") + .javaText("public native @MemberGetter @ByRef ShapeRefiner refiner();")) + .put( + new Info("TF_Function") + .pointerTypes("TF_Function") + .base("org.tensorflow.internal.c_api.AbstractTF_Function")) + .put( + new Info("TF_ImportGraphDefOptions") + .pointerTypes("TF_ImportGraphDefOptions") + .base("org.tensorflow.internal.c_api.AbstractTF_ImportGraphDefOptions")) + .put( + new Info( + "TF_Operation", + "TF_WhileParams", + "TFE_MonitoringCounterCell", + "TFE_MonitoringSamplerCell", + "TFE_MonitoringCounter0", + "TFE_MonitoringCounter1", + "TFE_MonitoringCounter2", + "TFE_MonitoringIntGaugeCell", + "TFE_MonitoringStringGaugeCell", + "TFE_MonitoringBoolGaugeCell", + "TFE_MonitoringIntGauge0", + "TFE_MonitoringIntGauge1", + "TFE_MonitoringIntGauge2", + "TFE_MonitoringStringGauge0", + "TFE_MonitoringStringGauge1", + "TFE_MonitoringStringGauge2", + "TFE_MonitoringBoolGauge0", + "TFE_MonitoringBoolGauge1", + "TFE_MonitoringBoolGauge2", + "TFE_MonitoringSampler0", + "TFE_MonitoringSampler1", + "TFE_MonitoringSampler2") + .purify()) + .put( + new Info("TF_Operation::node") + .javaText("public native @MemberGetter @ByRef Node node();")) + .put( + new Info("TFE_MonitoringCounterCell::cell") + .javaText("public native @MemberGetter @ByRef CounterCell cell();")) + .put( + new Info("TFE_MonitoringSamplerCell::cell") + .javaText("public native @MemberGetter @ByRef SamplerCell cell();")) + .put( + new Info("TFE_MonitoringIntGaugeCell::cell") + .javaText("public native @MemberGetter @ByRef IntGaugeCell cell();")) + .put( + new Info("TFE_MonitoringStringGaugeCell::cell") + .javaText("public native @MemberGetter @ByRef StringGaugeCell cell();")) + .put( + new Info("TFE_MonitoringBoolGaugeCell::cell") + .javaText("public native @MemberGetter @ByRef BoolGaugeCell cell();")) + .put( + new Info("TFE_Context") + .pointerTypes("TFE_Context") + .base("org.tensorflow.internal.c_api.AbstractTFE_Context")) + .put( + new Info("TFE_ContextOptions") + .pointerTypes("TFE_ContextOptions") + .base("org.tensorflow.internal.c_api.AbstractTFE_ContextOptions")) + .put( + new Info("TFE_Context::context") + .javaText("@MemberGetter public native @ByRef EagerContext context();")) + .put( + new Info("TFE_Op") + .pointerTypes("TFE_Op") + .base("org.tensorflow.internal.c_api.AbstractTFE_Op")) + .put( + new Info("TFE_Op::operation") + .javaText("@MemberGetter public native @ByRef EagerOperation operation();")) + .put( + new Info("TFE_TensorHandle") + .pointerTypes("TFE_TensorHandle") + .base("org.tensorflow.internal.c_api.AbstractTFE_TensorHandle")) + .put( + new Info( + "TF_ShapeInferenceContextDimValueKnown", + "TFE_NewTensorHandle(const tensorflow::Tensor&, TF_Status*)") + .skip()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java index 0fe171602e3..87987c78517 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java @@ -1,19 +1,19 @@ /* - Copyright 2021 The TensorFlow Authors. All Rights Reserved. + Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ============================================================================== - */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ package org.tensorflow.op.core; import java.util.Map; @@ -24,16 +24,13 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; -/** - * Ops for calling {@link ConcreteFunction}. Even though the C API docs say the name of the Op needs to be the name of - * the function, they mean the type. - */ +/** Ops for calling {@link ConcreteFunction}. */ @Operator(name = "call") public abstract class Function { /** - * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The - * inputs and outputs are keyed by the names set in the {@code Signature}. + * Calls the function in an execution environment, adding its graph as a function if it isn't + * already present. The inputs and outputs are keyed by the names set in the {@code Signature}. * * @param scope the scope to call the function in * @param arguments the arguments to the call @@ -41,15 +38,14 @@ public abstract class Function { * @see ConcreteFunction#call(Ops, Map) */ @Endpoint - public static Map> call(Scope scope, ConcreteFunction function, - Map> arguments) { + public static Map> call( + Scope scope, ConcreteFunction function, Map> arguments) { return function.call(scope, arguments); } - /** - * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only - * works for functions with a single input and output. + * Calls the function in an execution environment, adding its graph as a function if it isn't + * already present. Only works for functions with a single input and output. * * @param scope the scope to call the function in * @param argument the argument to the call @@ -57,9 +53,7 @@ public static Map> call(Scope scope, ConcreteFunction functio * @see ConcreteFunction#call(Ops, Operand) */ @Endpoint - public static Operand call(Scope scope, ConcreteFunction function, - Operand argument) { + public static Operand call(Scope scope, ConcreteFunction function, Operand argument) { return function.call(scope, argument); } - } From 3df55b9ba9882913add09b962fe7f99b1892905c Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 21 May 2021 16:57:42 -0700 Subject: [PATCH 26/34] Add session function test, Signature.builder with name Signed-off-by: Ryan Nett --- .../annotations/org/tensorflow/op/Ops.java | 8 +- .../main/java/org/tensorflow/Signature.java | 7 ++ .../test/java/org/tensorflow/SessionTest.java | 90 +++++++++++-------- 3 files changed, 62 insertions(+), 43 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index b405b8a39d0..c68b6ee8ff7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -1120,8 +1120,8 @@ public Bucketize bucketize(Operand input, List boundar } /** - * Calls the function in an execution environment, adding its graph as a function if it isn't already present. Only - * works for functions with a single input and output. + * Calls the function in an execution environment, adding its graph as a function if it isn't + * already present. Only works for functions with a single input and output. * * @param argument the argument to the call * @return the output of the function @@ -1132,8 +1132,8 @@ public Operand call(ConcreteFunction function, Operand argument) { } /** - * Calls the function in an execution environment, adding its graph as a function if it isn't already present. The - * inputs and outputs are keyed by the names set in the {@code Signature}. + * Calls the function in an execution environment, adding its graph as a function if it isn't + * already present. The inputs and outputs are keyed by the names set in the {@code Signature}. * * @param arguments the arguments to the call * @return the outputs of the function diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java index 8780543a1e7..8da71a36cca 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -179,6 +179,13 @@ public static Builder builder() { return new Builder(); } + /** + * Returns a new builder for creating a signature, with the methodName and key set to {@code name} + */ + public static Builder builder(String name) { + return new Builder().methodName(name).key(name); + } + /** Return the key of this signature */ public String key() { return key; diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index 4223a03ee23..8b1d6c8ce2c 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -43,18 +43,33 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; -/** - * Unit tests for {@link org.tensorflow.Session}. - */ +/** Unit tests for {@link org.tensorflow.Session}. */ public class SessionTest { + @Test + public void runUsingFunction() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + Ops tf = Ops.create(g); + transpose_A_times_X(tf, new int[][] {{2}, {3}}); + Signature sig = + Signature.builder("sess").input("X", g.output("X")).output("Y", g.output("Y")).build(); + SessionFunction func = s.function(sig); + + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); + TInt32 y = (TInt32) func.call(x)) { + assertEquals(31, y.getInt(0, 0)); + } + } + } + @Test public void runUsingOperationNames() { try (Graph g = new Graph(); Session s = new Session(g)) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][]{{2}, {3}}); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}})); + transpose_A_times_X(tf, new int[][] {{2}, {3}}); + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); AutoCloseableList outputs = new AutoCloseableList<>(s.runner().feed("X", x).fetch("Y").run())) { assertEquals(1, outputs.size()); @@ -68,10 +83,10 @@ public void runUsingOperationHandles() { try (Graph g = new Graph(); Session s = new Session(g)) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][]{{2}, {3}}); + transpose_A_times_X(tf, new int[][] {{2}, {3}}); Output feed = g.operation("X").output(0); Output fetch = g.operation("Y").output(0); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}})); + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); AutoCloseableList outputs = new AutoCloseableList<>(s.runner().feed(feed, x).fetch(fetch).run())) { assertEquals(1, outputs.size()); @@ -95,12 +110,9 @@ public void runUsingColonSeparatedNames() { } // Feed using colon separated names. try (TInt32 fed = TInt32.vectorOf(4, 3, 2, 1); - TInt32 fetched = (TInt32) s.runner() - .feed("Split:0", fed) - .feed("Split:1", fed) - .fetch("Add") - .run() - .get(0)) { + TInt32 fetched = + (TInt32) + s.runner().feed("Split:0", fed).feed("Split:1", fed).fetch("Add").run().get(0)) { assertEquals(NdArrays.vectorOf(8, 6, 4, 2), fetched); } } @@ -111,13 +123,14 @@ public void runWithMetadata() { try (Graph g = new Graph(); Session s = new Session(g)) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][]{{2}, {3}}); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}}))) { - Session.Run result = s.runner() - .feed("X", x) - .fetch("Y") - .setOptions(fullTraceRunOptions()) - .runAndFetchMetadata(); + transpose_A_times_X(tf, new int[][] {{2}, {3}}); + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) { + Session.Run result = + s.runner() + .feed("X", x) + .fetch("Y") + .setOptions(fullTraceRunOptions()) + .runAndFetchMetadata(); // Sanity check on outputs. AutoCloseableList outputs = new AutoCloseableList<>(result.outputs); assertEquals(1, outputs.size()); @@ -163,8 +176,7 @@ public void failOnUseAfterClose() { @Test public void createWithConfigProto() { try (Graph g = new Graph(); - Session s = new Session(g, singleThreadConfigProto())) { - } + Session s = new Session(g, singleThreadConfigProto())) {} } @Test @@ -219,10 +231,12 @@ public void saveAndRestore() throws IOException { Path testFolder = Files.createTempDirectory("tf-session-save-restore-test"); try (Graph g = new Graph()) { Ops tf = Ops.create(g); - Variable x = tf.withName("x") - .variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); - Variable y = tf.withName("y") - .variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); + Variable x = + tf.withName("x") + .variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); + Variable y = + tf.withName("y") + .variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); Init init = tf.init(); try (Session s = new Session(g)) { @@ -234,9 +248,10 @@ public void saveAndRestore() throws IOException { restoredGraph.importGraphDef(graphDef); try (Session restoredSession = new Session(restoredGraph)) { restoredSession.restore(testFolder.resolve("checkpoint").toString()); - try (AutoCloseableList oldList = new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run()); - AutoCloseableList newList = new AutoCloseableList<>( - restoredSession.runner().fetch("x").fetch("y").run())) { + try (AutoCloseableList oldList = + new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run()); + AutoCloseableList newList = + new AutoCloseableList<>(restoredSession.runner().fetch("x").fetch("y").run())) { assertEquals(oldList.get(0), newList.get(0)); assertEquals(oldList.get(1), newList.get(1)); } @@ -265,7 +280,6 @@ public static void testFetchVariable() { try (TInt32 value = (TInt32) s.runner().addTarget(assign).fetch(variable).run().get(0)) { assertEquals(2, value.getInt()); } - } } @@ -295,14 +309,11 @@ public static void testFetchVariableReusingRead() { } assertEquals(0, numOperations(g) - ops); - } } private static RunOptions fullTraceRunOptions() { - return RunOptions.newBuilder() - .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE) - .build(); + return RunOptions.newBuilder().setTraceLevel(RunOptions.TraceLevel.FULL_TRACE).build(); } private static ConfigProto singleThreadConfigProto() { @@ -313,10 +324,11 @@ private static ConfigProto singleThreadConfigProto() { } private static void transpose_A_times_X(Ops tf, int[][] a) { - tf.withName("Y").linalg.matMul( - tf.withName("A").constant(a), - tf.withName("X").placeholder(TInt32.class), - MatMul.transposeA(true).transposeB(false) - ); + tf.withName("Y") + .linalg + .matMul( + tf.withName("A").constant(a), + tf.withName("X").placeholder(TInt32.class), + MatMul.transposeA(true).transposeB(false)); } } From 99a3403926a620722ccda22b7c9cb7dce2475990 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 21 May 2021 16:58:03 -0700 Subject: [PATCH 27/34] Remove extra synchronization Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Graph.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 1dd4dde9711..f3e712492b8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -407,7 +407,7 @@ public void attachFunction(ConcreteFunction function) { * * @param outerScope the pointer scope to attach the functions to. */ - synchronized List getNativeFunctions(PointerScope outerScope) { + List getNativeFunctions(PointerScope outerScope) { try (Reference ref = ref(); PointerScope scope = new PointerScope()) { TF_Status status = TF_Status.newStatus(); @@ -441,7 +441,7 @@ synchronized List getNativeFunctions(PointerScope outerScope) { * @return the found {@link ConcreteFunction}, or {@code null} if none were found with the correct * name */ - public synchronized ConcreteFunction getFunction(String key) { + public ConcreteFunction getFunction(String key) { try (Reference ref = ref(); PointerScope scope = new PointerScope()) { List funcs = getNativeFunctions(scope); @@ -461,7 +461,7 @@ public synchronized ConcreteFunction getFunction(String key) { * * @return all functions attached to this graph. */ - public synchronized List getFunctions() { + public List getFunctions() { try (Reference ref = ref(); PointerScope scope = new PointerScope()) { List funcs = getNativeFunctions(scope); From cecf71d1fe9418c284f40125f99a9be311aee86b Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 28 May 2021 12:04:44 -0700 Subject: [PATCH 28/34] Formatting Signed-off-by: Ryan Nett --- .../java/org/tensorflow/CallableFunction.java | 25 ++-- .../java/org/tensorflow/ConcreteFunction.java | 28 ++--- .../org/tensorflow/EagerOperationBuilder.java | 24 ++-- .../org/tensorflow/ExecutionEnvironment.java | 6 +- .../org/tensorflow/GraphOperationBuilder.java | 101 +++++++++------- .../java/org/tensorflow/NativeFunction.java | 31 +++-- .../java/org/tensorflow/OperationBuilder.java | 68 +++++------ .../java/org/tensorflow/SavedModelBundle.java | 24 ++-- .../src/main/java/org/tensorflow/Session.java | 24 ++-- .../java/org/tensorflow/SessionFunction.java | 65 ++++++----- .../main/java/org/tensorflow/Signature.java | 28 ++--- .../main/java/org/tensorflow/TensorFlow.java | 24 ++-- .../internal/c_api/AbstractTF_Function.java | 10 +- .../internal/c_api/presets/tensorflow.java | 26 ++--- .../java/org/tensorflow/op/core/Function.java | 25 ++-- .../org/tensorflow/ConcreteFunctionTest.java | 71 ++++++----- .../tensorflow/EagerOperationBuilderTest.java | 26 ++--- .../tensorflow/GraphOperationBuilderTest.java | 28 ++--- .../org/tensorflow/SavedModelBundleTest.java | 110 +++++++++--------- .../test/java/org/tensorflow/SessionTest.java | 24 ++-- .../org/tensorflow/op/core/FunctionTest.java | 11 +- 21 files changed, 407 insertions(+), 372 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java index bcb5c775f74..21b5bb0422e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java @@ -1,19 +1,18 @@ -/* - Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -============================================================================== -*/ + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import java.util.LinkedHashMap; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 16b9c2817fc..576598c1607 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -1,17 +1,17 @@ -/* - * Copyright 2020 The TensorFlow Authors. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. +/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= */ package org.tensorflow; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java index fd9c436c251..6c7a322fdef 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -1,18 +1,18 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_Execute; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java index eafc2698789..6f50aeafe98 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java @@ -38,9 +38,11 @@ enum Types { OperationBuilder opBuilder(String type, String name); /** - * Attach the function and its dependencies to this execution environment, allowing it to be called. + * Attach the function and its dependencies to this execution environment, allowing it to be + * called. * - * Done automatically in the {@link org.tensorflow.op.Ops#call(ConcreteFunction, java.util.Map)} ops. + *

Done automatically in the {@link org.tensorflow.op.Ops#call(ConcreteFunction, + * java.util.Map)} ops. */ void attachFunction(ConcreteFunction function); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java index c8ec73c0346..73f190a2d71 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java @@ -1,18 +1,18 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TF_AddControlInput; @@ -63,9 +63,7 @@ import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.NameAttrList; -/** - * An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}. - */ +/** An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}. */ public final class GraphOperationBuilder implements OperationBuilder { GraphOperationBuilder(Graph graph, String type, String name) { @@ -103,7 +101,8 @@ public GraphOperationBuilder addControlInput(Operation control) { } if (control.env() != graph) { - throw new IllegalArgumentException("Control input " + control + " was from a different graph, can't use."); + throw new IllegalArgumentException( + "Control input " + control + " was from a different graph, can't use."); } Graph.Reference r = graph.ref(); @@ -369,9 +368,12 @@ public OperationBuilder setAttr(String name, ConcreteFunction[] value) { } try (Reference r = graph.ref()) { - setAttrFunctionList(unsafeNativeHandle, name, Arrays.stream(value) - .map(ConcreteFunction::getNativeFunctionName) - .collect(Collectors.toList())); + setAttrFunctionList( + unsafeNativeHandle, + name, + Arrays.stream(value) + .map(ConcreteFunction::getNativeFunctionName) + .collect(Collectors.toList())); } return this; } @@ -426,11 +428,16 @@ private static void addInput(TF_OperationDescription handle, TF_Operation opHand } } - private static void addInputList(TF_OperationDescription handle, TF_Operation[] opHandles, int[] indices) { + private static void addInputList( + TF_OperationDescription handle, TF_Operation[] opHandles, int[] indices) { requireHandle(handle); if (indices.length != opHandles.length) { - throw new IllegalArgumentException("mismatch in number of Operations (" - + opHandles.length + ") and output indices (" + indices.length + ") provided"); + throw new IllegalArgumentException( + "mismatch in number of Operations (" + + opHandles.length + + ") and output indices (" + + indices.length + + ") provided"); } try (PointerScope scope = new PointerScope()) { @@ -444,8 +451,8 @@ private static void addInputList(TF_OperationDescription handle, TF_Operation[] private static void addControlInput(TF_OperationDescription handle, TF_Operation opHandle) { if (opHandle == null || opHandle.isNull()) { - throw new IllegalStateException("control input is not valid, " - + "perhaps the Graph containing it has been closed()?"); + throw new IllegalStateException( + "control input is not valid, " + "perhaps the Graph containing it has been closed()?"); } requireHandle(handle); TF_AddControlInput(handle, opHandle); @@ -491,7 +498,8 @@ private static void setAttrBool(TF_OperationDescription handle, String name, boo TF_SetAttrBool(handle, name, (byte) (value ? 1 : 0)); } - private static void setAttrBoolList(TF_OperationDescription handle, String name, boolean[] value) { + private static void setAttrBoolList( + TF_OperationDescription handle, String name, boolean[] value) { requireHandle(handle); try (PointerScope scope = new PointerScope()) { TF_SetAttrBoolList(handle, name, new BytePointer(new BooleanPointer(value)), value.length); @@ -508,7 +516,8 @@ private static void setAttrTypeList(TF_OperationDescription handle, String name, TF_SetAttrTypeList(handle, name, type, type.length); } - private static void setAttrTensor(TF_OperationDescription handle, String name, TF_Tensor tensorHandle) { + private static void setAttrTensor( + TF_OperationDescription handle, String name, TF_Tensor tensorHandle) { requireHandle(handle); requireTensor(tensorHandle); @@ -519,7 +528,8 @@ private static void setAttrTensor(TF_OperationDescription handle, String name, T } } - private static void setAttrTensorList(TF_OperationDescription handle, String name, TF_Tensor[] tensorHandles) { + private static void setAttrTensorList( + TF_OperationDescription handle, String name, TF_Tensor[] tensorHandles) { requireHandle(handle); try (PointerScope scope = new PointerScope()) { @@ -530,12 +540,14 @@ private static void setAttrTensorList(TF_OperationDescription handle, String nam } TF_Status status = TF_Status.newStatus(); - TF_SetAttrTensorList(handle, new BytePointer(name), tensors.position(0), tensorHandles.length, status); + TF_SetAttrTensorList( + handle, new BytePointer(name), tensors.position(0), tensorHandles.length, status); status.throwExceptionIfNotOK(); } } - private static void setAttrShape(TF_OperationDescription handle, String name, long[] shape, int numDims) { + private static void setAttrShape( + TF_OperationDescription handle, String name, long[] shape, int numDims) { requireHandle(handle); // num_dims and env->GetArrayLength(shape) are assumed to be consistent. @@ -543,7 +555,8 @@ private static void setAttrShape(TF_OperationDescription handle, String name, lo TF_SetAttrShape(handle, name, shape, numDims); } - private static void setAttrShapeList(TF_OperationDescription handle, String name, long[] shapes, int[] numDims) { + private static void setAttrShapeList( + TF_OperationDescription handle, String name, long[] shapes, int[] numDims) { requireHandle(handle); try (PointerScope scope = new PointerScope()) { @@ -553,11 +566,13 @@ private static void setAttrShapeList(TF_OperationDescription handle, String name shapesPointers.put(i, shapesPointer); shapesPointer.position(shapesPointer.position() + numDims[i] * 8); } - TF_SetAttrShapeList(handle, new BytePointer(name), shapesPointers, new IntPointer(numDims), numDims.length); + TF_SetAttrShapeList( + handle, new BytePointer(name), shapesPointers, new IntPointer(numDims), numDims.length); } } - private static void setAttrStringList(TF_OperationDescription handle, String name, byte[][] value) { + private static void setAttrStringList( + TF_OperationDescription handle, String name, byte[][] value) { requireHandle(handle); try (PointerScope scope = new PointerScope()) { @@ -572,23 +587,29 @@ private static void setAttrStringList(TF_OperationDescription handle, String nam } } - private static void setAttrFunctionName(TF_OperationDescription opHandle, String attrName, String functionName) { + private static void setAttrFunctionName( + TF_OperationDescription opHandle, String attrName, String functionName) { requireHandle(opHandle); try (PointerScope scope = new PointerScope()) { TF_SetAttrFuncName(opHandle, attrName, functionName, functionName.length()); } } - private static void setAttrFunctionList(TF_OperationDescription opHandle, String attrName, - List functionNames) { + private static void setAttrFunctionList( + TF_OperationDescription opHandle, String attrName, List functionNames) { requireHandle(opHandle); try (PointerScope scope = new PointerScope()) { TF_Status status = TF_Status.newStatus(); - AttrValue value = AttrValue.newBuilder().setList(ListValue.newBuilder().addAllFunc( - functionNames.stream() - .map(x -> NameAttrList.newBuilder().setName(x).build()) - .collect(Collectors.toList()) - ).build()).build(); + AttrValue value = + AttrValue.newBuilder() + .setList( + ListValue.newBuilder() + .addAllFunc( + functionNames.stream() + .map(x -> NameAttrList.newBuilder().setName(x).build()) + .collect(Collectors.toList())) + .build()) + .build(); byte[] bytes = value.toByteArray(); TF_SetAttrValueProto(opHandle, attrName, new BytePointer(bytes), bytes.length, status); status.throwExceptionIfNotOK(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java index 0144dca1e59..faab6dbca7b 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/NativeFunction.java @@ -1,19 +1,18 @@ -/* - Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -============================================================================== -*/ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionName; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java index e09de39b6c6..569f37c8f4a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java @@ -1,18 +1,18 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import org.tensorflow.ndarray.Shape; @@ -49,7 +49,7 @@ public interface OperationBuilder { * *

The OperationBuilder is not usable after build() returns. */ - Operation build(); + Operation build(); /** * Add the output of another operation as the next input of the operation being built. @@ -57,7 +57,7 @@ public interface OperationBuilder { * @param input {@link Output} supposed to be the input of the operation being built. * @return the OperationBuilder instance for chaining. */ - OperationBuilder addInput(Output input); + OperationBuilder addInput(Output input); /** * Add the outputs of another operation as the next inputs of the operation being built. @@ -65,7 +65,7 @@ public interface OperationBuilder { * @param inputs list of {@link Output} supposed to be the inputs of the operation being built. * @return the OperationBuilder instance for chaining. */ - OperationBuilder addInputList(Output[] inputs); + OperationBuilder addInputList(Output[] inputs); /** * Ensure that the operation does not execute before the control operation does. @@ -80,7 +80,7 @@ public interface OperationBuilder { * @param control operation that must be executed before running this operation. * @return the OperationBuilder instance for chaining. */ - OperationBuilder addControlInput(Operation control); + OperationBuilder addControlInput(Operation control); /** * Set the device requested for computing the operation being built. @@ -88,7 +88,7 @@ public interface OperationBuilder { * @param device the requested device, as a string * @return the OperationBuilder instance for chaining. */ - OperationBuilder setDevice(String device); + OperationBuilder setDevice(String device); /** * Set the string values of an attribute of the operation being built. @@ -97,7 +97,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, String[] value); + OperationBuilder setAttr(String name, String[] value); /** * Set the string value of an attribute of the operation being built. @@ -106,7 +106,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, String value); + OperationBuilder setAttr(String name, String value); /** * Set the byte values of an attribute of the operation being built. @@ -115,7 +115,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, byte[] value); + OperationBuilder setAttr(String name, byte[] value); /** * Set the long value of an attribute of the operation being built. @@ -124,7 +124,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, long value); + OperationBuilder setAttr(String name, long value); /** * Set the long values of an attribute of the operation being built. @@ -133,7 +133,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, long[] value); + OperationBuilder setAttr(String name, long[] value); /** * Set the float value of an attribute of the operation being built. @@ -142,7 +142,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, float value); + OperationBuilder setAttr(String name, float value); /** * Set the float values of an attribute of the operation being built. @@ -151,7 +151,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, float[] value); + OperationBuilder setAttr(String name, float[] value); /** * Set the boolean value of an attribute of the operation being built. @@ -160,7 +160,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, boolean value); + OperationBuilder setAttr(String name, boolean value); /** * Set the boolean values of an attribute of the operation being built. @@ -169,7 +169,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, boolean[] value); + OperationBuilder setAttr(String name, boolean[] value); /** * Set the type value of an attribute of the operation being built. @@ -178,7 +178,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, DataType value); + OperationBuilder setAttr(String name, DataType value); /** * Set the type values of an attribute of the operation being built. @@ -187,7 +187,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, DataType[] value); + OperationBuilder setAttr(String name, DataType[] value); /** * Set the tensor value of an attribute of the operation being built. @@ -196,7 +196,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, Tensor value); + OperationBuilder setAttr(String name, Tensor value); /** * Set the tensor values of an attribute of the operation being built. @@ -205,7 +205,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, Tensor[] value); + OperationBuilder setAttr(String name, Tensor[] value); /** * Set the shape value of an attribute of the operation being built. @@ -226,8 +226,8 @@ public interface OperationBuilder { OperationBuilder setAttr(String name, Shape[] value); /** - * Set the function value of an attribute of the operation being built. Also attaches the function and dependencies to - * the execution environment. + * Set the function value of an attribute of the operation being built. Also attaches the function + * and dependencies to the execution environment. * * @param name attribute name * @param value attribute value @@ -236,8 +236,8 @@ public interface OperationBuilder { OperationBuilder setAttr(String name, ConcreteFunction value); /** - * Set the function values of an attribute of the operation being built. Also attaches the functions and dependencies - * to the execution environment. + * Set the function values of an attribute of the operation being built. Also attaches the + * functions and dependencies to the execution environment. * * @param name attribute name * @param value attribute value diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 4e7f2776710..84a1a10ca9f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -1,18 +1,18 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TF_LoadSessionFromSavedModel; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index d5adae161c0..fd0b390bc28 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -1,18 +1,18 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.tensorflow.Graph.resolveOutputs; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java index 1100d5d849f..053b945ed5a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java @@ -1,5 +1,4 @@ -/* - Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,7 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - ============================================================================== + ======================================================================= */ package org.tensorflow; @@ -21,11 +20,13 @@ import java.util.Map; /** - * A callable function backed by a session. All calls of this function will be ran on the same session. + * A callable function backed by a session. All calls of this function will be ran on the same + * session. * - * Does no resource management, the session and all returned tensors are the caller's responsibility. + *

Does no resource management, the session and all returned tensors are the caller's + * responsibility. * - * Does not initialize the session, since it may be shared. + *

Does not initialize the session, since it may be shared. */ public class SessionFunction implements CallableFunction { @@ -36,13 +37,19 @@ public SessionFunction(Signature signature, Session session) { this.signature = signature; this.session = session; - signature.getInputs().forEach((name, description) -> { - CallableFunction.validateDescription(description, session.graph(), name, "Input"); - }); - - signature.getInputs().forEach((name, description) -> { - CallableFunction.validateDescription(description, session.graph(), name, "Output"); - }); + signature + .getInputs() + .forEach( + (name, description) -> { + CallableFunction.validateDescription(description, session.graph(), name, "Input"); + }); + + signature + .getInputs() + .forEach( + (name, description) -> { + CallableFunction.validateDescription(description, session.graph(), name, "Output"); + }); } public static SessionFunction create(Signature signature, Session session) { @@ -70,19 +77,25 @@ public SessionFunction withNewSession(Session session) { @Override public Map call(Map arguments) { Session.Runner runner = session.runner(); - signature.getInputs().forEach((argName, operand) -> { - if (!arguments.containsKey(argName)) { - throw new IllegalArgumentException("No argument found for parameter \"" + argName + "\""); - } - Tensor value = arguments.get(argName); - - if (value == null) { - throw new IllegalArgumentException( - "Can't pass null as an argument to a function. Argument \"" + argName + "\" was null."); - } - - runner.feed(operand.name, value); - }); + signature + .getInputs() + .forEach( + (argName, operand) -> { + if (!arguments.containsKey(argName)) { + throw new IllegalArgumentException( + "No argument found for parameter \"" + argName + "\""); + } + Tensor value = arguments.get(argName); + + if (value == null) { + throw new IllegalArgumentException( + "Can't pass null as an argument to a function. Argument \"" + + argName + + "\" was null."); + } + + runner.feed(operand.name, value); + }); signature.getOutputs().values().forEach(x -> runner.fetch(x.name)); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java index 8da71a36cca..41fab27e068 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -1,17 +1,17 @@ -/* - * Copyright 2020 The TensorFlow Authors. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. +/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= */ package org.tensorflow; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java index b930da217f6..23f4c62bc7f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java @@ -1,18 +1,18 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteBuffer; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java index 0d021244c6b..a3647b5671d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Function.java @@ -1,5 +1,4 @@ -/* - Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,9 +11,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - ============================================================================== + ======================================================================= */ - package org.tensorflow.internal.c_api; import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteFunction; @@ -48,9 +46,7 @@ public TF_Function withDeallocator() { return this.deallocator(new DeleteDeallocator((TF_Function) this)); } - /** - * Calls the deallocator, if registered, otherwise has no effect. - */ + /** Calls the deallocator, if registered, otherwise has no effect. */ public void delete() { deallocate(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java index 3c691f6f23d..66dead59967 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java @@ -1,20 +1,18 @@ -/* -Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -======================================================================= -*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow.internal.c_api.presets; import java.util.List; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java index 87987c78517..255a62e1253 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Function.java @@ -1,19 +1,18 @@ -/* - Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -============================================================================== -*/ + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow.op.core; import java.util.Map; diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java index 5eed335bf32..576ea2e3f95 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -1,17 +1,18 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -32,7 +33,7 @@ public class ConcreteFunctionTest { private static Signature plusFive(Ops tf) { Placeholder input = tf.placeholder(TFloat32.class); Add output = tf.math.add(input, tf.constant(5.0f)); - Init init = tf.init(); // for native resource management tests + Init init = tf.init(); // for native resource management tests return Signature.builder().key("plusFive").input("x", input).output("y", output).build(); } @@ -48,7 +49,11 @@ private static Signature plusFiveMinusTwo(Ops tf) { try (ConcreteFunction plusFive = ConcreteFunction.create(ConcreteFunctionTest::plusFive); ConcreteFunction minusTwo = ConcreteFunction.create(ConcreteFunctionTest::minusTwo)) { Operand result = (Operand) minusTwo.call(tf, plusFive.call(tf, input)); - return Signature.builder().key("plusFiveMinusTwo").input("x", input).output("y", result).build(); + return Signature.builder() + .key("plusFiveMinusTwo") + .input("x", input) + .output("y", result) + .build(); } } @@ -66,7 +71,7 @@ public void createFunctionFromGraph() { Signature signature = plusFive(Ops.create(g)); try (ConcreteFunction f = ConcreteFunction.create(signature, g); TFloat32 x = TFloat32.scalarOf(3.0f)) { - assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat()); + assertEquals(8.0f, ((TFloat32) f.call(x)).getFloat()); } } } @@ -78,7 +83,7 @@ public void createFunctionFromSession() { try (Session s = new Session(g)) { try (ConcreteFunction f = ConcreteFunction.create(signature, s); TFloat32 x = TFloat32.scalarOf(3.0f)) { - assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat()); + assertEquals(8.0f, ((TFloat32) f.call(x)).getFloat()); } } } @@ -113,7 +118,8 @@ public void getGraphFunctions() { @Test public void testNestedFunctionEager() { try (EagerSession sess = EagerSession.create(); - ConcreteFunction function = ConcreteFunction.create(ConcreteFunctionTest::plusFiveMinusTwo)) { + ConcreteFunction function = + ConcreteFunction.create(ConcreteFunctionTest::plusFiveMinusTwo)) { Ops tf = Ops.create(sess); Operand a = tf.constant(10f); Operand result = (Operand) function.call(tf, a); @@ -126,7 +132,8 @@ public void testNestedFunctionEager() { @Test public void testNestedFunctionGraph() { try (Graph graph = new Graph(); - ConcreteFunction function = ConcreteFunction.create(ConcreteFunctionTest::plusFiveMinusTwo)) { + ConcreteFunction function = + ConcreteFunction.create(ConcreteFunctionTest::plusFiveMinusTwo)) { Ops tf = Ops.create(graph); Operand a = tf.constant(10f); Operand result = (Operand) function.call(tf, a); @@ -140,11 +147,16 @@ public void testNestedFunctionGraph() { private static Signature square(Ops tf) { Placeholder input = tf.placeholder(TFloat32.class); Operand output = tf.math.square(input); - return Signature.builder().methodName("square").key("square").input("x", input).output("y", output).build(); + return Signature.builder() + .methodName("square") + .key("square") + .input("x", input) + .output("y", output) + .build(); } // call op gradients are not defined in c++ -// @Test + // @Test public void testGradientsGraph() { try (Graph g = new Graph(); ConcreteFunction square = ConcreteFunction.create(ConcreteFunctionTest::square); @@ -157,12 +169,12 @@ public void testGradientsGraph() { Output y1 = (Output) square.call(tf, y0); Output y2 = tf.math.addN(Arrays.asList(y0, x2)).sum(); - Output[] grads0 = g.addGradients(y1, new Output[]{x1}); + Output[] grads0 = g.addGradients(y1, new Output[] {x1}); assertNotNull(grads0); assertEquals(1, grads0.length); assertEquals(DataType.DT_FLOAT, grads0[0].dataType()); - Output[] grads1 = g.addGradients(y2, new Output[]{x1, x2}); + Output[] grads1 = g.addGradients(y2, new Output[] {x1, x2}); assertNotNull(grads1); assertEquals(2, grads1.length); assertEquals(DataType.DT_FLOAT, grads1[0].dataType()); @@ -170,14 +182,15 @@ public void testGradientsGraph() { try (TFloat32 c1 = TFloat32.scalarOf(3.0f); TFloat32 c2 = TFloat32.scalarOf(2.0f); - AutoCloseableList outputs = new AutoCloseableList<>( - s.runner() - .feed(x1, c1) - .feed(x2, c2) - .fetch(grads0[0]) - .fetch(grads1[0]) - .fetch(grads1[1]) - .run())) { + AutoCloseableList outputs = + new AutoCloseableList<>( + s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) + .run())) { assertEquals(3, outputs.size()); assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java index 8a3d56bd37f..b694e0e5a39 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java @@ -1,18 +1,18 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.junit.jupiter.api.Assertions.fail; @@ -124,7 +124,7 @@ public void setAttrs() { .build(); // bool opBuilder(session, "All", "Bool") - .addInput(tf.constant(new boolean[]{true, true, false}).asOutput()) + .addInput(tf.constant(new boolean[] {true, true, false}).asOutput()) .addInput(tf.constant(0).asOutput()) .setAttr("keep_dims", false) .build(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java index 66ba2122501..d0e79534d2c 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java @@ -1,18 +1,18 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -96,14 +96,14 @@ public void setAttr() { g.opBuilder("MaxPool", "IntList") .addInput(tf.constant(new float[2][2][2][2]).asOutput()) .setAttr("ksize", new long[] {1, 1, 1, 1}) - .setAttr("strides", new long[]{1, 1, 1, 1}) + .setAttr("strides", new long[] {1, 1, 1, 1}) .setAttr("padding", "SAME") .build(); assertTrue(hasNode(g, "IntList")); // list(float) g.opBuilder("FractionalMaxPool", "FloatList") .addInput(tf.constant(new float[2][2][2][2]).asOutput()) - .setAttr("pooling_ratio", new float[]{1.0f, 1.44f, 1.73f, 1.0f}) + .setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f}) .build(); assertTrue(hasNode(g, "FloatList")); // Missing tests: float, list(dtype), list(tensor), list(string), list(bool), list(func) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 3501a77b590..542b2f03e67 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -1,18 +1,18 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -42,9 +42,7 @@ import org.tensorflow.proto.framework.RunOptions; import org.tensorflow.types.TFloat32; -/** - * Unit tests for {@link org.tensorflow.SavedModelBundle}. - */ +/** Unit tests for {@link org.tensorflow.SavedModelBundle}. */ public class SavedModelBundleTest { private static final float EPSILON = 1e-7f; @@ -53,9 +51,12 @@ public class SavedModelBundleTest { static { try { - SAVED_MODEL_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model").toURI()).toString(); - SAVED_MODEL_PY_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model_using_python/model").toURI()) - .toString(); + SAVED_MODEL_PATH = + Paths.get(SavedModelBundleTest.class.getResource("/saved_model").toURI()).toString(); + SAVED_MODEL_PY_PATH = + Paths.get( + SavedModelBundleTest.class.getResource("/saved_model_using_python/model").toURI()) + .toString(); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -84,11 +85,12 @@ public void loadNonExistentBundle() { @Test public void loader() { - try (SavedModelBundle bundle = SavedModelBundle.loader(SAVED_MODEL_PATH) - .withTags("serve") - .withConfigProto(sillyConfigProto()) - .withRunOptions(sillyRunOptions()) - .load()) { + try (SavedModelBundle bundle = + SavedModelBundle.loader(SAVED_MODEL_PATH) + .withTags("serve") + .withConfigProto(sillyConfigProto()) + .withRunOptions(sillyRunOptions()) + .load()) { assertNotNull(bundle.session()); assertNotNull(bundle.graph()); assertNotNull(bundle.metaGraphDef()); @@ -103,25 +105,22 @@ public void exportMultipleFunctions() throws IOException { Ops tf = Ops.create(g); Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); Signature f2Signature = buildIdentityGraph(tf, "identity"); - try (Session s = new Session(g);) { + try (Session s = new Session(g); ) { SessionFunction f1 = SessionFunction.create(f1Signature, s); SessionFunction f2 = SessionFunction.create(f2Signature, s); s.runInit(); - try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); + try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[] {2, 2})); TFloat32 t = (TFloat32) f1.call(x)) { reducedSum = t.getFloat(); } - SavedModelBundle.exporter(testFolder.toString()) - .withFunction(f1) - .withFunction(f2) - .export(); + SavedModelBundle.exporter(testFolder.toString()).withFunction(f1).withFunction(f2).export(); } } try (SavedModelBundle model = SavedModelBundle.load(testFolder.toString())) { assertEquals(2, model.signatures().size()); SessionFunction f1 = model.function(Signature.DEFAULT_KEY); assertNotNull(f1); - try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); + try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[] {2, 2})); TFloat32 t = (TFloat32) f1.call(x)) { assertEquals(reducedSum, t.getFloat(), EPSILON); } @@ -147,7 +146,7 @@ public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOExcepti Ops tf = Ops.create(g); Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); Signature f2Signature = buildIdentityGraph(tf, Signature.DEFAULT_KEY); - try (Session s = new Session(g);) { + try (Session s = new Session(g); ) { SessionFunction f1 = SessionFunction.create(f1Signature, s); SessionFunction f2 = SessionFunction.create(f2Signature, s); s.runInit(); @@ -166,24 +165,21 @@ public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOExcepti @Test public void cannotExportOrImportInvalidTags() { - assertThrows(IllegalArgumentException.class, () -> - SavedModelBundle.loader("/").withTags(null) - ); - assertThrows(IllegalArgumentException.class, () -> - SavedModelBundle.loader("/").withTags(new String[]{"tag", null}) - ); - assertThrows(IllegalArgumentException.class, () -> - SavedModelBundle.loader("/").withTags(new String[]{"tag", ""}) - ); - assertThrows(IllegalArgumentException.class, () -> - SavedModelBundle.exporter("/").withTags(null) - ); - assertThrows(IllegalArgumentException.class, () -> - SavedModelBundle.exporter("/").withTags(new String[]{"tag", null}) - ); - assertThrows(IllegalArgumentException.class, () -> - SavedModelBundle.exporter("/").withTags(new String[]{"tag", ""}) - ); + assertThrows(IllegalArgumentException.class, () -> SavedModelBundle.loader("/").withTags(null)); + assertThrows( + IllegalArgumentException.class, + () -> SavedModelBundle.loader("/").withTags(new String[] {"tag", null})); + assertThrows( + IllegalArgumentException.class, + () -> SavedModelBundle.loader("/").withTags(new String[] {"tag", ""})); + assertThrows( + IllegalArgumentException.class, () -> SavedModelBundle.exporter("/").withTags(null)); + assertThrows( + IllegalArgumentException.class, + () -> SavedModelBundle.exporter("/").withTags(new String[] {"tag", null})); + assertThrows( + IllegalArgumentException.class, + () -> SavedModelBundle.exporter("/").withTags(new String[] {"tag", ""})); } @Test @@ -215,8 +211,11 @@ public void pythonTfFunction() { args.put("dummy", dummy); // TF functions always require an input, so we supply a dummy one here // This test actually checks that resource variables can be loaded correctly. - try (TFloat32 v = (TFloat32) getVariable.call(args) - .get(getVariable.signature().outputNames().iterator().next())) { + try (TFloat32 v = + (TFloat32) + getVariable + .call(args) + .get(getVariable.signature().outputNames().iterator().next())) { assertEquals(2f, v.getFloat()); } } @@ -225,8 +224,9 @@ public void pythonTfFunction() { private static Signature buildGraphWithVariables(Ops tf, Shape xShape) { Placeholder x = tf.placeholder(TFloat32.class, Placeholder.shape(xShape)); - Variable y = tf.withName("variable") - .variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.class)); + Variable y = + tf.withName("variable") + .variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.class)); ReduceSum z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1)); Init init = tf.init(); return Signature.builder().input("input", x).output("reducedSum", z).build(); @@ -239,9 +239,7 @@ private static Signature buildIdentityGraph(Ops tf, String signatureKey) { } private static RunOptions sillyRunOptions() { - return RunOptions.newBuilder() - .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE) - .build(); + return RunOptions.newBuilder().setTraceLevel(RunOptions.TraceLevel.FULL_TRACE).build(); } private static ConfigProto sillyConfigProto() { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index 8b1d6c8ce2c..8a3e64c3336 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -1,18 +1,18 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/FunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/FunctionTest.java index ea06933a7b9..be4386698fa 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/FunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/FunctionTest.java @@ -1,5 +1,4 @@ -/* - Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,7 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - ============================================================================== + ======================================================================= */ package org.tensorflow.op.core; @@ -29,15 +28,13 @@ import org.tensorflow.op.math.Add; import org.tensorflow.types.TFloat32; -/** - * Tests for GraphFunction and it's ops - */ +/** Tests for GraphFunction and it's ops */ public class FunctionTest { private static Signature plusFive(Ops tf) { Placeholder input = tf.placeholder(TFloat32.class); Add output = tf.math.add(input, tf.constant(5.0f)); - Init init = tf.init(); // for native resource management tests + Init init = tf.init(); // for native resource management tests return Signature.builder().key("plusFive").input("x", input).output("y", output).build(); } From 61e5ffd67fa75f7ab546df657f9d3f1298e378dc Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sat, 29 May 2021 17:49:08 -0700 Subject: [PATCH 29/34] New names Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/ConcreteFunction.java | 8 +++----- .../src/main/java/org/tensorflow/SavedModelBundle.java | 4 ++-- .../src/main/java/org/tensorflow/SessionFunction.java | 6 +++--- .../{CallableFunction.java => TensorFunction.java} | 2 +- 4 files changed, 9 insertions(+), 11 deletions(-) rename tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/{CallableFunction.java => TensorFunction.java} (99%) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 576598c1607..c62a2636360 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -64,7 +64,7 @@ * Map outputTensorMap = myFunction.call(inputTensorMap); * }

*/ -public class ConcreteFunction implements AutoCloseable, CallableFunction { +public class ConcreteFunction implements AutoCloseable, TensorFunction { /** * Creates a function by building a new graph. @@ -567,16 +567,14 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature) signature.getInputs().entrySet().stream() .map( (x) -> - CallableFunction.validateDescription( - x.getValue(), graph, x.getKey(), "Input")) + TensorFunction.validateDescription(x.getValue(), graph, x.getKey(), "Input")) .collect(Collectors.toList()); List> outputs = signature.getOutputs().entrySet().stream() .map( (x) -> - CallableFunction.validateDescription( - x.getValue(), graph, x.getKey(), "Output")) + TensorFunction.validateDescription(x.getValue(), graph, x.getKey(), "Output")) .collect(Collectors.toList()); List ops = diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 84a1a10ca9f..debd105201c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -341,7 +341,7 @@ public List signatures() { * @return object that can be used to make calls to a function * @throws IllegalArgumentException if {@code signatureKey} is not found in this saved model. */ - public SessionFunction function(String signatureKey) { + public TensorFunction function(String signatureKey) { SessionFunction function = functions.get(signatureKey); if (function == null) { throw new IllegalArgumentException( @@ -351,7 +351,7 @@ public SessionFunction function(String signatureKey) { } /** Get all functions in the bundle. */ - public List functions() { + public List functions() { return new ArrayList<>(functions.values()); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java index 053b945ed5a..a778d964c81 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java @@ -28,7 +28,7 @@ * *

Does not initialize the session, since it may be shared. */ -public class SessionFunction implements CallableFunction { +public class SessionFunction implements TensorFunction { private final Signature signature; private final Session session; @@ -41,14 +41,14 @@ public SessionFunction(Signature signature, Session session) { .getInputs() .forEach( (name, description) -> { - CallableFunction.validateDescription(description, session.graph(), name, "Input"); + TensorFunction.validateDescription(description, session.graph(), name, "Input"); }); signature .getInputs() .forEach( (name, description) -> { - CallableFunction.validateDescription(description, session.graph(), name, "Output"); + TensorFunction.validateDescription(description, session.graph(), name, "Output"); }); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java similarity index 99% rename from tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java rename to tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java index 21b5bb0422e..0304d786494 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/CallableFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFunction.java @@ -20,7 +20,7 @@ import org.tensorflow.Signature.TensorDescription; /** A function that can be called with tensors. */ -public interface CallableFunction { +public interface TensorFunction { /** Returns the signature of this function */ Signature signature(); From 12582762ab35c8499ccdfd7fa50becb9e5e56deb Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sat, 29 May 2021 17:53:00 -0700 Subject: [PATCH 30/34] Note on SavedModel functions Signed-off-by: Ryan Nett --- .../main/java/org/tensorflow/SavedModelBundle.java | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index debd105201c..e92ccba548a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -52,6 +52,9 @@ /** * SavedModelBundle represents a model loaded from storage. * + *

All operations on a loaded bundle, and any functions from it, share the same underlying + * session. The session is initialized when loaded. + * *

The model consists of a description of the computation (a {@link Graph}), a {@link Session} * with tensors (e.g., parameters or variables in the graph) initialized to values saved in storage, * and a description of the model as a signatures() { * Map outputTensorMap = myFunction.call(session, inputTensorMap); * }

* + * All functions use the bundle's underlying session. + * * @param signatureKey name of the {@code SignatureDef} in the saved model. * @return object that can be used to make calls to a function * @throws IllegalArgumentException if {@code signatureKey} is not found in this saved model. @@ -350,7 +355,11 @@ public TensorFunction function(String signatureKey) { return function; } - /** Get all functions in the bundle. */ + /** + * Get all functions in the bundle. + * + *

All functions use the bundle's underlying session. + */ public List functions() { return new ArrayList<>(functions.values()); } @@ -369,6 +378,8 @@ public List functions() { * *

Caller is responsible for closing all returned Tensors. * + *

This uses the model's underlying session + * * @param arguments list of input tensors, mapped by their signature name * @return list of output tensors, mapped by the signature name * @throws IllegalArgumentException if no function can be selected by default From b1b378e5133088fb23fe0beef25457b7566056d1 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sat, 29 May 2021 17:57:48 -0700 Subject: [PATCH 31/34] Fix tests Signed-off-by: Ryan Nett --- .../test/java/org/tensorflow/SavedModelBundleTest.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 542b2f03e67..3344288b2c3 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -118,13 +118,13 @@ public void exportMultipleFunctions() throws IOException { } try (SavedModelBundle model = SavedModelBundle.load(testFolder.toString())) { assertEquals(2, model.signatures().size()); - SessionFunction f1 = model.function(Signature.DEFAULT_KEY); + TensorFunction f1 = model.function(Signature.DEFAULT_KEY); assertNotNull(f1); try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[] {2, 2})); TFloat32 t = (TFloat32) f1.call(x)) { assertEquals(reducedSum, t.getFloat(), EPSILON); } - SessionFunction f2 = model.function("identity"); + TensorFunction f2 = model.function("identity"); assertNotNull(f2); try (TFloat32 x = TFloat32.scalarOf(10.0f); TFloat32 t = (TFloat32) f2.call(x)) { @@ -190,7 +190,7 @@ public void pythonTfFunction() { * Test model was created in python * Signature name used for saving 'add', argument names 'a' and 'b' */ - SessionFunction add = bundle.function("add"); + TensorFunction add = bundle.function("add"); Map args = new HashMap<>(); try (TFloat32 a = TFloat32.scalarOf(10.0f); TFloat32 b = TFloat32.scalarOf(15.5f)) { @@ -206,7 +206,7 @@ public void pythonTfFunction() { args.clear(); // variable unwrapping happens in Session, which is used by ConcreteFunction.call - SessionFunction getVariable = bundle.function("get_variable"); + TensorFunction getVariable = bundle.function("get_variable"); try (TFloat32 dummy = TFloat32.scalarOf(1.0f)) { args.put("dummy", dummy); // TF functions always require an input, so we supply a dummy one here From cba0fea302ece30cb7cdb21010a7f452ca7b9ee2 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sat, 29 May 2021 20:42:25 -0700 Subject: [PATCH 32/34] Rename name method Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/ConcreteFunction.java | 8 ++++---- .../main/java/org/tensorflow/EagerOperationBuilder.java | 6 ++---- .../main/java/org/tensorflow/GraphOperationBuilder.java | 6 ++---- .../test/java/org/tensorflow/ConcreteFunctionTest.java | 2 +- 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index c62a2636360..3e264e0e25d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -176,10 +176,10 @@ public Signature signature() { } /** - * Get the name of the function. This is what it will show up under in the graph and any exported - * GraphDefs. + * Get the name of the function definition. This is what it will show up under in the graph and + * any exported GraphDefs, and should be used for anything using tensorflow core directly. */ - public String getNativeFunctionName() { + public String getDefinedName() { return nativeFunction.getName(); } @@ -248,7 +248,7 @@ public Map> call(Scope scope, Map> argumen } scope.env().attachFunction(this); - String name = getNativeFunctionName(); + String name = getDefinedName(); String displayName = Scope.isValidOpName(name) ? name : "FunctionCall"; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java index 6c7a322fdef..e3283ee2ab3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -226,7 +226,7 @@ public EagerOperationBuilder setAttr(String name, Shape[] values) { @Override public OperationBuilder setAttr(String name, ConcreteFunction value) { session.attachFunction(value); - setAttrFunctionName(opHandle, name, value.getNativeFunctionName()); + setAttrFunctionName(opHandle, name, value.getDefinedName()); return this; } @@ -240,9 +240,7 @@ public OperationBuilder setAttr(String name, ConcreteFunction[] value) { opHandle, session.nativeHandle(), name, - Arrays.stream(value) - .map(ConcreteFunction::getNativeFunctionName) - .collect(Collectors.toList())); + Arrays.stream(value).map(ConcreteFunction::getDefinedName).collect(Collectors.toList())); return this; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java index 73f190a2d71..53ab50db4b4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java @@ -356,7 +356,7 @@ public GraphOperationBuilder setAttr(String name, String[] value) { public OperationBuilder setAttr(String name, ConcreteFunction value) { graph.attachFunction(value); try (Reference r = graph.ref()) { - setAttrFunctionName(unsafeNativeHandle, name, value.getNativeFunctionName()); + setAttrFunctionName(unsafeNativeHandle, name, value.getDefinedName()); } return this; } @@ -371,9 +371,7 @@ public OperationBuilder setAttr(String name, ConcreteFunction[] value) { setAttrFunctionList( unsafeNativeHandle, name, - Arrays.stream(value) - .map(ConcreteFunction::getNativeFunctionName) - .collect(Collectors.toList())); + Arrays.stream(value).map(ConcreteFunction::getDefinedName).collect(Collectors.toList())); } return this; } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java index 576ea2e3f95..64c33f451fb 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -105,7 +105,7 @@ public void getGraphFunctions() { Ops tf = Ops.create(g); tf.call(function, tf.constant(3f)); - ConcreteFunction attached = g.getFunction(function.getNativeFunctionName()); + ConcreteFunction attached = g.getFunction(function.getDefinedName()); assertNotNull(attached); try (TFloat32 x = TFloat32.scalarOf(10f); From e67ae1e72f454771429b2820722ceff1d682f1fa Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sun, 30 May 2021 13:19:30 -0700 Subject: [PATCH 33/34] Re-add tests w/ SessionFunction Signed-off-by: Ryan Nett --- .../org/tensorflow/SavedModelBundleTest.java | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 3344288b2c3..6ac6656775d 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -26,10 +26,12 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TensorFlowException; +import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.op.Ops; @@ -40,6 +42,8 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.RunOptions; +import org.tensorflow.proto.framework.SignatureDef; +import org.tensorflow.proto.framework.TensorInfo; import org.tensorflow.types.TFloat32; /** Unit tests for {@link org.tensorflow.SavedModelBundle}. */ @@ -139,6 +143,104 @@ public void exportMultipleFunctions() throws IOException { } } + @Test + public void exportFunctionWithVariables() throws IOException { + Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); + float reducedSum; + FloatNdArray xValue = StdArrays.ndCopyOf(new float[][] {{0, 1, 2}, {3, 4, 5}}); + Shape xyShape = Shape.of(2, 3L); + try (Graph g = new Graph(); + Session session = new Session(g)) { + Ops tf = Ops.create(g); + SessionFunction f = session.function(buildGraphWithVariables(tf, xyShape)); + // Init variable state by running the Init operation directly + session.runInit(); + + // Call the graph and remember the result of computation for later + try (TFloat32 xTensor = TFloat32.tensorOf(xValue); + TFloat32 zTensor = (TFloat32) f.call(xTensor)) { + reducedSum = zTensor.getFloat(); + } + // Save/export the model (which is a single function in this case) + SavedModelBundle.exporter(testFolder.toString()).withFunction(f).export(); + } + assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.index")))); + assertTrue( + Files.exists(testFolder.resolve(Paths.get("variables", "variables.data-00000-of-00001")))); + assertTrue(Files.exists(testFolder.resolve("saved_model.pb"))); + + // Reload the model just saved and validate its data + try (SavedModelBundle savedModel = + SavedModelBundle.load(testFolder.toString(), SavedModelBundle.DEFAULT_TAG)) { + assertNotNull(savedModel.metaGraphDef()); + assertNotNull(savedModel.metaGraphDef().getSaverDef()); + assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount()); + assertEquals( + Signature.DEFAULT_KEY, + savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next()); + + TensorFunction function = savedModel.function(Signature.DEFAULT_KEY); + assertNotNull(function); + + Signature signature = function.signature(); + assertNotNull(signature); + assertEquals(1, signature.inputNames().size()); + assertEquals("input", signature.inputNames().iterator().next()); + assertEquals(1, signature.outputNames().size()); + assertEquals("reducedSum", signature.outputNames().iterator().next()); + + SignatureDef signatureDef = signature.asSignatureDef(); + assertEquals(1, signatureDef.getInputsCount()); + assertEquals(1, signatureDef.getOutputsCount()); + + TensorInfo inputInfo = signatureDef.getInputsMap().get("input"); + assertNotNull(inputInfo); + assertEquals(xyShape.numDimensions(), inputInfo.getTensorShape().getDimCount()); + for (int i = 0; i < xyShape.numDimensions(); ++i) { + assertEquals(xyShape.size(i), inputInfo.getTensorShape().getDim(i).getSize()); + } + + TensorInfo outputInfo = signatureDef.getOutputsMap().get("reducedSum"); + assertNotNull(outputInfo); + assertEquals(0, outputInfo.getTensorShape().getDimCount()); + + try (TFloat32 xTensor = TFloat32.tensorOf(xValue)) { + // Call the saved model function and make sure it returns the same result as before + try (TFloat32 zTensor = (TFloat32) function.call(xTensor)) { + assertEquals(reducedSum, zTensor.getFloat(), EPSILON); + } + // Now call the same function directly from the model + try (TFloat32 zTensor = + (TFloat32) + savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum")) { + assertEquals(reducedSum, zTensor.getFloat(), EPSILON); + } + } + } + } + + @Test + public void cannotExportMultipleFunctionsWithDifferentSessions() throws IOException { + Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); + try (Graph g = new Graph(); + Session s1 = new Session(g); + Session s2 = new Session(g)) { + Ops tf = Ops.create(g); + Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); + Signature f2Signature = buildIdentityGraph(tf, "identity"); + SessionFunction f1 = s1.function(f1Signature); + SessionFunction f2 = s2.function(f2Signature); + s1.runInit(); + s2.runInit(); + try { + SavedModelBundle.exporter(testFolder.toString()).withFunction(f1).withFunction(f2).export(); + fail(); + } catch (UnsupportedOperationException e) { + // as expected + } + } + } + @Test public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOException { Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); From c09385d5e704585cb6d9b25af39558dfa8c13b86 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sun, 30 May 2021 13:36:07 -0700 Subject: [PATCH 34/34] Helper methods for saving Signed-off-by: Ryan Nett --- .../java/org/tensorflow/SavedModelBundle.java | 37 +++++++++++++++++++ .../java/org/tensorflow/SessionFunction.java | 14 +++++++ .../org/tensorflow/SavedModelBundleTest.java | 7 +--- 3 files changed, 53 insertions(+), 5 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index e92ccba548a..3a6433701e6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -194,12 +194,31 @@ public Exporter withFunction(SessionFunction function) { return this; } + /** + * Save multiple functions. Wrapper around {@link #withFunction(SessionFunction)}. All functions + * must have the same session. + * + * @param functions the functions to export + * @return this object + * @throws IllegalArgumentException if a function with the same name has already been added to + * the model + * @throws UnsupportedOperationException if the session is already set to a different session + * @see #withFunction(SessionFunction) + */ + public Exporter withFunctions(SessionFunction... functions) { + for (SessionFunction f : functions) { + withFunction(f); + } + return this; + } + /** * Add a signature to the model. This wraps the signature in a {@link SessionFunction} using the * exporter's already-set session. As such, either {@link #withSession(Session)} or {@link * #withFunction(SessionFunction)} must be called before this method. * * @throws IllegalStateException if no session has been set + * @return this */ public Exporter withSignature(Signature signature) { if (session == null) { @@ -209,6 +228,24 @@ public Exporter withSignature(Signature signature) { return withFunction(session.function(signature)); } + /** + * Add multiple signatures to the model. Wraps {@link #withSignature(Signature)} + * + *

Either {@link #withSession(Session)} or {@link * #withFunction(SessionFunction)} must + * be called before this method, and the session set there will be used for these + * signatures. + * + * @throws IllegalStateException if no session has been set + * @return this + * @see #withSession(Session) + */ + public Exporter withSignatures(Signature... signatures) { + for (Signature s : signatures) { + withSignature(s); + } + return this; + } + /** * Save the model into the export directory. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java index a778d964c81..07bc418ac51 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java @@ -15,6 +15,7 @@ */ package org.tensorflow; +import java.io.IOException; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -56,6 +57,19 @@ public static SessionFunction create(Signature signature, Session session) { return new SessionFunction(signature, session); } + /** + * Save this function using {@link SavedModelBundle}. + * + *

This is identical to calling {@code + * SavedModelBundle.exporter(exportDir).withFunction(this).export()}. + * + * @param exportDir the directory path containing a saved model. + * @throws IOException if saved model or variable state cannot be written on disk + */ + public void save(String exportDir) throws IOException { + SavedModelBundle.exporter(exportDir).withFunction(this).export(); + } + @Override public Signature signature() { return signature; diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 6ac6656775d..1561842a689 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -162,7 +162,7 @@ public void exportFunctionWithVariables() throws IOException { reducedSum = zTensor.getFloat(); } // Save/export the model (which is a single function in this case) - SavedModelBundle.exporter(testFolder.toString()).withFunction(f).export(); + f.save(testFolder.toString()); } assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.index")))); assertTrue( @@ -253,10 +253,7 @@ public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOExcepti SessionFunction f2 = SessionFunction.create(f2Signature, s); s.runInit(); try { - SavedModelBundle.exporter(testFolder.toString()) - .withFunction(f1) - .withFunction(f2) - .export(); + SavedModelBundle.exporter(testFolder.toString()).withFunctions(f1, f2).export(); fail(); } catch (IllegalArgumentException e) { // as expected