Skip to content

Init exporting and loading #376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Oct 13, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
package org.tensorflow.op;

import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.tensorflow.ConcreteFunction;
import org.tensorflow.DeviceSpec;
import org.tensorflow.EagerSession;
import org.tensorflow.ExecutionEnvironment;
import org.tensorflow.Operand;
import org.tensorflow.Operation;
import org.tensorflow.ndarray.BooleanNdArray;
import org.tensorflow.ndarray.ByteNdArray;
import org.tensorflow.ndarray.DoubleNdArray;
Expand Down Expand Up @@ -365,20 +367,20 @@ public final class Ops {

public final SparseOps sparse;

public final TpuOps tpu;

public final BitwiseOps bitwise;

public final TpuOps tpu;

public final MathOps math;

public final AudioOps audio;

public final SignalOps signal;

public final QuantizationOps quantization;

public final TrainOps train;

public final QuantizationOps quantization;

private final Scope scope;

private Ops(Scope scope) {
Expand All @@ -396,13 +398,13 @@ private Ops(Scope scope) {
random = new RandomOps(this);
strings = new StringsOps(this);
sparse = new SparseOps(this);
tpu = new TpuOps(this);
bitwise = new BitwiseOps(this);
tpu = new TpuOps(this);
math = new MathOps(this);
audio = new AudioOps(this);
signal = new SignalOps(this);
quantization = new QuantizationOps(this);
train = new TrainOps(this);
quantization = new QuantizationOps(this);
}

/**
Expand Down Expand Up @@ -8132,6 +8134,33 @@ public Ops withControlDependencies(Iterable<Op> controls) {
return new Ops(scope.withControlDependencies(controls));
}

/**
* Returns an API that adds operations to the graph with the provided control dependencies.
*
* @see {@link Scope#withControlDependencies(Iterable<Op<?>>)}
*/
public Ops withControlDependencies(Op... controls) {
return withControlDependencies(Arrays.asList(controls));
}

/**
* Returns an API that adds operations to the graph with the provided control dependencies.
*
* @see {@link Scope#withControlDependencyOps(Iterable<Operation>)}
*/
public Ops withControlDependencyOps(Iterable<Operation> controls) {
return new Ops(scope.withControlDependencyOps(controls));
}

/**
* Returns an API that adds operations to the graph with the provided control dependencies.
*
* @see {@link Scope#withControlDependencyOps(Iterable<Operation>)}
*/
public Ops withControlDependencyOps(Operation... controls) {
return withControlDependencyOps(Arrays.asList(controls));
}

/**
* Returns the current {@link Scope scope} of this API
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -541,20 +541,25 @@ public void importGraphDef(GraphDef graphDef, String prefix) throws IllegalArgum
});
}

private synchronized void addInitOp() {
if (!newInitializers) {
return;
/**
* Create and return a NoOp that will run all init ops. If {@code required} is false and there are
* no new init ops since the last call, will do nothing and return null.
*/
synchronized GraphOperation addInitOp(boolean required) {
if (!newInitializers && !required) {
return null;
}
if (initializers.isEmpty()) {
return;
if (initializers.isEmpty() && !required) {
return null;
}

baseScope.refreshNames();
OperationBuilder builder =
baseScope().withInitScope().opBuilder(NoOp.OP_NAME, INIT_OP_BASE_NAME);
initializers.forEach(builder::addControlInput);
builder.build();
GraphOperation initOp = (GraphOperation) builder.build();
newInitializers = false;
return initOp;
}

/**
Expand All @@ -568,7 +573,7 @@ private synchronized void addInitOp() {
* @see #importGraphDef(GraphDef, String)
*/
public GraphDef toGraphDef() {
addInitOp();
addInitOp(false);
synchronized (nativeHandleLock) {
return toGraphDef(nativeHandle);
}
Expand Down Expand Up @@ -1239,6 +1244,8 @@ private static Object[] whileLoop(
private static SaverDef addVariableSaver(Graph graph) {
Ops tf = Ops.create(graph).withSubScope(SAVER_DEF_SCOPE);

// TODO handle resource variables, too

List<String> varNames = new ArrayList<>();
List<Operand<?>> varOutputs = new ArrayList<>();
List<Class<? extends TType>> varTypes = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,15 @@
import org.tensorflow.internal.c_api.TF_Session;
import org.tensorflow.internal.c_api.TF_SessionOptions;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.proto.framework.CollectionDef;
import org.tensorflow.proto.framework.CollectionDef.NodeList;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.MetaGraphDef.MetaInfoDef;
import org.tensorflow.proto.framework.RunOptions;
import org.tensorflow.proto.framework.SavedModel;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.proto.util.SaverDef;

/**
Expand All @@ -64,6 +68,27 @@ public class SavedModelBundle implements AutoCloseable {

public static final String DEFAULT_TAG = "serve";

/** Signature used to track Java init ops, for our init scope. */
private static final String JAVA_INIT_OP_SIGNATURE_KEY = "__saved_model_java_init_op_tracker";

/**
* Tensorflow init op tracking signature. Init ops are executed before loading variables, so this
* does not work for us.
*/
private static final String INIT_OP_SIGNATURE_KEY = "__saved_model_init_op";

/**
* A backup Tensorflow init op collection key. In TF1, init ops will be stored in collections
* instead of signatures.
*/
private static final String MAIN_OP_COLLECTION_KEY = "saved_model_main_op";

/** An even more legacy init op collection key. */
private static final String LEGACY_INIT_OP_COLLECTION_KEY = "legacy_init_op";

/** The collection where table initializers are stored in some hub models. */
private static final String TABLE_INITIALIZERS_COLLECTION_KEY = "table_initializer";

/** Options for loading a SavedModel. */
public static final class Loader {

Expand Down Expand Up @@ -260,13 +285,29 @@ public void export() throws IOException {
// new ops to the graph for saving and restoring the variables.
SaverDef saverDef = graph.saverDef();

GraphOperation initOp = null;
if (!functions.containsKey(JAVA_INIT_OP_SIGNATURE_KEY)) {
initOp = graph.addInitOp(true);
}

MetaGraphDef.Builder metaGraphDef =
metaGraphDefBuilder
.setSaverDef(saverDef)
.setGraphDef(graph.toGraphDef())
.setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(Arrays.asList(tags)));
functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signature().asSignatureDef()));

if (!functions.containsKey(JAVA_INIT_OP_SIGNATURE_KEY)) {

metaGraphDef.putSignatureDef(
JAVA_INIT_OP_SIGNATURE_KEY,
SignatureDef.newBuilder()
.putOutputs(
JAVA_INIT_OP_SIGNATURE_KEY,
TensorInfo.newBuilder().setName(initOp.name() + ":0").build())
.build());
}

// Make sure saved model directories exist
Path variableDir = Paths.get(exportDir, "variables");
variableDir.toFile().mkdirs();
Expand Down Expand Up @@ -365,7 +406,14 @@ public Session session() {

/** Return the signature of all functions available in this saved model. */
public List<Signature> signatures() {
return functions.values().stream().map(f -> f.signature()).collect(Collectors.toList());
// the init signatures aren't actual functions, just markers
return functions.values().stream()
.map(SessionFunction::signature)
.filter(
signature ->
!signature.key().equals(INIT_OP_SIGNATURE_KEY)
&& !signature.key().equals(JAVA_INIT_OP_SIGNATURE_KEY))
.collect(Collectors.toList());
}

/**
Expand Down Expand Up @@ -459,6 +507,32 @@ private SavedModelBundle(
Collectors.toMap(Entry::getKey, e -> new SessionFunction(e.getValue(), session)));
}

private static GraphOperation findInitOp(
Graph graph, Map<String, Signature> signatures, Map<String, CollectionDef> collections) {

Signature initSig = signatures.get(INIT_OP_SIGNATURE_KEY);
if (initSig != null) {
return (GraphOperation)
graph.outputOrThrow(initSig.getOutputs().get(INIT_OP_SIGNATURE_KEY).name).op();
}

CollectionDef initCollection;
if (collections.containsKey(MAIN_OP_COLLECTION_KEY)) {
initCollection = collections.get(MAIN_OP_COLLECTION_KEY);
} else {
initCollection = collections.get(LEGACY_INIT_OP_COLLECTION_KEY);
}

if (initCollection != null) {
NodeList nodes = initCollection.getNodeList();
if (nodes.getValueCount() != 1) {
throw new IllegalArgumentException("Expected exactly one main op in saved model.");
}
return (GraphOperation) graph.outputOrThrow(nodes.getValue(0)).op();
}
return null;
}

/**
* Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session
* object, plus the MetaGraphDef.
Expand Down Expand Up @@ -486,6 +560,41 @@ private static SavedModelBundle fromHandle(
functions.put(signatureName, signature);
}
});

GraphOperation initOp = findInitOp(graph, functions, metaGraphDef.getCollectionDefMap());
if (initOp != null) {
graph.registerInitOp(initOp);
}

// java init ops are marked as ran, since the variable restore will restore any state
// they mutated.
// Technically, init ops should be ran first, then variable restore, but that is not possible
// since TF_Session.loadSessionFromSavedModel does it in reverse order, so we just mark them as
// ran.
if (functions.containsKey(JAVA_INIT_OP_SIGNATURE_KEY)) {
String initOpName =
functions
.get(JAVA_INIT_OP_SIGNATURE_KEY)
.getOutputs()
.get(JAVA_INIT_OP_SIGNATURE_KEY)
.name;
graph.registerInitOp(graph.outputOrThrow(initOpName).op());
}

session.setInitialized();

if (metaGraphDef.containsCollectionDef(TABLE_INITIALIZERS_COLLECTION_KEY)) {
metaGraphDef
.getCollectionDefMap()
.get(TABLE_INITIALIZERS_COLLECTION_KEY)
.getNodeList()
.getValueList()
.forEach(
node -> {
graph.registerInitOp(graph.operationOrThrow(node));
});
}

return new SavedModelBundle(graph, session, metaGraphDef, functions);
}

Expand Down Expand Up @@ -525,6 +634,7 @@ private static SavedModelBundle load(
throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e);
}
}
bundle.session.initialize();

return bundle;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,21 @@ public void close() {
public void initialize() {
Runner runner = runner();
graph.initializers().stream().filter((x) -> !ranInits.contains(x)).forEach(runner::addTarget);
ranInits.clear();
ranInits.addAll(graph.initializers());
setInitialized();
if (!runner.isEmpty()) {
runner.runNoInit();
}
}

/**
* Set the ran initializers to all initializers in the graph, as if they had been run. <b>Does not
* actually ensure they are ran.</b>
*/
void setInitialized() {
ranInits.clear();
ranInits.addAll(graph.initializers());
}

/**
* Execute the graph's initializers, regardless of whether the session has been initialized.
*
Expand Down Expand Up @@ -686,8 +694,7 @@ public void restore(String prefix) {
.feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix))
.runNoInit();
// TODO better way of doing this, only count as ran assignments to the restored variables.
ranInits.clear();
ranInits.addAll(graph.initializers());
setInitialized();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,8 @@ public void exportFunctionWithVariables() throws IOException {
assertEquals("save/control_dependency", saverDef.getSaveTensorName());
assertEquals("save/restore_all", saverDef.getRestoreOpName());

assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount());
assertEquals(
Signature.DEFAULT_KEY,
savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next());
assertEquals(2, savedModel.metaGraphDef().getSignatureDefCount());
assertTrue(savedModel.metaGraphDef().getSignatureDefMap().containsKey(Signature.DEFAULT_KEY));

TensorFunction function = savedModel.function(Signature.DEFAULT_KEY);
assertNotNull(function);
Expand Down Expand Up @@ -269,18 +267,15 @@ public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOExcepti

@Test
public void cannotExportOrImportInvalidTags() {
assertThrows(IllegalArgumentException.class, () ->
SavedModelBundle.loader("/").withTags(null)
);
assertThrows(IllegalArgumentException.class, () ->
SavedModelBundle.loader("/").withTags(new String[]{"tag", null})
);
assertThrows(IllegalArgumentException.class, () ->
SavedModelBundle.exporter("/").withTags(null)
);
assertThrows(IllegalArgumentException.class, () ->
SavedModelBundle.exporter("/").withTags(new String[]{"tag", null})
);
assertThrows(IllegalArgumentException.class, () -> SavedModelBundle.loader("/").withTags(null));
assertThrows(
IllegalArgumentException.class,
() -> SavedModelBundle.loader("/").withTags(new String[] {"tag", null}));
assertThrows(
IllegalArgumentException.class, () -> SavedModelBundle.exporter("/").withTags(null));
assertThrows(
IllegalArgumentException.class,
() -> SavedModelBundle.exporter("/").withTags(new String[] {"tag", null}));
}

@Test
Expand Down
Loading