Skip to content

Commit 266e2b9

Browse files
committed
Don't export java init scope as init ops
Signed-off-by: Ryan Nett <[email protected]>
1 parent 8bac25e commit 266e2b9

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
public class SavedModelBundle implements AutoCloseable {
6868

6969
public static final String DEFAULT_TAG = "serve";
70+
private static final String JAVA_INIT_OP_SIGNATURE_KEY = "__saved_model_java_init_op_tracker";
7071
private static final String INIT_OP_SIGNATURE_KEY = "__saved_model_init_op";
7172
private static final String MAIN_OP_COLLECTION_KEY = "saved_model_main_op";
7273
private static final String LEGACY_INIT_OP_COLLECTION_KEY = "legacy_init_op";
@@ -269,7 +270,7 @@ public void export() throws IOException {
269270
SaverDef saverDef = graph.saverDef();
270271

271272
GraphOperation initOp = null;
272-
if (!functions.containsKey(INIT_OP_SIGNATURE_KEY)) {
273+
if (!functions.containsKey(JAVA_INIT_OP_SIGNATURE_KEY)) {
273274
initOp = graph.addInitOp(true);
274275
}
275276

@@ -280,13 +281,13 @@ public void export() throws IOException {
280281
.setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(Arrays.asList(tags)));
281282
functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signature().asSignatureDef()));
282283

283-
if (!functions.containsKey(INIT_OP_SIGNATURE_KEY)) {
284+
if (!functions.containsKey(JAVA_INIT_OP_SIGNATURE_KEY)) {
284285

285286
metaGraphDef.putSignatureDef(
286-
INIT_OP_SIGNATURE_KEY,
287+
JAVA_INIT_OP_SIGNATURE_KEY,
287288
SignatureDef.newBuilder()
288289
.putOutputs(
289-
INIT_OP_SIGNATURE_KEY,
290+
JAVA_INIT_OP_SIGNATURE_KEY,
290291
TensorInfo.newBuilder().setName(initOp.name() + ":0").build())
291292
.build());
292293
}
@@ -544,6 +545,17 @@ private static SavedModelBundle fromHandle(
544545
if (initOp != null) {
545546
graph.registerInitOp(initOp);
546547
}
548+
549+
// java init ops are marked as ran, since the variable restore will restore any state
550+
// they mutated.
551+
// Technically, init ops should be ran first, then variable restore, but that is not possible
552+
// since TF_Session.loadSessionFromSavedModel does it in reverse order, so we just mark them as
553+
// ran.
554+
if(functions.containsKey(JAVA_INIT_OP_SIGNATURE_KEY)){
555+
String initOpName = functions.get(JAVA_INIT_OP_SIGNATURE_KEY).getOutputs().get(JAVA_INIT_OP_SIGNATURE_KEY).name;
556+
graph.registerInitOp(graph.outputOrThrow(initOpName).op());
557+
}
558+
547559
session.setInitialized();
548560

549561
if (metaGraphDef.containsCollectionDef(TABLE_INITIALIZERS_COLLECTION_KEY)) {
@@ -599,8 +611,6 @@ private static SavedModelBundle load(
599611
}
600612
bundle.session.initialize();
601613

602-
// bundle.session.restore(exportDir + "/variables/variables");
603-
604614
return bundle;
605615
}
606616

0 commit comments

Comments
 (0)