Skip to content

Commit 2d449a6

Browse files
authored
Init exporting and loading (#376)
1 parent 5e3fc49 commit 2d449a6

File tree

7 files changed

+244
-50
lines changed

7 files changed

+244
-50
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818
package org.tensorflow.op;
1919

2020
import java.nio.charset.Charset;
21+
import java.util.Arrays;
2122
import java.util.List;
2223
import java.util.Map;
2324
import org.tensorflow.ConcreteFunction;
2425
import org.tensorflow.DeviceSpec;
2526
import org.tensorflow.EagerSession;
2627
import org.tensorflow.ExecutionEnvironment;
2728
import org.tensorflow.Operand;
29+
import org.tensorflow.Operation;
2830
import org.tensorflow.ndarray.BooleanNdArray;
2931
import org.tensorflow.ndarray.ByteNdArray;
3032
import org.tensorflow.ndarray.DoubleNdArray;
@@ -365,20 +367,20 @@ public final class Ops {
365367

366368
public final SparseOps sparse;
367369

368-
public final TpuOps tpu;
369-
370370
public final BitwiseOps bitwise;
371371

372+
public final TpuOps tpu;
373+
372374
public final MathOps math;
373375

374376
public final AudioOps audio;
375377

376378
public final SignalOps signal;
377379

378-
public final QuantizationOps quantization;
379-
380380
public final TrainOps train;
381381

382+
public final QuantizationOps quantization;
383+
382384
private final Scope scope;
383385

384386
private Ops(Scope scope) {
@@ -396,13 +398,13 @@ private Ops(Scope scope) {
396398
random = new RandomOps(this);
397399
strings = new StringsOps(this);
398400
sparse = new SparseOps(this);
399-
tpu = new TpuOps(this);
400401
bitwise = new BitwiseOps(this);
402+
tpu = new TpuOps(this);
401403
math = new MathOps(this);
402404
audio = new AudioOps(this);
403405
signal = new SignalOps(this);
404-
quantization = new QuantizationOps(this);
405406
train = new TrainOps(this);
407+
quantization = new QuantizationOps(this);
406408
}
407409

408410
/**
@@ -8132,6 +8134,33 @@ public Ops withControlDependencies(Iterable<Op> controls) {
81328134
return new Ops(scope.withControlDependencies(controls));
81338135
}
81348136

8137+
/**
8138+
* Returns an API that adds operations to the graph with the provided control dependencies.
8139+
*
8140+
* @see {@link Scope#withControlDependencies(Iterable<Op<?>>)}
8141+
*/
8142+
public Ops withControlDependencies(Op... controls) {
8143+
return withControlDependencies(Arrays.asList(controls));
8144+
}
8145+
8146+
/**
8147+
* Returns an API that adds operations to the graph with the provided control dependencies.
8148+
*
8149+
* @see {@link Scope#withControlDependencyOps(Iterable<Operation>)}
8150+
*/
8151+
public Ops withControlDependencyOps(Iterable<Operation> controls) {
8152+
return new Ops(scope.withControlDependencyOps(controls));
8153+
}
8154+
8155+
/**
8156+
* Returns an API that adds operations to the graph with the provided control dependencies.
8157+
*
8158+
* @see {@link Scope#withControlDependencyOps(Iterable<Operation>)}
8159+
*/
8160+
public Ops withControlDependencyOps(Operation... controls) {
8161+
return withControlDependencyOps(Arrays.asList(controls));
8162+
}
8163+
81358164
/**
81368165
* Returns the current {@link Scope scope} of this API
81378166
*/

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -541,20 +541,25 @@ public void importGraphDef(GraphDef graphDef, String prefix) throws IllegalArgum
541541
});
542542
}
543543

544-
private synchronized void addInitOp() {
545-
if (!newInitializers) {
546-
return;
544+
/**
545+
* Create and return a NoOp that will run all init ops. If {@code required} is false and there are
546+
* no new init ops since the last call, will do nothing and return null.
547+
*/
548+
synchronized GraphOperation addInitOp(boolean required) {
549+
if (!newInitializers && !required) {
550+
return null;
547551
}
548-
if (initializers.isEmpty()) {
549-
return;
552+
if (initializers.isEmpty() && !required) {
553+
return null;
550554
}
551555

552556
baseScope.refreshNames();
553557
OperationBuilder builder =
554558
baseScope().withInitScope().opBuilder(NoOp.OP_NAME, INIT_OP_BASE_NAME);
555559
initializers.forEach(builder::addControlInput);
556-
builder.build();
560+
GraphOperation initOp = (GraphOperation) builder.build();
557561
newInitializers = false;
562+
return initOp;
558563
}
559564

560565
/**
@@ -568,7 +573,7 @@ private synchronized void addInitOp() {
568573
* @see #importGraphDef(GraphDef, String)
569574
*/
570575
public GraphDef toGraphDef() {
571-
addInitOp();
576+
addInitOp(false);
572577
synchronized (nativeHandleLock) {
573578
return toGraphDef(nativeHandle);
574579
}
@@ -1239,6 +1244,8 @@ private static Object[] whileLoop(
12391244
private static SaverDef addVariableSaver(Graph graph) {
12401245
Ops tf = Ops.create(graph).withSubScope(SAVER_DEF_SCOPE);
12411246

1247+
// TODO handle resource variables, too
1248+
12421249
List<String> varNames = new ArrayList<>();
12431250
List<Operand<?>> varOutputs = new ArrayList<>();
12441251
List<Class<? extends TType>> varTypes = new ArrayList<>();

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,15 @@
4141
import org.tensorflow.internal.c_api.TF_Session;
4242
import org.tensorflow.internal.c_api.TF_SessionOptions;
4343
import org.tensorflow.internal.c_api.TF_Status;
44+
import org.tensorflow.proto.framework.CollectionDef;
45+
import org.tensorflow.proto.framework.CollectionDef.NodeList;
4446
import org.tensorflow.proto.framework.ConfigProto;
4547
import org.tensorflow.proto.framework.MetaGraphDef;
4648
import org.tensorflow.proto.framework.MetaGraphDef.MetaInfoDef;
4749
import org.tensorflow.proto.framework.RunOptions;
4850
import org.tensorflow.proto.framework.SavedModel;
51+
import org.tensorflow.proto.framework.SignatureDef;
52+
import org.tensorflow.proto.framework.TensorInfo;
4953
import org.tensorflow.proto.util.SaverDef;
5054

5155
/**
@@ -64,6 +68,27 @@ public class SavedModelBundle implements AutoCloseable {
6468

6569
public static final String DEFAULT_TAG = "serve";
6670

71+
/** Signature used to track Java init ops, for our init scope. */
72+
private static final String JAVA_INIT_OP_SIGNATURE_KEY = "__saved_model_java_init_op_tracker";
73+
74+
/**
75+
* Tensorflow init op tracking signature. Init ops are executed before loading variables, so this
76+
* does not work for us.
77+
*/
78+
private static final String INIT_OP_SIGNATURE_KEY = "__saved_model_init_op";
79+
80+
/**
81+
* A backup Tensorflow init op collection key. In TF1, init ops will be stored in collections
82+
* instead of signatures.
83+
*/
84+
private static final String MAIN_OP_COLLECTION_KEY = "saved_model_main_op";
85+
86+
/** An even more legacy init op collection key. */
87+
private static final String LEGACY_INIT_OP_COLLECTION_KEY = "legacy_init_op";
88+
89+
/** The collection where table initializers are stored in some hub models. */
90+
private static final String TABLE_INITIALIZERS_COLLECTION_KEY = "table_initializer";
91+
6792
/** Options for loading a SavedModel. */
6893
public static final class Loader {
6994

@@ -260,13 +285,29 @@ public void export() throws IOException {
260285
// new ops to the graph for saving and restoring the variables.
261286
SaverDef saverDef = graph.saverDef();
262287

288+
GraphOperation initOp = null;
289+
if (!functions.containsKey(JAVA_INIT_OP_SIGNATURE_KEY)) {
290+
initOp = graph.addInitOp(true);
291+
}
292+
263293
MetaGraphDef.Builder metaGraphDef =
264294
metaGraphDefBuilder
265295
.setSaverDef(saverDef)
266296
.setGraphDef(graph.toGraphDef())
267297
.setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(Arrays.asList(tags)));
268298
functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signature().asSignatureDef()));
269299

300+
if (!functions.containsKey(JAVA_INIT_OP_SIGNATURE_KEY)) {
301+
302+
metaGraphDef.putSignatureDef(
303+
JAVA_INIT_OP_SIGNATURE_KEY,
304+
SignatureDef.newBuilder()
305+
.putOutputs(
306+
JAVA_INIT_OP_SIGNATURE_KEY,
307+
TensorInfo.newBuilder().setName(initOp.name() + ":0").build())
308+
.build());
309+
}
310+
270311
// Make sure saved model directories exist
271312
Path variableDir = Paths.get(exportDir, "variables");
272313
variableDir.toFile().mkdirs();
@@ -365,7 +406,14 @@ public Session session() {
365406

366407
/** Return the signature of all functions available in this saved model. */
367408
public List<Signature> signatures() {
368-
return functions.values().stream().map(f -> f.signature()).collect(Collectors.toList());
409+
// the init signatures aren't actual functions, just markers
410+
return functions.values().stream()
411+
.map(SessionFunction::signature)
412+
.filter(
413+
signature ->
414+
!signature.key().equals(INIT_OP_SIGNATURE_KEY)
415+
&& !signature.key().equals(JAVA_INIT_OP_SIGNATURE_KEY))
416+
.collect(Collectors.toList());
369417
}
370418

371419
/**
@@ -459,6 +507,32 @@ private SavedModelBundle(
459507
Collectors.toMap(Entry::getKey, e -> new SessionFunction(e.getValue(), session)));
460508
}
461509

510+
private static GraphOperation findInitOp(
511+
Graph graph, Map<String, Signature> signatures, Map<String, CollectionDef> collections) {
512+
513+
Signature initSig = signatures.get(INIT_OP_SIGNATURE_KEY);
514+
if (initSig != null) {
515+
return (GraphOperation)
516+
graph.outputOrThrow(initSig.getOutputs().get(INIT_OP_SIGNATURE_KEY).name).op();
517+
}
518+
519+
CollectionDef initCollection;
520+
if (collections.containsKey(MAIN_OP_COLLECTION_KEY)) {
521+
initCollection = collections.get(MAIN_OP_COLLECTION_KEY);
522+
} else {
523+
initCollection = collections.get(LEGACY_INIT_OP_COLLECTION_KEY);
524+
}
525+
526+
if (initCollection != null) {
527+
NodeList nodes = initCollection.getNodeList();
528+
if (nodes.getValueCount() != 1) {
529+
throw new IllegalArgumentException("Expected exactly one main op in saved model.");
530+
}
531+
return (GraphOperation) graph.outputOrThrow(nodes.getValue(0)).op();
532+
}
533+
return null;
534+
}
535+
462536
/**
463537
* Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session
464538
* object, plus the MetaGraphDef.
@@ -486,6 +560,41 @@ private static SavedModelBundle fromHandle(
486560
functions.put(signatureName, signature);
487561
}
488562
});
563+
564+
GraphOperation initOp = findInitOp(graph, functions, metaGraphDef.getCollectionDefMap());
565+
if (initOp != null) {
566+
graph.registerInitOp(initOp);
567+
}
568+
569+
// java init ops are marked as ran, since the variable restore will restore any state
570+
// they mutated.
571+
// Technically, init ops should be ran first, then variable restore, but that is not possible
572+
// since TF_Session.loadSessionFromSavedModel does it in reverse order, so we just mark them as
573+
// ran.
574+
if (functions.containsKey(JAVA_INIT_OP_SIGNATURE_KEY)) {
575+
String initOpName =
576+
functions
577+
.get(JAVA_INIT_OP_SIGNATURE_KEY)
578+
.getOutputs()
579+
.get(JAVA_INIT_OP_SIGNATURE_KEY)
580+
.name;
581+
graph.registerInitOp(graph.outputOrThrow(initOpName).op());
582+
}
583+
584+
session.setInitialized();
585+
586+
if (metaGraphDef.containsCollectionDef(TABLE_INITIALIZERS_COLLECTION_KEY)) {
587+
metaGraphDef
588+
.getCollectionDefMap()
589+
.get(TABLE_INITIALIZERS_COLLECTION_KEY)
590+
.getNodeList()
591+
.getValueList()
592+
.forEach(
593+
node -> {
594+
graph.registerInitOp(graph.operationOrThrow(node));
595+
});
596+
}
597+
489598
return new SavedModelBundle(graph, session, metaGraphDef, functions);
490599
}
491600

@@ -525,6 +634,7 @@ private static SavedModelBundle load(
525634
throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e);
526635
}
527636
}
637+
bundle.session.initialize();
528638

529639
return bundle;
530640
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,13 +186,21 @@ public void close() {
186186
public void initialize() {
187187
Runner runner = runner();
188188
graph.initializers().stream().filter((x) -> !ranInits.contains(x)).forEach(runner::addTarget);
189-
ranInits.clear();
190-
ranInits.addAll(graph.initializers());
189+
setInitialized();
191190
if (!runner.isEmpty()) {
192191
runner.runNoInit();
193192
}
194193
}
195194

195+
/**
196+
* Set the ran initializers to all initializers in the graph, as if they had been run. <b>Does not
197+
* actually ensure they are ran.</b>
198+
*/
199+
void setInitialized() {
200+
ranInits.clear();
201+
ranInits.addAll(graph.initializers());
202+
}
203+
196204
/**
197205
* Execute the graph's initializers, regardless of whether the session has been initialized.
198206
*
@@ -686,8 +694,7 @@ public void restore(String prefix) {
686694
.feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix))
687695
.runNoInit();
688696
// TODO better way of doing this, only count as ran assignments to the restored variables.
689-
ranInits.clear();
690-
ranInits.addAll(graph.initializers());
697+
setInitialized();
691698
}
692699

693700
/**

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,8 @@ public void exportFunctionWithVariables() throws IOException {
179179
assertEquals("save/control_dependency", saverDef.getSaveTensorName());
180180
assertEquals("save/restore_all", saverDef.getRestoreOpName());
181181

182-
assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount());
183-
assertEquals(
184-
Signature.DEFAULT_KEY,
185-
savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next());
182+
assertEquals(2, savedModel.metaGraphDef().getSignatureDefCount());
183+
assertTrue(savedModel.metaGraphDef().getSignatureDefMap().containsKey(Signature.DEFAULT_KEY));
186184

187185
TensorFunction function = savedModel.function(Signature.DEFAULT_KEY);
188186
assertNotNull(function);
@@ -269,18 +267,15 @@ public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOExcepti
269267

270268
@Test
271269
public void cannotExportOrImportInvalidTags() {
272-
assertThrows(IllegalArgumentException.class, () ->
273-
SavedModelBundle.loader("/").withTags(null)
274-
);
275-
assertThrows(IllegalArgumentException.class, () ->
276-
SavedModelBundle.loader("/").withTags(new String[]{"tag", null})
277-
);
278-
assertThrows(IllegalArgumentException.class, () ->
279-
SavedModelBundle.exporter("/").withTags(null)
280-
);
281-
assertThrows(IllegalArgumentException.class, () ->
282-
SavedModelBundle.exporter("/").withTags(new String[]{"tag", null})
283-
);
270+
assertThrows(IllegalArgumentException.class, () -> SavedModelBundle.loader("/").withTags(null));
271+
assertThrows(
272+
IllegalArgumentException.class,
273+
() -> SavedModelBundle.loader("/").withTags(new String[] {"tag", null}));
274+
assertThrows(
275+
IllegalArgumentException.class, () -> SavedModelBundle.exporter("/").withTags(null));
276+
assertThrows(
277+
IllegalArgumentException.class,
278+
() -> SavedModelBundle.exporter("/").withTags(new String[] {"tag", null}));
284279
}
285280

286281
@Test

0 commit comments

Comments
 (0)