Skip to content

[WIP] tf.Variable like API for variables #179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,6 @@
import org.tensorflow.op.core.Unstage;
import org.tensorflow.op.core.VarHandleOp;
import org.tensorflow.op.core.VarIsInitializedOp;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.core.VariableShape;
import org.tensorflow.op.core.Where;
import org.tensorflow.op.core.XlaSpmdFullToShardShape;
Expand All @@ -294,6 +293,7 @@
import org.tensorflow.types.TUint8;
import org.tensorflow.types.family.TNumber;
import org.tensorflow.types.family.TType;
import org.tensorflow.variable.Variable;

/**
* An API for building operations as {@link Op Op}s
Expand Down Expand Up @@ -394,6 +394,43 @@ private Ops(Scope scope) {
quantization = new QuantizationOps(this);
}

/**
* Create a new {@link Variable} object, representing a mutable tensor value with constant shape and data type, with
* support for assignment and initialization that works in both eager and graph modes.
* <p>
* Initializes the variable with the provided value, and uses it to determin the variables shape and data type.
* <p>
* The name can be set using {@link org.tensorflow.op.Ops#withName(String)} just like any other op.
*
* @param scope
* @param initialValue the initial value of the variable.
* @param options carries optional attributes values
* @return a new {@link Variable} instance.
* @see Variable
*/
public <T extends TType> Variable<T> Variable(Operand<T> initialValue,
Variable.Options... options) {
return Variable.create(scope, initialValue, options);
}

/**
* Create a new {@link Variable} object, representing a mutable tensor value with constant shape and data type, with
* support for assignment and initialization that works in both eager and graph modes.
* <p>
* The name can be set using {@link org.tensorflow.op.Ops#withName(String)} just like any other op.
*
* @param scope
* @param shape the static shape of the variable.
* @param dataType the data type of the variable.
* @param options carries optional attributes values
* @return a new {@link Variable} instance.
* @see Variable
*/
public <T extends TType> Variable<T> Variable(Shape shape, Class<T> dataType,
Variable.Options... options) {
return Variable.create(scope, shape, dataType, options);
}

/**
* Raise a exception to abort the process when called.
* <p>
Expand Down Expand Up @@ -7947,8 +7984,11 @@ public VarIsInitializedOp varIsInitializedOp(Operand<?> resource) {
* @param init The op to use to initialise this variable.
* @param options carries optional attributes values
* @return a new instance of Variable
* @deprecated Use {@link org.tensorflow.op.Ops#Variable(Operand)} instead for a tf.Variable like API.
*/
public <T extends TType> Variable<T> variable(Operand<T> init, Variable.Options... options) {
@Deprecated
public <T extends TType> org.tensorflow.op.core.Variable<T> variable(Operand<T> init,
org.tensorflow.op.core.Variable.Options... options) {
return Helpers.createVariableWithInit(scope, init, options);
}

Expand All @@ -7964,10 +8004,12 @@ public <T extends TType> Variable<T> variable(Operand<T> init, Variable.Options.
* @param dtype The type of elements in the variable tensor.
* @param options carries optional attributes values
* @return a new instance of Variable
* @deprecated Use {@link org.tensorflow.op.Ops#Variable(Shape, DataType)} instead for a tf.Variable like API.
*/
public <T extends TType> Variable<T> variable(Shape shape, Class<T> dtype,
Variable.Options... options) {
return Variable.create(scope, shape, dtype, options);
@Deprecated
public <T extends TType> org.tensorflow.op.core.Variable<T> variable(Shape shape, Class<T> dtype,
org.tensorflow.op.core.Variable.Options... options) {
return org.tensorflow.op.core.Variable.create(scope, shape, dtype, options);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ private Options() {
* @param dtype The type of elements in the variable tensor.
* @param options carries optional attributes values
* @return a new instance of Variable
* @deprecated Use {@link org.tensorflow.op.Ops#Variable(Shape, DataType)} instead for a tf.Variable like API.
*/
@Endpoint(describeByClass = true)
@Deprecated
public static <T extends TType> Variable<T> create(Scope scope, Shape shape, Class<T> dtype, Options... options) {
OperationBuilder opBuilder = scope.env().opBuilder("VariableV2", scope.makeOpName("Variable"));
opBuilder = scope.apply(opBuilder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_DeleteContext;
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_NewContext;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerScope;
Expand All @@ -31,8 +34,8 @@
import org.tensorflow.op.Op;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Variable;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.variable.Variable;

/**
* An environment for executing TensorFlow operations eagerly.
Expand Down Expand Up @@ -290,7 +293,7 @@ public Types environmentType() {
@Override
public boolean isOpEnabled(String opType) {
switch (opType) {
case Variable.OP_NAME:
case org.tensorflow.op.core.Variable.OP_NAME:
case Placeholder.OP_NAME:
case Assign.OP_NAME:
return false;
Expand Down Expand Up @@ -357,6 +360,18 @@ void detach(Pointer... resources) {
}
}

private final Map<String, Variable<?>> variables = new LinkedHashMap<>();

@Override
public void registerVariable(Variable<?> variable) {
variables.put(variable.getName(), variable);
}

@Override
public Map<String, Variable<?>> variables() {
return Collections.unmodifiableMap(variables);
}

private static volatile EagerSession defaultSession = null;

private final WeakPointerScope nativeResources;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
package org.tensorflow;

import org.tensorflow.op.Op;

/**
* Defines an environment for creating and executing TensorFlow {@link Operation}s.
*/
import java.util.Map;
import org.tensorflow.variable.Variable;
/** Defines an environment for creating and executing TensorFlow {@link Operation}s. */
public interface ExecutionEnvironment {

enum Types {
Expand Down Expand Up @@ -64,6 +63,13 @@ default boolean isOpEnabled(String opType) {
*/
Types environmentType();

Map<String, Variable<?>> variables();

/**
* Registers a variable with this execution environment. For internal use only.
*/
void registerVariable(Variable<?> variable);

default boolean isEager() {
return environmentType() == Types.EAGER;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerScope;
Expand All @@ -57,6 +59,7 @@
import org.tensorflow.proto.util.SaverDef;
import org.tensorflow.types.TString;
import org.tensorflow.types.family.TType;
import org.tensorflow.variable.Variable;


/**
Expand Down Expand Up @@ -478,6 +481,18 @@ synchronized SaverDef saverDef() {

private final List<Op> initializers = new ArrayList<>();

private final Map<String, Variable<?>> variables = new LinkedHashMap<>();

@Override
public void registerVariable(Variable<?> variable) {
variables.put(variable.getName(), variable);
}

@Override
public Map<String, Variable<?>> variables() {
return Collections.unmodifiableMap(variables);
}

// Related native objects (such as the TF_Operation object backing an Operation instance)
// have a validity tied to that of the Graph. The handles to those native objects are not
// valid after Graph.close() has been invoked.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ NameScope withSubScope(String scopeName) {
}

NameScope withName(String name) {
checkPattern(NAME_REGEX, name);
if(name != null) {
checkPattern(NAME_REGEX, name);
}
// All context except for the opName is shared with the new scope.
return new NameScope(opPrefix, name, ids);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ public Scope withSubScope(String childScopeName) {
*
* <p>Names must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*}
*
* @param opName name for an operator in the returned scope
* <p>{@code opName} may be null, which unsets the name.
*
* @param opName name for an operator in the returned scope. May be null to unset name.
* @return a new Scope that uses opName for operations.
* @throws IllegalArgumentException if the name is invalid
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ private Helpers() {}
* @param init The op to use to initialise this variable.
* @param options carries optional attributes values
* @return a new instance of Variable
* @deprecated Use {@link org.tensorflow.op.Ops#Variable(Operand)} instead for a tf.Variable like API.
*/
@Endpoint(name = "variable")
@Deprecated
public static <T extends TType> Variable<T> createVariableWithInit(Scope scope, Operand<T> init, Variable.Options... options) {
Variable<T> newVar = Variable.create(scope, init.shape(), init.type(), options);
Assign<T> assignOp = Assign.create(scope, newVar, init);
Expand Down
Loading