diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 74f7efb4623..b4ab7753142 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -18,6 +18,7 @@ package org.tensorflow.op; import java.nio.charset.Charset; +import java.util.Arrays; import java.util.List; import java.util.Map; import org.tensorflow.ConcreteFunction; @@ -25,6 +26,7 @@ import org.tensorflow.EagerSession; import org.tensorflow.ExecutionEnvironment; import org.tensorflow.Operand; +import org.tensorflow.Operation; import org.tensorflow.ndarray.BooleanNdArray; import org.tensorflow.ndarray.ByteNdArray; import org.tensorflow.ndarray.DoubleNdArray; @@ -365,20 +367,20 @@ public final class Ops { public final SparseOps sparse; - public final TpuOps tpu; - public final BitwiseOps bitwise; + public final TpuOps tpu; + public final MathOps math; public final AudioOps audio; public final SignalOps signal; - public final QuantizationOps quantization; - public final TrainOps train; + public final QuantizationOps quantization; + private final Scope scope; private Ops(Scope scope) { @@ -396,13 +398,13 @@ private Ops(Scope scope) { random = new RandomOps(this); strings = new StringsOps(this); sparse = new SparseOps(this); - tpu = new TpuOps(this); bitwise = new BitwiseOps(this); + tpu = new TpuOps(this); math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - quantization = new QuantizationOps(this); train = new TrainOps(this); + quantization = new QuantizationOps(this); } /** @@ -8132,6 +8134,33 @@ public Ops withControlDependencies(Iterable controls) { return new Ops(scope.withControlDependencies(controls)); } + /** + * Returns an API that adds operations to the graph with the provided control dependencies. + * + * @see {@link Scope#withControlDependencies(Iterable>)} + */ + public Ops withControlDependencies(Op... controls) { + return withControlDependencies(Arrays.asList(controls)); + } + + /** + * Returns an API that adds operations to the graph with the provided control dependencies. + * + * @see {@link Scope#withControlDependencyOps(Iterable)} + */ + public Ops withControlDependencyOps(Iterable controls) { + return new Ops(scope.withControlDependencyOps(controls)); + } + + /** + * Returns an API that adds operations to the graph with the provided control dependencies. + * + * @see {@link Scope#withControlDependencyOps(Iterable)} + */ + public Ops withControlDependencyOps(Operation... controls) { + return withControlDependencyOps(Arrays.asList(controls)); + } + /** * Returns the current {@link Scope scope} of this API */ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 98b6c98abc4..fa431244173 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -541,20 +541,25 @@ public void importGraphDef(GraphDef graphDef, String prefix) throws IllegalArgum }); } - private synchronized void addInitOp() { - if (!newInitializers) { - return; + /** + * Create and return a NoOp that will run all init ops. If {@code required} is false and there are + * no new init ops since the last call, will do nothing and return null. + */ + synchronized GraphOperation addInitOp(boolean required) { + if (!newInitializers && !required) { + return null; } - if (initializers.isEmpty()) { - return; + if (initializers.isEmpty() && !required) { + return null; } baseScope.refreshNames(); OperationBuilder builder = baseScope().withInitScope().opBuilder(NoOp.OP_NAME, INIT_OP_BASE_NAME); initializers.forEach(builder::addControlInput); - builder.build(); + GraphOperation initOp = (GraphOperation) builder.build(); newInitializers = false; + return initOp; } /** @@ -568,7 +573,7 @@ private synchronized void addInitOp() { * @see #importGraphDef(GraphDef, String) */ public GraphDef toGraphDef() { - addInitOp(); + addInitOp(false); synchronized (nativeHandleLock) { return toGraphDef(nativeHandle); } @@ -1239,6 +1244,8 @@ private static Object[] whileLoop( private static SaverDef addVariableSaver(Graph graph) { Ops tf = Ops.create(graph).withSubScope(SAVER_DEF_SCOPE); + // TODO handle resource variables, too + List varNames = new ArrayList<>(); List> varOutputs = new ArrayList<>(); List> varTypes = new ArrayList<>(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 6445f05d276..cdf6c2bfc5a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -41,11 +41,15 @@ import org.tensorflow.internal.c_api.TF_Session; import org.tensorflow.internal.c_api.TF_SessionOptions; import org.tensorflow.internal.c_api.TF_Status; +import org.tensorflow.proto.framework.CollectionDef; +import org.tensorflow.proto.framework.CollectionDef.NodeList; import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.MetaGraphDef; import org.tensorflow.proto.framework.MetaGraphDef.MetaInfoDef; import org.tensorflow.proto.framework.RunOptions; import org.tensorflow.proto.framework.SavedModel; +import org.tensorflow.proto.framework.SignatureDef; +import org.tensorflow.proto.framework.TensorInfo; import org.tensorflow.proto.util.SaverDef; /** @@ -64,6 +68,27 @@ public class SavedModelBundle implements AutoCloseable { public static final String DEFAULT_TAG = "serve"; + /** Signature used to track Java init ops, for our init scope. */ + private static final String JAVA_INIT_OP_SIGNATURE_KEY = "__saved_model_java_init_op_tracker"; + + /** + * Tensorflow init op tracking signature. Init ops are executed before loading variables, so this + * does not work for us. + */ + private static final String INIT_OP_SIGNATURE_KEY = "__saved_model_init_op"; + + /** + * A backup Tensorflow init op collection key. In TF1, init ops will be stored in collections + * instead of signatures. + */ + private static final String MAIN_OP_COLLECTION_KEY = "saved_model_main_op"; + + /** An even more legacy init op collection key. */ + private static final String LEGACY_INIT_OP_COLLECTION_KEY = "legacy_init_op"; + + /** The collection where table initializers are stored in some hub models. */ + private static final String TABLE_INITIALIZERS_COLLECTION_KEY = "table_initializer"; + /** Options for loading a SavedModel. */ public static final class Loader { @@ -260,6 +285,11 @@ public void export() throws IOException { // new ops to the graph for saving and restoring the variables. SaverDef saverDef = graph.saverDef(); + GraphOperation initOp = null; + if (!functions.containsKey(JAVA_INIT_OP_SIGNATURE_KEY)) { + initOp = graph.addInitOp(true); + } + MetaGraphDef.Builder metaGraphDef = metaGraphDefBuilder .setSaverDef(saverDef) @@ -267,6 +297,17 @@ public void export() throws IOException { .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(Arrays.asList(tags))); functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signature().asSignatureDef())); + if (!functions.containsKey(JAVA_INIT_OP_SIGNATURE_KEY)) { + + metaGraphDef.putSignatureDef( + JAVA_INIT_OP_SIGNATURE_KEY, + SignatureDef.newBuilder() + .putOutputs( + JAVA_INIT_OP_SIGNATURE_KEY, + TensorInfo.newBuilder().setName(initOp.name() + ":0").build()) + .build()); + } + // Make sure saved model directories exist Path variableDir = Paths.get(exportDir, "variables"); variableDir.toFile().mkdirs(); @@ -365,7 +406,14 @@ public Session session() { /** Return the signature of all functions available in this saved model. */ public List signatures() { - return functions.values().stream().map(f -> f.signature()).collect(Collectors.toList()); + // the init signatures aren't actual functions, just markers + return functions.values().stream() + .map(SessionFunction::signature) + .filter( + signature -> + !signature.key().equals(INIT_OP_SIGNATURE_KEY) + && !signature.key().equals(JAVA_INIT_OP_SIGNATURE_KEY)) + .collect(Collectors.toList()); } /** @@ -459,6 +507,32 @@ private SavedModelBundle( Collectors.toMap(Entry::getKey, e -> new SessionFunction(e.getValue(), session))); } + private static GraphOperation findInitOp( + Graph graph, Map signatures, Map collections) { + + Signature initSig = signatures.get(INIT_OP_SIGNATURE_KEY); + if (initSig != null) { + return (GraphOperation) + graph.outputOrThrow(initSig.getOutputs().get(INIT_OP_SIGNATURE_KEY).name).op(); + } + + CollectionDef initCollection; + if (collections.containsKey(MAIN_OP_COLLECTION_KEY)) { + initCollection = collections.get(MAIN_OP_COLLECTION_KEY); + } else { + initCollection = collections.get(LEGACY_INIT_OP_COLLECTION_KEY); + } + + if (initCollection != null) { + NodeList nodes = initCollection.getNodeList(); + if (nodes.getValueCount() != 1) { + throw new IllegalArgumentException("Expected exactly one main op in saved model."); + } + return (GraphOperation) graph.outputOrThrow(nodes.getValue(0)).op(); + } + return null; + } + /** * Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session * object, plus the MetaGraphDef. @@ -486,6 +560,41 @@ private static SavedModelBundle fromHandle( functions.put(signatureName, signature); } }); + + GraphOperation initOp = findInitOp(graph, functions, metaGraphDef.getCollectionDefMap()); + if (initOp != null) { + graph.registerInitOp(initOp); + } + + // java init ops are marked as ran, since the variable restore will restore any state + // they mutated. + // Technically, init ops should be ran first, then variable restore, but that is not possible + // since TF_Session.loadSessionFromSavedModel does it in reverse order, so we just mark them as + // ran. + if (functions.containsKey(JAVA_INIT_OP_SIGNATURE_KEY)) { + String initOpName = + functions + .get(JAVA_INIT_OP_SIGNATURE_KEY) + .getOutputs() + .get(JAVA_INIT_OP_SIGNATURE_KEY) + .name; + graph.registerInitOp(graph.outputOrThrow(initOpName).op()); + } + + session.setInitialized(); + + if (metaGraphDef.containsCollectionDef(TABLE_INITIALIZERS_COLLECTION_KEY)) { + metaGraphDef + .getCollectionDefMap() + .get(TABLE_INITIALIZERS_COLLECTION_KEY) + .getNodeList() + .getValueList() + .forEach( + node -> { + graph.registerInitOp(graph.operationOrThrow(node)); + }); + } + return new SavedModelBundle(graph, session, metaGraphDef, functions); } @@ -525,6 +634,7 @@ private static SavedModelBundle load( throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e); } } + bundle.session.initialize(); return bundle; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index dfb8b7bbf60..c1fdbf10a17 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -186,13 +186,21 @@ public void close() { public void initialize() { Runner runner = runner(); graph.initializers().stream().filter((x) -> !ranInits.contains(x)).forEach(runner::addTarget); - ranInits.clear(); - ranInits.addAll(graph.initializers()); + setInitialized(); if (!runner.isEmpty()) { runner.runNoInit(); } } + /** + * Set the ran initializers to all initializers in the graph, as if they had been run. Does not + * actually ensure they are ran. + */ + void setInitialized() { + ranInits.clear(); + ranInits.addAll(graph.initializers()); + } + /** * Execute the graph's initializers, regardless of whether the session has been initialized. * @@ -686,8 +694,7 @@ public void restore(String prefix) { .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) .runNoInit(); // TODO better way of doing this, only count as ran assignments to the restored variables. - ranInits.clear(); - ranInits.addAll(graph.initializers()); + setInitialized(); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index ffe4d072bb5..be6f952fb6a 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -179,10 +179,8 @@ public void exportFunctionWithVariables() throws IOException { assertEquals("save/control_dependency", saverDef.getSaveTensorName()); assertEquals("save/restore_all", saverDef.getRestoreOpName()); - assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount()); - assertEquals( - Signature.DEFAULT_KEY, - savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next()); + assertEquals(2, savedModel.metaGraphDef().getSignatureDefCount()); + assertTrue(savedModel.metaGraphDef().getSignatureDefMap().containsKey(Signature.DEFAULT_KEY)); TensorFunction function = savedModel.function(Signature.DEFAULT_KEY); assertNotNull(function); @@ -269,18 +267,15 @@ public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOExcepti @Test public void cannotExportOrImportInvalidTags() { - assertThrows(IllegalArgumentException.class, () -> - SavedModelBundle.loader("/").withTags(null) - ); - assertThrows(IllegalArgumentException.class, () -> - SavedModelBundle.loader("/").withTags(new String[]{"tag", null}) - ); - assertThrows(IllegalArgumentException.class, () -> - SavedModelBundle.exporter("/").withTags(null) - ); - assertThrows(IllegalArgumentException.class, () -> - SavedModelBundle.exporter("/").withTags(new String[]{"tag", null}) - ); + assertThrows(IllegalArgumentException.class, () -> SavedModelBundle.loader("/").withTags(null)); + assertThrows( + IllegalArgumentException.class, + () -> SavedModelBundle.loader("/").withTags(new String[] {"tag", null})); + assertThrows( + IllegalArgumentException.class, () -> SavedModelBundle.exporter("/").withTags(null)); + assertThrows( + IllegalArgumentException.class, + () -> SavedModelBundle.exporter("/").withTags(new String[] {"tag", null})); } @Test diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/Names.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/Names.java index d2218c7d6e0..38987e81cb3 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/Names.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/Names.java @@ -1,21 +1,22 @@ /* - Copyright 2021 The TensorFlow Authors. All Rights Reserved. + Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ============================================================================== - */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ package org.tensorflow; +import com.squareup.javapoet.ArrayTypeName; import com.squareup.javapoet.ClassName; import com.squareup.javapoet.ParameterizedTypeName; import com.squareup.javapoet.TypeName; @@ -50,15 +51,22 @@ public class Names { public static final ClassName RawOp = ClassName.get(OpPackage, "RawOp"); public static final ClassName Operation = ClassName.get(TensorflowPackage, "Operation"); public static final ClassName Operands = ClassName.get(OpPackage, "Operands"); - public static final ClassName OperationBuilder = ClassName.get(TensorflowPackage, "OperationBuilder"); - public static final TypeName IterableOp = ParameterizedTypeName.get(ClassName.get(Iterable.class), Op); + public static final ClassName OperationBuilder = + ClassName.get(TensorflowPackage, "OperationBuilder"); + public static final TypeName IterableOp = + ParameterizedTypeName.get(ClassName.get(Iterable.class), Op); + public static final TypeName IterableOperation = + ParameterizedTypeName.get(ClassName.get(Iterable.class), Operation); + public static final TypeName ArrayOp = ArrayTypeName.of(Op); + public static final TypeName ArrayOperation = ArrayTypeName.of(Operation); public static final ClassName Operand = ClassName.get(TensorflowPackage, "Operand"); public static final ClassName Output = ClassName.get(TensorflowPackage, "Output"); public static final ClassName Shape = ClassName.get(TensorflowPackage + ".ndarray", "Shape"); public static final ClassName Tensor = ClassName.get(TensorflowPackage, "Tensor"); - public static final ClassName ConcreteFunction = ClassName.get(TensorflowPackage, "ConcreteFunction"); + public static final ClassName ConcreteFunction = + ClassName.get(TensorflowPackage, "ConcreteFunction"); public static final ClassName Scope = ClassName.get(OpPackage, "Scope"); public static final TypeName DeviceSpec = ClassName.get(TensorflowPackage, "DeviceSpec"); @@ -69,5 +77,5 @@ public class Names { public static final TypeName EagerSession = ClassName.get(TensorflowPackage, "EagerSession"); public static final TypeName String = ClassName.get(String.class); - + public static final ClassName Arrays = ClassName.get(java.util.Arrays.class); } diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java index 2a33483cb7b..70c7bb0a7de 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java @@ -619,6 +619,44 @@ private static TypeSpec buildTopClass(OpsSpec spec) { Names.Scope) .build()); + opsBuilder.addMethod( + MethodSpec.methodBuilder("withControlDependencies") + .addModifiers(Modifier.PUBLIC) + .addParameter(Names.ArrayOp, "controls") + .varargs() + .returns(Names.Ops) + .addStatement("return withControlDependencies($T.asList(controls))", Names.Arrays) + .addJavadoc( + "Returns an API that adds operations to the graph with the provided control dependencies.\n\n" + + "@see {@link $T#withControlDependencies(Iterable>)}\n", + Names.Scope) + .build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder("withControlDependencyOps") + .addModifiers(Modifier.PUBLIC) + .addParameter(Names.IterableOperation, "controls") + .returns(Names.Ops) + .addStatement("return new Ops(scope.withControlDependencyOps(controls))") + .addJavadoc( + "Returns an API that adds operations to the graph with the provided control dependencies.\n\n" + + "@see {@link $T#withControlDependencyOps(Iterable)}\n", + Names.Scope) + .build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder("withControlDependencyOps") + .addModifiers(Modifier.PUBLIC) + .addParameter(Names.ArrayOperation, "controls") + .varargs() + .returns(Names.Ops) + .addStatement("return withControlDependencyOps($T.asList(controls))", Names.Arrays) + .addJavadoc( + "Returns an API that adds operations to the graph with the provided control dependencies.\n\n" + + "@see {@link $T#withControlDependencyOps(Iterable)}\n", + Names.Scope) + .build()); + opsBuilder.addField( FieldSpec.builder(Names.Scope, "scope") .addModifiers(Modifier.PRIVATE, Modifier.FINAL)