From e4cb2fbf705a9cf54ab39518aa6239c681c340ec Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 9 Dec 2020 16:54:17 -0800 Subject: [PATCH 01/17] Start of tf.Variable style variable class Signed-off-by: Ryan Nett --- .../annotations/org/tensorflow/op/Ops.java | 29 +++- .../java/org/tensorflow/op/core/Variable.java | 2 + .../java/org/tensorflow/EagerSession.java | 16 ++ .../org/tensorflow/ExecutionEnvironment.java | 16 +- .../src/main/java/org/tensorflow/Graph.java | 15 ++ .../java/org/tensorflow/op/NameScope.java | 4 +- .../main/java/org/tensorflow/op/Scope.java | 4 +- .../java/org/tensorflow/op/core/Helpers.java | 2 + .../tensorflow/variable/EagerVariable.java | 50 ++++++ .../tensorflow/variable/GraphVariable.java | 56 +++++++ .../org/tensorflow/variable/Variable.java | 143 ++++++++++++++++++ 11 files changed, 326 insertions(+), 11 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/EagerVariable.java create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index ea3ef31313e..e8689d8ae3d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -278,7 +278,6 @@ import org.tensorflow.op.core.Unstage; import org.tensorflow.op.core.VarHandleOp; import org.tensorflow.op.core.VarIsInitializedOp; -import org.tensorflow.op.core.Variable; import org.tensorflow.op.core.VariableShape; import org.tensorflow.op.core.Where; import org.tensorflow.op.core.XlaSpmdFullToShardShape; @@ -294,6 +293,7 @@ import org.tensorflow.types.TUint8; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; +import org.tensorflow.variable.Variable; /** * An API for building operations as {@link Op Op}s @@ -394,6 +394,20 @@ private Ops(Scope scope) { quantization = new QuantizationOps(this); } + /** + * empty + */ + public Variable Variable(Operand initialValue) { + return Variable.create(scope, initialValue); + } + + /** + * empty + */ + public Variable Variable(Shape shape, DataType dataType) { + return Variable.create(scope, shape, dataType); + } + /** * Raise a exception to abort the process when called. *

@@ -7947,8 +7961,11 @@ public VarIsInitializedOp varIsInitializedOp(Operand resource) { * @param init The op to use to initialise this variable. * @param options carries optional attributes values * @return a new instance of Variable + * @deprecated Use {@link org.tensorflow.op.Ops#Variable(Operand)} instead for a tf.Variable like API. */ - public Variable variable(Operand init, Variable.Options... options) { + @Deprecated + public org.tensorflow.op.core.Variable variable(Operand init, + org.tensorflow.op.core.Variable.Options... options) { return Helpers.createVariableWithInit(scope, init, options); } @@ -7964,10 +7981,12 @@ public Variable variable(Operand init, Variable.Options. * @param dtype The type of elements in the variable tensor. * @param options carries optional attributes values * @return a new instance of Variable + * @deprecated Use {@link org.tensorflow.op.Ops#Variable(Shape, DataType)} instead for a tf.Variable like API. */ - public Variable variable(Shape shape, Class dtype, - Variable.Options... options) { - return Variable.create(scope, shape, dtype, options); + @Deprecated + public org.tensorflow.op.core.Variable variable(Shape shape, + Class dtype, org.tensorflow.op.core.Variable.Options... options) { + return org.tensorflow.op.core.Variable.create(scope, shape, dtype, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java index 98e545d7b76..1694afd052b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Variable.java @@ -79,8 +79,10 @@ private Options() { * @param dtype The type of elements in the variable tensor. * @param options carries optional attributes values * @return a new instance of Variable + * @deprecated Use {@link org.tensorflow.op.Ops#Variable(Shape, DataType)} instead for a tf.Variable like API. */ @Endpoint(describeByClass = true) + @Deprecated public static Variable create(Scope scope, Shape shape, Class dtype, Options... options) { OperationBuilder opBuilder = scope.env().opBuilder("VariableV2", scope.makeOpName("Variable")); opBuilder = scope.apply(opBuilder); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index 8e7465388a8..cff25316e0d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -21,6 +21,9 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TFE_DeleteContext; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_NewContext; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerScope; @@ -33,6 +36,7 @@ import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.proto.framework.ConfigProto; +import org.tensorflow.variable.Variable; /** * An environment for executing TensorFlow operations eagerly. @@ -357,6 +361,18 @@ void detach(Pointer... resources) { } } + private final Map> variables = new LinkedHashMap<>(); + + @Override + public void registerVariable(Variable variable) { + variables.put(variable.getName(), variable); + } + + @Override + public Map> variables() { + return Collections.unmodifiableMap(variables); + } + private static volatile EagerSession defaultSession = null; private final WeakPointerScope nativeResources; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java index d5389bcd0ad..b7fc526ec04 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java @@ -16,10 +16,9 @@ package org.tensorflow; import org.tensorflow.op.Op; - -/** - * Defines an environment for creating and executing TensorFlow {@link Operation}s. - */ +import java.util.Map; +import org.tensorflow.variable.Variable; +/** Defines an environment for creating and executing TensorFlow {@link Operation}s. */ public interface ExecutionEnvironment { enum Types { @@ -64,6 +63,15 @@ default boolean isOpEnabled(String opType) { */ Types environmentType(); + Map> variables(); + + /** + * Registers a variable with this execution environment. + * @deprecated Done automatically in Variable's constructor, should only be used internally. + */ + @Deprecated + void registerVariable(Variable variable); + default boolean isEager() { return environmentType() == Types.EAGER; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 0f7a291466c..6dcb4820eed 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -31,7 +31,9 @@ import java.util.Arrays; import java.util.Collections; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerScope; @@ -57,6 +59,7 @@ import org.tensorflow.proto.util.SaverDef; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; +import org.tensorflow.variable.Variable; /** @@ -478,6 +481,18 @@ synchronized SaverDef saverDef() { private final List initializers = new ArrayList<>(); + private final Map> variables = new LinkedHashMap<>(); + + @Override + public void registerVariable(Variable variable) { + variables.put(variable.getName(), variable); + } + + @Override + public Map> variables() { + return Collections.unmodifiableMap(variables); + } + // Related native objects (such as the TF_Operation object backing an Operation instance) // have a validity tied to that of the Graph. The handles to those native objects are not // valid after Graph.close() has been invoked. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java index 2e84cac1ac7..e7d9db79797 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java @@ -45,7 +45,9 @@ NameScope withSubScope(String scopeName) { } NameScope withName(String name) { - checkPattern(NAME_REGEX, name); + if(name != null) { + checkPattern(NAME_REGEX, name); + } // All context except for the opName is shared with the new scope. return new NameScope(opPrefix, name, ids); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java index 85e283d9260..87aff4858f8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java @@ -117,7 +117,9 @@ public Scope withSubScope(String childScopeName) { * *

Names must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*} * - * @param opName name for an operator in the returned scope + *

{@code opName} may be null, which unsets the name. + * + * @param opName name for an operator in the returned scope. May be null to unset name. * @return a new Scope that uses opName for operations. * @throws IllegalArgumentException if the name is invalid */ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java index 59682777966..fafaabf7a2f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java @@ -42,8 +42,10 @@ private Helpers() {} * @param init The op to use to initialise this variable. * @param options carries optional attributes values * @return a new instance of Variable + * @deprecated Use {@link org.tensorflow.op.Ops#Variable(Operand)} instead for a tf.Variable like API. */ @Endpoint(name = "variable") + @Deprecated public static Variable createVariableWithInit(Scope scope, Operand init, Variable.Options... options) { Variable newVar = Variable.create(scope, init.shape(), init.type(), options); Assign assignOp = Assign.create(scope, newVar, init); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/EagerVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/EagerVariable.java new file mode 100644 index 00000000000..6bc1c77a585 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/EagerVariable.java @@ -0,0 +1,50 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.variable; + +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.types.family.TType; + +class EagerVariable extends Variable { + + private Operand value = null; + + EagerVariable(Scope scope, Shape shape, DataType dtype) { + super(scope, shape, dtype); + } + + @Override + protected Operand getValue() { + if(value == null){ + throw new IllegalStateException("Value has not been initialized."); + } + return value; + } + + @Override + protected void doInitialize(Scope scope, Operand value) { + this.value = value; + } + + @Override + protected void doAssign(Scope scope, Operand value) { + this.value = value; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java new file mode 100644 index 00000000000..3e734829f95 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java @@ -0,0 +1,56 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.variable; + +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Init; +import org.tensorflow.types.family.TType; + +class GraphVariable extends Variable { + + private final org.tensorflow.op.core.Variable variable; + private Operand get; + + GraphVariable(Scope scope, Shape shape, DataType dataType) { + super(scope, shape, dataType); + variable = org.tensorflow.op.core.Variable.create(scope, shape, dataType); + } + + @Override + protected Operand getValue() { + if(get == null){ + throw new IllegalStateException("Variable has not been initialized."); + } + return get; + } + + @Override + protected void doInitialize(Scope scope, Operand value) { + Assign assignOp = Assign.create(scope, variable, value); + Init.add(scope, assignOp); + get = assignOp; + } + + @Override + protected void doAssign(Scope scope, Operand value) { + get = Assign.create(scope, variable, value); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java new file mode 100644 index 00000000000..1e024d45d57 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java @@ -0,0 +1,143 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.variable; + +import java.util.Arrays; +import org.tensorflow.DataType; +import org.tensorflow.EagerSession; +import org.tensorflow.ExecutionEnvironment; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +@Operator +public abstract class Variable implements Operand { + private final Scope initialScope; + private final String name; + + protected final Shape shape; + protected final DataType dataType; + + protected boolean hasInitialized = false; + + protected Variable(Scope scope, Shape shape, DataType dataType){ + this.initialScope = scope.withName(null); + + this.shape = shape; + this.dataType = dataType; + this.name = scope.makeOpName("Variable"); + scope.env().registerVariable(this); + } + + public Shape getShape() { + return shape; + } + + public DataType getDataType() { + return dataType; + } + + public boolean isInitialized() { + return hasInitialized; + } + + public String getName() { + return name; + } + + public Operand value(){ + if(!hasInitialized){ + throw new IllegalStateException("Variable has not been initialized, can not get."); + } + return getValue(); + } + + private void checkInput(Operand value){ + if(value.shape().isCompatibleWith(this.shape)){ + throw new IllegalArgumentException("Shape of new value (" + value.shape() + + ") is not compatible with the variable's shape (" + this.shape + ")."); + } + //TODO better checking w/ new types after refactor + if(value.asOutput().dataType() != dataType){ + throw new IllegalArgumentException("Data type of new value (" + value.asOutput().dataType() + + ") is not compatible with the variable's data type (" + dataType + ")."); + } + } + + public Operand initialize(Operand value){ + if(hasInitialized){ + throw new IllegalStateException("Variable has already been initialized, can't initialize again."); + } + checkInput(value); + doInitialize(initialScope, value); + hasInitialized = true; + return value(); + } + + public Operand assign(Operand value, Op... controlDependencies){ + checkInput(value); + doAssign(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value); + hasInitialized = true; + return value(); + } + + protected abstract Operand getValue(); + protected abstract void doInitialize(Scope scope, Operand value); + protected abstract void doAssign(Scope scope, Operand value); + + @Override + public Output asOutput() { + return value().asOutput(); + } + + @Override + public Operation op() { + return value().op(); + } + + @Override + public Shape shape() { + if(isInitialized()) { + return value().shape(); + } else { + return getShape(); + } + } + + @Endpoint(name = "Variable") + public static Variable create(Scope scope, Shape shape, DataType dataType){ + if(scope.env().isEager()) { + return new EagerVariable<>(scope, shape, dataType); + } else { + return new GraphVariable<>(scope, shape, dataType); + } + } + + @Endpoint(name = "Variable") + public static Variable create(Scope scope, Operand initialValue){ + Variable variable = create(scope, initialValue.shape(), initialValue.asOutput().dataType()); + variable.initialize(variable); + return variable; + } +} From 52c23fea9978020b707a081263d1f1d020bac8aa Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 9 Dec 2020 18:15:15 -0800 Subject: [PATCH 02/17] Javadoc, change initialize to no-op if already done (for function version) Signed-off-by: Ryan Nett --- .../annotations/org/tensorflow/op/Ops.java | 23 ++++- .../tensorflow/variable/GraphVariable.java | 7 +- .../org/tensorflow/variable/Variable.java | 92 ++++++++++++++++++- 3 files changed, 118 insertions(+), 4 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index e8689d8ae3d..08125f84e6d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -395,14 +395,33 @@ private Ops(Scope scope) { } /** - * empty + * Create a new {@link Variable} object, representing a mutable tensor value with constant shape and data type, with + * support for assignment and initialization that works in both eager and graph modes. + *

+ * Initializes the variable with the provided value, and uses it to determin the variables shape and data type. + *

+ * The name can be set using {@link org.tensorflow.op.Ops#withName(String)} just like any other op. + * + * @param scope + * @param initialValue the initial value of the variable. + * @return a new {@link Variable} instance. + * @see Variable */ public Variable Variable(Operand initialValue) { return Variable.create(scope, initialValue); } /** - * empty + * Create a new {@link Variable} object, representing a mutable tensor value with constant shape and data type, with + * support for assignment and initialization that works in both eager and graph modes. + *

+ * The name can be set using {@link org.tensorflow.op.Ops#withName(String)} just like any other op. + * + * @param scope + * @param shape the static shape of the variable. + * @param dataType the data type of the variable. + * @return a new {@link Variable} instance. + * @see Variable */ public Variable Variable(Shape shape, DataType dataType) { return Variable.create(scope, shape, dataType); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java index 3e734829f95..2f171a5882a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java @@ -16,6 +16,7 @@ */ package org.tensorflow.variable; +import java.util.Collections; import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; @@ -27,7 +28,7 @@ class GraphVariable extends Variable { private final org.tensorflow.op.core.Variable variable; - private Operand get; + private Operand get = null; GraphVariable(Scope scope, Shape shape, DataType dataType) { super(scope, shape, dataType); @@ -51,6 +52,10 @@ protected void doInitialize(Scope scope, Operand value) { @Override protected void doAssign(Scope scope, Operand value) { + if(get != null){ + scope = scope.withControlDependencies(Collections.singletonList(get)); + } + get = Assign.create(scope, variable, value); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java index 1e024d45d57..e3777d256bb 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java @@ -17,6 +17,8 @@ package org.tensorflow.variable; import java.util.Arrays; +import java.util.function.Consumer; +import java.util.function.Supplier; import org.tensorflow.DataType; import org.tensorflow.EagerSession; import org.tensorflow.ExecutionEnvironment; @@ -31,6 +33,18 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; +/** + * A class representing a mutable tensor variable with a constant shape and data type. Analogous to Python's tf.Variable. + * Any access will always return the most recent assignment. + *

+ * Supports eager and graph mode, and will use {@code VariableV2} in graph mode and enforce ordered assignments. + *

+ * Provides methods to get the value, assign a new value, and initialize the value if it hasn't already been set. + * Also implements {@code Operand} using the stored value. + * The exposed value will not usually be a {@link org.tensorflow.op.core.Variable}. + *

+ * Variables will be registered in their execution enviroment's {@link ExecutionEnvironment#variables()}. + */ @Operator public abstract class Variable implements Operand { private final Scope initialScope; @@ -50,22 +64,38 @@ protected Variable(Scope scope, Shape shape, DataType dataType){ scope.env().registerVariable(this); } + /** + * Get the variable's constant shape. + * This may have unknown dimensions, which do not impose a requirement on the value's dimensions. + */ public Shape getShape() { return shape; } + /** + * Get the variable's constant data type. + */ public DataType getDataType() { return dataType; } + /** + * Get whether the variable has had a value assigned to it. + */ public boolean isInitialized() { return hasInitialized; } + /** + * Get the name of the variable, set using {@link org.tensorflow.op.Ops#withName(String)} the same as any other op. + */ public String getName() { return name; } + /** + * Get the current value of this variable. + */ public Operand value(){ if(!hasInitialized){ throw new IllegalStateException("Variable has not been initialized, can not get."); @@ -85,9 +115,14 @@ private void checkInput(Operand value){ } } + /** + * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has. + * @param value the value to initialize this variable with. + * @return the new value (or current if it was already initialized). + */ public Operand initialize(Operand value){ if(hasInitialized){ - throw new IllegalStateException("Variable has already been initialized, can't initialize again."); + return value(); } checkInput(value); doInitialize(initialScope, value); @@ -95,6 +130,28 @@ public Operand initialize(Operand value){ return value(); } + + /** + * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has. + *

+ * The provided function will not be invoked if this function no-ops. + * @param value a function returning the value to initialize this variable with. + * Will only be called if initialization is done. + * @return the new value (or current if it was already initialized). + */ + public Operand initialize(Supplier> value){ + if(hasInitialized){ + return value(); + } + return initialize(value.get()); + } + + /** + * Assign a new value to this variable. + * @param value the value to assign. + * @param controlDependencies any control dependencies of the assignment. + * @return the new value + */ public Operand assign(Operand value, Op... controlDependencies){ checkInput(value); doAssign(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value); @@ -106,16 +163,26 @@ public Operand assign(Operand value, Op... controlDependencies){ protected abstract void doInitialize(Scope scope, Operand value); protected abstract void doAssign(Scope scope, Operand value); + /** + * Get the current value as an Output. + */ @Override public Output asOutput() { return value().asOutput(); } + /** + * Get the op of the current value. + */ @Override public Operation op() { return value().op(); } + /** + * Gets the current shape of this variable. May have less unknown dimensions than {@link #getShape()}, + * in which case they will be filled in from the current value. + */ @Override public Shape shape() { if(isInitialized()) { @@ -125,6 +192,17 @@ public Shape shape() { } } + /** + * Create a new {@link Variable} object, representing a mutable tensor value with constant shape and data type, with + * support for assignment and initialization that works in both eager and graph modes. + *

+ * The name can be set using {@link org.tensorflow.op.Ops#withName(String)} just like any other op. + * @param scope + * @param shape the static shape of the variable. + * @param dataType the data type of the variable. + * @return a new {@link Variable} instance. + * @see Variable + */ @Endpoint(name = "Variable") public static Variable create(Scope scope, Shape shape, DataType dataType){ if(scope.env().isEager()) { @@ -134,6 +212,18 @@ public static Variable create(Scope scope, Shape shape, Dat } } + /** + * Create a new {@link Variable} object, representing a mutable tensor value with constant shape and data type, with + * support for assignment and initialization that works in both eager and graph modes. + *

+ * Initializes the variable with the provided value, and uses it to determin the variables shape and data type. + *

+ * The name can be set using {@link org.tensorflow.op.Ops#withName(String)} just like any other op. + * @param scope + * @param initialValue the initial value of the variable. + * @return a new {@link Variable} instance. + * @see Variable + */ @Endpoint(name = "Variable") public static Variable create(Scope scope, Operand initialValue){ Variable variable = create(scope, initialValue.shape(), initialValue.asOutput().dataType()); From 3c75f276dbebc1cc5beff2829d8f332b27a3f1fb Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 28 Dec 2020 20:36:25 -0800 Subject: [PATCH 03/17] Change to immutable by default Signed-off-by: Ryan Nett --- .../java/org/tensorflow/EagerSession.java | 8 +- .../org/tensorflow/ExecutionEnvironment.java | 6 +- .../src/main/java/org/tensorflow/Graph.java | 8 +- .../tensorflow/variable/EagerVariable.java | 2 +- .../tensorflow/variable/GraphVariable.java | 2 +- .../tensorflow/variable/MutableVariable.java | 155 ++++++++++++++++++ .../org/tensorflow/variable/Variable.java | 110 ++----------- 7 files changed, 186 insertions(+), 105 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index cff25316e0d..4c5ac08b519 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -36,7 +36,7 @@ import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.proto.framework.ConfigProto; -import org.tensorflow.variable.Variable; +import org.tensorflow.variable.MutableVariable; /** * An environment for executing TensorFlow operations eagerly. @@ -361,15 +361,15 @@ void detach(Pointer... resources) { } } - private final Map> variables = new LinkedHashMap<>(); + private final Map> variables = new LinkedHashMap<>(); @Override - public void registerVariable(Variable variable) { + public void registerVariable(MutableVariable variable) { variables.put(variable.getName(), variable); } @Override - public Map> variables() { + public Map> variables() { return Collections.unmodifiableMap(variables); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java index b7fc526ec04..0cfcb4fb598 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java @@ -17,7 +17,7 @@ import org.tensorflow.op.Op; import java.util.Map; -import org.tensorflow.variable.Variable; +import org.tensorflow.variable.MutableVariable; /** Defines an environment for creating and executing TensorFlow {@link Operation}s. */ public interface ExecutionEnvironment { @@ -63,14 +63,14 @@ default boolean isOpEnabled(String opType) { */ Types environmentType(); - Map> variables(); + Map> variables(); /** * Registers a variable with this execution environment. * @deprecated Done automatically in Variable's constructor, should only be used internally. */ @Deprecated - void registerVariable(Variable variable); + void registerVariable(MutableVariable variable); default boolean isEager() { return environmentType() == Types.EAGER; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 6dcb4820eed..e3e5dd058a6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -59,7 +59,7 @@ import org.tensorflow.proto.util.SaverDef; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; -import org.tensorflow.variable.Variable; +import org.tensorflow.variable.MutableVariable; /** @@ -481,15 +481,15 @@ synchronized SaverDef saverDef() { private final List initializers = new ArrayList<>(); - private final Map> variables = new LinkedHashMap<>(); + private final Map> variables = new LinkedHashMap<>(); @Override - public void registerVariable(Variable variable) { + public void registerVariable(MutableVariable variable) { variables.put(variable.getName(), variable); } @Override - public Map> variables() { + public Map> variables() { return Collections.unmodifiableMap(variables); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/EagerVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/EagerVariable.java index 6bc1c77a585..e76d9c52c72 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/EagerVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/EagerVariable.java @@ -22,7 +22,7 @@ import org.tensorflow.op.Scope; import org.tensorflow.types.family.TType; -class EagerVariable extends Variable { +class EagerVariable extends MutableVariable { private Operand value = null; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java index 2f171a5882a..afb59915f62 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java @@ -25,7 +25,7 @@ import org.tensorflow.op.core.Init; import org.tensorflow.types.family.TType; -class GraphVariable extends Variable { +class GraphVariable extends MutableVariable { private final org.tensorflow.op.core.Variable variable; private Operand get = null; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java new file mode 100644 index 00000000000..fd424b03364 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java @@ -0,0 +1,155 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.variable; + +import java.util.Arrays; +import java.util.function.Supplier; +import org.tensorflow.DataType; +import org.tensorflow.ExecutionEnvironment; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * The implementation of {@link Variable}, with mutation methods. + * + * @see Variable + */ +public abstract class MutableVariable implements Variable { + private final Scope initialScope; + private final String name; + + protected final Shape shape; + protected final DataType dataType; + + protected boolean hasInitialized = false; + + protected MutableVariable(Scope scope, Shape shape, DataType dataType){ + this.initialScope = scope.withName(null); + + this.shape = shape; + this.dataType = dataType; + this.name = scope.makeOpName("Variable"); + scope.env().registerVariable(this); + } + + /** + * Get the variable's constant shape. + * This may have unknown dimensions, which do not impose a requirement on the value's dimensions. + */ + public Shape getShape() { + return shape; + } + + /** + * Get the variable's constant data type. + */ + public DataType getDataType() { + return dataType; + } + + /** + * Get whether the variable has had a value assigned to it. + */ + public boolean isInitialized() { + return hasInitialized; + } + + /** + * Get the name of the variable, set using {@link org.tensorflow.op.Ops#withName(String)} the same as any other op. + */ + public String getName() { + return name; + } + + /** + * Get the current value of this variable. + */ + public Operand value(){ + if(!hasInitialized){ + throw new IllegalStateException("Variable has not been initialized, can not get."); + } + return getValue(); + } + + private void checkInput(Operand value){ + if(value.shape().isCompatibleWith(this.shape)){ + throw new IllegalArgumentException("Shape of new value (" + value.shape() + + ") is not compatible with the variable's shape (" + this.shape + ")."); + } + //TODO better checking w/ new types after refactor + if(value.asOutput().dataType() != dataType){ + throw new IllegalArgumentException("Data type of new value (" + value.asOutput().dataType() + + ") is not compatible with the variable's data type (" + dataType + ")."); + } + } + + /** + * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has. + * @param value the value to initialize this variable with. + * @return the new value (or current if it was already initialized). + */ + public Operand initialize(Operand value){ + if(hasInitialized){ + return value(); + } + checkInput(value); + doInitialize(initialScope, value); + hasInitialized = true; + return value(); + } + + + /** + * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has. + *

+ * The provided function will not be invoked if this function no-ops. + * @param value a function returning the value to initialize this variable with. + * Will only be called if initialization is done. + * @return the new value (or current if it was already initialized). + */ + public Operand initialize(Supplier> value){ + if(hasInitialized){ + return value(); + } + return initialize(value.get()); + } + + /** + * Assign a new value to this variable. + * @param value the value to assign. + * @param controlDependencies any control dependencies of the assignment. + * @return the new value + */ + public Operand assign(Operand value, Op... controlDependencies){ + checkInput(value); + doAssign(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value); + hasInitialized = true; + return value(); + } + + protected abstract Operand getValue(); + protected abstract void doInitialize(Scope scope, Operand value); + protected abstract void doAssign(Scope scope, Operand value); +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java index e3777d256bb..1ce73e9ca02 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java @@ -16,120 +16,68 @@ */ package org.tensorflow.variable; -import java.util.Arrays; -import java.util.function.Consumer; import java.util.function.Supplier; import org.tensorflow.DataType; -import org.tensorflow.EagerSession; import org.tensorflow.ExecutionEnvironment; -import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Op; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; /** - * A class representing a mutable tensor variable with a constant shape and data type. Analogous to Python's tf.Variable. + * A class representing a read-only tensor variable with a constant shape and data type. Analogous to Python's tf.Variable. * Any access will always return the most recent assignment. *

* Supports eager and graph mode, and will use {@code VariableV2} in graph mode and enforce ordered assignments. *

- * Provides methods to get the value, assign a new value, and initialize the value if it hasn't already been set. + * Provides methods to get the value and initialize the value if it hasn't already been set. * Also implements {@code Operand} using the stored value. * The exposed value will not usually be a {@link org.tensorflow.op.core.Variable}. *

- * Variables will be registered in their execution enviroment's {@link ExecutionEnvironment#variables()}. + * Variables will be registered in their execution environment's {@link ExecutionEnvironment#variables()}. + *

+ * Implemented by {@link MutableVariable}, which provides mutability. + * + * @see MutableVariable */ @Operator -public abstract class Variable implements Operand { - private final Scope initialScope; - private final String name; - - protected final Shape shape; - protected final DataType dataType; - - protected boolean hasInitialized = false; - - protected Variable(Scope scope, Shape shape, DataType dataType){ - this.initialScope = scope.withName(null); - - this.shape = shape; - this.dataType = dataType; - this.name = scope.makeOpName("Variable"); - scope.env().registerVariable(this); - } - +public interface Variable extends Operand { /** * Get the variable's constant shape. * This may have unknown dimensions, which do not impose a requirement on the value's dimensions. */ - public Shape getShape() { - return shape; - } + Shape getShape(); /** * Get the variable's constant data type. */ - public DataType getDataType() { - return dataType; - } + DataType getDataType(); /** * Get whether the variable has had a value assigned to it. */ - public boolean isInitialized() { - return hasInitialized; - } + boolean isInitialized(); /** * Get the name of the variable, set using {@link org.tensorflow.op.Ops#withName(String)} the same as any other op. */ - public String getName() { - return name; - } + String getName(); /** * Get the current value of this variable. */ - public Operand value(){ - if(!hasInitialized){ - throw new IllegalStateException("Variable has not been initialized, can not get."); - } - return getValue(); - } - - private void checkInput(Operand value){ - if(value.shape().isCompatibleWith(this.shape)){ - throw new IllegalArgumentException("Shape of new value (" + value.shape() + - ") is not compatible with the variable's shape (" + this.shape + ")."); - } - //TODO better checking w/ new types after refactor - if(value.asOutput().dataType() != dataType){ - throw new IllegalArgumentException("Data type of new value (" + value.asOutput().dataType() + - ") is not compatible with the variable's data type (" + dataType + ")."); - } - } + Operand value(); /** * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has. * @param value the value to initialize this variable with. * @return the new value (or current if it was already initialized). */ - public Operand initialize(Operand value){ - if(hasInitialized){ - return value(); - } - checkInput(value); - doInitialize(initialScope, value); - hasInitialized = true; - return value(); - } - + Operand initialize(Operand value); /** * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has. @@ -139,35 +87,13 @@ public Operand initialize(Operand value){ * Will only be called if initialization is done. * @return the new value (or current if it was already initialized). */ - public Operand initialize(Supplier> value){ - if(hasInitialized){ - return value(); - } - return initialize(value.get()); - } - - /** - * Assign a new value to this variable. - * @param value the value to assign. - * @param controlDependencies any control dependencies of the assignment. - * @return the new value - */ - public Operand assign(Operand value, Op... controlDependencies){ - checkInput(value); - doAssign(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value); - hasInitialized = true; - return value(); - } - - protected abstract Operand getValue(); - protected abstract void doInitialize(Scope scope, Operand value); - protected abstract void doAssign(Scope scope, Operand value); + Operand initialize(Supplier> value); /** * Get the current value as an Output. */ @Override - public Output asOutput() { + default Output asOutput() { return value().asOutput(); } @@ -175,7 +101,7 @@ public Output asOutput() { * Get the op of the current value. */ @Override - public Operation op() { + default Operation op() { return value().op(); } @@ -184,7 +110,7 @@ public Operation op() { * in which case they will be filled in from the current value. */ @Override - public Shape shape() { + default Shape shape() { if(isInitialized()) { return value().shape(); } else { From c5ad2b3f189bce5b194c13fbd7b088414002a2f2 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 28 Dec 2020 21:01:52 -0800 Subject: [PATCH 04/17] todo Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/variable/MutableVariable.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java index fd424b03364..41dd2db4346 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java @@ -45,6 +45,8 @@ public abstract class MutableVariable implements Variable { protected boolean hasInitialized = false; + //TODO use the new resource API. + protected MutableVariable(Scope scope, Shape shape, DataType dataType){ this.initialScope = scope.withName(null); From 19e6f3c3bf30b8eba634fa34588e6aac3d1a9934 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 30 Dec 2020 15:10:19 -0800 Subject: [PATCH 05/17] new type system Signed-off-by: Ryan Nett --- .../org/tensorflow/variable/EagerVariable.java | 4 ++-- .../org/tensorflow/variable/GraphVariable.java | 8 ++++---- .../org/tensorflow/variable/MutableVariable.java | 9 +++++---- .../java/org/tensorflow/variable/Variable.java | 14 ++++++++------ 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/EagerVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/EagerVariable.java index e76d9c52c72..f1af8a33ce6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/EagerVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/EagerVariable.java @@ -16,17 +16,17 @@ */ package org.tensorflow.variable; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.family.TType; class EagerVariable extends MutableVariable { private Operand value = null; - EagerVariable(Scope scope, Shape shape, DataType dtype) { + EagerVariable(Scope scope, Shape shape, DataType dtype) { super(scope, shape, dtype); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java index afb59915f62..e78d670ff8a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java @@ -17,22 +17,22 @@ package org.tensorflow.variable; import java.util.Collections; -import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Init; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.family.TType; class GraphVariable extends MutableVariable { - private final org.tensorflow.op.core.Variable variable; + private org.tensorflow.op.core.Variable variable; private Operand get = null; - GraphVariable(Scope scope, Shape shape, DataType dataType) { + GraphVariable(Scope scope, Shape shape, DataType dataType) { super(scope, shape, dataType); - variable = org.tensorflow.op.core.Variable.create(scope, shape, dataType); +// variable = org.tensorflow.op.core.Variable.create(scope, shape, dataType.); } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java index 41dd2db4346..f6cad7408d2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java @@ -18,17 +18,18 @@ import java.util.Arrays; import java.util.function.Supplier; -import org.tensorflow.DataType; import org.tensorflow.ExecutionEnvironment; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; +import org.tensorflow.op.Operands; import org.tensorflow.op.Ops; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.family.TType; /** @@ -41,13 +42,13 @@ public abstract class MutableVariable implements Variable { private final String name; protected final Shape shape; - protected final DataType dataType; + protected final DataType dataType; protected boolean hasInitialized = false; //TODO use the new resource API. - protected MutableVariable(Scope scope, Shape shape, DataType dataType){ + protected MutableVariable(Scope scope, Shape shape, DataType dataType){ this.initialScope = scope.withName(null); this.shape = shape; @@ -67,7 +68,7 @@ public Shape getShape() { /** * Get the variable's constant data type. */ - public DataType getDataType() { + public DataType getDataType() { return dataType; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java index 1ce73e9ca02..66dad72c4d0 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java @@ -17,15 +17,17 @@ package org.tensorflow.variable; import java.util.function.Supplier; -import org.tensorflow.DataType; import org.tensorflow.ExecutionEnvironment; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.Output; +import org.tensorflow.internal.types.registry.TensorTypeRegistry; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.family.TType; /** @@ -55,7 +57,7 @@ public interface Variable extends Operand { /** * Get the variable's constant data type. */ - DataType getDataType(); + DataType getDataType(); /** * Get whether the variable has had a value assigned to it. @@ -130,11 +132,11 @@ default Shape shape() { * @see Variable */ @Endpoint(name = "Variable") - public static Variable create(Scope scope, Shape shape, DataType dataType){ + public static Variable create(Scope scope, Shape shape, Class dataType){ if(scope.env().isEager()) { - return new EagerVariable<>(scope, shape, dataType); + return new EagerVariable<>(scope, shape, Operands.toDataType(dataType)); } else { - return new GraphVariable<>(scope, shape, dataType); + return new GraphVariable<>(scope, shape, Operands.toDataType(dataType)); } } @@ -152,7 +154,7 @@ public static Variable create(Scope scope, Shape shape, Dat */ @Endpoint(name = "Variable") public static Variable create(Scope scope, Operand initialValue){ - Variable variable = create(scope, initialValue.shape(), initialValue.asOutput().dataType()); + Variable variable = create(scope, initialValue.shape(), initialValue.type()); variable.initialize(variable); return variable; } From d4336de05abb913f7111c21788cf9e6b7e87e990 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 30 Dec 2020 19:29:33 -0800 Subject: [PATCH 06/17] Resource API version Signed-off-by: Ryan Nett --- .../annotations/org/tensorflow/op/Ops.java | 6 +- .../tensorflow/variable/EagerVariable.java | 50 ---- .../tensorflow/variable/GraphVariable.java | 61 ----- .../tensorflow/variable/MutableVariable.java | 218 ++++++++++++------ .../org/tensorflow/variable/Variable.java | 41 ++-- .../java/org/tensorflow/VariableTest.java | 125 ++++++++++ 6 files changed, 309 insertions(+), 192 deletions(-) delete mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/EagerVariable.java delete mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java create mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 08125f84e6d..39f725a5085 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -423,7 +423,7 @@ public Variable Variable(Operand initialValue) { * @return a new {@link Variable} instance. * @see Variable */ - public Variable Variable(Shape shape, DataType dataType) { + public Variable Variable(Shape shape, Class dataType) { return Variable.create(scope, shape, dataType); } @@ -8003,8 +8003,8 @@ public org.tensorflow.op.core.Variable variable(Operand * @deprecated Use {@link org.tensorflow.op.Ops#Variable(Shape, DataType)} instead for a tf.Variable like API. */ @Deprecated - public org.tensorflow.op.core.Variable variable(Shape shape, - Class dtype, org.tensorflow.op.core.Variable.Options... options) { + public org.tensorflow.op.core.Variable variable(Shape shape, Class dtype, + org.tensorflow.op.core.Variable.Options... options) { return org.tensorflow.op.core.Variable.create(scope, shape, dtype, options); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/EagerVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/EagerVariable.java deleted file mode 100644 index f1af8a33ce6..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/EagerVariable.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - Copyright 2020 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ============================================================================== - */ -package org.tensorflow.variable; - -import org.tensorflow.Operand; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Scope; -import org.tensorflow.proto.framework.DataType; -import org.tensorflow.types.family.TType; - -class EagerVariable extends MutableVariable { - - private Operand value = null; - - EagerVariable(Scope scope, Shape shape, DataType dtype) { - super(scope, shape, dtype); - } - - @Override - protected Operand getValue() { - if(value == null){ - throw new IllegalStateException("Value has not been initialized."); - } - return value; - } - - @Override - protected void doInitialize(Scope scope, Operand value) { - this.value = value; - } - - @Override - protected void doAssign(Scope scope, Operand value) { - this.value = value; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java deleted file mode 100644 index e78d670ff8a..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/GraphVariable.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - Copyright 2020 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ============================================================================== - */ -package org.tensorflow.variable; - -import java.util.Collections; -import org.tensorflow.Operand; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Scope; -import org.tensorflow.op.core.Assign; -import org.tensorflow.op.core.Init; -import org.tensorflow.proto.framework.DataType; -import org.tensorflow.types.family.TType; - -class GraphVariable extends MutableVariable { - - private org.tensorflow.op.core.Variable variable; - private Operand get = null; - - GraphVariable(Scope scope, Shape shape, DataType dataType) { - super(scope, shape, dataType); -// variable = org.tensorflow.op.core.Variable.create(scope, shape, dataType.); - } - - @Override - protected Operand getValue() { - if(get == null){ - throw new IllegalStateException("Variable has not been initialized."); - } - return get; - } - - @Override - protected void doInitialize(Scope scope, Operand value) { - Assign assignOp = Assign.create(scope, variable, value); - Init.add(scope, assignOp); - get = assignOp; - } - - @Override - protected void doAssign(Scope scope, Operand value) { - if(get != null){ - scope = scope.withControlDependencies(Collections.singletonList(get)); - } - - get = Assign.create(scope, variable, value); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java index f6cad7408d2..daf0011b946 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java @@ -17,19 +17,23 @@ package org.tensorflow.variable; import java.util.Arrays; +import java.util.Collections; import java.util.function.Supplier; -import org.tensorflow.ExecutionEnvironment; import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Operands; -import org.tensorflow.op.Ops; import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; +import org.tensorflow.op.core.AssignAddVariableOp; +import org.tensorflow.op.core.AssignSubVariableOp; +import org.tensorflow.op.core.AssignVariableOp; +import org.tensorflow.op.core.Init; +import org.tensorflow.op.core.IsVariableInitialized; +import org.tensorflow.op.core.ReadVariableOp; +import org.tensorflow.op.core.VarHandleOp; +import org.tensorflow.op.core.VarHandleOp.Options; import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.TBool; import org.tensorflow.types.family.TType; /** @@ -37,122 +41,206 @@ * * @see Variable */ -public abstract class MutableVariable implements Variable { +public class MutableVariable implements Variable { + private final Scope initialScope; private final String name; - protected final Shape shape; - protected final DataType dataType; - - protected boolean hasInitialized = false; + private final Shape shape; + private final DataType dataType; + private final Class tType; + private final VarHandleOp handle; - //TODO use the new resource API. + private IsVariableInitialized isInitializedOp = null; + private Op initializationOp = null; + private ReadVariableOp cachedRead = null; + private Op lastAssign = null; - protected MutableVariable(Scope scope, Shape shape, DataType dataType){ - this.initialScope = scope.withName(null); + private boolean hasInitialized = false; + protected MutableVariable(Scope scope, Shape shape, Class dataType) { this.shape = shape; - this.dataType = dataType; + this.dataType = Operands.toDataType(dataType); + this.tType = dataType; + this.name = scope.makeOpName("Variable"); + + scope = scope.withName(null); + this.initialScope = scope.withSubScope(this.name); + + VarHandleOp.Options[] options; + + if (scope.env().isGraph()) { + options = new Options[]{VarHandleOp.sharedName(this.name)}; + } else { + options = new Options[0]; + } + + this.handle = VarHandleOp.create(initialScope.withName(name), dataType, shape, options); + scope.env().registerVariable(this); } - /** - * Get the variable's constant shape. - * This may have unknown dimensions, which do not impose a requirement on the value's dimensions. - */ + @Override public Shape getShape() { return shape; } - /** - * Get the variable's constant data type. - */ + @Override public DataType getDataType() { return dataType; } - /** - * Get whether the variable has had a value assigned to it. - */ + @Override public boolean isInitialized() { return hasInitialized; } - /** - * Get the name of the variable, set using {@link org.tensorflow.op.Ops#withName(String)} the same as any other op. - */ + @Override public String getName() { return name; } - /** - * Get the current value of this variable. - */ - public Operand value(){ - if(!hasInitialized){ + @Override + public Operand value(Scope scope) { + if (!hasInitialized) { throw new IllegalStateException("Variable has not been initialized, can not get."); } - return getValue(); + if (cachedRead == null) { + if (lastAssign != null) { + scope = scope.withControlDependencies(Collections.singletonList(lastAssign)); + } + cachedRead = ReadVariableOp.create(scope, handle, tType); + } + return cachedRead; + } + + @Override + public Operand value() { + return value(initialScope); } - private void checkInput(Operand value){ - if(value.shape().isCompatibleWith(this.shape)){ + private void checkInput(Operand value) { + if (value.shape().isCompatibleWith(this.shape)) { throw new IllegalArgumentException("Shape of new value (" + value.shape() + ") is not compatible with the variable's shape (" + this.shape + ")."); } //TODO better checking w/ new types after refactor - if(value.asOutput().dataType() != dataType){ + if (value.asOutput().dataType() != dataType) { throw new IllegalArgumentException("Data type of new value (" + value.asOutput().dataType() + ") is not compatible with the variable's data type (" + dataType + ")."); } } + @Override + public Op initialize(Operand value) { + if (hasInitialized) { + return initializationOp; + } + checkInput(value); + initializationOp = AssignVariableOp.create(initialScope, handle, value); + lastAssign = initializationOp; + Init.add(initialScope, lastAssign); + hasInitialized = true; + cachedRead = null; + return initializationOp; + } + + @Override + public Op initialize(Supplier> value) { + if (hasInitialized) { + return initializationOp; + } + return initialize(value.get()); + } + + @Override + public Operand isValueInitialized() { + if (isInitializedOp == null) { + isInitializedOp = IsVariableInitialized.create(initialScope, handle); + } + + return isInitializedOp; + } + /** - * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has. - * @param value the value to initialize this variable with. - * @return the new value (or current if it was already initialized). + * Assign a new value to this variable using the given scope. + * + * @param value the value to assign. + * @see AssignVariableOp#create */ - public Operand initialize(Operand value){ - if(hasInitialized){ - return value(); - } + public Op assign(Scope scope, Operand value) { checkInput(value); - doInitialize(initialScope, value); + lastAssign = AssignVariableOp.create(scope, handle, value); hasInitialized = true; - return value(); + cachedRead = null; + return lastAssign; } + /** + * Assign a new value to this variable using the variable's scope. + * + * @param value the value to assign. + * @param controlDependencies any control dependencies of the assignment. + * @see AssignVariableOp#create + */ + public Op assign(Operand value, Op... controlDependencies) { + return assign(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value); + } /** - * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has. - *

- * The provided function will not be invoked if this function no-ops. - * @param value a function returning the value to initialize this variable with. - * Will only be called if initialization is done. - * @return the new value (or current if it was already initialized). + * Decrement the variable's value by the given value, using the given scope. + * + * @param value amount to decrease the variable's value by. + * @see AssignSubVariableOp#create */ - public Operand initialize(Supplier> value){ - if(hasInitialized){ - return value(); + public Op assignSub(Scope scope, Operand value) { + if (!hasInitialized) { + throw new IllegalStateException("Variable has not been initialized, can not decrement."); } - return initialize(value.get()); + checkInput(value); + lastAssign = AssignSubVariableOp.create(scope, handle, value); + hasInitialized = true; + cachedRead = null; + return lastAssign; } /** - * Assign a new value to this variable. - * @param value the value to assign. + * Decrement the variable's value by the given value, using the variable's scope. + * + * @param value amount to decrease the variable's value by. * @param controlDependencies any control dependencies of the assignment. - * @return the new value + * @see AssignSubVariableOp#create */ - public Operand assign(Operand value, Op... controlDependencies){ + public Op assignSub(Operand value, Op... controlDependencies) { + return assignSub(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value); + } + + /** + * Increment the variable's value by the given value, using the given scope. + * + * @param value amount to decrease the variable's value by. + * @see AssignAddVariableOp#create + */ + public Op assignAdd(Scope scope, Operand value) { + if (!hasInitialized) { + throw new IllegalStateException("Variable has not been initialized, can not increment."); + } checkInput(value); - doAssign(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value); + lastAssign = AssignAddVariableOp.create(scope, handle, value); hasInitialized = true; - return value(); + cachedRead = null; + return lastAssign; } - protected abstract Operand getValue(); - protected abstract void doInitialize(Scope scope, Operand value); - protected abstract void doAssign(Scope scope, Operand value); + /** + * Increment the variable's value by the given value, using the variable's scope. + * + * @param value amount to decrease the variable's value by. + * @param controlDependencies any control dependencies of the assignment. + * @see AssignAddVariableOp#create + */ + public Op assignAdd(Operand value, Op... controlDependencies) { + return assignAdd(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value); + } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java index 66dad72c4d0..884f7137b7e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java @@ -23,11 +23,13 @@ import org.tensorflow.Output; import org.tensorflow.internal.types.registry.TensorTypeRegistry; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; import org.tensorflow.op.Operands; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.TBool; import org.tensorflow.types.family.TType; /** @@ -41,8 +43,6 @@ * The exposed value will not usually be a {@link org.tensorflow.op.core.Variable}. *

* Variables will be registered in their execution environment's {@link ExecutionEnvironment#variables()}. - *

- * Implemented by {@link MutableVariable}, which provides mutability. * * @see MutableVariable */ @@ -60,26 +60,39 @@ public interface Variable extends Operand { DataType getDataType(); /** - * Get whether the variable has had a value assigned to it. + * Get whether the variable has had a value assigned to it. This method relates to the Java object, not the graph variable. + *

+ * This operation returns true if {@code initialize} or {@code assign} methods have been used on + * the variable object, it does not provide any information about the state of the graph variable. + * For that, use {@link #isValueInitialized()} */ boolean isInitialized(); + /** + * Get whether the graph value is initialized. In eager mode, this will be the same as {@link #isInitialized()}. + */ + Operand isValueInitialized(); + /** * Get the name of the variable, set using {@link org.tensorflow.op.Ops#withName(String)} the same as any other op. */ String getName(); /** - * Get the current value of this variable. + * Get the current value of this variable, using the variable's scope. */ Operand value(); + /** + * Get the current value of this variable, using the given scope. + */ + Operand value(Scope scope); + /** * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has. * @param value the value to initialize this variable with. - * @return the new value (or current if it was already initialized). */ - Operand initialize(Operand value); + Op initialize(Operand value); /** * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has. @@ -87,9 +100,8 @@ public interface Variable extends Operand { * The provided function will not be invoked if this function no-ops. * @param value a function returning the value to initialize this variable with. * Will only be called if initialization is done. - * @return the new value (or current if it was already initialized). */ - Operand initialize(Supplier> value); + Op initialize(Supplier> value); /** * Get the current value as an Output. @@ -120,6 +132,13 @@ default Shape shape() { } } + /** + * Get the underlying mutable variable. + */ + default MutableVariable asMutableVariable(){ + return (MutableVariable) this; + } + /** * Create a new {@link Variable} object, representing a mutable tensor value with constant shape and data type, with * support for assignment and initialization that works in both eager and graph modes. @@ -133,11 +152,7 @@ default Shape shape() { */ @Endpoint(name = "Variable") public static Variable create(Scope scope, Shape shape, Class dataType){ - if(scope.env().isEager()) { - return new EagerVariable<>(scope, shape, Operands.toDataType(dataType)); - } else { - return new GraphVariable<>(scope, shape, Operands.toDataType(dataType)); - } + return new MutableVariable<>(scope, shape, dataType); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java new file mode 100644 index 00000000000..0c6a55d18a1 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java @@ -0,0 +1,125 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.variable.MutableVariable; + +/** + * Unit tests for {@link org.tensorflow.variable.Variable}/{@link org.tensorflow.variable.MutableVariable} + */ +public class VariableTest { + + @Test + public void testEager() { + Ops tf = Ops.create(EagerSession.create()); + MutableVariable variable = tf.Variable(Shape.of(10, 10), TFloat32.class).asMutableVariable(); + + assertFalse(variable.isInitialized()); + assertFalse(variable.isValueInitialized().asTensor().getBoolean(0)); + + variable.initialize(tf.ones(tf.array(10, 10), TFloat32.class)); + + assertTrue(variable.isInitialized()); + assertTrue(variable.isValueInitialized().asTensor().getBoolean(0)); + + assertEquals(1, variable.value().asTensor().getFloat(0, 0)); + + variable.assign(tf.math.add(tf.ones(tf.array(10, 10), TFloat32.class), tf.constant(2f))); + + assertEquals(3, variable.value().asTensor().getFloat(0, 0)); + + variable.assignAdd(tf.ones(tf.array(10, 10), TFloat32.class)); + + assertEquals(4, variable.value().asTensor().getFloat(0, 0)); + + variable.assignSub(tf.ones(tf.array(10, 10), TFloat32.class)); + + assertEquals(3, variable.value().asTensor().getFloat(0, 0)); + } + + @Test + public void testGraph() { + Graph graph = new Graph(); + Ops tf = Ops.create(graph); + MutableVariable variable = tf.Variable(Shape.of(10, 10), TFloat32.class).asMutableVariable(); + + assertFalse(variable.isInitialized()); + + variable.initialize(tf.ones(tf.array(10, 10), TFloat32.class)); + + assertTrue(variable.isInitialized()); + + Operand original = variable.value(); + + Op assign = variable.assign(tf.math.add(tf.ones(tf.array(10, 10), TFloat32.class), tf.constant(2f))); + Operand afterAssign = variable.value(); + + Op increment = variable.assignAdd(tf.ones(tf.array(10, 10), TFloat32.class)); + Operand afterIncrement = variable.value(); + + Op decrement = variable.assignSub(tf.ones(tf.array(10, 10), TFloat32.class)); + Operand afterDecrement = variable.value(); + + Session session = new Session(graph); + + assertFalse(((TBool) session.runner().fetch(variable.isValueInitialized()).run().get(0)).getBoolean(0)); + + session.run(tf.init()); + + assertTrue(((TBool) session.runner().fetch(variable.isValueInitialized()).run().get(0)).getBoolean(0)); + + // test control deps (in-run assign) + + assertEquals(1, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0)); + + assertEquals(3, ((TFloat32) session.runner().fetch(afterAssign).run().get(0)).getFloat(0, 0)); + + assertEquals(4, ((TFloat32) session.runner().fetch(afterIncrement).run().get(0)).getFloat(0, 0)); + + assertEquals(3, ((TFloat32) session.runner().fetch(afterDecrement).run().get(0)).getFloat(0, 0)); + + // test persistence (multi-run assign) + + session.run(decrement); + session.run(decrement); + + assertEquals(1, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0)); + + session.run(increment); + session.run(increment); + session.run(increment); + session.run(increment); + + assertEquals(5, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0)); + + session.run(assign); + + assertEquals(3, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0)); + + } + +} From 3fc2613d4b0b0765dd58669c5b090696c4a6ce84 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 30 Dec 2020 20:21:42 -0800 Subject: [PATCH 07/17] forgot a not Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/variable/MutableVariable.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java index daf0011b946..396f6f6f2e7 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java @@ -121,7 +121,7 @@ public Operand value() { } private void checkInput(Operand value) { - if (value.shape().isCompatibleWith(this.shape)) { + if (!value.shape().isCompatibleWith(this.shape)) { throw new IllegalArgumentException("Shape of new value (" + value.shape() + ") is not compatible with the variable's shape (" + this.shape + ")."); } From 8b789d95f2f65c053972c2d0593244637d8b063a Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 30 Dec 2020 20:22:47 -0800 Subject: [PATCH 08/17] forgot a not Signed-off-by: Ryan Nett --- .../main/java/org/tensorflow/variable/MutableVariable.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java index 396f6f6f2e7..791ce4ec5b4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java @@ -125,8 +125,8 @@ private void checkInput(Operand value) { throw new IllegalArgumentException("Shape of new value (" + value.shape() + ") is not compatible with the variable's shape (" + this.shape + ")."); } - //TODO better checking w/ new types after refactor - if (value.asOutput().dataType() != dataType) { + + if (!tType.isAssignableFrom(value.asOutput().type())) { throw new IllegalArgumentException("Data type of new value (" + value.asOutput().dataType() + ") is not compatible with the variable's data type (" + dataType + ")."); } From 7e2dc39659faf82224f97fd0d601ecdc62676932 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 30 Dec 2020 20:37:52 -0800 Subject: [PATCH 09/17] don't add init in eager mode Signed-off-by: Ryan Nett --- .../main/java/org/tensorflow/variable/MutableVariable.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java index 791ce4ec5b4..fb024adc2c9 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java @@ -139,8 +139,11 @@ public Op initialize(Operand value) { } checkInput(value); initializationOp = AssignVariableOp.create(initialScope, handle, value); - lastAssign = initializationOp; - Init.add(initialScope, lastAssign); + + //TODO this if will be unnecessary after the init PR + if(initialScope.env().isGraph()) + Init.add(initialScope, initializationOp); + hasInitialized = true; cachedRead = null; return initializationOp; From 4f8c54190bedc52c24f4b8b0886bef9ff315f56d Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 31 Dec 2020 17:01:19 -0800 Subject: [PATCH 10/17] Add handle getter, fix data races Signed-off-by: Ryan Nett --- .../tensorflow/variable/MutableVariable.java | 26 ++++++++++++------- .../org/tensorflow/variable/Variable.java | 7 +++++ 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java index fb024adc2c9..b01b03d0056 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java @@ -49,6 +49,7 @@ public class MutableVariable implements Variable { private final Shape shape; private final DataType dataType; private final Class tType; + private final VarHandleOp handle; private IsVariableInitialized isInitializedOp = null; @@ -81,6 +82,11 @@ protected MutableVariable(Scope scope, Shape shape, Class dataType) { scope.env().registerVariable(this); } + @Override + public VarHandleOp getHandle() { + return handle; + } + @Override public Shape getShape() { return shape; @@ -102,17 +108,19 @@ public String getName() { } @Override - public Operand value(Scope scope) { + public synchronized Operand value(Scope scope) { if (!hasInitialized) { throw new IllegalStateException("Variable has not been initialized, can not get."); } - if (cachedRead == null) { + ReadVariableOp ret = cachedRead; + if (ret == null) { if (lastAssign != null) { scope = scope.withControlDependencies(Collections.singletonList(lastAssign)); } - cachedRead = ReadVariableOp.create(scope, handle, tType); + ret = ReadVariableOp.create(scope, handle, tType); } - return cachedRead; + cachedRead = ret; + return ret; } @Override @@ -133,7 +141,7 @@ private void checkInput(Operand value) { } @Override - public Op initialize(Operand value) { + public synchronized Op initialize(Operand value) { if (hasInitialized) { return initializationOp; } @@ -150,7 +158,7 @@ public Op initialize(Operand value) { } @Override - public Op initialize(Supplier> value) { + public synchronized Op initialize(Supplier> value) { if (hasInitialized) { return initializationOp; } @@ -172,7 +180,7 @@ public Operand isValueInitialized() { * @param value the value to assign. * @see AssignVariableOp#create */ - public Op assign(Scope scope, Operand value) { + public synchronized Op assign(Scope scope, Operand value) { checkInput(value); lastAssign = AssignVariableOp.create(scope, handle, value); hasInitialized = true; @@ -197,7 +205,7 @@ public Op assign(Operand value, Op... controlDependencies) { * @param value amount to decrease the variable's value by. * @see AssignSubVariableOp#create */ - public Op assignSub(Scope scope, Operand value) { + public synchronized Op assignSub(Scope scope, Operand value) { if (!hasInitialized) { throw new IllegalStateException("Variable has not been initialized, can not decrement."); } @@ -225,7 +233,7 @@ public Op assignSub(Operand value, Op... controlDependencies) { * @param value amount to decrease the variable's value by. * @see AssignAddVariableOp#create */ - public Op assignAdd(Scope scope, Operand value) { + public synchronized Op assignAdd(Scope scope, Operand value) { if (!hasInitialized) { throw new IllegalStateException("Variable has not been initialized, can not increment."); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java index 884f7137b7e..90ad1d66c82 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java @@ -28,6 +28,7 @@ import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; +import org.tensorflow.op.core.VarHandleOp; import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TType; @@ -48,6 +49,12 @@ */ @Operator public interface Variable extends Operand { + + /** + * Get the variable handle operation. + */ + VarHandleOp getHandle(); + /** * Get the variable's constant shape. * This may have unknown dimensions, which do not impose a requirement on the value's dimensions. From 2ae55ab1033446fb4735d76de3272ac5cbe5baea Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 1 Jan 2021 15:39:02 -0800 Subject: [PATCH 11/17] try-with-resources in test Signed-off-by: Ryan Nett --- .../java/org/tensorflow/VariableTest.java | 98 ++++++++++--------- 1 file changed, 51 insertions(+), 47 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java index 0c6a55d18a1..0a17511deae 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java @@ -35,90 +35,94 @@ public class VariableTest { @Test public void testEager() { - Ops tf = Ops.create(EagerSession.create()); - MutableVariable variable = tf.Variable(Shape.of(10, 10), TFloat32.class).asMutableVariable(); + try(EagerSession es = EagerSession.create()) { + Ops tf = Ops.create(es); + MutableVariable variable = tf.Variable(Shape.of(10, 10), TFloat32.class).asMutableVariable(); - assertFalse(variable.isInitialized()); - assertFalse(variable.isValueInitialized().asTensor().getBoolean(0)); + assertFalse(variable.isInitialized()); + assertFalse(variable.isValueInitialized().asTensor().getBoolean(0)); - variable.initialize(tf.ones(tf.array(10, 10), TFloat32.class)); + variable.initialize(tf.ones(tf.array(10, 10), TFloat32.class)); - assertTrue(variable.isInitialized()); - assertTrue(variable.isValueInitialized().asTensor().getBoolean(0)); + assertTrue(variable.isInitialized()); + assertTrue(variable.isValueInitialized().asTensor().getBoolean(0)); - assertEquals(1, variable.value().asTensor().getFloat(0, 0)); + assertEquals(1, variable.value().asTensor().getFloat(0, 0)); - variable.assign(tf.math.add(tf.ones(tf.array(10, 10), TFloat32.class), tf.constant(2f))); + variable.assign(tf.math.add(tf.ones(tf.array(10, 10), TFloat32.class), tf.constant(2f))); - assertEquals(3, variable.value().asTensor().getFloat(0, 0)); + assertEquals(3, variable.value().asTensor().getFloat(0, 0)); - variable.assignAdd(tf.ones(tf.array(10, 10), TFloat32.class)); + variable.assignAdd(tf.ones(tf.array(10, 10), TFloat32.class)); - assertEquals(4, variable.value().asTensor().getFloat(0, 0)); + assertEquals(4, variable.value().asTensor().getFloat(0, 0)); - variable.assignSub(tf.ones(tf.array(10, 10), TFloat32.class)); + variable.assignSub(tf.ones(tf.array(10, 10), TFloat32.class)); - assertEquals(3, variable.value().asTensor().getFloat(0, 0)); + assertEquals(3, variable.value().asTensor().getFloat(0, 0)); + } } @Test public void testGraph() { - Graph graph = new Graph(); - Ops tf = Ops.create(graph); - MutableVariable variable = tf.Variable(Shape.of(10, 10), TFloat32.class).asMutableVariable(); + try(Graph graph = new Graph()) { + Ops tf = Ops.create(graph); + MutableVariable variable = tf.Variable(Shape.of(10, 10), TFloat32.class).asMutableVariable(); - assertFalse(variable.isInitialized()); + assertFalse(variable.isInitialized()); - variable.initialize(tf.ones(tf.array(10, 10), TFloat32.class)); + variable.initialize(tf.ones(tf.array(10, 10), TFloat32.class)); - assertTrue(variable.isInitialized()); + assertTrue(variable.isInitialized()); - Operand original = variable.value(); + Operand original = variable.value(); - Op assign = variable.assign(tf.math.add(tf.ones(tf.array(10, 10), TFloat32.class), tf.constant(2f))); - Operand afterAssign = variable.value(); + Op assign = variable.assign(tf.math.add(tf.ones(tf.array(10, 10), TFloat32.class), tf.constant(2f))); + Operand afterAssign = variable.value(); - Op increment = variable.assignAdd(tf.ones(tf.array(10, 10), TFloat32.class)); - Operand afterIncrement = variable.value(); + Op increment = variable.assignAdd(tf.ones(tf.array(10, 10), TFloat32.class)); + Operand afterIncrement = variable.value(); - Op decrement = variable.assignSub(tf.ones(tf.array(10, 10), TFloat32.class)); - Operand afterDecrement = variable.value(); + Op decrement = variable.assignSub(tf.ones(tf.array(10, 10), TFloat32.class)); + Operand afterDecrement = variable.value(); - Session session = new Session(graph); + try(Session session = new Session(graph)) { - assertFalse(((TBool) session.runner().fetch(variable.isValueInitialized()).run().get(0)).getBoolean(0)); + assertFalse(((TBool) session.runner().fetch(variable.isValueInitialized()).run().get(0)).getBoolean(0)); - session.run(tf.init()); + session.run(tf.init()); - assertTrue(((TBool) session.runner().fetch(variable.isValueInitialized()).run().get(0)).getBoolean(0)); + assertTrue(((TBool) session.runner().fetch(variable.isValueInitialized()).run().get(0)).getBoolean(0)); - // test control deps (in-run assign) + // test control deps (in-run assign) - assertEquals(1, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0)); + assertEquals(1, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0)); - assertEquals(3, ((TFloat32) session.runner().fetch(afterAssign).run().get(0)).getFloat(0, 0)); + assertEquals(3, ((TFloat32) session.runner().fetch(afterAssign).run().get(0)).getFloat(0, 0)); - assertEquals(4, ((TFloat32) session.runner().fetch(afterIncrement).run().get(0)).getFloat(0, 0)); + assertEquals(4, ((TFloat32) session.runner().fetch(afterIncrement).run().get(0)).getFloat(0, 0)); - assertEquals(3, ((TFloat32) session.runner().fetch(afterDecrement).run().get(0)).getFloat(0, 0)); + assertEquals(3, ((TFloat32) session.runner().fetch(afterDecrement).run().get(0)).getFloat(0, 0)); - // test persistence (multi-run assign) + // test persistence (multi-run assign) - session.run(decrement); - session.run(decrement); + session.run(decrement); + session.run(decrement); - assertEquals(1, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0)); + assertEquals(1, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0)); - session.run(increment); - session.run(increment); - session.run(increment); - session.run(increment); + session.run(increment); + session.run(increment); + session.run(increment); + session.run(increment); - assertEquals(5, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0)); + assertEquals(5, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0)); - session.run(assign); + session.run(assign); - assertEquals(3, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0)); + assertEquals(3, ((TFloat32) session.runner().fetch(original).run().get(0)).getFloat(0, 0)); + } + } } From 49e154785456d89fa5c6c516b6df007ac138d59f Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sat, 2 Jan 2021 17:04:45 -0800 Subject: [PATCH 12/17] typo fix Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/variable/Variable.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java index 90ad1d66c82..63b125e2b18 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java @@ -177,7 +177,7 @@ public static Variable create(Scope scope, Shape shape, Cla @Endpoint(name = "Variable") public static Variable create(Scope scope, Operand initialValue){ Variable variable = create(scope, initialValue.shape(), initialValue.type()); - variable.initialize(variable); + variable.initialize(initialValue); return variable; } } From 1309ee594c8ccef092f9f7d75b9a0ece78d832ba Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sat, 2 Jan 2021 18:00:10 -0800 Subject: [PATCH 13/17] make assignAdd and assignSub have control deps on any cached reads, to match python Signed-off-by: Ryan Nett --- .../java/org/tensorflow/variable/MutableVariable.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java index b01b03d0056..a5d1dab35f0 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java @@ -209,6 +209,10 @@ public synchronized Op assignSub(Scope scope, Operand value) { if (!hasInitialized) { throw new IllegalStateException("Variable has not been initialized, can not decrement."); } + + if(cachedRead != null) + scope = scope.withControlDependencies(Collections.singletonList(cachedRead)); + checkInput(value); lastAssign = AssignSubVariableOp.create(scope, handle, value); hasInitialized = true; @@ -237,6 +241,10 @@ public synchronized Op assignAdd(Scope scope, Operand value) { if (!hasInitialized) { throw new IllegalStateException("Variable has not been initialized, can not increment."); } + + if(cachedRead != null) + scope = scope.withControlDependencies(Collections.singletonList(cachedRead)); + checkInput(value); lastAssign = AssignAddVariableOp.create(scope, handle, value); hasInitialized = true; From 8d3ea0cfaca78e464792b607352ec63ccf5f4547 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sat, 2 Jan 2021 18:01:11 -0800 Subject: [PATCH 14/17] Add gradient test, ignore and note C++ issue Signed-off-by: Ryan Nett --- .../java/org/tensorflow/VariableTest.java | 42 +++++++++++++++++-- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java index 0a17511deae..30ff5369cec 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java @@ -18,15 +18,21 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.util.Arrays; +import org.junit.Ignore; import org.junit.jupiter.api.Test; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Gradients; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; +import org.tensorflow.types.family.TType; import org.tensorflow.variable.MutableVariable; +import org.tensorflow.variable.Variable; /** * Unit tests for {@link org.tensorflow.variable.Variable}/{@link org.tensorflow.variable.MutableVariable} @@ -35,7 +41,7 @@ public class VariableTest { @Test public void testEager() { - try(EagerSession es = EagerSession.create()) { + try (EagerSession es = EagerSession.create()) { Ops tf = Ops.create(es); MutableVariable variable = tf.Variable(Shape.of(10, 10), TFloat32.class).asMutableVariable(); @@ -65,7 +71,7 @@ public void testEager() { @Test public void testGraph() { - try(Graph graph = new Graph()) { + try (Graph graph = new Graph()) { Ops tf = Ops.create(graph); MutableVariable variable = tf.Variable(Shape.of(10, 10), TFloat32.class).asMutableVariable(); @@ -86,7 +92,7 @@ public void testGraph() { Op decrement = variable.assignSub(tf.ones(tf.array(10, 10), TFloat32.class)); Operand afterDecrement = variable.value(); - try(Session session = new Session(graph)) { + try (Session session = new Session(graph)) { assertFalse(((TBool) session.runner().fetch(variable.isValueInitialized()).run().get(0)).getBoolean(0)); @@ -126,4 +132,34 @@ public void testGraph() { } + @Test + @Ignore // gradient not specified at c++ level: https://github.com/tensorflow/tensorflow/issues/46114. + public void gradientTest() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Ops tf = Ops.create(g); + + Variable variable = tf.Variable(tf.placeholder(TFloat32.class)); + + Output y0 = tf.math.square(variable.value()).y(); + Output y1 = tf.math.square(tf.math.square(variable.value())).y(); + + Output x = variable.getHandle().asOutput(); + + Gradients grads = Gradients.create(tf.scope(), Arrays.asList(y0, y1), Arrays.asList(x)); + + assertNotNull(grads); + assertNotNull(grads.dy()); + assertEquals(1, grads.dy().size()); + + try (TFloat32 c = TFloat32.scalarOf(3.0f); + AutoCloseableList outputs = + new AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) { + + //TODO expected value may be wrong, check once C++ gradient exists + assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); + } + } + } + } From 3b0a6f5aae18558fb9a667fcb00206ee31b7c662 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sat, 2 Jan 2021 19:53:51 -0800 Subject: [PATCH 15/17] Add variable options, with trainable (not used yet) Signed-off-by: Ryan Nett --- .../annotations/org/tensorflow/op/Ops.java | 12 ++++--- .../tensorflow/variable/MutableVariable.java | 29 ++++++++++++--- .../org/tensorflow/variable/Variable.java | 36 ++++++++++++++++--- 3 files changed, 64 insertions(+), 13 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 39f725a5085..f08e4da08d3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -404,11 +404,13 @@ private Ops(Scope scope) { * * @param scope * @param initialValue the initial value of the variable. + * @param options carries optional attributes values * @return a new {@link Variable} instance. * @see Variable */ - public Variable Variable(Operand initialValue) { - return Variable.create(scope, initialValue); + public Variable Variable(Operand initialValue, + Variable.Options... options) { + return Variable.create(scope, initialValue, options); } /** @@ -420,11 +422,13 @@ public Variable Variable(Operand initialValue) { * @param scope * @param shape the static shape of the variable. * @param dataType the data type of the variable. + * @param options carries optional attributes values * @return a new {@link Variable} instance. * @see Variable */ - public Variable Variable(Shape shape, Class dataType) { - return Variable.create(scope, shape, dataType); + public Variable Variable(Shape shape, Class dataType, + Variable.Options... options) { + return Variable.create(scope, shape, dataType, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java index a5d1dab35f0..6ce801e4834 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java @@ -48,6 +48,7 @@ public class MutableVariable implements Variable { private final Shape shape; private final DataType dataType; + private final boolean trainable; private final Class tType; private final VarHandleOp handle; @@ -59,29 +60,47 @@ public class MutableVariable implements Variable { private boolean hasInitialized = false; - protected MutableVariable(Scope scope, Shape shape, Class dataType) { + protected MutableVariable(Scope scope, Shape shape, Class dataType, + Options[] options) { this.shape = shape; this.dataType = Operands.toDataType(dataType); this.tType = dataType; + boolean trainable = true; + if (options != null) { + for (Options opts : options) { + if (opts.trainable != null) { + trainable = opts.trainable; + } + } + } + this.trainable = trainable; + this.name = scope.makeOpName("Variable"); scope = scope.withName(null); this.initialScope = scope.withSubScope(this.name); - VarHandleOp.Options[] options; + + + VarHandleOp.Options[] handleOptions; if (scope.env().isGraph()) { - options = new Options[]{VarHandleOp.sharedName(this.name)}; + handleOptions = new VarHandleOp.Options[]{VarHandleOp.sharedName(this.name)}; } else { - options = new Options[0]; + handleOptions = new VarHandleOp.Options[0]; } - this.handle = VarHandleOp.create(initialScope.withName(name), dataType, shape, options); + this.handle = VarHandleOp.create(initialScope.withName(name), dataType, shape, handleOptions); scope.env().registerVariable(this); } + @Override + public boolean isTrainable() { + return trainable; + } + @Override public VarHandleOp getHandle() { return handle; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java index 63b125e2b18..e0320a0466d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java @@ -29,6 +29,7 @@ import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.op.core.VarHandleOp; +import org.tensorflow.op.core.Variable.Options; import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TType; @@ -50,6 +51,11 @@ @Operator public interface Variable extends Operand { + /** + * Get whether the variable is trainable (whether it should be updated by optimizers). + */ + boolean isTrainable(); + /** * Get the variable handle operation. */ @@ -146,6 +152,26 @@ default MutableVariable asMutableVariable(){ return (MutableVariable) this; } + /** + * Optional attributes for {@link org.tensorflow.op.core.Variable} + */ + public static class Options { + + /** + * @param trainable If non-null, this variable's trainability is as given. + * Otherwise, it is trainable. + */ + public Options trainable(boolean trainable) { + this.trainable = trainable; + return this; + } + + Boolean trainable = null; + + private Options() { + } + } + /** * Create a new {@link Variable} object, representing a mutable tensor value with constant shape and data type, with * support for assignment and initialization that works in both eager and graph modes. @@ -154,12 +180,13 @@ default MutableVariable asMutableVariable(){ * @param scope * @param shape the static shape of the variable. * @param dataType the data type of the variable. + * @param options carries optional attributes values * @return a new {@link Variable} instance. * @see Variable */ @Endpoint(name = "Variable") - public static Variable create(Scope scope, Shape shape, Class dataType){ - return new MutableVariable<>(scope, shape, dataType); + public static Variable create(Scope scope, Shape shape, Class dataType, Options... options){ + return new MutableVariable<>(scope, shape, dataType, options); } /** @@ -171,12 +198,13 @@ public static Variable create(Scope scope, Shape shape, Cla * The name can be set using {@link org.tensorflow.op.Ops#withName(String)} just like any other op. * @param scope * @param initialValue the initial value of the variable. + * @param options carries optional attributes values * @return a new {@link Variable} instance. * @see Variable */ @Endpoint(name = "Variable") - public static Variable create(Scope scope, Operand initialValue){ - Variable variable = create(scope, initialValue.shape(), initialValue.type()); + public static Variable create(Scope scope, Operand initialValue, Options... options){ + Variable variable = create(scope, initialValue.shape(), initialValue.type(), options); variable.initialize(initialValue); return variable; } From 6294f8119db3baf097e1cacc8b6c81b0e7c7cf94 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sat, 2 Jan 2021 20:55:03 -0800 Subject: [PATCH 16/17] Merge Variable and MutableVariable, un-deprecate registerVariable Signed-off-by: Ryan Nett --- .../java/org/tensorflow/EagerSession.java | 8 +- .../org/tensorflow/ExecutionEnvironment.java | 10 +- .../src/main/java/org/tensorflow/Graph.java | 8 +- .../tensorflow/variable/MutableVariable.java | 284 --------------- .../org/tensorflow/variable/Variable.java | 323 ++++++++++++++---- .../java/org/tensorflow/VariableTest.java | 7 +- 6 files changed, 278 insertions(+), 362 deletions(-) delete mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index 4c5ac08b519..cff25316e0d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -36,7 +36,7 @@ import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.proto.framework.ConfigProto; -import org.tensorflow.variable.MutableVariable; +import org.tensorflow.variable.Variable; /** * An environment for executing TensorFlow operations eagerly. @@ -361,15 +361,15 @@ void detach(Pointer... resources) { } } - private final Map> variables = new LinkedHashMap<>(); + private final Map> variables = new LinkedHashMap<>(); @Override - public void registerVariable(MutableVariable variable) { + public void registerVariable(Variable variable) { variables.put(variable.getName(), variable); } @Override - public Map> variables() { + public Map> variables() { return Collections.unmodifiableMap(variables); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java index 0cfcb4fb598..665b3688944 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java @@ -17,7 +17,7 @@ import org.tensorflow.op.Op; import java.util.Map; -import org.tensorflow.variable.MutableVariable; +import org.tensorflow.variable.Variable; /** Defines an environment for creating and executing TensorFlow {@link Operation}s. */ public interface ExecutionEnvironment { @@ -63,14 +63,12 @@ default boolean isOpEnabled(String opType) { */ Types environmentType(); - Map> variables(); + Map> variables(); /** - * Registers a variable with this execution environment. - * @deprecated Done automatically in Variable's constructor, should only be used internally. + * Registers a variable with this execution environment. For internal use only. */ - @Deprecated - void registerVariable(MutableVariable variable); + void registerVariable(Variable variable); default boolean isEager() { return environmentType() == Types.EAGER; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index e3e5dd058a6..6dcb4820eed 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -59,7 +59,7 @@ import org.tensorflow.proto.util.SaverDef; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; -import org.tensorflow.variable.MutableVariable; +import org.tensorflow.variable.Variable; /** @@ -481,15 +481,15 @@ synchronized SaverDef saverDef() { private final List initializers = new ArrayList<>(); - private final Map> variables = new LinkedHashMap<>(); + private final Map> variables = new LinkedHashMap<>(); @Override - public void registerVariable(MutableVariable variable) { + public void registerVariable(Variable variable) { variables.put(variable.getName(), variable); } @Override - public Map> variables() { + public Map> variables() { return Collections.unmodifiableMap(variables); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java deleted file mode 100644 index 6ce801e4834..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/MutableVariable.java +++ /dev/null @@ -1,284 +0,0 @@ -/* - Copyright 2020 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ============================================================================== - */ -package org.tensorflow.variable; - -import java.util.Arrays; -import java.util.Collections; -import java.util.function.Supplier; -import org.tensorflow.Operand; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Op; -import org.tensorflow.op.Operands; -import org.tensorflow.op.Scope; -import org.tensorflow.op.core.AssignAddVariableOp; -import org.tensorflow.op.core.AssignSubVariableOp; -import org.tensorflow.op.core.AssignVariableOp; -import org.tensorflow.op.core.Init; -import org.tensorflow.op.core.IsVariableInitialized; -import org.tensorflow.op.core.ReadVariableOp; -import org.tensorflow.op.core.VarHandleOp; -import org.tensorflow.op.core.VarHandleOp.Options; -import org.tensorflow.proto.framework.DataType; -import org.tensorflow.types.TBool; -import org.tensorflow.types.family.TType; - -/** - * The implementation of {@link Variable}, with mutation methods. - * - * @see Variable - */ -public class MutableVariable implements Variable { - - private final Scope initialScope; - private final String name; - - private final Shape shape; - private final DataType dataType; - private final boolean trainable; - private final Class tType; - - private final VarHandleOp handle; - - private IsVariableInitialized isInitializedOp = null; - private Op initializationOp = null; - private ReadVariableOp cachedRead = null; - private Op lastAssign = null; - - private boolean hasInitialized = false; - - protected MutableVariable(Scope scope, Shape shape, Class dataType, - Options[] options) { - this.shape = shape; - this.dataType = Operands.toDataType(dataType); - this.tType = dataType; - - boolean trainable = true; - if (options != null) { - for (Options opts : options) { - if (opts.trainable != null) { - trainable = opts.trainable; - } - } - } - this.trainable = trainable; - - this.name = scope.makeOpName("Variable"); - - scope = scope.withName(null); - this.initialScope = scope.withSubScope(this.name); - - - - VarHandleOp.Options[] handleOptions; - - if (scope.env().isGraph()) { - handleOptions = new VarHandleOp.Options[]{VarHandleOp.sharedName(this.name)}; - } else { - handleOptions = new VarHandleOp.Options[0]; - } - - this.handle = VarHandleOp.create(initialScope.withName(name), dataType, shape, handleOptions); - - scope.env().registerVariable(this); - } - - @Override - public boolean isTrainable() { - return trainable; - } - - @Override - public VarHandleOp getHandle() { - return handle; - } - - @Override - public Shape getShape() { - return shape; - } - - @Override - public DataType getDataType() { - return dataType; - } - - @Override - public boolean isInitialized() { - return hasInitialized; - } - - @Override - public String getName() { - return name; - } - - @Override - public synchronized Operand value(Scope scope) { - if (!hasInitialized) { - throw new IllegalStateException("Variable has not been initialized, can not get."); - } - ReadVariableOp ret = cachedRead; - if (ret == null) { - if (lastAssign != null) { - scope = scope.withControlDependencies(Collections.singletonList(lastAssign)); - } - ret = ReadVariableOp.create(scope, handle, tType); - } - cachedRead = ret; - return ret; - } - - @Override - public Operand value() { - return value(initialScope); - } - - private void checkInput(Operand value) { - if (!value.shape().isCompatibleWith(this.shape)) { - throw new IllegalArgumentException("Shape of new value (" + value.shape() + - ") is not compatible with the variable's shape (" + this.shape + ")."); - } - - if (!tType.isAssignableFrom(value.asOutput().type())) { - throw new IllegalArgumentException("Data type of new value (" + value.asOutput().dataType() + - ") is not compatible with the variable's data type (" + dataType + ")."); - } - } - - @Override - public synchronized Op initialize(Operand value) { - if (hasInitialized) { - return initializationOp; - } - checkInput(value); - initializationOp = AssignVariableOp.create(initialScope, handle, value); - - //TODO this if will be unnecessary after the init PR - if(initialScope.env().isGraph()) - Init.add(initialScope, initializationOp); - - hasInitialized = true; - cachedRead = null; - return initializationOp; - } - - @Override - public synchronized Op initialize(Supplier> value) { - if (hasInitialized) { - return initializationOp; - } - return initialize(value.get()); - } - - @Override - public Operand isValueInitialized() { - if (isInitializedOp == null) { - isInitializedOp = IsVariableInitialized.create(initialScope, handle); - } - - return isInitializedOp; - } - - /** - * Assign a new value to this variable using the given scope. - * - * @param value the value to assign. - * @see AssignVariableOp#create - */ - public synchronized Op assign(Scope scope, Operand value) { - checkInput(value); - lastAssign = AssignVariableOp.create(scope, handle, value); - hasInitialized = true; - cachedRead = null; - return lastAssign; - } - - /** - * Assign a new value to this variable using the variable's scope. - * - * @param value the value to assign. - * @param controlDependencies any control dependencies of the assignment. - * @see AssignVariableOp#create - */ - public Op assign(Operand value, Op... controlDependencies) { - return assign(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value); - } - - /** - * Decrement the variable's value by the given value, using the given scope. - * - * @param value amount to decrease the variable's value by. - * @see AssignSubVariableOp#create - */ - public synchronized Op assignSub(Scope scope, Operand value) { - if (!hasInitialized) { - throw new IllegalStateException("Variable has not been initialized, can not decrement."); - } - - if(cachedRead != null) - scope = scope.withControlDependencies(Collections.singletonList(cachedRead)); - - checkInput(value); - lastAssign = AssignSubVariableOp.create(scope, handle, value); - hasInitialized = true; - cachedRead = null; - return lastAssign; - } - - /** - * Decrement the variable's value by the given value, using the variable's scope. - * - * @param value amount to decrease the variable's value by. - * @param controlDependencies any control dependencies of the assignment. - * @see AssignSubVariableOp#create - */ - public Op assignSub(Operand value, Op... controlDependencies) { - return assignSub(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value); - } - - /** - * Increment the variable's value by the given value, using the given scope. - * - * @param value amount to decrease the variable's value by. - * @see AssignAddVariableOp#create - */ - public synchronized Op assignAdd(Scope scope, Operand value) { - if (!hasInitialized) { - throw new IllegalStateException("Variable has not been initialized, can not increment."); - } - - if(cachedRead != null) - scope = scope.withControlDependencies(Collections.singletonList(cachedRead)); - - checkInput(value); - lastAssign = AssignAddVariableOp.create(scope, handle, value); - hasInitialized = true; - cachedRead = null; - return lastAssign; - } - - /** - * Increment the variable's value by the given value, using the variable's scope. - * - * @param value amount to decrease the variable's value by. - * @param controlDependencies any control dependencies of the assignment. - * @see AssignAddVariableOp#create - */ - public Op assignAdd(Operand value, Op... controlDependencies) { - return assignAdd(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java index e0320a0466d..7bd1b798773 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java @@ -16,20 +16,26 @@ */ package org.tensorflow.variable; +import java.util.Arrays; +import java.util.Collections; import java.util.function.Supplier; import org.tensorflow.ExecutionEnvironment; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.Output; -import org.tensorflow.internal.types.registry.TensorTypeRegistry; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Operands; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; +import org.tensorflow.op.core.AssignAddVariableOp; +import org.tensorflow.op.core.AssignSubVariableOp; +import org.tensorflow.op.core.AssignVariableOp; +import org.tensorflow.op.core.Init; +import org.tensorflow.op.core.IsVariableInitialized; +import org.tensorflow.op.core.ReadVariableOp; import org.tensorflow.op.core.VarHandleOp; -import org.tensorflow.op.core.Variable.Options; import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TType; @@ -42,85 +48,118 @@ *

* Provides methods to get the value and initialize the value if it hasn't already been set. * Also implements {@code Operand} using the stored value. - * The exposed value will not usually be a {@link org.tensorflow.op.core.Variable}. *

* Variables will be registered in their execution environment's {@link ExecutionEnvironment#variables()}. * - * @see MutableVariable */ @Operator -public interface Variable extends Operand { +public class Variable implements Operand { - /** - * Get whether the variable is trainable (whether it should be updated by optimizers). - */ - boolean isTrainable(); + private final Scope initialScope; + private final String name; - /** - * Get the variable handle operation. - */ - VarHandleOp getHandle(); + private final Shape shape; + private final DataType dataType; + private final boolean trainable; + private final Class tType; - /** - * Get the variable's constant shape. - * This may have unknown dimensions, which do not impose a requirement on the value's dimensions. - */ - Shape getShape(); + private final VarHandleOp handle; - /** - * Get the variable's constant data type. - */ - DataType getDataType(); + private IsVariableInitialized isInitializedOp = null; + private Op initializationOp = null; + private ReadVariableOp cachedRead = null; + private Op lastAssign = null; + + private boolean hasInitialized = false; + + protected Variable(Scope scope, Shape shape, Class dataType, + Options[] options) { + this.shape = shape; + this.dataType = Operands.toDataType(dataType); + this.tType = dataType; + + boolean trainable = true; + if (options != null) { + for (Options opts : options) { + if (opts.trainable != null) { + trainable = opts.trainable; + } + } + } + this.trainable = trainable; + + this.name = scope.makeOpName("Variable"); + + scope = scope.withName(null); + this.initialScope = scope.withSubScope(this.name); - /** - * Get whether the variable has had a value assigned to it. This method relates to the Java object, not the graph variable. - *

- * This operation returns true if {@code initialize} or {@code assign} methods have been used on - * the variable object, it does not provide any information about the state of the graph variable. - * For that, use {@link #isValueInitialized()} - */ - boolean isInitialized(); + + + VarHandleOp.Options[] handleOptions; + + if (scope.env().isGraph()) { + handleOptions = new VarHandleOp.Options[]{VarHandleOp.sharedName(this.name)}; + } else { + handleOptions = new VarHandleOp.Options[0]; + } + + this.handle = VarHandleOp.create(initialScope.withName(name), dataType, shape, handleOptions); + + scope.env().registerVariable(this); + } /** - * Get whether the graph value is initialized. In eager mode, this will be the same as {@link #isInitialized()}. - */ - Operand isValueInitialized(); + * Get whether the variable is trainable (whether it should be updated by optimizers). + */ + public boolean isTrainable() { + return trainable; + } /** - * Get the name of the variable, set using {@link org.tensorflow.op.Ops#withName(String)} the same as any other op. - */ - String getName(); + * Get the variable handle operation. + */ + public VarHandleOp getHandle() { + return handle; + } /** - * Get the current value of this variable, using the variable's scope. - */ - Operand value(); + * Get the variable's constant shape. + * This may have unknown dimensions, which do not impose a requirement on the value's dimensions. + */ + public Shape getShape() { + return shape; + } /** - * Get the current value of this variable, using the given scope. - */ - Operand value(Scope scope); + * Get the variable's constant data type. + */ + public DataType getDataType() { + return dataType; + } /** - * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has. - * @param value the value to initialize this variable with. - */ - Op initialize(Operand value); + * Get whether the variable has had a value assigned to it. This method relates to the Java object, not the graph variable. + *

+ * This operation returns true if {@code initialize} or {@code assign} methods have been used on + * the variable object, it does not provide any information about the state of the graph variable. + * For that, use {@link #isValueInitialized()} + */ + public boolean isInitialized() { + return hasInitialized; + } /** - * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has. - *

- * The provided function will not be invoked if this function no-ops. - * @param value a function returning the value to initialize this variable with. - * Will only be called if initialization is done. - */ - Op initialize(Supplier> value); + * Get the name of the variable, set using {@link org.tensorflow.op.Ops#withName(String)} the same as any other op. + */ + public String getName() { + return name; + } /** * Get the current value as an Output. */ @Override - default Output asOutput() { + public Output asOutput() { return value().asOutput(); } @@ -128,7 +167,7 @@ default Output asOutput() { * Get the op of the current value. */ @Override - default Operation op() { + public Operation op() { return value().op(); } @@ -137,7 +176,7 @@ default Operation op() { * in which case they will be filled in from the current value. */ @Override - default Shape shape() { + public Shape shape() { if(isInitialized()) { return value().shape(); } else { @@ -146,10 +185,174 @@ default Shape shape() { } /** - * Get the underlying mutable variable. + * Get the current value of this variable, using the given scope. + */ + public synchronized Operand value(Scope scope) { + if (!hasInitialized) { + throw new IllegalStateException("Variable has not been initialized, can not get."); + } + ReadVariableOp ret = cachedRead; + if (ret == null) { + if (lastAssign != null) { + scope = scope.withControlDependencies(Collections.singletonList(lastAssign)); + } + ret = ReadVariableOp.create(scope, handle, tType); + } + cachedRead = ret; + return ret; + } + + /** + * Get the current value of this variable, using the variable's scope. + */ + public Operand value() { + return value(initialScope); + } + + private void checkInput(Operand value) { + if (!value.shape().isCompatibleWith(this.shape)) { + throw new IllegalArgumentException("Shape of new value (" + value.shape() + + ") is not compatible with the variable's shape (" + this.shape + ")."); + } + + if (!tType.isAssignableFrom(value.asOutput().type())) { + throw new IllegalArgumentException("Data type of new value (" + value.asOutput().dataType() + + ") is not compatible with the variable's data type (" + dataType + ")."); + } + } + + /** + * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has. + * @param value the value to initialize this variable with. + */ + public synchronized Op initialize(Operand value) { + if (hasInitialized) { + return initializationOp; + } + checkInput(value); + initializationOp = AssignVariableOp.create(initialScope, handle, value); + + //TODO this if will be unnecessary after the init PR + if(initialScope.env().isGraph()) + Init.add(initialScope, initializationOp); + + hasInitialized = true; + cachedRead = null; + return initializationOp; + } + + /** + * Initialize this variable with a value, if it hasn't already been assigned a value. No-op if it has. + *

+ * The provided function will not be invoked if this function no-ops. + * @param value a function returning the value to initialize this variable with. + * Will only be called if initialization is done. + */ + public synchronized Op initialize(Supplier> value) { + if (hasInitialized) { + return initializationOp; + } + return initialize(value.get()); + } + + /** + * Get whether the graph value is initialized. In eager mode, this will be the same as {@link #isInitialized()}. + */ + public Operand isValueInitialized() { + if (isInitializedOp == null) { + isInitializedOp = IsVariableInitialized.create(initialScope, handle); + } + + return isInitializedOp; + } + + /** + * Assign a new value to this variable using the given scope. + * + * @param value the value to assign. + * @see AssignVariableOp#create + */ + public synchronized Op assign(Scope scope, Operand value) { + checkInput(value); + lastAssign = AssignVariableOp.create(scope, handle, value); + hasInitialized = true; + cachedRead = null; + return lastAssign; + } + + /** + * Assign a new value to this variable using the variable's scope. + * + * @param value the value to assign. + * @param controlDependencies any control dependencies of the assignment. + * @see AssignVariableOp#create + */ + public Op assign(Operand value, Op... controlDependencies) { + return assign(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value); + } + + /** + * Decrement the variable's value by the given value, using the given scope. + * + * @param value amount to decrease the variable's value by. + * @see AssignSubVariableOp#create + */ + public synchronized Op assignSub(Scope scope, Operand value) { + if (!hasInitialized) { + throw new IllegalStateException("Variable has not been initialized, can not decrement."); + } + + if(cachedRead != null) + scope = scope.withControlDependencies(Collections.singletonList(cachedRead)); + + checkInput(value); + lastAssign = AssignSubVariableOp.create(scope, handle, value); + hasInitialized = true; + cachedRead = null; + return lastAssign; + } + + /** + * Decrement the variable's value by the given value, using the variable's scope. + * + * @param value amount to decrease the variable's value by. + * @param controlDependencies any control dependencies of the assignment. + * @see AssignSubVariableOp#create + */ + public Op assignSub(Operand value, Op... controlDependencies) { + return assignSub(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value); + } + + /** + * Increment the variable's value by the given value, using the given scope. + * + * @param value amount to decrease the variable's value by. + * @see AssignAddVariableOp#create + */ + public synchronized Op assignAdd(Scope scope, Operand value) { + if (!hasInitialized) { + throw new IllegalStateException("Variable has not been initialized, can not increment."); + } + + if(cachedRead != null) + scope = scope.withControlDependencies(Collections.singletonList(cachedRead)); + + checkInput(value); + lastAssign = AssignAddVariableOp.create(scope, handle, value); + hasInitialized = true; + cachedRead = null; + return lastAssign; + } + + /** + * Increment the variable's value by the given value, using the variable's scope. + * + * @param value amount to decrease the variable's value by. + * @param controlDependencies any control dependencies of the assignment. + * @see AssignAddVariableOp#create */ - default MutableVariable asMutableVariable(){ - return (MutableVariable) this; + public Op assignAdd(Operand value, Op... controlDependencies) { + return assignAdd(initialScope.withControlDependencies(Arrays.asList(controlDependencies)), value); } /** @@ -186,7 +389,7 @@ private Options() { */ @Endpoint(name = "Variable") public static Variable create(Scope scope, Shape shape, Class dataType, Options... options){ - return new MutableVariable<>(scope, shape, dataType, options); + return new Variable<>(scope, shape, dataType, options); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java index 30ff5369cec..95a39a0ddb5 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java @@ -31,11 +31,10 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import org.tensorflow.variable.MutableVariable; import org.tensorflow.variable.Variable; /** - * Unit tests for {@link org.tensorflow.variable.Variable}/{@link org.tensorflow.variable.MutableVariable} + * Unit tests for {@link org.tensorflow.variable.Variable}/{@link Variable} */ public class VariableTest { @@ -43,7 +42,7 @@ public class VariableTest { public void testEager() { try (EagerSession es = EagerSession.create()) { Ops tf = Ops.create(es); - MutableVariable variable = tf.Variable(Shape.of(10, 10), TFloat32.class).asMutableVariable(); + Variable variable = tf.Variable(Shape.of(10, 10), TFloat32.class); assertFalse(variable.isInitialized()); assertFalse(variable.isValueInitialized().asTensor().getBoolean(0)); @@ -73,7 +72,7 @@ public void testEager() { public void testGraph() { try (Graph graph = new Graph()) { Ops tf = Ops.create(graph); - MutableVariable variable = tf.Variable(Shape.of(10, 10), TFloat32.class).asMutableVariable(); + Variable variable = tf.Variable(Shape.of(10, 10), TFloat32.class); assertFalse(variable.isInitialized()); From 342744ffb0f42d6ddaf4395ad6792791fced7958 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 12 Mar 2021 12:54:36 -0800 Subject: [PATCH 17/17] Some test fixes - gradients work. Control deps are broken and will be fixed in #237 PR. Signed-off-by: Ryan Nett --- .../java/org/tensorflow/EagerSession.java | 3 +-- .../org/tensorflow/variable/Variable.java | 13 +++++----- .../java/org/tensorflow/VariableTest.java | 26 +++++++++++-------- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index cff25316e0d..03618323f65 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -34,7 +34,6 @@ import org.tensorflow.op.Op; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Placeholder; -import org.tensorflow.op.core.Variable; import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.variable.Variable; @@ -294,7 +293,7 @@ public Types environmentType() { @Override public boolean isOpEnabled(String opType) { switch (opType) { - case Variable.OP_NAME: + case org.tensorflow.op.core.Variable.OP_NAME: case Placeholder.OP_NAME: case Assign.OP_NAME: return false; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java index 7bd1b798773..ad19909cd6c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/variable/Variable.java @@ -26,6 +26,7 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Operands; +import org.tensorflow.op.Ops; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; @@ -36,6 +37,7 @@ import org.tensorflow.op.core.IsVariableInitialized; import org.tensorflow.op.core.ReadVariableOp; import org.tensorflow.op.core.VarHandleOp; +import org.tensorflow.op.core.VarIsInitializedOp; import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TType; @@ -65,7 +67,7 @@ public class Variable implements Operand { private final VarHandleOp handle; - private IsVariableInitialized isInitializedOp = null; + private VarIsInitializedOp isInitializedOp = null; private Op initializationOp = null; private ReadVariableOp cachedRead = null; private Op lastAssign = null; @@ -192,7 +194,7 @@ public synchronized Operand value(Scope scope) { throw new IllegalStateException("Variable has not been initialized, can not get."); } ReadVariableOp ret = cachedRead; - if (ret == null) { + if (ret == null || scope.env().isEager()) { if (lastAssign != null) { scope = scope.withControlDependencies(Collections.singletonList(lastAssign)); } @@ -233,8 +235,7 @@ public synchronized Op initialize(Operand value) { initializationOp = AssignVariableOp.create(initialScope, handle, value); //TODO this if will be unnecessary after the init PR - if(initialScope.env().isGraph()) - Init.add(initialScope, initializationOp); + Init.add(initialScope, initializationOp); hasInitialized = true; cachedRead = null; @@ -259,8 +260,8 @@ public synchronized Op initialize(Supplier> value) { * Get whether the graph value is initialized. In eager mode, this will be the same as {@link #isInitialized()}. */ public Operand isValueInitialized() { - if (isInitializedOp == null) { - isInitializedOp = IsVariableInitialized.create(initialScope, handle); + if (isInitializedOp == null || initializationOp.env().isEager()) { + isInitializedOp = VarIsInitializedOp.create(initialScope, handle); } return isInitializedOp; diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java index 95a39a0ddb5..bb73db88775 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/VariableTest.java @@ -28,6 +28,7 @@ import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Gradients; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; @@ -45,12 +46,12 @@ public void testEager() { Variable variable = tf.Variable(Shape.of(10, 10), TFloat32.class); assertFalse(variable.isInitialized()); - assertFalse(variable.isValueInitialized().asTensor().getBoolean(0)); + assertFalse(variable.isValueInitialized().asTensor().getBoolean()); variable.initialize(tf.ones(tf.array(10, 10), TFloat32.class)); assertTrue(variable.isInitialized()); - assertTrue(variable.isValueInitialized().asTensor().getBoolean(0)); + assertTrue(variable.isValueInitialized().asTensor().getBoolean()); assertEquals(1, variable.value().asTensor().getFloat(0, 0)); @@ -93,11 +94,11 @@ public void testGraph() { try (Session session = new Session(graph)) { - assertFalse(((TBool) session.runner().fetch(variable.isValueInitialized()).run().get(0)).getBoolean(0)); + assertFalse(((TBool) session.runner().fetch(variable.isValueInitialized()).run().get(0)).getBoolean()); - session.run(tf.init()); + session.runInit(); - assertTrue(((TBool) session.runner().fetch(variable.isValueInitialized()).run().get(0)).getBoolean(0)); + assertTrue(((TBool) session.runner().fetch(variable.isValueInitialized()).run().get(0)).getBoolean()); // test control deps (in-run assign) @@ -138,7 +139,8 @@ public void gradientTest() { Session sess = new Session(g)) { Ops tf = Ops.create(g); - Variable variable = tf.Variable(tf.placeholder(TFloat32.class)); + Placeholder initialValue = tf.placeholder(TFloat32.class); + Variable variable = tf.Variable(initialValue); Output y0 = tf.math.square(variable.value()).y(); Output y1 = tf.math.square(tf.math.square(variable.value())).y(); @@ -151,12 +153,14 @@ public void gradientTest() { assertNotNull(grads.dy()); assertEquals(1, grads.dy().size()); - try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) { + try (TFloat32 c = TFloat32.scalarOf(3.0f)) { + sess.runner().addTarget(tf.init()).feed(initialValue, c).run(); + try (AutoCloseableList outputs = + new AutoCloseableList<>(sess.runner().feed(initialValue, c).fetch(grads.dy(0)).run())) { - //TODO expected value may be wrong, check once C++ gradient exists - assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); + //TODO expected value may be wrong, check once C++ gradient exists + assertEquals(114.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); + } } } }