Skip to content

Commit 27ac746

Browse files
committed
Fixes
Signed-off-by: Ryan Nett <[email protected]>
1 parent 266e2b9 commit 27ac746

File tree

2 files changed

+8
-30
lines changed

2 files changed

+8
-30
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,8 +551,13 @@ private static SavedModelBundle fromHandle(
551551
// Technically, init ops should be ran first, then variable restore, but that is not possible
552552
// since TF_Session.loadSessionFromSavedModel does it in reverse order, so we just mark them as
553553
// ran.
554-
if(functions.containsKey(JAVA_INIT_OP_SIGNATURE_KEY)){
555-
String initOpName = functions.get(JAVA_INIT_OP_SIGNATURE_KEY).getOutputs().get(JAVA_INIT_OP_SIGNATURE_KEY).name;
554+
if (functions.containsKey(JAVA_INIT_OP_SIGNATURE_KEY)) {
555+
String initOpName =
556+
functions
557+
.get(JAVA_INIT_OP_SIGNATURE_KEY)
558+
.getOutputs()
559+
.get(JAVA_INIT_OP_SIGNATURE_KEY)
560+
.name;
556561
graph.registerInitOp(graph.outputOrThrow(initOpName).op());
557562
}
558563

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ public void exportMultipleFunctions() throws IOException {
121121
}
122122
}
123123
try (SavedModelBundle model = SavedModelBundle.load(testFolder.toString())) {
124-
assertEquals(2, model.signatures().size());
124+
assertEquals(3, model.signatures().size());
125125
TensorFunction f1 = model.function(Signature.DEFAULT_KEY);
126126
assertNotNull(f1);
127127
try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[] {2, 2}));
@@ -319,33 +319,6 @@ public void pythonTfFunction() {
319319
}
320320
}
321321

322-
@Test
323-
public void exportAndLoadInitializers() throws IOException {
324-
Path testFolder = Files.createTempDirectory("tf-saved-model-export-test");
325-
try (Graph g = new Graph();
326-
Session s = new Session(g)) {
327-
Ops tf = Ops.create(g);
328-
Ops init = tf.withInitScope();
329-
Operand<?> handle = init.withName("variable").varHandleOp(TInt32.class, Shape.scalar());
330-
init.withName("init").assignVariableOp(handle, init.constant(10));
331-
332-
SessionFunction f =
333-
SessionFunction.create(
334-
Signature.builder()
335-
.key("f")
336-
.output("out", tf.withName("read").readVariableOp(handle, TInt32.class))
337-
.build(),
338-
s);
339-
340-
SavedModelBundle.exporter(testFolder.toString()).withFunction(f).export();
341-
}
342-
343-
try (SavedModelBundle savedModel = SavedModelBundle.load(testFolder.toString())) {
344-
TInt32 tensor = (TInt32) savedModel.session().runner().fetch("read", 0).run().get(0);
345-
assertEquals(10, tensor.getInt());
346-
}
347-
}
348-
349322
private static Signature buildGraphWithVariables(Ops tf, Shape xShape) {
350323
Placeholder<TFloat32> x = tf.placeholder(TFloat32.class, Placeholder.shape(xShape));
351324
Variable<TFloat32> y =

0 commit comments

Comments
 (0)