Skip to content

Init Scope #338

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 48 commits into from
Aug 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
3f5817b
Start of init scope
rnett Jun 16, 2021
f2c6000
Fix tests
rnett Jun 16, 2021
342237c
Javadoc updates
rnett Jun 16, 2021
5502ec3
Session init helpers
rnett Jun 16, 2021
abc5617
Format fixes
rnett Jun 16, 2021
c959696
Make initEnv default to this.
rnett Jun 16, 2021
fae97bb
More formatting fixes
rnett Jun 16, 2021
e8fbee6
Small fixes, add native pointer based equals and hashCode to EagerOpe…
rnett Jun 17, 2021
feb90fc
Export init ops to GraphDefs and import from them
rnett Jun 18, 2021
bb2b05c
Test adding init ops after import
rnett Jun 18, 2021
94d48c3
Automatically lift constants to init if required
rnett Jun 18, 2021
4e97df5
Add withInitScope
rnett Jun 18, 2021
2dbaebf
Allow init ops to depend on other init ops
rnett Jun 18, 2021
18ef780
Add void withInitScope
rnett Jun 18, 2021
1c8b463
Lift init inputs to init as well
rnett Jun 18, 2021
8b43786
Formatting
rnett Jun 18, 2021
4adb76b
Allow use of init input lifting
rnett Jun 18, 2021
89a4fda
Add forceInitialize to reinitialize session
rnett Jun 18, 2021
cc3ebcc
Replace withInitScope with liftToInitScope
rnett Jun 18, 2021
27ce339
Update framework
rnett Jun 18, 2021
3f2593a
Rebase fixes
rnett Jul 7, 2021
b8e281e
Remove Java 10 API
rnett Jul 16, 2021
067a4c0
Check control dependencies up front
rnett Jul 16, 2021
669446c
Reorder methods
rnett Jul 16, 2021
d0ccd67
Session constructors
rnett Jul 16, 2021
2946251
Variable with init value uses passed scope
rnett Jul 16, 2021
2a6d1ec
Change initScope to withInitScope
rnett Jul 16, 2021
2ac92ce
Fix initScope usages
rnett Jul 16, 2021
e32dd31
Variable fixes
rnett Jul 16, 2021
27e2fd5
Don't put reset ops in init scope
rnett Jul 16, 2021
b340659
New session initialization, formatting
rnett Jul 16, 2021
7130e57
Fix a test
rnett Jul 16, 2021
53d54f0
Fix generated names
rnett Jul 17, 2021
e958e80
Comment fixes
rnett Jul 26, 2021
550892a
More comment fixes
rnett Jul 26, 2021
6486fb9
Track init ops and topmost init ops separately
rnett Jul 26, 2021
c24f247
Don't track topmost init ops
rnett Jul 26, 2021
d7adebb
Unsynchronize Session
rnett Jul 26, 2021
819f6e0
Fix format
rnett Jul 26, 2021
e663db7
Fix not building init op
rnett Jul 26, 2021
9436f2b
Use init scope in variable-with-init
rnett Jul 30, 2021
dd6e94f
Fix comments, make ranInits final
rnett Jul 30, 2021
7acc6b5
Update toGraphDef comment
rnett Jul 30, 2021
a16c11e
Update Sessions comment
rnett Jul 30, 2021
07143ae
Remove wildcard import
rnett Jul 31, 2021
430e0ff
Change init op name to Init
rnett Aug 11, 2021
7df071c
Don't include NoOps in initializer list
rnett Aug 11, 2021
2ecdebf
Fix format
rnett Aug 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@
import org.tensorflow.op.core.IdentityN;
import org.tensorflow.op.core.If;
import org.tensorflow.op.core.ImmutableConst;
import org.tensorflow.op.core.Init;
import org.tensorflow.op.core.InitializeTable;
import org.tensorflow.op.core.InitializeTableFromTextFile;
import org.tensorflow.op.core.InplaceAdd;
Expand Down Expand Up @@ -295,6 +294,12 @@
import org.tensorflow.op.core.VariableShape;
import org.tensorflow.op.core.Where;
import org.tensorflow.op.core.While;
import org.tensorflow.op.core.XlaConvV2;
import org.tensorflow.op.core.XlaDotV2;
import org.tensorflow.op.core.XlaSetDynamicDimensionSize;
import org.tensorflow.op.core.XlaSpmdFullToShardShape;
import org.tensorflow.op.core.XlaSpmdShardToFullShape;
import org.tensorflow.op.core.XlaVariadicSort;
import org.tensorflow.op.core.Zeros;
import org.tensorflow.op.core.ZerosLike;
import org.tensorflow.types.TBool;
Expand Down Expand Up @@ -366,20 +371,20 @@ public final class Ops {

public final SparseOps sparse;

public final TpuOps tpu;

public final BitwiseOps bitwise;

public final TpuOps tpu;

public final MathOps math;

public final AudioOps audio;

public final SignalOps signal;

public final QuantizationOps quantization;

public final TrainOps train;

public final QuantizationOps quantization;

private final Scope scope;

private Ops(Scope scope) {
Expand All @@ -397,13 +402,13 @@ private Ops(Scope scope) {
random = new RandomOps(this);
strings = new StringsOps(this);
sparse = new SparseOps(this);
tpu = new TpuOps(this);
bitwise = new BitwiseOps(this);
tpu = new TpuOps(this);
math = new MathOps(this);
audio = new AudioOps(this);
signal = new SignalOps(this);
quantization = new QuantizationOps(this);
train = new TrainOps(this);
quantization = new QuantizationOps(this);
}

/**
Expand Down Expand Up @@ -1952,14 +1957,15 @@ public Constant<TFloat32> constant(Shape shape, FloatDataBuffer data) {
}

/**
* Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be truncated if it does not
* fit in the target type.
* Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be
* truncated if it does not fit in the target type.
*
* @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating})
* @param type the type of tensor to create. Must be concrete (i.e. not {@link
* org.tensorflow.types.family.TFloating})
* @param number the value of the tensor
* @return a constant of the passed type
* @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or
* unknown.
* @throws IllegalArgumentException if the type is abstract (i.e. {@link
* org.tensorflow.types.family.TFloating}) or unknown.
*/
public <T extends TNumber> Constant<T> constant(Class<T> type, Number number) {
return Constant.tensorOf(scope, type, number);
Expand Down Expand Up @@ -1994,11 +2000,12 @@ public <T extends TType> Constant<T> constant(Class<T> type, Shape shape, ByteDa
}

/**
* Create a constant by making an immutable copy of {@code tensor}. {@code tensor} may be closed afterwards without
* issue.
* Create a constant by making an immutable copy of {@code tensor}. {@code tensor} may be closed
* afterwards without issue.
*
* <p>Note: this endpoint cannot be simply called {@code constant} since it will conflict with
* other endpoints accepting an NdArray in parameter {e.g. {@link #tensorOf(Scope, FloatNdArray)}}.
* other endpoints accepting an NdArray in parameter {e.g. {@link #tensorOf(Scope,
* FloatNdArray)}}.
*
* @param tensor a Tensor holding the constant value
* @return a constant of the same data type as `tensor`
Expand All @@ -2008,8 +2015,8 @@ public <T extends TType> Constant<T> constantOf(T tensor) {
}

/**
* Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code number} may be
* truncated if it does not fit in the target type.
* Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code
* number} may be truncated if it does not fit in the target type.
*
* @param toMatch the operand providing the target type
* @param number the value of the tensor
Expand Down Expand Up @@ -2993,80 +3000,6 @@ public <T extends TType> ImmutableConst<T> immutableConst(Class<T> dtype, Shape
return ImmutableConst.create(scope, dtype, shape, memoryRegionName);
}

/**
* Factory method to create an operation executing all initializers of a graph.
*
* <p>All initializers added to a graph via
* {@link org.tensorflow.op.core.Init#add(Scope, Op) tf.initAdd} are grouped together as a single
* unit of computation in the graph. This operation must then be added to any graph using one or
* more {@link Variable variables} and executed once before running the graph so the variable
* states are initialized properly.</p>
*
* <p>When the graph is built by the same process that is running the session, the initializers
* can be invoked by executing this single endpoint. For example:</p>
* <pre>{@code
* try (Graph g = new Graph()) {
* Variable<TInt32> x = tf.variable(tf.constant(10)); // initAdd is called implicitly
* Variable<TInt32> y = tf.variable(tf.constant(20)); // idem
* Add<TInt32> z = tf.math.add(x, y);
*
* try (Session s = new Session(g)) {
* s.run(tf.init()); // initialize all variables
*
* try (TInt32 t = (TInt32)s.runner().fetch(z).run().get(0)) {
* assertEquals(30, t.data().getInt());
* }
* }
* }
* }</pre>
*
* <p>When the graph is built by a separate process, the initializers can be invoked by running
* the init op by its name, which defaults to {@link org.tensorflow.op.core.Init#DEFAULT_NAME}.
* For example:</p>
* <pre>{@code
* // Building the model
* try (Graph g = new Graph()) {
* Variable<TInt32> x = tf.variable(tf.constant(10)); // initAdd is called implicitly
* Variable<TInt32> y = tf.variable(tf.constant(20)); // idem
* Add<TInt32> z = tf.withName("z").math.add(x, y);
*
* tf.init(); // add variables initializers to the graph, as Init.DEFAULT_NAME
* // ...exporting graph as a saved model...
* }
*
* ...
*
* // Running the model
* try (SavedModelBundle model = SavedModelBundle.load("/path/to/model", "train")) {
* model.session().run(Init.DEFAULT_NAME);
*
* try (TInt32 t = (TInt32)s.runner().fetch("z").run().get(0)) {
* assertEquals(30, t.data().getInt());
* }
* }
* }</pre>
*
* @return an op grouping all initializers added to the graph
* @throws IllegalArgumentException if the execution environment in scope is not a graph
*/
public Init init() {
return Init.create(scope);
}

/**
* Register an op as an initializer of the graph.
*
* <p>Registered initializers are then grouped as a single unit of computation by adding
* and executing an {@link org.tensorflow.op.core.Init#create(Scope) init} operation from a graph
* session. This is a no-op if executed in an eager session.
*
* @param initializer
* @see org.tensorflow.op.core.Init#create(Scope) init
*/
public void initAdd(Op initializer) {
Init.add(scope, initializer);
}

/**
* Table initializer that takes two tensors for keys and values respectively.
*
Expand Down Expand Up @@ -7947,10 +7880,11 @@ public VarIsInitializedOp varIsInitializedOp(Operand<? extends TType> resource)
}

/**
* Factory method to create a new Variable with it's initializer.
* <p>
* Only supported on Graph sessions as the {@link org.tensorflow.op.core.Assign} op
* does not work in an EagerSession.
* Factory method to create a new Variable with its initializer. Both the creation and assignment
* are done in the init scope.
*
* <p>Only supported on Graph sessions as the {@link org.tensorflow.op.core.Assign} op does not
* work in an EagerSession.
*
* @param init The op to use to initialise this variable.
* @param options carries optional attributes values
Expand Down Expand Up @@ -8143,6 +8077,37 @@ public Ops withSubScope(String childScopeName) {
return new Ops(scope.withSubScope(childScopeName));
}

/**
* Returns an API that builds init operations. {@link #liftToInitScope(Operand)} will be called for all created operations.
* <p>
* Init operations will be initialized at session creation, will have their inputs (and control inputs) made init ops as well,
* and are ignored when used as control dependencies.
* Additionally, this scope ignores any control dependencies.
* <p>
* If an input can not be made an init op (i.e. a Placeholder), will throw an {@link IllegalStateException} on op creation.
* @see #liftToInitScope(Operand)
*/
public Ops withInitScope() {
return new Ops(scope.withInitScope());
}

/**
* Make {@code op} an init operation, doing the same for all of it's inputs (and control inputs).
* <p>
* Init operations will be initialized at session creation, will have their inputs (and control inputs) made init ops as well,
* and are ignored when used as control dependencies.
* Additionally, this scope ignores any control dependencies.
* <p>
* If an input can not be made an init op (i.e. a Placeholder), will throw an {@link IllegalStateException} on op creation.
* @see ExecutionEnvironment#registerInitOp(Operation)
*
* @throws IllegalStateException if the op or one of its inputs can't be made an init op.
*/
public <T extends Operand> T liftToInitScope(T op) {
scope.env().registerInitOp(op.op());
return op;
}

/**
* Returns an API that uses the provided name for an op.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,31 @@ TFE_TensorHandle getUnsafeNativeHandle(int outputIndex) {
return outputHandles[outputIndex];
}

@Override
public int hashCode() {
return Long.valueOf(opHandle.address()).hashCode();
}

@Override
public boolean equals(Object o) {
if (o == this) {
return true;
}
if (!(o instanceof EagerOperation)) {
return false;
}
EagerOperation that = (EagerOperation) o;
if (session != that.session) {
return false;
}

if (opHandle == null || that.opHandle == null || opHandle.isNull() || that.opHandle.isNull()) {
// if they are the same object, we will already have returned
return false;
}
return opHandle.equals(that.opHandle);
}

@Override
Shape shape(int outputIndex) {
// If the tensor of this output has already been resolved, return its shape.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================
*/
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================
*/
package org.tensorflow;

import static org.tensorflow.internal.c_api.global.tensorflow.TFE_Execute;
Expand Down Expand Up @@ -53,24 +53,29 @@
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.TF_Tensor;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Scope;
import org.tensorflow.proto.framework.DataType;

/**
* An {@link OperationBuilder} for building {@link Operation Operations} that are executed eagerly.
*/
final class EagerOperationBuilder implements OperationBuilder {

EagerOperationBuilder(EagerSession session, String type, String name) {
EagerOperationBuilder(EagerSession session, String type, String name, Scope scope) {
this.session = session;
this.type = type;
this.name = name;
this.scope = scope;
this.opHandle = allocate(session, type);
}

@Override
public EagerOperation build() {
scope.apply(this);
TFE_TensorHandle[] tensorHandles = execute(opHandle, session);
return new EagerOperation(session, opHandle, tensorHandles, type, name);
EagerOperation op = new EagerOperation(session, opHandle, tensorHandles, type, name);
scope.onOpCreated(op);
return op;
}

@Override
Expand Down Expand Up @@ -250,6 +255,7 @@ public OperationBuilder setAttr(String name, ConcreteFunction[] value) {
private final EagerSession session;
private final String type;
private final String name;
private final Scope scope;

/** This value should be >= to the maximum number of outputs in any op */
private static final int MAX_OUTPUTS_PER_OP = 1000;
Expand Down
Loading