67
67
public class SavedModelBundle implements AutoCloseable {
68
68
69
69
public static final String DEFAULT_TAG = "serve" ;
70
+ private static final String JAVA_INIT_OP_SIGNATURE_KEY = "__saved_model_java_init_op_tracker" ;
70
71
private static final String INIT_OP_SIGNATURE_KEY = "__saved_model_init_op" ;
71
72
private static final String MAIN_OP_COLLECTION_KEY = "saved_model_main_op" ;
72
73
private static final String LEGACY_INIT_OP_COLLECTION_KEY = "legacy_init_op" ;
@@ -269,7 +270,7 @@ public void export() throws IOException {
269
270
SaverDef saverDef = graph .saverDef ();
270
271
271
272
GraphOperation initOp = null ;
272
- if (!functions .containsKey (INIT_OP_SIGNATURE_KEY )) {
273
+ if (!functions .containsKey (JAVA_INIT_OP_SIGNATURE_KEY )) {
273
274
initOp = graph .addInitOp (true );
274
275
}
275
276
@@ -280,13 +281,13 @@ public void export() throws IOException {
280
281
.setMetaInfoDef (MetaInfoDef .newBuilder ().addAllTags (Arrays .asList (tags )));
281
282
functions .forEach ((k , f ) -> metaGraphDef .putSignatureDef (k , f .signature ().asSignatureDef ()));
282
283
283
- if (!functions .containsKey (INIT_OP_SIGNATURE_KEY )) {
284
+ if (!functions .containsKey (JAVA_INIT_OP_SIGNATURE_KEY )) {
284
285
285
286
metaGraphDef .putSignatureDef (
286
- INIT_OP_SIGNATURE_KEY ,
287
+ JAVA_INIT_OP_SIGNATURE_KEY ,
287
288
SignatureDef .newBuilder ()
288
289
.putOutputs (
289
- INIT_OP_SIGNATURE_KEY ,
290
+ JAVA_INIT_OP_SIGNATURE_KEY ,
290
291
TensorInfo .newBuilder ().setName (initOp .name () + ":0" ).build ())
291
292
.build ());
292
293
}
@@ -544,6 +545,17 @@ private static SavedModelBundle fromHandle(
544
545
if (initOp != null ) {
545
546
graph .registerInitOp (initOp );
546
547
}
548
+
549
+ // java init ops are marked as ran, since the variable restore will restore any state
550
+ // they mutated.
551
+ // Technically, init ops should be ran first, then variable restore, but that is not possible
552
+ // since TF_Session.loadSessionFromSavedModel does it in reverse order, so we just mark them as
553
+ // ran.
554
+ if (functions .containsKey (JAVA_INIT_OP_SIGNATURE_KEY )){
555
+ String initOpName = functions .get (JAVA_INIT_OP_SIGNATURE_KEY ).getOutputs ().get (JAVA_INIT_OP_SIGNATURE_KEY ).name ;
556
+ graph .registerInitOp (graph .outputOrThrow (initOpName ).op ());
557
+ }
558
+
547
559
session .setInitialized ();
548
560
549
561
if (metaGraphDef .containsCollectionDef (TABLE_INITIALIZERS_COLLECTION_KEY )) {
@@ -599,8 +611,6 @@ private static SavedModelBundle load(
599
611
}
600
612
bundle .session .initialize ();
601
613
602
- // bundle.session.restore(exportDir + "/variables/variables");
603
-
604
614
return bundle ;
605
615
}
606
616
0 commit comments