diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
index ea3ef31313e..f08e4da08d3 100644
--- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
+++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
@@ -278,7 +278,6 @@
import org.tensorflow.op.core.Unstage;
import org.tensorflow.op.core.VarHandleOp;
import org.tensorflow.op.core.VarIsInitializedOp;
-import org.tensorflow.op.core.Variable;
import org.tensorflow.op.core.VariableShape;
import org.tensorflow.op.core.Where;
import org.tensorflow.op.core.XlaSpmdFullToShardShape;
@@ -294,6 +293,7 @@
import org.tensorflow.types.TUint8;
import org.tensorflow.types.family.TNumber;
import org.tensorflow.types.family.TType;
+import org.tensorflow.variable.Variable;
/**
* An API for building operations as {@link Op Op}s
@@ -394,6 +394,43 @@ private Ops(Scope scope) {
quantization = new QuantizationOps(this);
}
+ /**
+ * Create a new {@link Variable} object, representing a mutable tensor value with constant shape and data type, with
+ * support for assignment and initialization that works in both eager and graph modes.
+ *
+ * Initializes the variable with the provided value, and uses it to determin the variables shape and data type.
+ *
+ * The name can be set using {@link org.tensorflow.op.Ops#withName(String)} just like any other op.
+ *
+ * @param scope
+ * @param initialValue the initial value of the variable.
+ * @param options carries optional attributes values
+ * @return a new {@link Variable} instance.
+ * @see Variable
+ */
+ public Variable Variable(Operand initialValue,
+ Variable.Options... options) {
+ return Variable.create(scope, initialValue, options);
+ }
+
+ /**
+ * Create a new {@link Variable} object, representing a mutable tensor value with constant shape and data type, with
+ * support for assignment and initialization that works in both eager and graph modes.
+ *
+ * The name can be set using {@link org.tensorflow.op.Ops#withName(String)} just like any other op.
+ *
+ * @param scope
+ * @param shape the static shape of the variable.
+ * @param dataType the data type of the variable.
+ * @param options carries optional attributes values
+ * @return a new {@link Variable} instance.
+ * @see Variable
+ */
+ public Variable Variable(Shape shape, Class dataType,
+ Variable.Options... options) {
+ return Variable.create(scope, shape, dataType, options);
+ }
+
/**
* Raise a exception to abort the process when called.
*
@@ -7947,8 +7984,11 @@ public VarIsInitializedOp varIsInitializedOp(Operand> resource) {
* @param init The op to use to initialise this variable.
* @param options carries optional attributes values
* @return a new instance of Variable
+ * @deprecated Use {@link org.tensorflow.op.Ops#Variable(Operand)} instead for a tf.Variable like API.
*/
- public Variable variable(Operand init, Variable.Options... options) {
+ @Deprecated
+ public org.tensorflow.op.core.Variable variable(Operand init,
+ org.tensorflow.op.core.Variable.Options... options) {
return Helpers.createVariableWithInit(scope, init, options);
}
@@ -7964,10 +8004,12 @@ public Variable variable(Operand init, Variable.Options.
* @param dtype The type of elements in the variable tensor.
* @param options carries optional attributes values
* @return a new instance of Variable
+ * @deprecated Use {@link org.tensorflow.op.Ops#Variable(Shape, DataType)} instead for a tf.Variable like API.
*/
- public Variable variable(Shape shape, Class dtype,
- Variable.Options... options) {
- return Variable.create(scope, shape, dtype, options);
+ @Deprecated
+ public org.tensorflow.op.core.Variable variable(Shape shape, Class dtype,
+ org.tensorflow.op.core.Variable.Options... options) {
+ return org.tensorflow.op.core.Variable.create(scope, shape, dtype, options);
}
/**
diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java
index 98e545d7b76..1694afd052b 100644
--- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java
+++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java
@@ -79,8 +79,10 @@ private Options() {
* @param dtype The type of elements in the variable tensor.
* @param options carries optional attributes values
* @return a new instance of Variable
+ * @deprecated Use {@link org.tensorflow.op.Ops#Variable(Shape, DataType)} instead for a tf.Variable like API.
*/
@Endpoint(describeByClass = true)
+ @Deprecated
public static Variable create(Scope scope, Shape shape, Class dtype, Options... options) {
OperationBuilder opBuilder = scope.env().opBuilder("VariableV2", scope.makeOpName("Variable"));
opBuilder = scope.apply(opBuilder);
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java
index 8e7465388a8..03618323f65 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java
@@ -21,6 +21,9 @@
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_DeleteContext;
import static org.tensorflow.internal.c_api.global.tensorflow.TFE_NewContext;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerScope;
@@ -31,8 +34,8 @@
import org.tensorflow.op.Op;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Placeholder;
-import org.tensorflow.op.core.Variable;
import org.tensorflow.proto.framework.ConfigProto;
+import org.tensorflow.variable.Variable;
/**
* An environment for executing TensorFlow operations eagerly.
@@ -290,7 +293,7 @@ public Types environmentType() {
@Override
public boolean isOpEnabled(String opType) {
switch (opType) {
- case Variable.OP_NAME:
+ case org.tensorflow.op.core.Variable.OP_NAME:
case Placeholder.OP_NAME:
case Assign.OP_NAME:
return false;
@@ -357,6 +360,18 @@ void detach(Pointer... resources) {
}
}
+ private final Map> variables = new LinkedHashMap<>();
+
+ @Override
+ public void registerVariable(Variable> variable) {
+ variables.put(variable.getName(), variable);
+ }
+
+ @Override
+ public Map> variables() {
+ return Collections.unmodifiableMap(variables);
+ }
+
private static volatile EagerSession defaultSession = null;
private final WeakPointerScope nativeResources;
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java
index d5389bcd0ad..665b3688944 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java
@@ -16,10 +16,9 @@
package org.tensorflow;
import org.tensorflow.op.Op;
-
-/**
- * Defines an environment for creating and executing TensorFlow {@link Operation}s.
- */
+import java.util.Map;
+import org.tensorflow.variable.Variable;
+/** Defines an environment for creating and executing TensorFlow {@link Operation}s. */
public interface ExecutionEnvironment {
enum Types {
@@ -64,6 +63,13 @@ default boolean isOpEnabled(String opType) {
*/
Types environmentType();
+ Map> variables();
+
+ /**
+ * Registers a variable with this execution environment. For internal use only.
+ */
+ void registerVariable(Variable> variable);
+
default boolean isEager() {
return environmentType() == Types.EAGER;
}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
index 0f7a291466c..6dcb4820eed 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
@@ -31,7 +31,9 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
+import java.util.LinkedHashMap;
import java.util.List;
+import java.util.Map;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerScope;
@@ -57,6 +59,7 @@
import org.tensorflow.proto.util.SaverDef;
import org.tensorflow.types.TString;
import org.tensorflow.types.family.TType;
+import org.tensorflow.variable.Variable;
/**
@@ -478,6 +481,18 @@ synchronized SaverDef saverDef() {
private final List initializers = new ArrayList<>();
+ private final Map> variables = new LinkedHashMap<>();
+
+ @Override
+ public void registerVariable(Variable> variable) {
+ variables.put(variable.getName(), variable);
+ }
+
+ @Override
+ public Map> variables() {
+ return Collections.unmodifiableMap(variables);
+ }
+
// Related native objects (such as the TF_Operation object backing an Operation instance)
// have a validity tied to that of the Graph. The handles to those native objects are not
// valid after Graph.close() has been invoked.
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java
index 2e84cac1ac7..e7d9db79797 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java
@@ -45,7 +45,9 @@ NameScope withSubScope(String scopeName) {
}
NameScope withName(String name) {
- checkPattern(NAME_REGEX, name);
+ if(name != null) {
+ checkPattern(NAME_REGEX, name);
+ }
// All context except for the opName is shared with the new scope.
return new NameScope(opPrefix, name, ids);
}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java
index 85e283d9260..87aff4858f8 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java
@@ -117,7 +117,9 @@ public Scope withSubScope(String childScopeName) {
*
* Names must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*}
*
- * @param opName name for an operator in the returned scope
+ *
{@code opName} may be null, which unsets the name.
+ *
+ * @param opName name for an operator in the returned scope. May be null to unset name.
* @return a new Scope that uses opName for operations.
* @throws IllegalArgumentException if the name is invalid
*/
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java
index 59682777966..fafaabf7a2f 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java
@@ -42,8 +42,10 @@ private Helpers() {}
* @param init The op to use to initialise this variable.
* @param options carries optional attributes values
* @return a new instance of Variable
+ * @deprecated Use {@link org.tensorflow.op.Ops#Variable(Operand)} instead for a tf.Variable like API.
*/
@Endpoint(name = "variable")
+ @Deprecated
public static Variable createVariableWithInit(Scope scope, Operand init, Variable.Options... options) {
Variable newVar = Variable.create(scope, init.shape(), init.type(), options);
Assign assignOp = Assign.create(scope, newVar, init);
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java
new file mode 100644
index 00000000000..ad19909cd6c
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java
@@ -0,0 +1,415 @@
+/*
+ Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================
+ */
+package org.tensorflow.variable;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.function.Supplier;
+import org.tensorflow.ExecutionEnvironment;
+import org.tensorflow.Operand;
+import org.tensorflow.Operation;
+import org.tensorflow.Output;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Op;
+import org.tensorflow.op.Operands;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.Scope;
+import org.tensorflow.op.annotation.Endpoint;
+import org.tensorflow.op.annotation.Operator;
+import org.tensorflow.op.core.AssignAddVariableOp;
+import org.tensorflow.op.core.AssignSubVariableOp;
+import org.tensorflow.op.core.AssignVariableOp;
+import org.tensorflow.op.core.Init;
+import org.tensorflow.op.core.IsVariableInitialized;
+import org.tensorflow.op.core.ReadVariableOp;
+import org.tensorflow.op.core.VarHandleOp;
+import org.tensorflow.op.core.VarIsInitializedOp;
+import org.tensorflow.proto.framework.DataType;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.family.TType;
+
+/**
+ * A class representing a read-only tensor variable with a constant shape and data type. Analogous to Python's tf.Variable.
+ * Any access will always return the most recent assignment.
+ *
+ * Supports eager and graph mode, and will use {@code VariableV2} in graph mode and enforce ordered assignments.
+ *
+ * Provides methods to get the value and initialize the value if it hasn't already been set.
+ * Also implements {@code Operand} using the stored value.
+ *
+ * Variables will be registered in their execution environment's {@link ExecutionEnvironment#variables()}.
+ *
+ */
+@Operator
+public class Variable implements Operand {
+
+ private final Scope initialScope;
+ private final String name;
+
+ private final Shape shape;
+ private final DataType dataType;
+ private final boolean trainable;
+ private final Class tType;
+
+ private final VarHandleOp handle;
+
+ private VarIsInitializedOp isInitializedOp = null;
+ private Op initializationOp = null;
+ private ReadVariableOp cachedRead = null;
+ private Op lastAssign = null;
+
+ private boolean hasInitialized = false;
+
+ protected Variable(Scope scope, Shape shape, Class dataType,
+ Options[] options) {
+ this.shape = shape;
+ this.dataType = Operands.toDataType(dataType);
+ this.tType = dataType;
+
+ boolean trainable = true;
+ if (options != null) {
+ for (Options opts : options) {
+ if (opts.trainable != null) {
+ trainable = opts.trainable;
+ }
+ }
+ }
+ this.trainable = trainable;
+
+ this.name = scope.makeOpName("Variable");
+
+ scope = scope.withName(null);
+ this.initialScope = scope.withSubScope(this.name);
+
+
+
+ VarHandleOp.Options[] handleOptions;
+
+ if (scope.env().isGraph()) {
+ handleOptions = new VarHandleOp.Options[]{VarHandleOp.sharedName(this.name)};
+ } else {
+ handleOptions = new VarHandleOp.Options[0];
+ }
+
+ this.handle = VarHandleOp.create(initialScope.withName(name), dataType, shape, handleOptions);
+
+ scope.env().registerVariable(this);
+ }
+
+ /**
+ * Get whether the variable is trainable (whether it should be updated by optimizers).
+ */
+ public boolean isTrainable() {
+ return trainable;
+ }
+
+ /**
+ * Get the variable handle operation.
+ */
+ public VarHandleOp getHandle() {
+ return handle;
+ }
+
+ /**
+ * Get the variable's constant shape.
+ * This may have unknown dimensions, which do not impose a requirement on the value's dimensions.
+ */
+ public Shape getShape() {
+ return shape;
+ }
+
+ /**
+ * Get the variable's constant data type.
+ */
+ public DataType getDataType() {
+ return dataType;
+ }
+
+ /**
+ * Get whether the variable has had a value assigned to it. This method relates to the Java object, not the graph variable.
+ *
+ * This operation returns true if {@code initialize} or {@code assign} methods have been used on
+ * the variable object, it does not provide any information about the state of the graph variable.
+ * For that, use {@link #isValueInitialized()}
+ */
+ public boolean isInitialized() {
+ return hasInitialized;
+ }
+
+ /**
+ * Get the name of the variable, set using {@link org.tensorflow.op.Ops#withName(String)} the same as any other op.
+ */
+ public String getName() {
+ return name;
+ }
+
+ /**
+ * Get the current value as an Output.
+ */
+ @Override
+ public Output asOutput() {
+ return value().asOutput();
+ }
+
+ /**
+ * Get the op of the current value.
+ */
+ @Override
+ public Operation op() {
+ return value().op();
+ }
+
+ /**
+ * Gets the current shape of this variable. May have less unknown dimensions than {@link #getShape()},
+ * in which case they will be filled in from the current value.
+ */
+ @Override
+ public Shape shape() {
+ if(isInitialized()) {
+ return value().shape();
+ } else {
+ return getShape();
+ }
+ }
+
+ /**
+ * Get the current value of this variable, using the given scope.
+ */
+ public synchronized Operand value(Scope scope) {
+ if (!hasInitialized) {
+ throw new IllegalStateException("Variable has not been initialized, can not get.");
+ }
+ ReadVariableOp ret = cachedRead;
+ if (ret == null || scope.env().isEager()) {
+ if (lastAssign != null) {
+ scope = scope.withControlDependencies(Collections.singletonList(lastAssign));
+ }
+ ret = ReadVariableOp.create(scope, handle, tType);
+ }
+ cachedRead = ret;
+ return ret;
+ }
+
+ /**
+ * Get the current value of this variable, using the variable's scope.
+ */
+ public Operand value() {
+ return value(initialScope);
+ }
+
+ private void checkInput(Operand value) {
+ if (!value.shape().isCompatibleWith(this.shape)) {
+ throw new IllegalArgumentException("Shape of new value (" + value.shape() +
+ ") is not compatible with the variable's shape (" + this.shape + ").");
+ }
+
+ if (!tType.isAssignableFrom(value.asOutput().type())) {
+ throw new IllegalArgumentException("Data type of new value (" + value.asOutput().dataType() +
+ ") is not compatible with the variable's data type (" + dataType + ").");
+ }
+ }
+
+ /**
+ * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has.
+ * @param value the value to initialize this variable with.
+ */
+ public synchronized Op initialize(Operand value) {
+ if (hasInitialized) {
+ return initializationOp;
+ }
+ checkInput(value);
+ initializationOp = AssignVariableOp.create(initialScope, handle, value);
+
+ //TODO this if will be unnecessary after the init PR
+ Init.add(initialScope, initializationOp);
+
+ hasInitialized = true;
+ cachedRead = null;
+ return initializationOp;
+ }
+
+ /**
+ * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has.
+ *
+ * The provided function will not be invoked if this function no-ops.
+ * @param value a function returning the value to initialize this variable with.
+ * Will only be called if initialization is done.
+ */
+ public synchronized Op initialize(Supplier> value) {
+ if (hasInitialized) {
+ return initializationOp;
+ }
+ return initialize(value.get());
+ }
+
+ /**
+ * Get whether the graph value is initialized. In eager mode, this will be the same as {@link #isInitialized()}.
+ */
+ public Operand isValueInitialized() {
+ if (isInitializedOp == null || initializationOp.env().isEager()) {
+ isInitializedOp = VarIsInitializedOp.create(initialScope, handle);
+ }
+
+ return isInitializedOp;
+ }
+
+ /**
+ * Assign a new value to this variable using the given scope.
+ *
+ * @param value the value to assign.
+ * @see AssignVariableOp#create
+ */
+ public synchronized Op assign(Scope scope, Operand value) {
+ checkInput(value);
+ lastAssign = AssignVariableOp.create(scope, handle, value);
+ hasInitialized = true;
+ cachedRead = null;
+ return lastAssign;
+ }
+
+ /**
+ * Assign a new value to this variable using the variable's scope.
+ *
+ * @param value the value to assign.
+ * @param controlDependencies any control dependencies of the assignment.
+ * @see AssignVariableOp#create
+ */
+ public Op assign(Operand value, Op... controlDependencies) {
+ return assign(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value);
+ }
+
+ /**
+ * Decrement the variable's value by the given value, using the given scope.
+ *
+ * @param value amount to decrease the variable's value by.
+ * @see AssignSubVariableOp#create
+ */
+ public synchronized Op assignSub(Scope scope, Operand value) {
+ if (!hasInitialized) {
+ throw new IllegalStateException("Variable has not been initialized, can not decrement.");
+ }
+
+ if(cachedRead != null)
+ scope = scope.withControlDependencies(Collections.singletonList(cachedRead));
+
+ checkInput(value);
+ lastAssign = AssignSubVariableOp.create(scope, handle, value);
+ hasInitialized = true;
+ cachedRead = null;
+ return lastAssign;
+ }
+
+ /**
+ * Decrement the variable's value by the given value, using the variable's scope.
+ *
+ * @param value amount to decrease the variable's value by.
+ * @param controlDependencies any control dependencies of the assignment.
+ * @see AssignSubVariableOp#create
+ */
+ public Op assignSub(Operand value, Op... controlDependencies) {
+ return assignSub(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value);
+ }
+
+ /**
+ * Increment the variable's value by the given value, using the given scope.
+ *
+ * @param value amount to decrease the variable's value by.
+ * @see AssignAddVariableOp#create
+ */
+ public synchronized Op assignAdd(Scope scope, Operand value) {
+ if (!hasInitialized) {
+ throw new IllegalStateException("Variable has not been initialized, can not increment.");
+ }
+
+ if(cachedRead != null)
+ scope = scope.withControlDependencies(Collections.singletonList(cachedRead));
+
+ checkInput(value);
+ lastAssign = AssignAddVariableOp.create(scope, handle, value);
+ hasInitialized = true;
+ cachedRead = null;
+ return lastAssign;
+ }
+
+ /**
+ * Increment the variable's value by the given value, using the variable's scope.
+ *
+ * @param value amount to decrease the variable's value by.
+ * @param controlDependencies any control dependencies of the assignment.
+ * @see AssignAddVariableOp#create
+ */
+ public Op assignAdd(Operand value, Op... controlDependencies) {
+ return assignAdd(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value);
+ }
+
+ /**
+ * Optional attributes for {@link org.tensorflow.op.core.Variable}
+ */
+ public static class Options {
+
+ /**
+ * @param trainable If non-null, this variable's trainability is as given.
+ * Otherwise, it is trainable.
+ */
+ public Options trainable(boolean trainable) {
+ this.trainable = trainable;
+ return this;
+ }
+
+ Boolean trainable = null;
+
+ private Options() {
+ }
+ }
+
+ /**
+ * Create a new {@link Variable} object, representing a mutable tensor value with constant shape and data type, with
+ * support for assignment and initialization that works in both eager and graph modes.
+ *
+ * The name can be set using {@link org.tensorflow.op.Ops#withName(String)} just like any other op.
+ * @param scope
+ * @param shape the static shape of the variable.
+ * @param dataType the data type of the variable.
+ * @param options carries optional attributes values
+ * @return a new {@link Variable} instance.
+ * @see Variable
+ */
+ @Endpoint(name = "Variable")
+ public static Variable create(Scope scope, Shape shape, Class dataType, Options... options){
+ return new Variable<>(scope, shape, dataType, options);
+ }
+
+ /**
+ * Create a new {@link Variable} object, representing a mutable tensor value with constant shape and data type, with
+ * support for assignment and initialization that works in both eager and graph modes.
+ *
+ * Initializes the variable with the provided value, and uses it to determin the variables shape and data type.
+ *
+ * The name can be set using {@link org.tensorflow.op.Ops#withName(String)} just like any other op.
+ * @param scope
+ * @param initialValue the initial value of the variable.
+ * @param options carries optional attributes values
+ * @return a new {@link Variable} instance.
+ * @see Variable
+ */
+ @Endpoint(name = "Variable")
+ public static Variable create(Scope scope, Operand initialValue, Options... options){
+ Variable variable = create(scope, initialValue.shape(), initialValue.type(), options);
+ variable.initialize(initialValue);
+ return variable;
+ }
+}
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java
new file mode 100644
index 00000000000..bb73db88775
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java
@@ -0,0 +1,168 @@
+/*
+ Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================
+ */
+package org.tensorflow;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.util.Arrays;
+import org.junit.Ignore;
+import org.junit.jupiter.api.Test;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.Gradients;
+import org.tensorflow.op.core.Placeholder;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.family.TType;
+import org.tensorflow.variable.Variable;
+
+/**
+ * Unit tests for {@link org.tensorflow.variable.Variable}/{@link Variable}
+ */
+public class VariableTest {
+
+ @Test
+ public void testEager() {
+ try (EagerSession es = EagerSession.create()) {
+ Ops tf = Ops.create(es);
+ Variable variable = tf.Variable(Shape.of(10, 10), TFloat32.class);
+
+ assertFalse(variable.isInitialized());
+ assertFalse(variable.isValueInitialized().asTensor().getBoolean());
+
+ variable.initialize(tf.ones(tf.array(10, 10), TFloat32.class));
+
+ assertTrue(variable.isInitialized());
+ assertTrue(variable.isValueInitialized().asTensor().getBoolean());
+
+ assertEquals(1, variable.value().asTensor().getFloat(0, 0));
+
+ variable.assign(tf.math.add(tf.ones(tf.array(10, 10), TFloat32.class), tf.constant(2f)));
+
+ assertEquals(3, variable.value().asTensor().getFloat(0, 0));
+
+ variable.assignAdd(tf.ones(tf.array(10, 10), TFloat32.class));
+
+ assertEquals(4, variable.value().asTensor().getFloat(0, 0));
+
+ variable.assignSub(tf.ones(tf.array(10, 10), TFloat32.class));
+
+ assertEquals(3, variable.value().asTensor().getFloat(0, 0));
+ }
+ }
+
+ @Test
+ public void testGraph() {
+ try (Graph graph = new Graph()) {
+ Ops tf = Ops.create(graph);
+ Variable variable = tf.Variable(Shape.of(10, 10), TFloat32.class);
+
+ assertFalse(variable.isInitialized());
+
+ variable.initialize(tf.ones(tf.array(10, 10), TFloat32.class));
+
+ assertTrue(variable.isInitialized());
+
+ Operand original = variable.value();
+
+ Op assign = variable.assign(tf.math.add(tf.ones(tf.array(10, 10), TFloat32.class), tf.constant(2f)));
+ Operand afterAssign = variable.value();
+
+ Op increment = variable.assignAdd(tf.ones(tf.array(10, 10), TFloat32.class));
+ Operand afterIncrement = variable.value();
+
+ Op decrement = variable.assignSub(tf.ones(tf.array(10, 10), TFloat32.class));
+ Operand afterDecrement = variable.value();
+
+ try (Session session = new Session(graph)) {
+
+ assertFalse(((TBool) session.runner().fetch(variable.isValueInitialized()).run().get(0)).getBoolean());
+
+ session.runInit();
+
+ assertTrue(((TBool) session.runner().fetch(variable.isValueInitialized()).run().get(0)).getBoolean());
+
+ // test control deps (in-run assign)
+
+ assertEquals(1, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0));
+
+ assertEquals(3, ((TFloat32) session.runner().fetch(afterAssign).run().get(0)).getFloat(0, 0));
+
+ assertEquals(4, ((TFloat32) session.runner().fetch(afterIncrement).run().get(0)).getFloat(0, 0));
+
+ assertEquals(3, ((TFloat32) session.runner().fetch(afterDecrement).run().get(0)).getFloat(0, 0));
+
+ // test persistence (multi-run assign)
+
+ session.run(decrement);
+ session.run(decrement);
+
+ assertEquals(1, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0));
+
+ session.run(increment);
+ session.run(increment);
+ session.run(increment);
+ session.run(increment);
+
+ assertEquals(5, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0));
+
+ session.run(assign);
+
+ assertEquals(3, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0));
+ }
+ }
+
+ }
+
+ @Test
+ @Ignore // gradient not specified at c++ level: https://github.com/tensorflow/tensorflow/issues/46114.
+ public void gradientTest() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Ops tf = Ops.create(g);
+
+ Placeholder initialValue = tf.placeholder(TFloat32.class);
+ Variable variable = tf.Variable(initialValue);
+
+ Output y0 = tf.math.square(variable.value()).y();
+ Output y1 = tf.math.square(tf.math.square(variable.value())).y();
+
+ Output x = variable.getHandle().asOutput();
+
+ Gradients grads = Gradients.create(tf.scope(), Arrays.asList(y0, y1), Arrays.asList(x));
+
+ assertNotNull(grads);
+ assertNotNull(grads.dy());
+ assertEquals(1, grads.dy().size());
+
+ try (TFloat32 c = TFloat32.scalarOf(3.0f)) {
+ sess.runner().addTarget(tf.init()).feed(initialValue, c).run();
+ try (AutoCloseableList outputs =
+ new AutoCloseableList<>(sess.runner().feed(initialValue, c).fetch(grads.dy(0)).run())) {
+
+ //TODO expected value may be wrong, check once C++ gradient exists
+ assertEquals(114.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f);
+ }
+ }
+ }
+ }
+
+}