diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index ea3ef31313e..f08e4da08d3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -278,7 +278,6 @@ import org.tensorflow.op.core.Unstage; import org.tensorflow.op.core.VarHandleOp; import org.tensorflow.op.core.VarIsInitializedOp; -import org.tensorflow.op.core.Variable; import org.tensorflow.op.core.VariableShape; import org.tensorflow.op.core.Where; import org.tensorflow.op.core.XlaSpmdFullToShardShape; @@ -294,6 +293,7 @@ import org.tensorflow.types.TUint8; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; +import org.tensorflow.variable.Variable; /** * An API for building operations as {@link Op Op}s @@ -394,6 +394,43 @@ private Ops(Scope scope) { quantization = new QuantizationOps(this); } + /** + * Create a new {@link Variable} object, representing a mutable tensor value with constant shape and data type, with + * support for assignment and initialization that works in both eager and graph modes. + *

+ * Initializes the variable with the provided value, and uses it to determine the variable's shape and data type. + *

+ * The name can be set using {@link org.tensorflow.op.Ops#withName(String)} just like any other op. + * + * @param scope + * @param initialValue the initial value of the variable. + * @param options carries optional attributes values + * @return a new {@link Variable} instance. + * @see Variable + */ + public Variable Variable(Operand initialValue, + Variable.Options... options) { + return Variable.create(scope, initialValue, options); + } + + /** + * Create a new {@link Variable} object, representing a mutable tensor value with constant shape and data type, with + * support for assignment and initialization that works in both eager and graph modes. + *

+ * The name can be set using {@link org.tensorflow.op.Ops#withName(String)} just like any other op. + * + * @param scope + * @param shape the static shape of the variable. + * @param dataType the data type of the variable. + * @param options carries optional attributes values + * @return a new {@link Variable} instance. + * @see Variable + */ + public Variable Variable(Shape shape, Class dataType, + Variable.Options... options) { + return Variable.create(scope, shape, dataType, options); + } + /** * Raise a exception to abort the process when called. *

@@ -7947,8 +7984,11 @@ public VarIsInitializedOp varIsInitializedOp(Operand resource) { * @param init The op to use to initialise this variable. * @param options carries optional attributes values * @return a new instance of Variable + * @deprecated Use {@link org.tensorflow.op.Ops#Variable(Operand)} instead for a tf.Variable like API. */ - public Variable variable(Operand init, Variable.Options... options) { + @Deprecated + public org.tensorflow.op.core.Variable variable(Operand init, + org.tensorflow.op.core.Variable.Options... options) { return Helpers.createVariableWithInit(scope, init, options); } @@ -7964,10 +8004,12 @@ public Variable variable(Operand init, Variable.Options. * @param dtype The type of elements in the variable tensor. * @param options carries optional attributes values * @return a new instance of Variable + * @deprecated Use {@link org.tensorflow.op.Ops#Variable(Shape, DataType)} instead for a tf.Variable like API. */ - public Variable variable(Shape shape, Class dtype, - Variable.Options... options) { - return Variable.create(scope, shape, dtype, options); + @Deprecated + public org.tensorflow.op.core.Variable variable(Shape shape, Class dtype, + org.tensorflow.op.core.Variable.Options... options) { + return org.tensorflow.op.core.Variable.create(scope, shape, dtype, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java index 98e545d7b76..1694afd052b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java @@ -79,8 +79,10 @@ private Options() { * @param dtype The type of elements in the variable tensor. 
* @param options carries optional attributes values * @return a new instance of Variable + * @deprecated Use {@link org.tensorflow.op.Ops#Variable(Shape, DataType)} instead for a tf.Variable like API. */ @Endpoint(describeByClass = true) + @Deprecated public static Variable create(Scope scope, Shape shape, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("VariableV2", scope.makeOpName("Variable")); opBuilder = scope.apply(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index 8e7465388a8..03618323f65 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -21,6 +21,9 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TFE_DeleteContext; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_NewContext; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerScope; @@ -31,8 +34,8 @@ import org.tensorflow.op.Op; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Placeholder; -import org.tensorflow.op.core.Variable; import org.tensorflow.proto.framework.ConfigProto; +import org.tensorflow.variable.Variable; /** * An environment for executing TensorFlow operations eagerly. @@ -290,7 +293,7 @@ public Types environmentType() { @Override public boolean isOpEnabled(String opType) { switch (opType) { - case Variable.OP_NAME: + case org.tensorflow.op.core.Variable.OP_NAME: case Placeholder.OP_NAME: case Assign.OP_NAME: return false; @@ -357,6 +360,18 @@ void detach(Pointer... 
resources) { } } + private final Map> variables = new LinkedHashMap<>(); + + @Override + public void registerVariable(Variable variable) { + variables.put(variable.getName(), variable); + } + + @Override + public Map> variables() { + return Collections.unmodifiableMap(variables); + } + private static volatile EagerSession defaultSession = null; private final WeakPointerScope nativeResources; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java index d5389bcd0ad..665b3688944 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java @@ -16,10 +16,9 @@ package org.tensorflow; import org.tensorflow.op.Op; - -/** - * Defines an environment for creating and executing TensorFlow {@link Operation}s. - */ +import java.util.Map; +import org.tensorflow.variable.Variable; +/** Defines an environment for creating and executing TensorFlow {@link Operation}s. */ public interface ExecutionEnvironment { enum Types { @@ -64,6 +63,13 @@ default boolean isOpEnabled(String opType) { */ Types environmentType(); + Map> variables(); + + /** + * Registers a variable with this execution environment. For internal use only. 
+ */ + void registerVariable(Variable variable); + default boolean isEager() { return environmentType() == Types.EAGER; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 0f7a291466c..6dcb4820eed 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -31,7 +31,9 @@ import java.util.Arrays; import java.util.Collections; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerScope; @@ -57,6 +59,7 @@ import org.tensorflow.proto.util.SaverDef; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; +import org.tensorflow.variable.Variable; /** @@ -478,6 +481,18 @@ synchronized SaverDef saverDef() { private final List initializers = new ArrayList<>(); + private final Map> variables = new LinkedHashMap<>(); + + @Override + public void registerVariable(Variable variable) { + variables.put(variable.getName(), variable); + } + + @Override + public Map> variables() { + return Collections.unmodifiableMap(variables); + } + // Related native objects (such as the TF_Operation object backing an Operation instance) // have a validity tied to that of the Graph. The handles to those native objects are not // valid after Graph.close() has been invoked. 
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java index 2e84cac1ac7..e7d9db79797 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java @@ -45,7 +45,9 @@ NameScope withSubScope(String scopeName) { } NameScope withName(String name) { - checkPattern(NAME_REGEX, name); + if(name != null) { + checkPattern(NAME_REGEX, name); + } // All context except for the opName is shared with the new scope. return new NameScope(opPrefix, name, ids); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java index 85e283d9260..87aff4858f8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java @@ -117,7 +117,9 @@ public Scope withSubScope(String childScopeName) { * *

Names must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*} * - * @param opName name for an operator in the returned scope + *

{@code opName} may be null, which unsets the name. + * + * @param opName name for an operator in the returned scope. May be null to unset name. * @return a new Scope that uses opName for operations. * @throws IllegalArgumentException if the name is invalid */ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java index 59682777966..fafaabf7a2f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java @@ -42,8 +42,10 @@ private Helpers() {} * @param init The op to use to initialise this variable. * @param options carries optional attributes values * @return a new instance of Variable + * @deprecated Use {@link org.tensorflow.op.Ops#Variable(Operand)} instead for a tf.Variable like API. */ @Endpoint(name = "variable") + @Deprecated public static Variable createVariableWithInit(Scope scope, Operand init, Variable.Options... options) { Variable newVar = Variable.create(scope, init.shape(), init.type(), options); Assign assignOp = Assign.create(scope, newVar, init); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java new file mode 100644 index 00000000000..ad19909cd6c --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java @@ -0,0 +1,415 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.variable; + +import java.util.Arrays; +import java.util.Collections; +import java.util.function.Supplier; +import org.tensorflow.ExecutionEnvironment; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Operands; +import org.tensorflow.op.Ops; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.op.core.AssignAddVariableOp; +import org.tensorflow.op.core.AssignSubVariableOp; +import org.tensorflow.op.core.AssignVariableOp; +import org.tensorflow.op.core.Init; +import org.tensorflow.op.core.IsVariableInitialized; +import org.tensorflow.op.core.ReadVariableOp; +import org.tensorflow.op.core.VarHandleOp; +import org.tensorflow.op.core.VarIsInitializedOp; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TType; + +/** + * A class representing a mutable tensor variable with a constant shape and data type. Analogous to Python's tf.Variable. + * Any access will always return the most recent assignment. + *

+ * Supports eager and graph mode. The variable is backed by a {@code VarHandleOp} handle in both modes, and assignments and reads are ordered through control dependencies. + *

+ * Provides methods to get the value and initialize the value if it hasn't already been set. + * Also implements {@code Operand} using the stored value. + *

+ * Variables will be registered in their execution environment's {@link ExecutionEnvironment#variables()}. + * + */ +@Operator +public class Variable implements Operand { + + private final Scope initialScope; + private final String name; + + private final Shape shape; + private final DataType dataType; + private final boolean trainable; + private final Class tType; + + private final VarHandleOp handle; + + private VarIsInitializedOp isInitializedOp = null; + private Op initializationOp = null; + private ReadVariableOp cachedRead = null; + private Op lastAssign = null; + + private boolean hasInitialized = false; + + protected Variable(Scope scope, Shape shape, Class dataType, + Options[] options) { + this.shape = shape; + this.dataType = Operands.toDataType(dataType); + this.tType = dataType; + + boolean trainable = true; + if (options != null) { + for (Options opts : options) { + if (opts.trainable != null) { + trainable = opts.trainable; + } + } + } + this.trainable = trainable; + + this.name = scope.makeOpName("Variable"); + + scope = scope.withName(null); + this.initialScope = scope.withSubScope(this.name); + + + + VarHandleOp.Options[] handleOptions; + + if (scope.env().isGraph()) { + handleOptions = new VarHandleOp.Options[]{VarHandleOp.sharedName(this.name)}; + } else { + handleOptions = new VarHandleOp.Options[0]; + } + + this.handle = VarHandleOp.create(initialScope.withName(name), dataType, shape, handleOptions); + + scope.env().registerVariable(this); + } + + /** + * Get whether the variable is trainable (whether it should be updated by optimizers). + */ + public boolean isTrainable() { + return trainable; + } + + /** + * Get the variable handle operation. + */ + public VarHandleOp getHandle() { + return handle; + } + + /** + * Get the variable's constant shape. + * This may have unknown dimensions, which do not impose a requirement on the value's dimensions. 
+ */ + public Shape getShape() { + return shape; + } + + /** + * Get the variable's constant data type. + */ + public DataType getDataType() { + return dataType; + } + + /** + * Get whether the variable has had a value assigned to it. This method relates to the Java object, not the graph variable. + *

+ * This method returns true if the {@code initialize} or {@code assign} methods have been used on + * the variable object; it does not provide any information about the state of the graph variable. + * For that, use {@link #isValueInitialized()}. + */ + public boolean isInitialized() { + return hasInitialized; + } + + /** + * Get the name of the variable, set using {@link org.tensorflow.op.Ops#withName(String)} the same as any other op. + */ + public String getName() { + return name; + } + + /** + * Get the current value as an Output. + * + * @throws IllegalStateException if the variable has not been assigned a value yet. + */ + @Override + public Output asOutput() { + return value().asOutput(); + } + + /** + * Get the op of the current value. + * + * @throws IllegalStateException if the variable has not been assigned a value yet. + */ + @Override + public Operation op() { + return value().op(); + } + + /** + * Gets the current shape of this variable: the shape of the current value once the variable has been assigned one, + * otherwise the declared shape from {@link #getShape()}. The value's shape may have fewer unknown dimensions than the declared shape. + */ + @Override + public Shape shape() { + if(isInitialized()) { + return value().shape(); + } else { + return getShape(); + } + } + + /** + * Get the current value of this variable, using the given scope. + * + * Outside eager mode the read op is cached and reused until the next assignment or initialization clears it; + * in eager mode a new read op is created on every call. A newly created read carries a control dependency on the + * most recent {@code assign}, {@code assignAdd} or {@code assignSub} op, if there is one. + * + * @param scope the scope used to create the read op when a new one is needed. + * @return an operand reading the variable's value. + * @throws IllegalStateException if the variable has not been assigned a value yet. + */ + public synchronized Operand value(Scope scope) { + if (!hasInitialized) { + throw new IllegalStateException("Variable has not been initialized, can not get."); + } + ReadVariableOp ret = cachedRead; + if (ret == null || scope.env().isEager()) { + if (lastAssign != null) { + scope = scope.withControlDependencies(Collections.singletonList(lastAssign)); + } + ret = ReadVariableOp.create(scope, handle, tType); + } + cachedRead = ret; + return ret; + } + + /** + * Get the current value of this variable, using the variable's scope. 
+ */ + public Operand value() { + return value(initialScope); + } + + private void checkInput(Operand value) { + if (!value.shape().isCompatibleWith(this.shape)) { + throw new IllegalArgumentException("Shape of new value (" + value.shape() + + ") is not compatible with the variable's shape (" + this.shape + ")."); + } + + if (!tType.isAssignableFrom(value.asOutput().type())) { + throw new IllegalArgumentException("Data type of new value (" + value.asOutput().dataType() + + ") is not compatible with the variable's data type (" + dataType + ")."); + } + } + + /** + * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has. + * @param value the value to initialize this variable with. + */ + public synchronized Op initialize(Operand value) { + if (hasInitialized) { + return initializationOp; + } + checkInput(value); + initializationOp = AssignVariableOp.create(initialScope, handle, value); + + //TODO this if will be unnecessary after the init PR + Init.add(initialScope, initializationOp); + + hasInitialized = true; + cachedRead = null; + return initializationOp; + } + + /** + * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has. + *

+ * The provided function will not be invoked if this function no-ops. + * @param value a function returning the value to initialize this variable with. + * Will only be called if initialization is done. + * @return the initialization op. May be null when the variable received its first value through {@code assign} rather than {@code initialize}. + */ + public synchronized Op initialize(Supplier> value) { + if (hasInitialized) { + return initializationOp; + } + return initialize(value.get()); + } + + /** + * Get whether the graph value is initialized. In eager mode, this will be the same as {@link #isInitialized()}. + * + * The check op is cached and reused outside eager mode, and rebuilt on every call in eager mode. + * The environment is read from the variable's own scope: {@code initializationOp} stays null until + * {@code initialize} has run, so it can not be used to look up the environment. + * + * @return an operand holding whether the underlying variable has a value. + */ + public Operand isValueInitialized() { + if (isInitializedOp == null || initialScope.env().isEager()) { + isInitializedOp = VarIsInitializedOp.create(initialScope, handle); + } + + return isInitializedOp; + } + + /** + * Assign a new value to this variable using the given scope. + * Marks the variable as initialized and discards the cached read. + * + * @param scope the scope used to create the assignment op. + * @param value the value to assign. + * @return the assignment op. + * @throws IllegalArgumentException if the value's shape or data type is not compatible with the variable's. + * @see AssignVariableOp#create + */ + public synchronized Op assign(Scope scope, Operand value) { + checkInput(value); + lastAssign = AssignVariableOp.create(scope, handle, value); + hasInitialized = true; + cachedRead = null; + return lastAssign; + } + + /** + * Assign a new value to this variable using the variable's scope. + * + * @param value the value to assign. + * @param controlDependencies any control dependencies of the assignment. + * @return the assignment op. + * @see AssignVariableOp#create + */ + public Op assign(Operand value, Op... controlDependencies) { + return assign(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value); + } + + /** + * Decrement the variable's value by the given value, using the given scope. + * + * @param value amount to decrease the variable's value by. 
+ * @param scope the scope used to create the op. + * @return the assignment op. + * @throws IllegalStateException if the variable has not been initialized. + * @see AssignSubVariableOp#create + */ + public synchronized Op assignSub(Scope scope, Operand value) { + if (!hasInitialized) { + throw new IllegalStateException("Variable has not been initialized, can not decrement."); + } + + if(cachedRead != null) + scope = scope.withControlDependencies(Collections.singletonList(cachedRead)); + + checkInput(value); + lastAssign = AssignSubVariableOp.create(scope, handle, value); + hasInitialized = true; + cachedRead = null; + return lastAssign; + } + + /** + * Decrement the variable's value by the given value, using the variable's scope. + * + * @param value amount to decrease the variable's value by. + * @param controlDependencies any control dependencies of the assignment. + * @return the assignment op. + * @throws IllegalStateException if the variable has not been initialized. + * @see AssignSubVariableOp#create + */ + public Op assignSub(Operand value, Op... controlDependencies) { + return assignSub(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value); + } + + /** + * Increment the variable's value by the given value, using the given scope. + * If a cached read exists, the op is created with a control dependency on it. + * + * @param scope the scope used to create the op. + * @param value amount to increase the variable's value by. + * @return the assignment op. + * @throws IllegalStateException if the variable has not been initialized. + * @see AssignAddVariableOp#create + */ + public synchronized Op assignAdd(Scope scope, Operand value) { + if (!hasInitialized) { + throw new IllegalStateException("Variable has not been initialized, can not increment."); + } + + if(cachedRead != null) + scope = scope.withControlDependencies(Collections.singletonList(cachedRead)); + + checkInput(value); + lastAssign = AssignAddVariableOp.create(scope, handle, value); + hasInitialized = true; + cachedRead = null; + return lastAssign; + } + + /** + * Increment the variable's value by the given value, using the variable's scope. + * + * @param value amount to increase the variable's value by. + * @param controlDependencies any control dependencies of the assignment. + * @return the assignment op. + * @throws IllegalStateException if the variable has not been initialized. + * @see AssignAddVariableOp#create + */ + public Op assignAdd(Operand value, Op... 
controlDependencies) { + return assignAdd(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value); + } + + /** + * Optional attributes for {@link Variable}. + * + * NOTE(review): the constructor is private and no static factory is visible in this class, so it is unclear how callers obtain an instance - confirm. + */ + public static class Options { + + /** + * Sets whether the variable is trainable. + * + * @param trainable whether the variable is trainable. When no option sets this, the variable is trainable. + * @return this Options instance. + */ + public Options trainable(boolean trainable) { + this.trainable = trainable; + return this; + } + + Boolean trainable = null; + + private Options() { + } + } + + /** + * Create a new {@link Variable} object, representing a mutable tensor value with constant shape and data type, with + * support for assignment and initialization that works in both eager and graph modes. + *

+ * The name can be set using {@link org.tensorflow.op.Ops#withName(String)} just like any other op. + * @param scope current scope + * @param shape the static shape of the variable. + * @param dataType the data type of the variable. + * @param options carries optional attributes values + * @return a new {@link Variable} instance. + * @see Variable + */ + @Endpoint(name = "Variable") + public static Variable create(Scope scope, Shape shape, Class dataType, Options... options){ + return new Variable<>(scope, shape, dataType, options); + } + + /** + * Create a new {@link Variable} object, representing a mutable tensor value with constant shape and data type, with + * support for assignment and initialization that works in both eager and graph modes. + *

+ * Initializes the variable with the provided value, and uses it to determine the variable's shape and data type. + *

+ * The name can be set using {@link org.tensorflow.op.Ops#withName(String)} just like any other op. + * @param scope + * @param initialValue the initial value of the variable. + * @param options carries optional attributes values + * @return a new {@link Variable} instance. + * @see Variable + */ + @Endpoint(name = "Variable") + public static Variable create(Scope scope, Operand initialValue, Options... options){ + Variable variable = create(scope, initialValue.shape(), initialValue.type(), options); + variable.initialize(initialValue); + return variable; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java new file mode 100644 index 00000000000..bb73db88775 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java @@ -0,0 +1,168 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================================== + */ +package org.tensorflow; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; +import org.junit.Ignore; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Gradients; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.family.TType; +import org.tensorflow.variable.Variable; + +/** + * Unit tests for {@link org.tensorflow.variable.Variable}/{@link Variable} + */ +public class VariableTest { + + @Test + public void testEager() { + try (EagerSession es = EagerSession.create()) { + Ops tf = Ops.create(es); + Variable variable = tf.Variable(Shape.of(10, 10), TFloat32.class); + + assertFalse(variable.isInitialized()); + assertFalse(variable.isValueInitialized().asTensor().getBoolean()); + + variable.initialize(tf.ones(tf.array(10, 10), TFloat32.class)); + + assertTrue(variable.isInitialized()); + assertTrue(variable.isValueInitialized().asTensor().getBoolean()); + + assertEquals(1, variable.value().asTensor().getFloat(0, 0)); + + variable.assign(tf.math.add(tf.ones(tf.array(10, 10), TFloat32.class), tf.constant(2f))); + + assertEquals(3, variable.value().asTensor().getFloat(0, 0)); + + variable.assignAdd(tf.ones(tf.array(10, 10), TFloat32.class)); + + assertEquals(4, variable.value().asTensor().getFloat(0, 0)); + + variable.assignSub(tf.ones(tf.array(10, 10), TFloat32.class)); + + assertEquals(3, variable.value().asTensor().getFloat(0, 0)); + } + } + + @Test + public void testGraph() { + try (Graph graph = new Graph()) { + Ops tf = Ops.create(graph); + Variable 
variable = tf.Variable(Shape.of(10, 10), TFloat32.class); + + assertFalse(variable.isInitialized()); + + variable.initialize(tf.ones(tf.array(10, 10), TFloat32.class)); + + assertTrue(variable.isInitialized()); + + Operand original = variable.value(); + + Op assign = variable.assign(tf.math.add(tf.ones(tf.array(10, 10), TFloat32.class), tf.constant(2f))); + Operand afterAssign = variable.value(); + + Op increment = variable.assignAdd(tf.ones(tf.array(10, 10), TFloat32.class)); + Operand afterIncrement = variable.value(); + + Op decrement = variable.assignSub(tf.ones(tf.array(10, 10), TFloat32.class)); + Operand afterDecrement = variable.value(); + + try (Session session = new Session(graph)) { + + assertFalse(((TBool) session.runner().fetch(variable.isValueInitialized()).run().get(0)).getBoolean()); + + session.runInit(); + + assertTrue(((TBool) session.runner().fetch(variable.isValueInitialized()).run().get(0)).getBoolean()); + + // test control deps (in-run assign) + + assertEquals(1, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0)); + + assertEquals(3, ((TFloat32) session.runner().fetch(afterAssign).run().get(0)).getFloat(0, 0)); + + assertEquals(4, ((TFloat32) session.runner().fetch(afterIncrement).run().get(0)).getFloat(0, 0)); + + assertEquals(3, ((TFloat32) session.runner().fetch(afterDecrement).run().get(0)).getFloat(0, 0)); + + // test persistence (multi-run assign) + + session.run(decrement); + session.run(decrement); + + assertEquals(1, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0)); + + session.run(increment); + session.run(increment); + session.run(increment); + session.run(increment); + + assertEquals(5, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0)); + + session.run(assign); + + assertEquals(3, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0)); + } + } + + } + + @Test + @Ignore // gradient not specified at c++ level: 
https://github.com/tensorflow/tensorflow/issues/46114. + public void gradientTest() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Ops tf = Ops.create(g); + + Placeholder initialValue = tf.placeholder(TFloat32.class); + Variable variable = tf.Variable(initialValue); + + Output y0 = tf.math.square(variable.value()).y(); + Output y1 = tf.math.square(tf.math.square(variable.value())).y(); + + Output x = variable.getHandle().asOutput(); + + Gradients grads = Gradients.create(tf.scope(), Arrays.asList(y0, y1), Arrays.asList(x)); + + assertNotNull(grads); + assertNotNull(grads.dy()); + assertEquals(1, grads.dy().size()); + + try (TFloat32 c = TFloat32.scalarOf(3.0f)) { + sess.runner().addTarget(tf.init()).feed(initialValue, c).run(); + try (AutoCloseableList outputs = + new AutoCloseableList<>(sess.runner().feed(initialValue, c).fetch(grads.dy(0)).run())) { + + //TODO expected value may be wrong, check once C++ gradient exists + assertEquals(114.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); + } + } + } + } + +}