diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 7abb451be2d..63b8bac28d7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -106,7 +106,6 @@ import org.tensorflow.op.core.IdentityN; import org.tensorflow.op.core.If; import org.tensorflow.op.core.ImmutableConst; -import org.tensorflow.op.core.Init; import org.tensorflow.op.core.InitializeTable; import org.tensorflow.op.core.InitializeTableFromTextFile; import org.tensorflow.op.core.InplaceAdd; @@ -295,6 +294,12 @@ import org.tensorflow.op.core.VariableShape; import org.tensorflow.op.core.Where; import org.tensorflow.op.core.While; +import org.tensorflow.op.core.XlaConvV2; +import org.tensorflow.op.core.XlaDotV2; +import org.tensorflow.op.core.XlaSetDynamicDimensionSize; +import org.tensorflow.op.core.XlaSpmdFullToShardShape; +import org.tensorflow.op.core.XlaSpmdShardToFullShape; +import org.tensorflow.op.core.XlaVariadicSort; import org.tensorflow.op.core.Zeros; import org.tensorflow.op.core.ZerosLike; import org.tensorflow.types.TBool; @@ -366,20 +371,20 @@ public final class Ops { public final SparseOps sparse; - public final TpuOps tpu; - public final BitwiseOps bitwise; + public final TpuOps tpu; + public final MathOps math; public final AudioOps audio; public final SignalOps signal; - public final QuantizationOps quantization; - public final TrainOps train; + public final QuantizationOps quantization; + private final Scope scope; private Ops(Scope scope) { @@ -397,13 +402,13 @@ private Ops(Scope scope) { random = new RandomOps(this); strings = new StringsOps(this); sparse = new SparseOps(this); - tpu = new TpuOps(this); bitwise = new BitwiseOps(this); + tpu = new TpuOps(this); math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - quantization = new QuantizationOps(this); train = new TrainOps(this); + quantization = new QuantizationOps(this); } /** @@ -1952,14 +1957,15 @@ public Constant constant(Shape shape, FloatDataBuffer data) { } /** - * Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be truncated if it does not - * fit in the target type. + * Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be + * truncated if it does not fit in the target type. * - * @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating}) + * @param type the type of tensor to create. Must be concrete (i.e. not {@link + * org.tensorflow.types.family.TFloating}) * @param number the value of the tensor * @return a constant of the passed type - * @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or - * unknown. + * @throws IllegalArgumentException if the type is abstract (i.e. {@link + * org.tensorflow.types.family.TFloating}) or unknown. */ public Constant constant(Class type, Number number) { return Constant.tensorOf(scope, type, number); @@ -1994,11 +2000,12 @@ public Constant constant(Class type, Shape shape, ByteDa } /** - * Create a constant by making an immutable copy of {@code tensor}. {@code tensor} may be closed afterwards without - * issue. + * Create a constant by making an immutable copy of {@code tensor}. {@code tensor} may be closed + * afterwards without issue. * *

Note: this endpoint cannot be simply called {@code constant} since it will conflict with - * other endpoints accepting an NdArray in parameter {e.g. {@link #tensorOf(Scope, FloatNdArray)}}. + * other endpoints accepting an NdArray in parameter {e.g. {@link #tensorOf(Scope, + * FloatNdArray)}}. * * @param tensor a Tensor holding the constant value * @return a constant of the same data type as `tensor` @@ -2008,8 +2015,8 @@ public Constant constantOf(T tensor) { } /** - * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code number} may be - * truncated if it does not fit in the target type. + * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code + * number} may be truncated if it does not fit in the target type. * * @param toMatch the operand providing the target type * @param number the value of the tensor @@ -2993,80 +3000,6 @@ public ImmutableConst immutableConst(Class dtype, Shape return ImmutableConst.create(scope, dtype, shape, memoryRegionName); } - /** - * Factory method to create an operation executing all initializers of a graph. - * - *

All initializers added to a graph via - * {@link org.tensorflow.op.core.Init#add(Scope, Op) tf.initAdd} are grouped together as a single - * unit of computation in the graph. This operation must then be added to any graph using one or - * more {@link Variable variables} and executed once before running the graph so the variable - * states are initialized properly.

- * - *

When the graph is built by the same process that is running the session, the initializers - * can be invoked by executing this single endpoint. For example:

- *
{@code
-   *  try (Graph g = new Graph()) {
-   *    Variable x = tf.variable(tf.constant(10));  // initAdd is called implicitly
-   *    Variable y = tf.variable(tf.constant(20));  // idem
-   *    Add z = tf.math.add(x, y);
-   *
-   *    try (Session s = new Session(g)) {
-   *      s.run(tf.init());  // initialize all variables
-   *
-   *      try (TInt32 t = (TInt32)s.runner().fetch(z).run().get(0)) {
-   *        assertEquals(30, t.data().getInt());
-   *      }
-   *    }
-   *  }
-   *  }
- * - *

When the graph is built by a separate process, the initializers can be invoked by running - * the init op by its name, which defaults to {@link org.tensorflow.op.core.Init#DEFAULT_NAME}. - * For example:

- *
{@code
-   *  // Building the model
-   *  try (Graph g = new Graph()) {
-   *    Variable x = tf.variable(tf.constant(10));  // initAdd is called implicitly
-   *    Variable y = tf.variable(tf.constant(20));  // idem
-   *    Add z = tf.withName("z").math.add(x, y);
-   *
-   *    tf.init();  // add variables initializers to the graph, as Init.DEFAULT_NAME
-   *    // ...exporting graph as a saved model...
-   *  }
-   *
-   *  ...
-   *
-   *  // Running the model
-   *  try (SavedModelBundle model = SavedModelBundle.load("/path/to/model", "train")) {
-   *    model.session().run(Init.DEFAULT_NAME);
-   *
-   *    try (TInt32 t = (TInt32)s.runner().fetch("z").run().get(0)) {
-   *      assertEquals(30, t.data().getInt());
-   *    }
-   *  }
-   *  }
- * - * @return an op grouping all initializers added to the graph - * @throws IllegalArgumentException if the execution environment in scope is not a graph - */ - public Init init() { - return Init.create(scope); - } - - /** - * Register an op as an initializer of the graph. - * - *

Registered initializers are then grouped as a single unit of computation by adding - * and executing an {@link org.tensorflow.op.core.Init#create(Scope) init} operation from a graph - * session. This is a no-op if executed in an eager session. - * - * @param initializer - * @see org.tensorflow.op.core.Init#create(Scope) init - */ - public void initAdd(Op initializer) { - Init.add(scope, initializer); - } - /** * Table initializer that takes two tensors for keys and values respectively. * @@ -7947,10 +7880,11 @@ public VarIsInitializedOp varIsInitializedOp(Operand resource) } /** - * Factory method to create a new Variable with it's initializer. - *

- * Only supported on Graph sessions as the {@link org.tensorflow.op.core.Assign} op - * does not work in an EagerSession. + * Factory method to create a new Variable with its initializer. Both the creation and assignment + * are done in the init scope. + * + *

Only supported on Graph sessions as the {@link org.tensorflow.op.core.Assign} op does not + * work in an EagerSession. * * @param init The op to use to initialise this variable. * @param options carries optional attributes values @@ -8143,6 +8077,37 @@ public Ops withSubScope(String childScopeName) { return new Ops(scope.withSubScope(childScopeName)); } + /** + * Returns an API that builds init operations. {@link #liftToInitScope(Operand)} will be called for all created operations. + *

+ * Init operations will be initialized at session creation, will have their inputs (and control inputs) made init ops as well, + * and are ignored when used as control dependencies. + * Additionally, this scope ignores any control dependencies. + *

+ * If an input can not be made an init op (i.e. a Placeholder), will throw an {@link IllegalStateException} on op creation. + * @see #liftToInitScope(Operand) + */ + public Ops withInitScope() { + return new Ops(scope.withInitScope()); + } + + /** + * Make {@code op} an init operation, doing the same for all of it's inputs (and control inputs). + *

+ * Init operations will be initialized at session creation, will have their inputs (and control inputs) made init ops as well, + * and are ignored when used as control dependencies. + * Additionally, this scope ignores any control dependencies. + *

+ * If an input can not be made an init op (i.e. a Placeholder), will throw an {@link IllegalStateException} on op creation. + * @see ExecutionEnvironment#registerInitOp(Operation) + * + * @throws IllegalStateException if the op or one of its inputs can't be made an init op. + */ + public T liftToInitScope(T op) { + scope.env().registerInitOp(op.op()); + return op; + } + /** * Returns an API that uses the provided name for an op. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java index 9f87fd8b95e..fa705ca6ea6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java @@ -91,6 +91,31 @@ TFE_TensorHandle getUnsafeNativeHandle(int outputIndex) { return outputHandles[outputIndex]; } + @Override + public int hashCode() { + return Long.valueOf(opHandle.address()).hashCode(); + } + + @Override + public boolean equals(Object o) { + if (o == this) { + return true; + } + if (!(o instanceof EagerOperation)) { + return false; + } + EagerOperation that = (EagerOperation) o; + if (session != that.session) { + return false; + } + + if (opHandle == null || that.opHandle == null || opHandle.isNull() || that.opHandle.isNull()) { + // if they are the same object, we will already have returned + return false; + } + return opHandle.equals(that.opHandle); + } + @Override Shape shape(int outputIndex) { // If the tensor of this output has already been resolved, return its shape. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java index e3283ee2ab3..cc47083bfce 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -1,18 +1,18 @@ /* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_Execute; @@ -53,6 +53,7 @@ import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; import org.tensorflow.proto.framework.DataType; /** @@ -60,17 +61,21 @@ */ final class EagerOperationBuilder implements OperationBuilder { - EagerOperationBuilder(EagerSession session, String type, String name) { + EagerOperationBuilder(EagerSession session, String type, String name, Scope scope) { this.session = session; this.type = type; this.name = name; + this.scope = scope; this.opHandle = allocate(session, type); } @Override public EagerOperation build() { + scope.apply(this); TFE_TensorHandle[] tensorHandles = execute(opHandle, session); - return new EagerOperation(session, opHandle, tensorHandles, type, name); + EagerOperation op = new EagerOperation(session, opHandle, tensorHandles, type, name); + scope.onOpCreated(op); + return op; } @Override @@ -250,6 +255,7 @@ public OperationBuilder setAttr(String name, ConcreteFunction[] value) { private final EagerSession session; private final String type; private final String name; + private final Scope scope; /** This value should be >= to the maximum number of outputs in any op */ private static final int MAX_OUTPUTS_PER_OP = 1000; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index 84fe7675c40..f141e9dc551 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -1,18 +1,18 @@ /* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TFE_ContextAddFunction; @@ -29,7 +29,6 @@ import org.tensorflow.internal.c_api.TFE_Context; import org.tensorflow.internal.c_api.TFE_ContextOptions; import org.tensorflow.internal.c_api.TF_Status; -import org.tensorflow.op.Op; import org.tensorflow.op.Scope; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Placeholder; @@ -277,12 +276,12 @@ static void closeDefaultForTest() { } @Override - public OperationBuilder opBuilder(String type, String name) { + public OperationBuilder opBuilder(String type, String name, Scope scope) { checkSession(); if (!isOpEnabled(type)) { throw new IllegalArgumentException("Op " + type + " is not valid in eager mode."); } - return new EagerOperationBuilder(this, type, name); + return new EagerOperationBuilder(this, type, name, scope); } @Override @@ -322,7 +321,7 @@ public boolean isOpEnabled(String opType) { } @Override - public void checkInput(Op input) { + public void checkInput(Operation input) { if (!input.env().isEager()) { throw new IllegalArgumentException("Can't use graph operation " + input + " in eager mode."); } @@ -333,6 +332,16 @@ public Scope baseScope() { return baseScope; } + /** Noop, initialization is meaningless for eager sessions */ + @Override + public boolean isInitOp(Operation op) { + return false; + } + + /** Noop, initialization is meaningless for eager sessions */ + @Override + public void registerInitOp(Operation op) {} + TFE_Context nativeHandle() { checkSession(); return nativeHandle; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java index 6f50aeafe98..87745138f01 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java @@ -1,18 +1,18 @@ /* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import org.tensorflow.op.Op; @@ -30,12 +30,14 @@ enum Types { * Returns a builder to create a new {@link Operation}. * * @param type of the Operation (i.e., identifies the computation to be performed) - * @param name to refer to the created Operation in this environment scope. + * @param name to refer to the created Operation in this environment scope. Should already have + * been made unique. + * @param scope the scope that owns the created op * @return an {@link OperationBuilder} to create an Operation when {@link * OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked, * then some resources may leak. */ - OperationBuilder opBuilder(String type, String name); + OperationBuilder opBuilder(String type, String name, Scope scope); /** * Attach the function and its dependencies to this execution environment, allowing it to be @@ -64,7 +66,19 @@ default boolean isOpEnabled(String opType) { * @throws IllegalArgumentException if input can't be used as an input in this execution * environment. */ - void checkInput(Op input); + default void checkInput(Op input) { + checkInput(input.op()); + } + + /** + * Checks that {@code input} is valid to use as an input in this execution environment. Throws + * {@link IllegalArgumentException} if not. + * + * @param input The op to check + * @throws IllegalArgumentException if input can't be used as an input in this execution + * environment. + */ + void checkInput(Operation input); /** * Get the type of this environment (from the `Environments` enumeration. @@ -86,4 +100,30 @@ default boolean isGraph() { * prevent name collisions. */ Scope baseScope(); + + /** + * Get the execution environment to use for initialization. In most cases is {@code this}. + * + *

Should generally only be used internally. + */ + default ExecutionEnvironment initEnv() { + return this; + } + + /** + * Register an op and all of its inputs (and control inputs) as an initialization op. + * + *

Should generally only be used internally, prefer {@link + * org.tensorflow.op.Ops#withInitScope()}. + * + * @throws IllegalStateException if the op or one of its inputs can't be made an init op. + */ + void registerInitOp(Operation op); + + /** + * Get whether an op is an initialization op. + * + *

Should generally only be used internally. + */ + boolean isInitOp(Operation op); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index f3e712492b8..67377ed65b6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -1,18 +1,18 @@ /* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TF_AddGradientsWithPrefix; @@ -366,21 +366,12 @@ public synchronized Set subgraphFrom(Set> inputs) { return downstream; } - /** - * Returns a builder to add {@link Operation}s to the Graph. - * - * @param type of the Operation (i.e., identifies the computation to be performed) - * @param name to refer to the created Operation in the graph. - * @return an {@link OperationBuilder}, which will add the Operation to the graph when {@link - * OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked, - * then some resources may leak. - */ @Override - public GraphOperationBuilder opBuilder(String type, String name) { + public GraphOperationBuilder opBuilder(String type, String name, Scope scope) { if (!isOpEnabled(type)) { throw new IllegalArgumentException("Op " + type + " is not valid in graph mode."); } - return new GraphOperationBuilder(this, type, name); + return new GraphOperationBuilder(this, type, name, scope); } @Override @@ -478,7 +469,7 @@ public Types environmentType() { } @Override - public void checkInput(Op input) { + public void checkInput(Operation input) { if (input.env().isEager()) { throw new IllegalArgumentException( "Input " @@ -510,6 +501,8 @@ public void importGraphDef(GraphDef graphDef) throws IllegalArgumentException { importGraphDef(graphDef, ""); } + private static final String INIT_OP_BASE_NAME = "Init"; + /** * Import a representation of a TensorFlow graph. * @@ -522,35 +515,117 @@ public void importGraphDef(GraphDef graphDef, String prefix) throws IllegalArgum if (graphDef == null || prefix == null) { throw new IllegalArgumentException("graphDef and prefix cannot be null"); } + synchronized (nativeHandleLock) { importGraphDef(nativeHandle, graphDef, prefix); } + baseScope.refreshNames(); + + String initPrefix; + if (!prefix.isEmpty()) { + if (prefix.endsWith("/")) { + initPrefix = prefix + INIT_OP_BASE_NAME; + } else { + initPrefix = prefix + "/" + INIT_OP_BASE_NAME; + } + } else { + initPrefix = INIT_OP_BASE_NAME; + } + + operations() + .forEachRemaining( + op -> { + if (op.name().startsWith(initPrefix)) { + registerInitOp(op); + } + }); + } + + private synchronized void addInitOp() { + if (!newInitializers) { + return; + } + if (initializers.isEmpty()) { + return; + } + + baseScope.refreshNames(); + OperationBuilder builder = + baseScope().withInitScope().opBuilder(NoOp.OP_NAME, INIT_OP_BASE_NAME); + initializers.forEach(builder::addControlInput); + builder.build(); + newInitializers = false; } /** * Generate a representation of the Graph. * + *

If there are newly registered initializers (after the last {@link #toGraphDef()} call), this + * call adds an initialization operation to this graph that depends on them, so that they can be + * loaded properly if the graph def is later imported. + * * @see #importGraphDef(GraphDef) * @see #importGraphDef(GraphDef, String) */ public GraphDef toGraphDef() { + addInitOp(); synchronized (nativeHandleLock) { return toGraphDef(nativeHandle); } } + private boolean registerInitOpHelper(Operation op) { + if (isInitOp(op)) return false; + checkInput(op); + + if (!(op instanceof GraphOperation)) { + throw new IllegalStateException("Can't use a non-graph op as a graph's init op."); + } + GraphOperation graphOp = (GraphOperation) op; + + if (op.type().equals(Placeholder.OP_NAME)) { + throw new IllegalStateException("Can not make a placeholder " + op + " an init op."); + } + + for (GraphOperation controlInput : graphOp.controlInputs()) { + registerInitOpHelper(controlInput); + } + + for (Operand input : graphOp.inputs()) { + registerInitOpHelper(input.op()); + } + return initializers.add(op); + } + + @Override + public void registerInitOp(Operation op) { + if (registerInitOpHelper(op)) { + newInitializers = true; + } + } + + @Override + public boolean isInitOp(Operation op) { + return initializers.contains(op); + } + /** - * Adds an initializer to the graph initializer list. + * Returns a set of ops that will run all initializers added to the graph via {@link + * #registerInitOp(Operation)}. * - * @param initializer An initializer to add to the list. + *

Note that NoOps aren't included in this list, since any inputs or control dependencies are + * guaranteed to also be in this list, and including the no-ops wouldn't change the initialization + * result. */ - public synchronized void addInitializer(Op initializer) { - initializers.add(initializer); + public Set initializers() { + return initializers.stream() + .filter(x -> !x.type().equals(NoOp.OP_NAME)) + .collect(Collectors.toSet()); } - /** Returns all initializers added to the graph via {@link #addInitializer(Op)} */ - public List initializers() { - return Collections.unmodifiableList(initializers); + /** Get whether the graph has any initializers */ + public boolean hasInitializers() { + return !initializers.isEmpty(); } /** @@ -808,7 +883,8 @@ synchronized SaverDef saverDef() { private SaverDef saverDef; private final Scope baseScope; - private final List initializers = new ArrayList<>(); + private final Set initializers = Collections.synchronizedSet(new LinkedHashSet<>()); + private boolean newInitializers = false; // Related native objects (such as the TF_Operation object backing an Operation instance) // have a validity tied to that of the Graph. The handles to those native objects are not diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java index 53ab50db4b4..d1040469992 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java @@ -1,18 +1,18 @@ /* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import static org.tensorflow.internal.c_api.global.tensorflow.TF_AddControlInput; @@ -58,6 +58,7 @@ import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; import org.tensorflow.proto.framework.AttrValue; import org.tensorflow.proto.framework.AttrValue.ListValue; import org.tensorflow.proto.framework.DataType; @@ -66,8 +67,9 @@ /** An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}. */ public final class GraphOperationBuilder implements OperationBuilder { - GraphOperationBuilder(Graph graph, String type, String name) { + GraphOperationBuilder(Graph graph, String type, String name, Scope scope) { this.graph = graph; + this.scope = scope; Graph.Reference r = graph.ref(); try { this.unsafeNativeHandle = allocate(r.nativeHandle(), type, name); @@ -83,10 +85,12 @@ public final class GraphOperationBuilder implements OperationBuilder { */ @Override public GraphOperation build() { + scope.apply(this); Graph.Reference r = graph.ref(); try { GraphOperation op = new GraphOperation(graph, finish(unsafeNativeHandle)); unsafeNativeHandle = null; + scope.onOpCreated(op); return op; } finally { r.close(); @@ -95,10 +99,7 @@ public GraphOperation build() { @Override public GraphOperationBuilder addControlInput(Operation control) { - if (!(control instanceof GraphOperation)) { - throw new IllegalArgumentException( - "Only GraphOperation instances can be used as control inputs"); - } + graph.checkInput(control); if (control.env() != graph) { throw new IllegalArgumentException( @@ -378,6 +379,7 @@ public OperationBuilder setAttr(String name, ConcreteFunction[] value) { private TF_OperationDescription unsafeNativeHandle; private Graph graph; + private final Scope scope; private static void requireHandle(Pointer handle) { if (handle == null || handle.isNull()) { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index fd0b390bc28..dfb8b7bbf60 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -1,18 +1,18 @@ /* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import static org.tensorflow.Graph.resolveOutputs; @@ -22,8 +22,11 @@ import com.google.protobuf.InvalidProtocolBufferException; import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Set; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; @@ -73,7 +76,8 @@ *

WARNING:A {@code Session} owns resources that must be explicitly freed by * invoking {@link #close()}. * - *

Instances of a Session are thread-safe. + *

Instances of a Session are thread-safe. Modifying a graph concurrently while other threads are + * using sessions of that graph is not safe. */ public final class Session implements AutoCloseable { @@ -97,6 +101,32 @@ public Session(Graph g) { * protocol buffer. */ public Session(Graph g, ConfigProto config) { + this(g, true, config); + } + + /** + * Construct and optionally initialize a new session with the associated {@link Graph}. + * + * @param g The {@link Graph} the created Session will operate on. + * @param autoInit Whether to initialize the session. + */ + public Session(Graph g, boolean autoInit) { + this(g, autoInit, null); + } + + /** + * Construct and optionally initialize a new session with the associated {@link Graph} and + * configuration options. + * + * @param g The {@link Graph} the created Session will operate on. + * @param autoInit Whether to initialize the session. + * @param config Configuration parameters for the session specified as a ConfigProto + * protocol buffer. + * @throws IllegalArgumentException if the config is not a valid serialization of the ConfigProto + * protocol buffer. + */ + public Session(Graph g, boolean autoInit, ConfigProto config) { graph = g; Graph.Reference r = g.ref(); try { @@ -105,13 +135,19 @@ public Session(Graph g, ConfigProto config) { } finally { r.close(); } + this.autoInit = autoInit; } - /** Wrap an existing session with the associated {@link Graph}. */ + /** + * Wrap an existing session with the associated {@link Graph}. + * + *

Does not enable auto-init. + */ Session(Graph g, TF_Session nativeHandle) { graph = g; this.nativeHandle = nativeHandle; graphRef = g.ref(); + this.autoInit = false; } /** @@ -141,6 +177,41 @@ public void close() { } } + /** + * Execute any un-ran initializers. Will be done automatically unless disabled at session + * creation. + * + *

This runs any ops that have been created with an init scope that have not already been ran. + */ + public void initialize() { + Runner runner = runner(); + graph.initializers().stream().filter((x) -> !ranInits.contains(x)).forEach(runner::addTarget); + ranInits.clear(); + ranInits.addAll(graph.initializers()); + if (!runner.isEmpty()) { + runner.runNoInit(); + } + } + + /** + * Execute the graph's initializers, regardless of whether the session has been initialized. + * + *

This runs any ops that have been created with an init scope. + * + * @return this + */ + public Session forceInitialize() { + Set initializers = graph.initializers(); + if (!initializers.isEmpty()) { + Runner runner = runner(); + initializers.forEach(runner::addTarget); + runner.runNoInit(); + } + ranInits.clear(); + ranInits.addAll(graph.initializers()); + return this; + } + /** * Run {@link Operation}s and evaluate {@link Tensor Tensors}. * @@ -373,6 +444,26 @@ public Runner setOptions(RunOptions options) { return this; } + /** True if there are no targets or fetches. */ + public boolean isEmpty() { + return targets.isEmpty() && outputs.isEmpty(); + } + + private void doInit() { + if (autoInit) { + initialize(); + } else { + graph + .initializers() + .forEach( + x -> { + if (!ranInits.contains(x)) + throw new IllegalStateException( + "Graph has un-ran initializers, but the session's autoInit is false. Run Session.initialize() before calling run()."); + }); + } + } + /** * Execute the graph fragments necessary to compute all requested fetches. * @@ -391,6 +482,11 @@ public Runner setOptions(RunOptions options) { * @return list of resulting tensors fetched by this session runner */ public List run() { + doInit(); + return runNoInit(); + } + + List runNoInit() { return runHelper(false).outputs; } @@ -405,6 +501,7 @@ public List run() { * @return list of resulting tensors fetched by this session runner, with execution metadata */ public Run runAndFetchMetadata() { + doInit(); return runHelper(true); } @@ -547,20 +644,6 @@ public Map run(Signature signature, Map argument return function(signature).call(arguments); } - /** - * Execute the graph's initializers. - * - *

This method is equivalent to {@code session.run(Ops.create(session.graph).init())}. - */ - public void runInit() { - List initializers = graph.initializers(); - if (!initializers.isEmpty()) { - Runner runner = runner(); - initializers.forEach(runner::addTarget); - runner.run(); - } - } - /** * Saves the actual state of the variables of this session's graph. * @@ -583,7 +666,8 @@ public void save(String prefix) { } /** - * Restore the actual state of the variables of this session's graph. + * Restore the actual state of the variables of this session's graph. Counts as initialization, + * but can be done after other initializations. * *

{@code prefix} is the path where the files containing the variables state live, followed by * the filename prefix. For example, if {@code prefix} is set to @@ -600,7 +684,10 @@ public void restore(String prefix) { runner() .addTarget(saverDef.getRestoreOpName()) .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) - .run(); + .runNoInit(); + // TODO better way of doing this, only count as ran assignments to the restored variables. + ranInits.clear(); + ranInits.addAll(graph.initializers()); } /** @@ -634,6 +721,9 @@ Graph graph() { private TF_Session nativeHandle; private int numActiveRuns; + private final boolean autoInit; + private final Set ranInits = Collections.synchronizedSet(new LinkedHashSet<>()); + private static void requireHandle(Pointer handle) { if (handle == null || handle.isNull()) { throw new IllegalStateException("close() has been called on the Session"); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java index 903a12f66b2..25b194f8db3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/NameScope.java @@ -1,18 +1,18 @@ /* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow.op; import java.util.HashMap; @@ -55,9 +55,7 @@ NameScope withName(String name) { private static final Pattern NAME_PATTERN = Pattern.compile("(.+)_(\\d+)", Pattern.DOTALL); - /** "Import" used names from a graph. Useful when adding to a loaded graph. */ - private NameScope withUsedFrom(ExecutionEnvironment env) { - + void importIdsFrom(ExecutionEnvironment env) { if (env instanceof Graph) { ((Graph) env) .operations() @@ -90,6 +88,11 @@ private NameScope withUsedFrom(ExecutionEnvironment env) { } }); } + } + + /** "Import" used names from a graph. Useful when adding to a loaded graph. */ + private NameScope withUsedFrom(ExecutionEnvironment env) { + importIdsFrom(env); return this; } @@ -134,7 +137,7 @@ private NameScope(String opPrefix, String opName, Map ids) { // // The second use of makeUnique("a") updates ids to "a" -> 2 // and returns "a_1", and so on. - private String makeUnique(String id) { + String makeUnique(String id) { if (!ids.containsKey(id)) { ids.put(id, 1); return id; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java index 2aef70f6af0..b4705ea95a3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java @@ -1,23 +1,27 @@ /* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow.op; import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; import org.tensorflow.DeviceSpec; import org.tensorflow.ExecutionEnvironment; +import org.tensorflow.Operation; import org.tensorflow.OperationBuilder; /** @@ -87,7 +91,7 @@ public final class Scope { * @param env The execution environment used by the scope. */ public Scope(ExecutionEnvironment env) { - this(env, new NameScope(env), new ArrayList<>(), DeviceSpec.newBuilder().build()); + this(env, new NameScope(env), new ArrayList<>(), DeviceSpec.newBuilder().build(), false); } /** Returns the execution environment used by this scope. */ @@ -110,7 +114,7 @@ public ExecutionEnvironment env() { */ public Scope withSubScope(String childScopeName) { return new Scope( - env, nameScope.withSubScope(childScopeName, env), controlDependencies, deviceSpec); + env, nameScope.withSubScope(childScopeName, env), controlDependencies, deviceSpec, isInit); } /** @@ -126,7 +130,7 @@ public Scope withSubScope(String childScopeName) { * @throws IllegalArgumentException if the name is invalid */ public Scope withName(String opName) { - return new Scope(env, nameScope.withName(opName), controlDependencies, deviceSpec); + return new Scope(env, nameScope.withName(opName), controlDependencies, deviceSpec, isInit); } /** @@ -150,7 +154,8 @@ public Scope withNameAsSubScope(String defaultName) { env, nameScope.withSubScope(nameScope.makeOpName(defaultName), env), controlDependencies, - deviceSpec); + deviceSpec, + isInit); } /** @@ -163,11 +168,18 @@ public Scope withNameAsSubScope(String defaultName) { * @return a new Scope that uses opName for operations. */ public Scope withDevice(DeviceSpec deviceSpec) { - return new Scope(env, nameScope, controlDependencies, deviceSpec); + return new Scope(env, nameScope, controlDependencies, deviceSpec, isInit); + } + + // TODO stop gradient recording in init scopes (once we have gradient recording) + + /** Get an extension of this scope that generates initialization ops. */ + public Scope withInitScope() { + return new Scope(env.initEnv(), nameScope, new ArrayList<>(), deviceSpec, true); } /** - * Create a unique name for an operator, using a provided default if necessary. + * Create a unique name for an operator and reserves it, using a provided default if necessary. * *

This is normally called only by operator building classes. * @@ -190,19 +202,50 @@ public String makeOpName(String defaultName) { return nameScope.makeOpName(defaultName); } + /** Makes a unique name from {@code id} and reserves it. */ + public String makeUnique(String id) { + return nameScope.makeUnique(id); + } + + /** + * Returns a builder to create a new {@link Operation}. + * + *

Note that {@code name} is automatically made unique. + * + * @param type of the Operation (i.e., identifies the computation to be performed) + * @param name to refer to the created Operation in this environment scope. Is uniquified. + * @return an {@link OperationBuilder} to create an Operation when {@link + * OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked, + * then some resources may leak. + */ + public OperationBuilder opBuilder(String type, String name) { + return env.opBuilder(type, makeOpName(name), this); + } + public static boolean isValidOpName(String name) { return NameScope.isValidName(name); } + /** + * Refresh the used name list (used for uniquifying names) from the underlying graph. + * + *

Should be used if you made changes to the graph from non-{@code Scope} APIs. + */ + public void refreshNames() { + nameScope.importIdsFrom(env); + } + private Scope( ExecutionEnvironment env, NameScope nameScope, - Iterable controlDependencies, - DeviceSpec deviceSpec) { + List controlDependencies, + DeviceSpec deviceSpec, + boolean isInit) { this.env = env; this.nameScope = nameScope; this.controlDependencies = controlDependencies; this.deviceSpec = deviceSpec; + this.isInit = isInit; } /** @@ -211,46 +254,88 @@ private Scope( *

Ops created with this scope will have a control edge from each of the provided controls. All * other properties are inherited from the current scope. * + *

Init ops will be ignored when used as control dependencies, they are assumed to be executed + * during session initialization. + * * @param controls control dependencies for ops created with the returned scope * @return a new scope with the provided control dependencies */ public Scope withControlDependencies(Iterable controls) { - for (Op control : controls) { + return withControlDependencyOps( + StreamSupport.stream(controls.spliterator(), false) + .map(Op::op) + .collect(Collectors.toList())); + } + + /** + * Returns a new scope where added operations will have the provided control dependencies. + * + *

Ops created with this scope will have a control edge from each of the provided controls. All + * other properties are inherited from the current scope. + * + *

Init ops will be ignored when used as control dependencies, they are assumed to be executed + * during session initialization. + * + * @param controls control dependencies for ops created with the returned scope + * @return a new scope with the provided control dependencies + */ + public Scope withControlDependencyOps(Iterable controls) { + ArrayList toAdd = new ArrayList<>(); + for (Operation control : controls) { env.checkInput(control); + if (isInit && !env.isInitOp(control)) { + throw new IllegalArgumentException("Init scope can not have non-init control dependency."); + } + if (isInit || !env.isInitOp(control)) { + toAdd.add(control); + } } - return new Scope(env, nameScope, controls, deviceSpec); + + return new Scope(env, nameScope, toAdd, deviceSpec, isInit); } /** * Applies device specification and adds each Operand in controlDependencies as a control input to * the provided builder. * + *

Should only be used from {@link OperationBuilder} implementations + * * @param builder OperationBuilder to add control inputs and device specification to */ public OperationBuilder apply(OperationBuilder builder) { builder.setDevice(deviceSpec.toString()); - return applyControlDependencies(builder); + for (Operation control : controlDependencies) { + if (isInit || !env.isInitOp(control)) { + builder.addControlInput(control); + } + } + return builder; } /** - * Adds each Operand in controlDependencies as a control input to the provided builder. + * Handle op creation, like registering it as an init op if the scope is init. * - * @param builder OperationBuilder to add control inputs to + *

FOR INTERNAL USE ONLY */ - public OperationBuilder applyControlDependencies(OperationBuilder builder) { - for (Op control : controlDependencies) { - builder = builder.addControlInput(control.op()); + public void onOpCreated(Operation op) { + if (isInit) { + env.registerInitOp(op); } - return builder; } - private final ExecutionEnvironment env; - private final Iterable controlDependencies; - private final NameScope nameScope; - private final DeviceSpec deviceSpec; - /** Returns device string from the scope. */ public String getDeviceString() { return deviceSpec.toString(); } + + /** Get whether this scope is building init ops. */ + public boolean isInit() { + return isInit; + } + + private final ExecutionEnvironment env; + private final List controlDependencies; + private final NameScope nameScope; + private final DeviceSpec deviceSpec; + private final boolean isInit; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java index 0e3d30acc54..f9f6e00f0f6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java @@ -58,8 +58,8 @@ /** * An operator producing a constant value. * - *

All endpoints of this operator are named `constant`, except those accepting vararg - * elements in parameter, which are named `array`. For example: + *

All endpoints of this operator are named `constant`, except those accepting vararg elements in + * parameter, which are named `array`. For example: * *

{@code
  * Ops tf = Ops.create();
@@ -126,8 +126,7 @@ public static Constant arrayOf(Scope scope, int... data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, int[][] data) {
-    try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data,
-        t))) {
+    try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -142,8 +141,7 @@ public static Constant tensorOf(Scope scope, int[][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, int[][][] data) {
-    try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data,
-        t))) {
+    try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -158,8 +156,7 @@ public static Constant tensorOf(Scope scope, int[][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, int[][][][] data) {
-    try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data,
-        t))) {
+    try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -174,8 +171,7 @@ public static Constant tensorOf(Scope scope, int[][][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, int[][][][][] data) {
-    try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data,
-        t))) {
+    try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -190,8 +186,7 @@ public static Constant tensorOf(Scope scope, int[][][][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, int[][][][][][] data) {
-    try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data,
-        t))) {
+    try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -283,8 +278,8 @@ public static Constant arrayOf(Scope scope, float... data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, float[][] data) {
-    try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(
-        data, t))) {
+    try (TFloat32 value =
+        TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -299,8 +294,8 @@ public static Constant tensorOf(Scope scope, float[][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, float[][][] data) {
-    try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(
-        data, t))) {
+    try (TFloat32 value =
+        TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -315,8 +310,8 @@ public static Constant tensorOf(Scope scope, float[][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, float[][][][] data) {
-    try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(
-        data, t))) {
+    try (TFloat32 value =
+        TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -331,8 +326,8 @@ public static Constant tensorOf(Scope scope, float[][][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, float[][][][][] data) {
-    try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(
-        data, t))) {
+    try (TFloat32 value =
+        TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -347,8 +342,8 @@ public static Constant tensorOf(Scope scope, float[][][][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, float[][][][][][] data) {
-    try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(
-        data, t))) {
+    try (TFloat32 value =
+        TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -440,8 +435,8 @@ public static Constant arrayOf(Scope scope, double... data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, double[][] data) {
-    try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(
-        data, t))) {
+    try (TFloat64 value =
+        TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -456,8 +451,8 @@ public static Constant tensorOf(Scope scope, double[][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, double[][][] data) {
-    try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(
-        data, t))) {
+    try (TFloat64 value =
+        TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -472,8 +467,8 @@ public static Constant tensorOf(Scope scope, double[][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, double[][][][] data) {
-    try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(
-        data, t))) {
+    try (TFloat64 value =
+        TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -488,8 +483,8 @@ public static Constant tensorOf(Scope scope, double[][][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, double[][][][][] data) {
-    try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(
-        data, t))) {
+    try (TFloat64 value =
+        TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -504,8 +499,8 @@ public static Constant tensorOf(Scope scope, double[][][][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, double[][][][][][] data) {
-    try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(
-        data, t))) {
+    try (TFloat64 value =
+        TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -582,8 +577,7 @@ public static Constant vectorOf(Scope scope, long[] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, long[][] data) {
-    try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data,
-        t))) {
+    try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -613,8 +607,7 @@ public static Constant arrayOf(Scope scope, long... data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, long[][][] data) {
-    try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data,
-        t))) {
+    try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -629,8 +622,7 @@ public static Constant tensorOf(Scope scope, long[][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, long[][][][] data) {
-    try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data,
-        t))) {
+    try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -645,8 +637,7 @@ public static Constant tensorOf(Scope scope, long[][][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, long[][][][][] data) {
-    try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data,
-        t))) {
+    try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -661,8 +652,7 @@ public static Constant tensorOf(Scope scope, long[][][][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, long[][][][][][] data) {
-    try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data,
-        t))) {
+    try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -754,8 +744,7 @@ public static Constant arrayOf(Scope scope, boolean... data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, boolean[][] data) {
-    try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data,
-        t))) {
+    try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -770,8 +759,7 @@ public static Constant tensorOf(Scope scope, boolean[][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, boolean[][][] data) {
-    try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data,
-        t))) {
+    try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -786,8 +774,7 @@ public static Constant tensorOf(Scope scope, boolean[][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, boolean[][][][] data) {
-    try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data,
-        t))) {
+    try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -802,8 +789,7 @@ public static Constant tensorOf(Scope scope, boolean[][][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, boolean[][][][][] data) {
-    try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data,
-        t))) {
+    try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -818,8 +804,7 @@ public static Constant tensorOf(Scope scope, boolean[][][][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, boolean[][][][][][] data) {
-    try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data,
-        t))) {
+    try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) {
       return create(scope, value);
     }
   }
@@ -911,8 +896,7 @@ public static Constant arrayOf(Scope scope, byte... data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, byte[][] data) {
-    try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data,
-        d))) {
+    try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d))) {
       return create(scope, value);
     }
   }
@@ -927,8 +911,7 @@ public static Constant tensorOf(Scope scope, byte[][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, byte[][][] data) {
-    try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data,
-        d))) {
+    try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d))) {
       return create(scope, value);
     }
   }
@@ -943,8 +926,7 @@ public static Constant tensorOf(Scope scope, byte[][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, byte[][][][] data) {
-    try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data,
-        d))) {
+    try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d))) {
       return create(scope, value);
     }
   }
@@ -959,8 +941,7 @@ public static Constant tensorOf(Scope scope, byte[][][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, byte[][][][][] data) {
-    try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data,
-        d))) {
+    try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d))) {
       return create(scope, value);
     }
   }
@@ -975,8 +956,7 @@ public static Constant tensorOf(Scope scope, byte[][][][][] data) {
    */
   @Endpoint
   public static Constant tensorOf(Scope scope, byte[][][][][][] data) {
-    try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data,
-        d))) {
+    try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d))) {
       return create(scope, value);
     }
   }
@@ -1027,8 +1007,8 @@ public static Constant tensorOf(Scope scope, Shape shape, ByteDataBuffer
    *     buffer
    */
   @Endpoint
-  public static  Constant tensorOf(Scope scope, Class type, Shape shape,
-      ByteDataBuffer data) {
+  public static  Constant tensorOf(
+      Scope scope, Class type, Shape shape, ByteDataBuffer data) {
     try (T value = Tensor.of(type, shape, data)) {
       return create(scope, value);
     }
@@ -1262,8 +1242,8 @@ public static Constant tensorOf(Scope scope, Shape shape, DataBuffer tensorOf(Scope scope, Charset charset, Shape shape,
-      DataBuffer data) {
+  public static Constant tensorOf(
+      Scope scope, Charset charset, Shape shape, DataBuffer data) {
     try (TString value = TString.tensorOf(charset, shape, data)) {
       return create(scope, value);
     }
@@ -1283,18 +1263,20 @@ public static Constant tensorOf(Scope scope, Shape shape) {
   }
 
   /**
-   * Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be truncated if it does not
-   * fit in the target type.
+   * Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be
+   * truncated if it does not fit in the target type.
    *
-   * @param type the type of tensor to create.  Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating})
+   * @param type the type of tensor to create. Must be concrete (i.e. not {@link
+   *     org.tensorflow.types.family.TFloating})
    * @param number the value of the tensor
    * @return a constant of the passed type
-   * @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or
-   * unknown.
+   * @throws IllegalArgumentException if the type is abstract (i.e. {@link
+   *     org.tensorflow.types.family.TFloating}) or unknown.
    */
   @SuppressWarnings("unchecked")
   @Endpoint
-  public static  Constant tensorOf(Scope scope, Class type, Number number) {
+  public static  Constant tensorOf(
+      Scope scope, Class type, Number number) {
     if (type.equals(TBfloat16.class)) {
       try (TBfloat16 tensor = TBfloat16.scalarOf(number.floatValue())) {
         return (Constant) create(scope, tensor);
@@ -1324,13 +1306,14 @@ public static  Constant tensorOf(Scope scope, Class typ
         return (Constant) create(scope, tensor);
       }
     } else {
-      throw new IllegalArgumentException("Tensor type " + type + " is an abstract or unknown numeric type.");
+      throw new IllegalArgumentException(
+          "Tensor type " + type + " is an abstract or unknown numeric type.");
     }
   }
 
   /**
-   * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code number} may be
-   * truncated if it does not fit in the target type.
+   * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code
+   * number} may be truncated if it does not fit in the target type.
    *
    * @param toMatch the operand providing the target type
    * @param number the value of the tensor
@@ -1339,16 +1322,18 @@ public static  Constant tensorOf(Scope scope, Class typ
    * @see Ops#constant(Class, Number)
    */
   @Endpoint(name = "constantOfSameType")
-  public static  Constant tensorOfSameType(Scope scope, Operand toMatch, Number number) {
+  public static  Constant tensorOfSameType(
+      Scope scope, Operand toMatch, Number number) {
     return tensorOf(scope, toMatch.type(), number);
   }
 
   /**
-   * Create a constant by making an immutable copy of {@code tensor}. {@code tensor} may be closed afterwards without
-   * issue.
+   * Create a constant by making an immutable copy of {@code tensor}. {@code tensor} may be closed
+   * afterwards without issue.
    *
    * 

Note: this endpoint cannot be simply called {@code constant} since it will conflict with - * other endpoints accepting an NdArray in parameter {e.g. {@link #tensorOf(Scope, FloatNdArray)}}. + * other endpoints accepting an NdArray in parameter {e.g. {@link #tensorOf(Scope, + * FloatNdArray)}}. * * @param scope is a scope used to add the underlying operation. * @param tensor a Tensor holding the constant value @@ -1356,13 +1341,11 @@ public static Constant tensorOfSameType(Scope scope, Oper */ @Endpoint(name = "constantOf") public static Constant create(Scope scope, T tensor) { - OperationBuilder builder = scope - .env() - .opBuilder(OP_NAME, scope.makeOpName(OP_NAME)) - .setAttr("value", tensor) - .setAttr("dtype", tensor.dataType()); - - scope.apply(builder); + OperationBuilder builder = + scope + .opBuilder(OP_NAME, OP_NAME) + .setAttr("value", tensor) + .setAttr("dtype", tensor.dataType()); return new Constant<>(builder.build()); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java index 59682777966..fa2c8d1f2bd 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java @@ -22,32 +22,31 @@ import org.tensorflow.types.family.TType; /** - * Container class for core methods which add or perform several operations - * and return one of them. + * Container class for core methods which add or perform several operations and return one of them. */ @Operator public abstract class Helpers { - /** - * This class contains static factories. - */ - private Helpers() {} + /** This class contains static factories. */ + private Helpers() {} - /** - * Factory method to create a new Variable with it's initializer. - *

- * Only supported on Graph sessions as the {@link org.tensorflow.op.core.Assign} op - * does not work in an EagerSession. - * @param scope current scope - * @param init The op to use to initialise this variable. - * @param options carries optional attributes values - * @return a new instance of Variable - */ - @Endpoint(name = "variable") - public static Variable createVariableWithInit(Scope scope, Operand init, Variable.Options... options) { - Variable newVar = Variable.create(scope, init.shape(), init.type(), options); - Assign assignOp = Assign.create(scope, newVar, init); - Init.add(scope, assignOp); - return newVar; - } + /** + * Factory method to create a new Variable with its initializer. Both the creation and assignment + * are done in the init scope. + * + *

Only supported on Graph sessions as the {@link org.tensorflow.op.core.Assign} op does not + * work in an EagerSession. + * + * @param scope current scope + * @param init The op to use to initialise this variable. + * @param options carries optional attributes values + * @return a new instance of Variable + */ + @Endpoint(name = "variable") + public static Variable createVariableWithInit( + Scope scope, Operand init, Variable.Options... options) { + Variable newVar = Variable.create(scope.withInitScope(), init.shape(), init.type(), options); + Assign assignOp = Assign.create(scope.withInitScope(), newVar, init); + return newVar; + } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Init.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Init.java deleted file mode 100644 index b05eb07c8ca..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Init.java +++ /dev/null @@ -1,110 +0,0 @@ -package org.tensorflow.op.core; - -import org.tensorflow.ExecutionEnvironment; -import org.tensorflow.Graph; -import org.tensorflow.Operation; -import org.tensorflow.OperationBuilder; -import org.tensorflow.op.Op; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; - -@Operator -public final class Init extends RawOp { - - public static final String DEFAULT_NAME = "init"; - - /** - * Factory method to create an operation executing all initializers of a graph. - * - *

All initializers added to a graph via - * {@link org.tensorflow.op.core.Init#add(Scope, Op) tf.initAdd} are grouped together as a single - * unit of computation in the graph. This operation must then be added to any graph using one or - * more {@link Variable variables} and executed once before running the graph so the variable - * states are initialized properly.

- * - *

When the graph is built by the same process that is running the session, the initializers - * can be invoked by executing this single endpoint. For example:

- *
{@code
-   * try (Graph g = new Graph()) {
-   *   Variable x = tf.variable(tf.constant(10));  // initAdd is called implicitly
-   *   Variable y = tf.variable(tf.constant(20));  // idem
-   *   Add z = tf.math.add(x, y);
-   *
-   *   try (Session s = new Session(g)) {
-   *     s.run(tf.init());  // initialize all variables
-   *
-   *     try (TInt32 t = (TInt32)s.runner().fetch(z).run().get(0)) {
-   *       assertEquals(30, t.data().getInt());
-   *     }
-   *   }
-   * }
-   * }
- * - *

When the graph is built by a separate process, the initializers can be invoked by running - * the init op by its name, which defaults to {@link org.tensorflow.op.core.Init#DEFAULT_NAME}. - * For example:

- *
{@code
-   * // Building the model
-   * try (Graph g = new Graph()) {
-   *   Variable x = tf.variable(tf.constant(10));  // initAdd is called implicitly
-   *   Variable y = tf.variable(tf.constant(20));  // idem
-   *   Add z = tf.withName("z").math.add(x, y);
-   *
-   *   tf.init();  // add variables initializers to the graph, as Init.DEFAULT_NAME
-   *   // ...exporting graph as a saved model...
-   * }
-   *
-   * ...
-   *
-   * // Running the model
-   * try (SavedModelBundle model = SavedModelBundle.load("/path/to/model", "train")) {
-   *   model.session().run(Init.DEFAULT_NAME);
-   *
-   *   try (TInt32 t = (TInt32)s.runner().fetch("z").run().get(0)) {
-   *     assertEquals(30, t.data().getInt());
-   *   }
-   * }
-   * }
- * - * @param scope current scope - * @return an op grouping all initializers added to the graph - * @throws IllegalArgumentException if the execution environment in scope is not a graph - */ - @Endpoint(name = "init") - public static Init create(Scope scope) { - ExecutionEnvironment exEnv = scope.env(); - if (!(exEnv instanceof Graph)) { - throw new IllegalArgumentException("init is only supported on Graph sessions."); - } - Graph graph = (Graph)exEnv; - OperationBuilder opBuilder = scope.env().opBuilder("NoOp", scope.makeOpName(DEFAULT_NAME)); - scope.withControlDependencies(graph.initializers()).applyControlDependencies(opBuilder); - return new Init(opBuilder.build()); - } - - /** - * Register an op as an initializer of the graph. - * - *

Registered initializers are then grouped as a single unit of computation by adding - * and executing an {@link org.tensorflow.op.core.Init#create(Scope) init} operation from a graph - * session. This is a no-op if executed in an eager session. - * - * @param scope - * @param initializer - * @see org.tensorflow.op.core.Init#create(Scope) init - */ - @Endpoint(name = "initAdd") - public static void add(Scope scope, Op initializer) { - ExecutionEnvironment exEnv = scope.env(); - - if (exEnv.isGraph()) { - ((Graph) exEnv).addInitializer(initializer); - } - } - - private Init(Operation operation) { - super(operation); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java index 64c33f451fb..37875f151c7 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -1,18 +1,18 @@ /* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -21,7 +21,6 @@ import java.util.Arrays; import org.junit.jupiter.api.Test; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Init; import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.math.Add; import org.tensorflow.op.math.Sub; @@ -33,7 +32,6 @@ public class ConcreteFunctionTest { private static Signature plusFive(Ops tf) { Placeholder input = tf.placeholder(TFloat32.class); Add output = tf.math.add(input, tf.constant(5.0f)); - Init init = tf.init(); // for native resource management tests return Signature.builder().key("plusFive").input("x", input).output("y", output).build(); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java index b694e0e5a39..e67312a9847 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java @@ -1,18 +1,18 @@ /* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import static org.junit.jupiter.api.Assertions.fail; @@ -31,7 +31,7 @@ public void failToCreateIfSessionIsClosed() { EagerSession session = EagerSession.create(); session.close(); try { - new EagerOperationBuilder(session, "Add", "add"); + new EagerOperationBuilder(session, "Add", "add", session.baseScope()); fail(); } catch (IllegalStateException e) { // expected @@ -42,7 +42,7 @@ public void failToCreateIfSessionIsClosed() { public void failToBuildOpIfSessionIsClosed() { EagerOperationBuilder opBuilder; try (EagerSession session = EagerSession.create()) { - opBuilder = new EagerOperationBuilder(session, "Empty", "empty"); + opBuilder = new EagerOperationBuilder(session, "Empty", "empty", session.baseScope()); } try { opBuilder.setAttr("dtype", DataType.DT_FLOAT); @@ -140,6 +140,6 @@ public void setAttrs() { } private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) { - return new EagerOperationBuilder(session, type, name); + return new EagerOperationBuilder(session, type, name, session.baseScope()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java index 38714b86599..e2dc82f4c48 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java @@ -26,9 +26,7 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; -/** - * Unit tests for {@link EagerOperation} class. - */ +/** Unit tests for {@link EagerOperation} class. */ public class EagerOperationTest { @Test @@ -50,7 +48,7 @@ public void failToCreateIfSessionIsClosed() { @Test public void outputDataTypeAndShape() { try (EagerSession session = EagerSession.create(); - TInt32 t = TInt32.tensorOf(Shape.of(2, 3))) { + TInt32 t = TInt32.tensorOf(Shape.of(2, 3))) { EagerOperation op = opBuilder(session, "Const", "OutputAttrs") .setAttr("dtype", t.dataType()) @@ -71,7 +69,7 @@ public void outputTensor() { .addInput(tf.constant(2).asOutput()) .addInput(tf.constant(4).asOutput()) .build(); - assertEquals(6, ((TInt32)add.tensor(0)).getInt()); + assertEquals(6, ((TInt32) add.tensor(0)).getInt()); // Validate that we retrieve the right shape and datatype from the tensor // that has been resolved @@ -84,12 +82,12 @@ public void outputTensor() { public void inputAndOutputListLengths() { try (EagerSession session = EagerSession.create()) { Ops tf = Ops.create(session); - Output c1 = tf.constant(new float[]{1f, 2f}).asOutput(); - Output c2 = tf.constant(new float[]{3f, 4f}).asOutput(); + Output c1 = tf.constant(new float[] {1f, 2f}).asOutput(); + Output c2 = tf.constant(new float[] {3f, 4f}).asOutput(); EagerOperation acc = opBuilder(session, "AddN", "InputListLength") - .addInputList(new Output[]{c1, c2}) + .addInputList(new Output[] {c1, c2}) .build(); assertEquals(2, acc.inputListLength("inputs")); assertEquals(1, acc.outputListLength("sum")); @@ -125,8 +123,8 @@ public void numOutputs() { Ops tf = Ops.create(session); EagerOperation op = opBuilder(session, "UniqueWithCountsV2", "unq") - .addInput(tf.constant(new int[]{1, 2, 1}).asOutput()) - .addInput(tf.constant(new int[]{0}).asOutput()) + .addInput(tf.constant(new int[] {1, 2, 1}).asOutput()) + .addInput(tf.constant(new int[] {0}).asOutput()) .setAttr("out_idx", DataType.DT_INT32) .build(); assertEquals(3, op.numOutputs()); @@ -189,6 +187,6 @@ public void outputIndexOutOfBounds() { } private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) { - return new EagerOperationBuilder(session, type, name); + return new EagerOperationBuilder(session, type, name, session.baseScope()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerSessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerSessionTest.java index 77325d50dcc..ba5d3c54a6c 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerSessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerSessionTest.java @@ -125,7 +125,7 @@ public void defaultSession() throws Exception { private static void buildOp(EagerSession s) { // Creating an operation is a safe point for resource cleanup try { - s.opBuilder("Const", "Const"); + s.baseScope().opBuilder("Const", "Const"); } catch (UnsupportedOperationException e) { // TODO (karlllessard) remove this exception catch when EagerOperationBuilder is implemented } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java index d0e79534d2c..84e1e56df56 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java @@ -1,18 +1,18 @@ /* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -35,7 +35,10 @@ public void failOnUseAfterBuild() { try (Graph g = new Graph(); TInt32 t = TInt32.scalarOf(1)) { OperationBuilder b = - g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t); + g.baseScope() + .opBuilder("Const", "Const") + .setAttr("dtype", t.dataType()) + .setAttr("value", t); b.build(); try { b.setAttr("dtype", t.dataType()); @@ -50,7 +53,11 @@ public void failOnUseAfterGraphClose() { OperationBuilder b = null; try (Graph g = new Graph(); TInt32 t = TInt32.scalarOf(1)) { - b = g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t); + b = + g.baseScope() + .opBuilder("Const", "Const") + .setAttr("dtype", t.dataType()) + .setAttr("value", t); } try { b.build(); @@ -72,7 +79,8 @@ public void setAttr() { Ops tf = Ops.create(g); // dtype, tensor attributes. try (TInt32 t = TInt32.scalarOf(1)) { - g.opBuilder("Const", "DataTypeAndTensor") + g.baseScope() + .opBuilder("Const", "DataTypeAndTensor") .setAttr("dtype", t.dataType()) .setAttr("value", t) .build() @@ -80,20 +88,23 @@ public void setAttr() { assertTrue(hasNode(g, "DataTypeAndTensor")); } // string, bool attributes. - g.opBuilder("Abort", "StringAndBool") + g.baseScope() + .opBuilder("Abort", "StringAndBool") .setAttr("error_msg", "SomeErrorMessage") .setAttr("exit_without_error", false) .build(); assertTrue(hasNode(g, "StringAndBool")); // int (TF "int" attributes are 64-bit signed, so a Java long). - g.opBuilder("RandomUniform", "Int") + g.baseScope() + .opBuilder("RandomUniform", "Int") .addInput(tf.array(1).asOutput()) .setAttr("seed", 10) .setAttr("dtype", DataType.DT_FLOAT) .build(); assertTrue(hasNode(g, "Int")); // list(int) - g.opBuilder("MaxPool", "IntList") + g.baseScope() + .opBuilder("MaxPool", "IntList") .addInput(tf.constant(new float[2][2][2][2]).asOutput()) .setAttr("ksize", new long[] {1, 1, 1, 1}) .setAttr("strides", new long[] {1, 1, 1, 1}) @@ -101,7 +112,8 @@ public void setAttr() { .build(); assertTrue(hasNode(g, "IntList")); // list(float) - g.opBuilder("FractionalMaxPool", "FloatList") + g.baseScope() + .opBuilder("FractionalMaxPool", "FloatList") .addInput(tf.constant(new float[2][2][2][2]).asOutput()) .setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f}) .build(); @@ -115,7 +127,8 @@ public void setAttr() { public void setAttrShape() { try (Graph g = new Graph()) { Output n = - g.opBuilder("Placeholder", "unknown") + g.baseScope() + .opBuilder("Placeholder", "unknown") .setAttr("dtype", DataType.DT_FLOAT) .setAttr("shape", Shape.unknown()) .build() @@ -124,7 +137,8 @@ public void setAttrShape() { assertEquals(DataType.DT_FLOAT, n.dataType()); n = - g.opBuilder("Placeholder", "batch_of_vectors") + g.baseScope() + .opBuilder("Placeholder", "batch_of_vectors") .setAttr("dtype", DataType.DT_FLOAT) .setAttr("shape", Shape.of(-1, 784)) .build() @@ -158,11 +172,13 @@ public void addControlInput() { Ops tf = Ops.create(g); Output placeholder = tf.placeholder(TBool.class).asOutput(); GraphOperation check = - g.opBuilder("Assert", "assert") - .addInput(placeholder) - .addInputList(new Output[] {placeholder}) - .build(); - Operation noop = g.opBuilder("NoOp", "noop").addControlInput(check).build(); + (GraphOperation) + g.baseScope() + .opBuilder("Assert", "assert") + .addInput(placeholder) + .addInputList(new Output[] {placeholder}) + .build(); + Operation noop = g.baseScope().opBuilder("NoOp", "noop").addControlInput(check).build(); // No problems when the Assert check succeeds s.runner().feed(placeholder, yes).addTarget(noop).run(); @@ -183,7 +199,8 @@ private static void testSetAttrShapeList(Shape[] shapes) { Ops tf = Ops.create(g); int[][] matrix = new int[][] {{0, 0}, {0, 0}}; Output queue = - g.opBuilder("FIFOQueue", "queue") + g.baseScope() + .opBuilder("FIFOQueue", "queue") .setAttr("component_types", new DataType[] {DataType.DT_INT32, DataType.DT_INT32}) .setAttr("shapes", shapes) .build() @@ -192,7 +209,8 @@ private static void testSetAttrShapeList(Shape[] shapes) { Output c1 = tf.constant(matrix).asOutput(); Output c2 = tf.constant(new int[][][] {matrix, matrix}).asOutput(); Operation enqueue = - g.opBuilder("QueueEnqueue", "enqueue") + g.baseScope() + .opBuilder("QueueEnqueue", "enqueue") .addInput(queue) .addInputList(new Output[] {c1, c2}) .build(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java index dddd5867d33..464701306f8 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java @@ -19,18 +19,17 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashSet; +import java.util.List; import java.util.Set; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TFInvalidArgumentException; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Constant; import org.tensorflow.op.linalg.MatMul; import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.GraphDef; @@ -46,11 +45,12 @@ public void graphDefRoundTrip() { // Create a graph for A * X + B try (Graph g = new Graph()) { Ops tf = Ops.create(g); - tf.withName("Y").linalg.matMul( - tf.withName("A").constant(new int[2][2]), - tf.withName("X").placeholder(TInt32.class), - MatMul.transposeA(true).transposeB(false) - ); + tf.withName("Y") + .linalg + .matMul( + tf.withName("A").constant(new int[2][2]), + tf.withName("X").placeholder(TInt32.class), + MatMul.transposeA(true).transposeB(false)); graphDef = g.toGraphDef(); } // Import the GraphDef and find all the nodes. @@ -64,6 +64,60 @@ public void graphDefRoundTrip() { } } + @Test + public void graphDefRoundTripWithInit() { + GraphDef graphDef; + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Ops init = tf.withInitScope(); + + Operand variable = init.variable(init.constant(4)); + Operand result = tf.withName("result").math.add(variable, tf.constant(2)); + graphDef = g.toGraphDef(); + } + + try (Graph g = new Graph()) { + g.importGraphDef(graphDef); + + Ops tf = Ops.create(g); + Ops init = tf.withInitScope(); + + Operand variable2 = init.withName("var2").variable(init.constant(4)); + + try (Session s = new Session(g, true)) { + List results = s.runner().fetch("result").fetch("var2").run(); + TInt32 result = (TInt32) results.get(0); + assertEquals(6, result.getInt()); + + TInt32 var2Result = (TInt32) results.get(1); + assertEquals(4, var2Result.getInt()); + + results.forEach(Tensor::close); + } + } + } + + @Test + public void graphDefRoundTripWithInitAndPrefix() { + GraphDef graphDef; + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Ops init = tf.withInitScope(); + + Operand variable = init.variable(init.constant(4)); + Operand result = tf.withName("result").math.add(variable, tf.constant(2)); + graphDef = g.toGraphDef(); + } + + try (Graph g = new Graph()) { + g.importGraphDef(graphDef, "test"); + try (Session s = new Session(g); + TInt32 result = (TInt32) s.runner().fetch("test/result").run().get(0)) { + assertEquals(6, result.getInt()); + } + } + } + // Helper function whose implementation is based on knowledge of how // TestUtil.transpose_A_times_X is implemented. private static void validateImportedGraph(Graph g, String prefix) { @@ -123,17 +177,21 @@ public void completeSubgraph() { try (Graph g = new Graph()) { Ops tf = Ops.create(g); Operand control = tf.constant(0); - Operand a = tf.withControlDependencies(Collections.singletonList(control)).constant(1); + Operand a = + tf.withControlDependencies(Collections.singletonList(control)).constant(1); Operand b = tf.constant(2); Operand c = tf.constant(3); Operand d = tf.math.add(a, b); Operand output = tf.math.mul(d, c); - Set subgraph = g - .completeSubgraph(new LinkedHashSet<>(Arrays.asList(control, a, b, c)), Collections.singleton(output)); + Set subgraph = + g.completeSubgraph( + new LinkedHashSet<>(Arrays.asList(control, a, b, c)), Collections.singleton(output)); - assertEquals(new LinkedHashSet<>(Arrays.asList(control.op(), a.op(), b.op(), c.op(), d.op(), output.op())), + assertEquals( + new LinkedHashSet<>( + Arrays.asList(control.op(), a.op(), b.op(), c.op(), d.op(), output.op())), subgraph); } } @@ -143,17 +201,20 @@ public void completeSubgraphWithConstants() { try (Graph g = new Graph()) { Ops tf = Ops.create(g); Operand control = tf.constant(0); - Operand a = tf.withControlDependencies(Collections.singletonList(control)).constant(1); + Operand a = + tf.withControlDependencies(Collections.singletonList(control)).constant(1); Operand b = tf.constant(2); Operand c = tf.constant(3); Operand d = tf.math.add(a, b); Operand output = tf.math.mul(d, c); - Set subgraph = g - .completeSubgraph(Collections.emptySet(), Collections.singleton(output)); + Set subgraph = + g.completeSubgraph(Collections.emptySet(), Collections.singleton(output)); - assertEquals(new LinkedHashSet<>(Arrays.asList(control.op(), a.op(), b.op(), c.op(), d.op(), output.op())), + assertEquals( + new LinkedHashSet<>( + Arrays.asList(control.op(), a.op(), b.op(), c.op(), d.op(), output.op())), subgraph); } } @@ -191,7 +252,7 @@ public void addGradientsToGraph() { Output y0 = tf.math.square(x1).y(); Output y1 = tf.math.square(y0).y(); Output y2 = tf.math.addN(Arrays.asList(y0, x2)).sum(); - + Output[] grads0 = g.addGradients(y1, toArray(x1)); assertNotNull(grads0); assertEquals(1, grads0.length); @@ -202,22 +263,23 @@ public void addGradientsToGraph() { assertEquals(2, grads1.length); assertEquals(DataType.DT_FLOAT, grads1[0].dataType()); assertEquals(DataType.DT_FLOAT, grads1[1].dataType()); - + try (TFloat32 c1 = TFloat32.scalarOf(3.0f); TFloat32 c2 = TFloat32.scalarOf(2.0f); - AutoCloseableList outputs = new AutoCloseableList<>( - s.runner() - .feed(x1, c1) - .feed(x2, c2) - .fetch(grads0[0]) - .fetch(grads1[0]) - .fetch(grads1[1]) - .run())) { - + AutoCloseableList outputs = + new AutoCloseableList<>( + s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) + .run())) { + assertEquals(3, outputs.size()); - assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); - assertEquals(6.0f, ((TFloat32)outputs.get(1)).getFloat(), 0.0f); - assertEquals(1.0f, ((TFloat32)outputs.get(2)).getFloat(), 0.0f); + assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); + assertEquals(6.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f); + assertEquals(1.0f, ((TFloat32) outputs.get(2)).getFloat(), 0.0f); } } } @@ -238,11 +300,7 @@ public void addGradientSumsToGraph() { assertEquals(DataType.DT_FLOAT, grad[0].dataType()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - TFloat32 output = (TFloat32)s.runner() - .feed(x, c) - .fetch(grad[0]) - .run() - .get(0)) { + TFloat32 output = (TFloat32) s.runner().feed(x, c).fetch(grad[0]).run().get(0)) { assertEquals(114.0f, output.getFloat(), 0.0f); } } @@ -257,7 +315,7 @@ public void addGradientsWithInitialValuesToGraph() { Output x = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x).y(); Output y1 = tf.math.square(y0).y(); - + Output[] grad0 = g.addGradients(y1, toArray(y0)); assertNotNull(grad0); assertEquals(1, grad0.length); @@ -269,11 +327,7 @@ public void addGradientsWithInitialValuesToGraph() { assertEquals(DataType.DT_FLOAT, grad1[0].dataType()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - TFloat32 output = (TFloat32)s.runner() - .feed(x, c) - .fetch(grad1[0]) - .run() - .get(0)) { + TFloat32 output = (TFloat32) s.runner().feed(x, c).fetch(grad1[0]).run().get(0)) { assertEquals(108.0f, output.getFloat(), 0.0f); } } @@ -316,24 +370,22 @@ public void buildWhileLoopSingleInput() { Output input = tf.placeholder(TInt32.class).output(); @SuppressWarnings("unchecked") - Output[] loopOutputs = g.whileLoop( - toArray(input), - (condGraph, condInputs, condOutputs) -> { - Ops tfc = Ops.create(condGraph); - condOutputs[0] = tfc.math.less((Output)condInputs[0], tfc.constant(16)).z(); - }, - (bodyGraph, bodyInputs, bodyOutputs) -> { - Ops tfb = Ops.create(bodyGraph); - bodyOutputs[0] = tfb.math.square((Output)bodyInputs[0]).y(); - }, - "test_loop"); + Output[] loopOutputs = + g.whileLoop( + toArray(input), + (condGraph, condInputs, condOutputs) -> { + Ops tfc = Ops.create(condGraph); + condOutputs[0] = + tfc.math.less((Output) condInputs[0], tfc.constant(16)).z(); + }, + (bodyGraph, bodyInputs, bodyOutputs) -> { + Ops tfb = Ops.create(bodyGraph); + bodyOutputs[0] = tfb.math.square((Output) bodyInputs[0]).y(); + }, + "test_loop"); try (TInt32 c = TInt32.scalarOf(2); - TInt32 output = (TInt32)s.runner() - .feed(input, c) - .fetch(loopOutputs[0]) - .run() - .get(0)) { + TInt32 output = (TInt32) s.runner().feed(input, c).fetch(loopOutputs[0]).run().get(0)) { assertEquals(16, output.getInt()); // ((2^2)^2) } } @@ -350,18 +402,20 @@ public void buildWhileLoopMultipleInputs() { Output[] inputs = toArray(input1, input2); @SuppressWarnings("unchecked") - Output[] loopOutputs = g.whileLoop( - inputs, - (condGraph, condInputs, condOutputs) -> { - Ops tfc = Ops.create(condGraph); - condOutputs[0] = tfc.math.less((Output)condInputs[0], tfc.constant(16)).z(); - }, - (bodyGraph, bodyInputs, bodyOutputs) -> { - Ops tfb = Ops.create(bodyGraph); - bodyOutputs[0] = tfb.math.square((Output)bodyInputs[0]).y(); - bodyOutputs[1] = tfb.math.square((Output)bodyInputs[1]).y(); - }, - "test_loop"); + Output[] loopOutputs = + g.whileLoop( + inputs, + (condGraph, condInputs, condOutputs) -> { + Ops tfc = Ops.create(condGraph); + condOutputs[0] = + tfc.math.less((Output) condInputs[0], tfc.constant(16)).z(); + }, + (bodyGraph, bodyInputs, bodyOutputs) -> { + Ops tfb = Ops.create(bodyGraph); + bodyOutputs[0] = tfb.math.square((Output) bodyInputs[0]).y(); + bodyOutputs[1] = tfb.math.square((Output) bodyInputs[1]).y(); + }, + "test_loop"); try (TInt32 c1 = TInt32.scalarOf(2); TInt32 c2 = TInt32.scalarOf(5); @@ -374,8 +428,8 @@ public void buildWhileLoopMultipleInputs() { .fetch(loopOutputs[1]) .run())) { assertEquals(2, outputs.size()); - assertEquals(16, ((TInt32)outputs.get(0)).getInt()); // ((2^2)^2) - assertEquals(625, ((TInt32)outputs.get(1)).getInt()); // ((5^2)^2) + assertEquals(16, ((TInt32) outputs.get(0)).getInt()); // ((2^2)^2) + assertEquals(625, ((TInt32) outputs.get(1)).getInt()); // ((5^2)^2) } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 1561842a689..ab1f5c17dd4 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -1,18 +1,18 @@ /* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -36,7 +36,6 @@ import org.tensorflow.ndarray.StdArrays; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Identity; -import org.tensorflow.op.core.Init; import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.ReduceSum; import org.tensorflow.op.core.Variable; @@ -112,7 +111,6 @@ public void exportMultipleFunctions() throws IOException { try (Session s = new Session(g); ) { SessionFunction f1 = SessionFunction.create(f1Signature, s); SessionFunction f2 = SessionFunction.create(f2Signature, s); - s.runInit(); try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[] {2, 2})); TFloat32 t = (TFloat32) f1.call(x)) { reducedSum = t.getFloat(); @@ -154,7 +152,7 @@ public void exportFunctionWithVariables() throws IOException { Ops tf = Ops.create(g); SessionFunction f = session.function(buildGraphWithVariables(tf, xyShape)); // Init variable state by running the Init operation directly - session.runInit(); + session.initialize(); // Call the graph and remember the result of computation for later try (TFloat32 xTensor = TFloat32.tensorOf(xValue); @@ -230,8 +228,8 @@ public void cannotExportMultipleFunctionsWithDifferentSessions() throws IOExcept Signature f2Signature = buildIdentityGraph(tf, "identity"); SessionFunction f1 = s1.function(f1Signature); SessionFunction f2 = s2.function(f2Signature); - s1.runInit(); - s2.runInit(); + s1.initialize(); + s2.initialize(); try { SavedModelBundle.exporter(testFolder.toString()).withFunction(f1).withFunction(f2).export(); fail(); @@ -251,7 +249,7 @@ public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOExcepti try (Session s = new Session(g); ) { SessionFunction f1 = SessionFunction.create(f1Signature, s); SessionFunction f2 = SessionFunction.create(f2Signature, s); - s.runInit(); + s.initialize(); try { SavedModelBundle.exporter(testFolder.toString()).withFunctions(f1, f2).export(); fail(); @@ -324,10 +322,10 @@ public void pythonTfFunction() { private static Signature buildGraphWithVariables(Ops tf, Shape xShape) { Placeholder x = tf.placeholder(TFloat32.class, Placeholder.shape(xShape)); Variable y = - tf.withName("variable") + tf.withInitScope() + .withName("variable") .variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.class)); ReduceSum z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1)); - Init init = tf.init(); return Signature.builder().input("input", x).output("reducedSum", z).build(); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index 8a3e64c3336..3575da6c8c2 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -1,18 +1,18 @@ /* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -32,7 +32,6 @@ import org.tensorflow.ndarray.StdArrays; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Init; import org.tensorflow.op.core.Split; import org.tensorflow.op.core.Variable; import org.tensorflow.op.linalg.MatMul; @@ -184,13 +183,13 @@ public void runInit() { try (Graph g = new Graph()) { Ops tf = Ops.create(g); - Variable var1 = tf.variable(Shape.scalar(), TInt32.class); - tf.initAdd(tf.assign(var1, tf.constant(10))); - Variable var2 = tf.variable(tf.constant(20)); + Variable var1 = tf.withInitScope().variable(Shape.scalar(), TInt32.class); + tf.withInitScope().assign(var1, tf.withInitScope().constant(10)); + Variable var2 = tf.withInitScope().variable(tf.withInitScope().constant(20)); Add add = tf.math.add(var1, var2); try (Session s = new Session(g)) { - s.run(tf.init()); + s.initialize(); try (TInt32 t = (TInt32) s.runner().fetch(add).run().get(0)) { assertEquals(30, t.getInt()); @@ -199,33 +198,6 @@ public void runInit() { } } - @Test - public void runInitByName() { - try (Graph g = new Graph()) { - Ops tf = Ops.create(g); - - Variable var1 = tf.variable(Shape.scalar(), TInt32.class); - tf.initAdd(tf.assign(var1, tf.constant(10))); - Variable var2 = tf.variable(tf.constant(20)); - Add add = tf.math.add(var1, var2); - tf.withName("init_test").init(); - - try (Session s = new Session(g)) { - s.run("init_test"); - - try (TInt32 t = (TInt32) s.runner().fetch(add).run().get(0)) { - assertEquals(30, t.getInt()); - } - try { - s.run("wrong_name"); - fail(); - } catch (IllegalArgumentException e) { - // as expected - } - } - } - } - @Test public void saveAndRestore() throws IOException { Path testFolder = Files.createTempDirectory("tf-session-save-restore-test"); @@ -233,14 +205,21 @@ public void saveAndRestore() throws IOException { Ops tf = Ops.create(g); Variable x = tf.withName("x") - .variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); + .withInitScope() + .variable( + tf.withInitScope() + .random + .randomUniform(tf.withInitScope().constant(Shape.of(3, 3L)), TFloat32.class)); Variable y = tf.withName("y") - .variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); - Init init = tf.init(); + .withInitScope() + .variable( + tf.withInitScope() + .random + .randomUniform(tf.withInitScope().constant(Shape.of(3, 3L)), TFloat32.class)); try (Session s = new Session(g)) { - s.run(init); + s.initialize(); s.save(testFolder.resolve("checkpoint").toString()); GraphDef graphDef = g.toGraphDef(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java index f8eeb84de90..faa53831e32 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java @@ -23,13 +23,12 @@ import java.io.File; import java.nio.file.Paths; - import org.junit.jupiter.api.Test; import org.tensorflow.proto.framework.OpList; /** Unit tests for {@link org.tensorflow.TensorFlow}. */ public class TensorFlowTest { - + @Test public void version() { assertTrue(TensorFlow.version().length() > 0); @@ -51,7 +50,7 @@ public void loadLibrary() { try (Graph g = new Graph()) { // Build a graph with an unrecognized operation. try { - g.opBuilder("MyTest", "MyTest").build(); + g.baseScope().opBuilder("MyTest", "MyTest").build(); fail("should not be able to construct graphs with unregistered ops"); } catch (IllegalArgumentException e) { // expected exception @@ -64,7 +63,7 @@ public void loadLibrary() { assertEquals(opList.getOpList().get(0).getName(), "MyTest"); // Now graph building should succeed. - g.opBuilder("MyTest", "MyTest").build(); + g.baseScope().opBuilder("MyTest", "MyTest").build(); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/RawOpTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/RawOpTest.java index c78864ec052..5d523a986ad 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/RawOpTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/RawOpTest.java @@ -38,9 +38,9 @@ public void equalsHashcode() { Output array = tf.constant(new int[2]).asOutput(); RawOp test1 = - new RawOp(g.opBuilder("Shape", "shape1").addInput(array).build()) {}; + new RawOp(g.baseScope().opBuilder("Shape", "shape1").addInput(array).build()) {}; RawOp test2 = - new RawOp(g.opBuilder("Shape", "shape2").addInput(array).build()) {}; + new RawOp(g.baseScope().opBuilder("Shape", "shape2").addInput(array).build()) {}; RawOp test3 = new RawOp(test1.operation) {}; // equals() tests diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java index 84eabd3da1a..6b37a908f8e 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java @@ -1,18 +1,18 @@ /* Copyright 2017-2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow.op; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -207,8 +207,7 @@ static Const create(Scope s, int[] v) { static Const create(Scope s, T value) { return new Const<>( - s.env() - .opBuilder("Const", s.makeOpName("Const")) + s.opBuilder("Const", "Const") .setAttr("dtype", value.dataType()) .setAttr("value", value) .build() @@ -229,12 +228,7 @@ private static final class Mean { static Mean create(Scope s, Output input, Output reductionIndices) { return new Mean<>( - s.env() - .opBuilder("Mean", s.makeOpName("Mean")) - .addInput(input) - .addInput(reductionIndices) - .build() - .output(0)); + s.opBuilder("Mean", "Mean").addInput(input).addInput(reductionIndices).build().output(0)); } Mean(Output o) { @@ -251,8 +245,7 @@ private static final class SquaredDifference { static SquaredDifference create(Scope s, Output x, Output y) { return new SquaredDifference<>( - s.env() - .opBuilder("SquaredDifference", s.makeOpName("SquaredDifference")) + s.opBuilder("SquaredDifference", "SquaredDifference") .addInput(x) .addInput(y) .build() diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/FunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/FunctionTest.java index be4386698fa..3348a418767 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/FunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/FunctionTest.java @@ -1,18 +1,18 @@ /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow.op.core; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -34,7 +34,6 @@ public class FunctionTest { private static Signature plusFive(Ops tf) { Placeholder input = tf.placeholder(TFloat32.class); Add output = tf.math.add(input, tf.constant(5.0f)); - Init init = tf.init(); // for native resource management tests return Signature.builder().key("plusFive").input("x", input).output("y", output).build(); } diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java index d5d617158c7..54f153b1988 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java @@ -1,4 +1,5 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -62,8 +63,6 @@ static boolean canGenerateOp(OpDef op, ApiDef apiDef) { .startsWith("_"); // TODO do I want this? Some interesting ops like _XlaCompile } - private static final String OP_NAME_FIELD_NAME = "OP_NAME"; - enum RenderMode { DEFAULT, LIST_OPERAND, @@ -318,8 +317,7 @@ void buildClass() { buildInterfaceImpl(); } - if (!isStateSelector) { - // add op name field + if (!isStateSelector) { // add op name field builder.addField( FieldSpec.builder( TypeResolver.STRING, @@ -507,10 +505,7 @@ private void buildFactoryMethods() { Set typeVars = new LinkedHashSet<>(typeParams); body.addStatement( - "$T opBuilder = scope.env().opBuilder($L, scope.makeOpName($S))", - Names.OperationBuilder, - OP_NAME_FIELD, - className); + "$T opBuilder = scope.opBuilder($L, $S)", Names.OperationBuilder, OP_NAME_FIELD, className); List functionArgs = new ArrayList<>(); List iterableFunctionArgs = new ArrayList<>(); @@ -546,8 +541,6 @@ private void buildFactoryMethods() { } } - body.addStatement("opBuilder = scope.apply(opBuilder)"); - // add the required attribute params, and build the default type maps for use in the secondary // factory Map defaultTypes = new HashMap<>(); diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java index 1b1d5cb0fb3..2a33483cb7b 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java @@ -1,18 +1,18 @@ /* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow.processor.operator; import com.github.javaparser.ast.comments.JavadocComment; @@ -550,6 +550,39 @@ private static TypeSpec buildTopClass(OpsSpec spec) { Names.Scope) .build()); + String initScopeComment = + "

\nInit operations will be initialized at session creation, will have their inputs (and control inputs) made init ops as well,\n" + + "and are ignored when used as control dependencies.\n" + + "Additionally, this scope ignores any control dependencies.\n" + + "

\nIf an input can not be made an init op (i.e. a Placeholder), will throw an {@link IllegalStateException} on op creation."; + + opsBuilder.addMethod( + MethodSpec.methodBuilder("withInitScope") + .addModifiers(Modifier.PUBLIC) + .returns(Names.Ops) + .addStatement("return new $T(scope.withInitScope())", Names.Ops) + .addJavadoc( + "Returns an API that builds init operations. {@link #liftToInitScope(Operand)} will be called for all created operations.\n" + + initScopeComment + + "\n@see #liftToInitScope(Operand)") + .build()); + + TypeVariableName T = TypeVariableName.get("T").withBounds(Names.Operand); + opsBuilder.addMethod( + MethodSpec.methodBuilder("liftToInitScope") + .addTypeVariable(T) + .addModifiers(Modifier.PUBLIC) + .addParameter(T, "op") + .returns(T) + .addStatement("scope.env().registerInitOp(op.op())") + .addStatement("return op") + .addJavadoc( + "Make {@code op} an init operation, doing the same for all of it's inputs (and control inputs).\n" + + initScopeComment + + "\n@see ExecutionEnvironment#registerInitOp(Operation)\n" + + "\n@throws IllegalStateException if the op or one of its inputs can't be made an init op.") + .build()); + opsBuilder.addMethod( MethodSpec.methodBuilder("withName") .addModifiers(Modifier.PUBLIC) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java index 8ae751823fe..30790d92030 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java @@ -29,7 +29,6 @@ import org.tensorflow.framework.data.impl.TensorSliceDataset; import org.tensorflow.framework.data.impl.TextLineDataset; import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TType; @@ -254,7 +253,7 @@ public DatasetIterator makeInitializeableIterator() { *

    *     try (Session session = new Session(graph) {
    *         // Immediately run initializers
-   *         session.run(tf.init());
+   *         session.initialize();
    *     }
    * 
* @@ -264,8 +263,8 @@ public DatasetIterator makeInitializeableIterator() { */ public DatasetIterator makeOneShotIterator() { DatasetIterator iterator = makeInitializeableIterator(); - Op initializer = iterator.makeInitializer(this); - if (tf.scope().env().isGraph()) tf.initAdd(initializer); + // TODO should pass the scope instead + tf.scope().env().registerInitOp(iterator.makeInitializer(this).op()); return iterator; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index 0ba94798c19..870f4972c3c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -685,22 +685,22 @@ private Map> build(Shape shape) { Zeros zeros = new Zeros<>(); Operand zero = zeros.call(tf, tf.constant(variableShape), type); if (truePositives == null) { - truePositives = tf.withName(getTruePositivesName()).variable(zero); + truePositives = tf.withName(getTruePositivesName()).withInitScope().variable(zero); initializers.put(ConfusionMatrixEnum.TRUE_POSITIVES, tf.assign(truePositives, zero)); } if (falsePositives == null) { - falsePositives = tf.withName(getFalsePositivesName()).variable(zero); + falsePositives = tf.withName(getFalsePositivesName()).withInitScope().variable(zero); initializers.put(ConfusionMatrixEnum.FALSE_POSITIVES, tf.assign(falsePositives, zero)); } if (trueNegatives == null) { - trueNegatives = tf.withName(getTrueNegativesName()).variable(zero); + trueNegatives = tf.withName(getTrueNegativesName()).withInitScope().variable(zero); initializers.put(ConfusionMatrixEnum.TRUE_NEGATIVES, tf.assign(trueNegatives, zero)); } if (falseNegatives == null) { - falseNegatives = tf.withName(getFalseNegativesName()).variable(zero); + falseNegatives = tf.withName(getFalseNegativesName()).withInitScope().variable(zero); initializers.put(ConfusionMatrixEnum.FALSE_NEGATIVES, tf.assign(falseNegatives, zero)); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java index 00ae3727249..6495379c4c4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -96,6 +96,7 @@ private void init() { totalConfusionMatrix = getTF() .withName(totalCMName) + .withInitScope() .variable(zeros.call(getTF(), getTF().constant(variableShape), type)); initializer = getTF() diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java index be09e7dd3f6..fa86125dbe7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java @@ -88,11 +88,11 @@ private boolean init(Shape shape) { Operand zero = zeros.call(getTF(), getTF().constant(shape), type); if (total == null) { - total = getTF().withName(totalName).variable(zero); + total = getTF().withName(totalName).withInitScope().variable(zero); totalInitializer = getTF().assign(total, zero); } if (count == null) { - count = getTF().withName(countName).variable(zero); + count = getTF().withName(countName).withInitScope().variable(zero); countInitializer = getTF().assign(count, zero); } this.initialized = true; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java index f978c0e20da..d81030ebedb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java @@ -279,11 +279,11 @@ private void init() { Operand zero = zeros.call(tf, tf.constant(Shape.of(thresholds.length)), type); if (this.truePositives == null) { - this.truePositives = tf.withName(truePositivesName).variable(zero); + this.truePositives = tf.withName(truePositivesName).withInitScope().variable(zero); initializers.add(tf.assign(truePositives, zero)); } if (this.falsePositives == null) { - this.falsePositives = tf.withName(falsePositivesName).variable(zero); + this.falsePositives = tf.withName(falsePositivesName).withInitScope().variable(zero); initializers.add(tf.assign(falsePositives, zero)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java index 6cb87f5be9e..7cdd01f0c56 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java @@ -308,13 +308,13 @@ private void init() { Operand zero = zeros.call(tf, tf.constant(Shape.of(this.thresholds.length)), type); if (truePositives == null) { - truePositives = tf.withName(truePositivesName).variable(zero); + truePositives = tf.withName(truePositivesName).withInitScope().variable(zero); initializers.add(tf.assign(truePositives, zero)); } if (this.falseNegatives == null) { - falseNegatives = tf.withName(falseNegativesName).variable(zero); + falseNegatives = tf.withName(falseNegativesName).withInitScope().variable(zero); initializers.add(tf.assign(falseNegatives, zero)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java index 4463e1f8213..cbff958fc6f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java @@ -119,6 +119,7 @@ private void init() { accumulator = getTF() .withName(getAccumulatorName()) + .withInitScope() .variable(zeros.call(getTF(), getTF().constant(variableShape), type)); initializer = getTF().assign(accumulator, zeros.call(getTF(), getTF().constant(variableShape), type)); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java index 6779b6b1f5a..870579ad636 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java @@ -92,22 +92,22 @@ private void init() { if (this.getTruePositives() == null) { - truePositives = tf.withName(truePositivesName).variable(zero); + truePositives = tf.withName(truePositivesName).withInitScope().variable(zero); truePositivesInitializer = tf.assign(truePositives, zero); } if (this.getFalsePositives() == null) { - falsePositives = tf.withName(falsePositivesName).variable(zero); + falsePositives = tf.withName(falsePositivesName).withInitScope().variable(zero); falsePositivesInitializer = tf.assign(falsePositives, zero); } if (this.getTrueNegatives() == null) { - trueNegatives = tf.withName(trueNegativesName).variable(zero); + trueNegatives = tf.withInitScope().withName(trueNegativesName).variable(zero); trueNegativesInitializer = tf.assign(trueNegatives, zero); } if (this.getFalseNegatives() == null) { - falseNegatives = tf.withName(falseNegativesName).variable(zero); + falseNegatives = tf.withInitScope().withName(falseNegativesName).variable(zero); falseNegativesInitializer = tf.assign(falseNegatives, zero); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java index e5bab9228b4..19f7584b152 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java @@ -15,6 +15,7 @@ */ package org.tensorflow.framework.optimizers; +import java.util.List; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; @@ -23,8 +24,6 @@ import org.tensorflow.op.train.ApplyAdadelta; import org.tensorflow.types.family.TType; -import java.util.List; - /** * Optimizer that implements the Adadelta algorithm. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index 66a170efcc2..5c51bbc1e4b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -15,16 +15,15 @@ */ package org.tensorflow.framework.optimizers; +import java.util.List; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.op.Op; -import org.tensorflow.op.train.ApplyAdagrad; import org.tensorflow.op.core.Variable; +import org.tensorflow.op.train.ApplyAdagrad; import org.tensorflow.types.family.TType; -import java.util.List; - /** * Optimizer that implements the Adagrad algorithm. * @@ -43,8 +42,8 @@ public class AdaGrad extends Optimizer { public static final float LEARNING_RATE_DEFAULT = 0.001f; public static final float INITIAL_ACCUMULATOR_DEFAULT = 0.01f; - private static final ApplyAdagrad.Options[] opts = new ApplyAdagrad.Options[]{ - ApplyAdagrad.updateSlots(true), ApplyAdagrad.useLocking(true)}; + private static final ApplyAdagrad.Options[] opts = + new ApplyAdagrad.Options[] {ApplyAdagrad.updateSlots(true), ApplyAdagrad.useLocking(true)}; private final float learningRate; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java index 64473b00f69..62ab8d309c9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java @@ -15,20 +15,18 @@ */ package org.tensorflow.framework.optimizers; +import java.util.List; +import java.util.Optional; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyAdagradDa; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; -import java.util.List; -import java.util.Optional; - /** * Optimizer that implements the Adagrad Dual-Averaging algorithm. * @@ -188,9 +186,11 @@ protected void createSlots(List> variables) { for (Output v : variables) { createAdaGradDASlot(v); } - globalStep = tf.withName("adagrad-da-global-step").variable(Shape.scalar(), TInt64.class); - Assign globalStepInitializer = tf.assign(globalStep, tf.constant(0L)); - graph.addInitializer(globalStepInitializer); + globalStep = + tf.withInitScope() + .withName("adagrad-da-global-step") + .variable(Shape.scalar(), TInt64.class); + tf.withInitScope().assign(globalStep, tf.constant(0L)); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java index ce581e41397..6cf1dbcc7c5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java @@ -15,6 +15,8 @@ */ package org.tensorflow.framework.optimizers; +import java.util.List; +import java.util.Optional; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; @@ -23,16 +25,12 @@ import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyAdam; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.List; -import java.util.Optional; - /** * Optimizer that implements the Adam algorithm. * @@ -190,12 +188,12 @@ protected void createSlots(List> variables) { for (Output v : variables) { createAdamSlot(v.asOutput()); } - betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.class); - Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne)); - graph.addInitializer(betaOnePowerInit); - betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.class); - Assign betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo)); - graph.addInitializer(betaTwoPowerInit); + betaOnePower = + tf.withInitScope().withName("beta1_power").variable(Shape.scalar(), TFloat32.class); + tf.withInitScope().assign(betaOnePower, tf.constant(betaOne)); + betaTwoPower = + tf.withInitScope().withName("beta2_power").variable(Shape.scalar(), TFloat32.class); + tf.withInitScope().assign(betaTwoPower, tf.constant(betaTwo)); } /** {@inheritDoc} */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java index 70b1497c2d8..635c2ecb862 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java @@ -1,20 +1,18 @@ package org.tensorflow.framework.optimizers; +import java.util.List; +import java.util.Optional; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyAdaMax; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.List; -import java.util.Optional; - /** * Optimizer that implements the Adamax algorithm. * @@ -135,9 +133,9 @@ protected void createSlots(List> variables) { for (Output v : variables) { createAdamaxSlot(v.asOutput()); } - betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.class); - Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne)); - ((Graph) tf.scope().env()).addInitializer(betaOnePowerInit); + betaOnePower = + tf.withInitScope().withName("beta1_power").variable(Shape.scalar(), TFloat32.class); + tf.withInitScope().assign(betaOnePower, tf.constant(betaOne)); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java index 5d8c1478231..962b64bab8e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java @@ -1,5 +1,6 @@ package org.tensorflow.framework.optimizers; +import java.util.List; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; @@ -8,8 +9,6 @@ import org.tensorflow.op.train.ApplyFtrl; import org.tensorflow.types.family.TType; -import java.util.List; - /** * Optimizer that implements the FTRL algorithm. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java index ca53bd0c7e8..b1f6ac8f4e5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java @@ -15,6 +15,7 @@ */ package org.tensorflow.framework.optimizers; +import java.util.List; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; @@ -23,8 +24,6 @@ import org.tensorflow.op.train.ApplyMomentum; import org.tensorflow.types.family.TType; -import java.util.List; - /** * Stochastic gradient descent plus momentum, either nesterov or traditional. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java index 5b94b548c0a..f55fb8cdc59 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java @@ -1,5 +1,7 @@ package org.tensorflow.framework.optimizers; +import java.util.List; +import java.util.Optional; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; @@ -12,9 +14,6 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; -import java.util.List; -import java.util.Optional; - /** * Nadam Optimizer that implements the NAdam algorithm. * @@ -140,17 +139,16 @@ protected void createSlots(List> variables) { for (Output v : variables) { createNadamSlot(v.asOutput()); } - betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.class); - Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne)); - ((Graph) tf.scope().env()).addInitializer(betaOnePowerInit); + betaOnePower = + tf.withInitScope().withName("beta1_power").variable(Shape.scalar(), TFloat32.class); + tf.withInitScope().assign(betaOnePower, tf.constant(betaOne)); - betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.class); - Assign betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo)); - ((Graph) tf.scope().env()).addInitializer(betaTwoPowerInit); + betaTwoPower = + tf.withInitScope().withName("beta2_power").variable(Shape.scalar(), TFloat32.class); + tf.withInitScope().assign(betaTwoPower, tf.constant(betaTwo)); - momentum = tf.withName("momentum").variable(Shape.scalar(), TFloat32.class); - Assign momentumInit = tf.assign(momentum, tf.constant(1.0F)); - ((Graph) tf.scope().env()).addInitializer(momentumInit); + momentum = tf.withInitScope().withName("momentum").variable(Shape.scalar(), TFloat32.class); + tf.withInitScope().assign(momentum, tf.constant(1.0F)); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index ed141831bbe..b1366146836 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -15,6 +15,12 @@ */ package org.tensorflow.framework.optimizers; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Operation; @@ -22,14 +28,10 @@ import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.Scope; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.NoOp; import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TType; -import java.util.*; -import java.util.stream.Collectors; - /** Base class for gradient optimizers. */ public abstract class Optimizer { @@ -41,6 +43,7 @@ public abstract class Optimizer { protected final Graph graph; /** The ops builder for the graph. */ protected final Ops tf; + /** Top level map key is the variable name, lower level map key is the slot name. */ private final Map>> slots; @@ -221,9 +224,10 @@ private Optional> getSlot(String varName, String s protected void createSlot( Output variable, String slotName, Operand initializer) { Variable slot = - tf.withName(createName(variable, slotName)).variable(variable.shape(), variable.type()); - Assign slotInit = tf.assign(slot, initializer); - graph.addInitializer(slotInit); + tf.withInitScope() + .withName(createName(variable, slotName)) + .variable(variable.shape(), variable.type()); + tf.withInitScope().assign(slot, initializer); String varName = variable.op().name(); Map> variables = slots.computeIfAbsent(slotName, (k) -> new HashMap<>()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java index 79ced52dc08..0d4daf748d4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java @@ -15,6 +15,7 @@ */ package org.tensorflow.framework.optimizers; +import java.util.List; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; @@ -24,8 +25,6 @@ import org.tensorflow.op.train.ApplyRmsProp; import org.tensorflow.types.family.TType; -import java.util.List; - /** * Optimizer that implements the RMSProp algorithm. * diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java index 882a64ba54d..1f8503829b7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java @@ -15,19 +15,18 @@ */ package org.tensorflow.framework.data; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.types.family.TType; import org.tensorflow.exceptions.TFOutOfRangeException; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; - -import java.util.Arrays; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; +import org.tensorflow.types.family.TType; public class DatasetIteratorTest extends DatasetTestBase { @@ -48,15 +47,15 @@ public void testGraphIteration() { Operand y = components.get(1); try (Session session = new Session(graph)) { - session.run(tf.init()); + session.initialize(); int batches = 0; while (true) { try { List outputs = session.runner().fetch(x).fetch(y).run(); - try (TInt32 xBatch = (TInt32)outputs.get(0); - TInt32 yBatch = (TInt32)outputs.get(1)) { + try (TInt32 xBatch = (TInt32) outputs.get(0); + TInt32 yBatch = (TInt32) outputs.get(1)) { assertEquals(testMatrix1.get(batches), xBatch); assertEquals(testMatrix2.get(batches), yBatch); batches++; @@ -81,8 +80,8 @@ public void testEagerIteration() { Dataset dataset = Dataset.fromTensorSlices(tf, tensors, dataTypes); int count = 0; for (List> outputs : dataset) { - try (TInt32 batch1 = (TInt32)outputs.get(0).asTensor(); - TInt32 batch2 = (TInt32)outputs.get(1).asTensor()) { + try (TInt32 batch1 = (TInt32) outputs.get(0).asTensor(); + TInt32 batch2 = (TInt32) outputs.get(1).asTensor()) { assertEquals(testMatrix1.get(count), batch1); assertEquals(testMatrix2.get(count), batch2); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java index 5f203427563..afa38e04ee8 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java @@ -15,22 +15,21 @@ */ package org.tensorflow.framework.data; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.types.family.TType; import org.tensorflow.exceptions.TFOutOfRangeException; -import org.tensorflow.op.Ops; import org.tensorflow.ndarray.IntNdArray; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; - -import java.util.Arrays; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; +import org.tensorflow.types.family.TType; public class MapDatasetTest extends DatasetTestBase { IntNdArray mapped1; @@ -73,15 +72,15 @@ public void testGraphIteration() { Operand y = components.get(1); try (Session session = new Session(graph)) { - session.run(tf.init()); + session.initialize(); int batches = 0; while (true) { try { List outputs = session.runner().fetch(X).fetch(y).run(); - try (TInt32 XBatch = (TInt32)outputs.get(0); - TInt32 yBatch = (TInt32)outputs.get(1)) { + try (TInt32 XBatch = (TInt32) outputs.get(0); + TInt32 yBatch = (TInt32) outputs.get(1)) { assertEquals(mapped1.get(batches), XBatch); assertEquals(mapped2.get(batches), yBatch); @@ -113,8 +112,8 @@ public void testEagerIteration() { int count = 0; for (List> outputs : dataset) { - try (TInt32 XBatch = (TInt32)outputs.get(0).asTensor(); - TInt32 yBatch = (TInt32)outputs.get(1).asTensor()) { + try (TInt32 XBatch = (TInt32) outputs.get(0).asTensor(); + TInt32 yBatch = (TInt32) outputs.get(1).asTensor()) { assertEquals(mapped1.get(count), XBatch); assertEquals(mapped2.get(count), yBatch); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java index bd693da1312..ae40074f3f6 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java @@ -14,6 +14,12 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -23,12 +29,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.tensorflow.framework.utils.CastHelper.cast; - public class AUCTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -47,7 +47,7 @@ public void testValueIsIdempotent() { Operand yTrue = tf.constant(trueArray); AUC instance = new AUC<>(tf, numThresholds, 1001L, TFloat32.class); - session.run(tf.init()); + session.initialize(); Op update = instance.updateState(yTrue, yPred, null); @@ -77,16 +77,13 @@ public void testCumulative() { Operand yTrue = tf.constant(trueArray); AUC instance = new AUC<>(tf, numThresholds, 1001L, TFloat32.class); - session.run(tf.init()); + session.initialize(); assertNull(instance.getTruePositives()); assertNull(instance.getFalsePositives()); assertNull(instance.getTrueNegatives()); assertNull(instance.getFalseNegatives()); - - - for (int i = 0; i < 3; i++) { Op update = instance.updateState(yTrue, yPred, null); session.run(update); @@ -118,7 +115,6 @@ public void basicTestSampleWeight() { float[] expectedThresholds = new float[] {-1e-7f, 0.5f, 1 + 1e-7f}; assertArrayEquals(expectedThresholds, instance.getThresholds(), epsilon); - Operand yPred = tf.constant(new float[] {0, 0, 1, 1}); Operand yTrue = tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}); Operand sampleWeights = tf.constant(new float[] {1, 0, 0, 1}); @@ -136,7 +132,7 @@ public void testUnweightedAllCorrect() { Ops tf = session.getTF(); Operand yTrue = cast(tf, tf.constant(this.trueArray), TFloat32.class); AUC instance = new AUC<>(tf, this.numThresholds, 1001L, TFloat32.class); - session.run(tf.init()); + session.initialize(); Op update = instance.updateState(yTrue, yTrue, null); session.run(update); @@ -153,7 +149,7 @@ public void testUnweighted() { Operand yPred = tf.constant(this.predArray); Operand yTrue = tf.constant(this.trueArray); AUC instance = new AUC<>(tf, this.numThresholds, 1001L, TFloat32.class); - session.run(tf.init()); + session.initialize(); Op update = instance.updateState(yTrue, yPred, null); session.run(update); Operand result = instance.result(); @@ -172,7 +168,7 @@ public void testManualThresholds() { AUC instance = new AUC<>(tf, new float[] {0.5f}, 1001L, TFloat32.class); float[] expectedThresholds = new float[] {-AUC.EPSILON, 0.5f, 1 + AUC.EPSILON}; assertArrayEquals(expectedThresholds, instance.getThresholds(), epsilon); - session.run(tf.init()); + session.initialize(); Op update = instance.updateState(yTrue, yPred, null); session.run(update); Operand result = instance.result(); @@ -191,7 +187,7 @@ public void testWeightedRocInterpolation() { Operand sampleWights = tf.constant(this.sampleWeight); AUC instance = new AUC<>(tf, this.numThresholds, 1001L, TFloat32.class); - session.run(tf.init()); + session.initialize(); Op update = instance.updateState(yTrue, yPred, sampleWights); session.run(update); Operand result = instance.result(); @@ -217,7 +213,7 @@ public void testWeightedRocMajoring() { AUCSummationMethod.MAJORING, 1001L, TFloat32.class); - session.run(tf.init()); + session.initialize(); Op update = instance.updateState(yTrue, yPred, sampleWights); session.run(update); Operand result = instance.result(); @@ -243,7 +239,7 @@ public void testWeightedRocMinoring() { AUCSummationMethod.MINORING, 1001L, TFloat32.class); - session.run(tf.init()); + session.initialize(); Op update = instance.updateState(yTrue, yPred, sampleWights); session.run(update); Operand result = instance.result(); @@ -269,7 +265,7 @@ public void testWeightedPrMajoring() { AUCSummationMethod.MAJORING, 1001L, TFloat32.class); - session.run(tf.init()); + session.initialize(); Op update = instance.updateState(yTrue, yPred, sampleWights); session.run(update); Operand result = instance.result(); @@ -294,7 +290,7 @@ public void testWeightedPrMinoring() { AUCSummationMethod.MINORING, 1001L, TFloat32.class); - session.run(tf.init()); + session.initialize(); Op update = instance.updateState(yTrue, yPred, sampleWights); session.run(update); Operand result = instance.result(); @@ -313,7 +309,7 @@ public void testWeightedPrInterpolation() { AUC instance = new AUC<>(tf, this.numThresholds, AUCCurve.PR, 1001L, TFloat32.class); - session.run(tf.init()); + session.initialize(); Op update = instance.updateState(yTrue, yPred, sampleWights); session.run(update); Operand result = instance.result(); @@ -362,7 +358,7 @@ public void testExtraDims() { Operand labels = tf.constant(labelArray); AUC instance = new AUC<>(tf, 1001L, TFloat32.class); - session.run(tf.init()); + session.initialize(); Op update = instance.updateState(labels, logits, null); session.run(update); Operand result = instance.result(); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java index fc08455d1c7..686e6371bc0 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java @@ -98,7 +98,7 @@ public void testZeroAndNonZeroEntries() { Operand labels = tf.constant(new int[] {1}); MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat32.class); - session.run(tf.init()); + session.initialize(); Op update = instance.updateState(labels, predictions, null); session.run(update); Operand result = instance.result(); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java index 0bb9392b8b0..ce5d87869ee 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -22,8 +24,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; -import static org.tensorflow.framework.utils.CastHelper.cast; - public class MeanRelativeErrorTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -38,8 +38,8 @@ public void testUnweighted() { MeanRelativeError instance = new MeanRelativeError<>(tf, labels, 1001L, TFloat32.class); + session.initialize(); session.run(instance.resetStates()); - session.run(tf.init()); Op update = instance.updateState(labels, predictions, null); session.run(update); Operand result = instance.result(); @@ -63,8 +63,8 @@ public void testWeighted() { MeanRelativeError instance = new MeanRelativeError<>(tf, labels, 1001L, TFloat32.class); + session.initialize(); session.run(instance.resetStates()); - session.run(tf.init()); Op update = instance.updateState(labels, predictions, sampleWeight); session.run(update); Operand result = instance.result(); @@ -87,8 +87,8 @@ public void testZeroNormalizer() { MeanRelativeError instance = new MeanRelativeError<>( tf, cast(tf, tf.zerosLike(labels), TFloat32.class), 1001L, TFloat32.class); + session.initialize(); session.run(instance.resetStates()); - session.run(tf.init()); Op update = instance.updateState(labels, predictions, null); session.run(update); Operand result = instance.result(); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java index ce473bbdf34..3fb11f86b45 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java @@ -32,7 +32,7 @@ public void testUnweighted() { Ops tf = session.getTF(); Operand values = tf.constant(new long[] {100, 40}); MeanTensor instance = new MeanTensor<>(tf, 1001L, TFloat64.class); - session.run(tf.init()); + session.initialize(); Op update = instance.updateState(values, null); session.run(update); Operand result = instance.result(); @@ -54,7 +54,7 @@ public void testWeighted() { Ops tf = session.getTF(); Operand values = tf.constant(new long[] {100, 30}); MeanTensor instance = new MeanTensor<>(tf, 1001L, TFloat64.class); - session.run(tf.init()); + session.initialize(); // check scalar weight Op update = instance.updateState(values, tf.constant(0.5f)); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java index 86a3200ac81..9b85b78a694 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java @@ -14,6 +14,12 @@ =======================================================================*/ package org.tensorflow.framework.optimizers; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.AdaDelta.ACCUMULATOR; +import static org.tensorflow.framework.optimizers.AdaDelta.ACCUMULATOR_UPDATE; + +import java.util.ArrayList; +import java.util.List; import org.junit.jupiter.api.*; import org.tensorflow.Graph; import org.tensorflow.framework.optimizers.Optimizer.GradAndVar; @@ -21,19 +27,11 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.tensorflow.framework.optimizers.AdaDelta.ACCUMULATOR; -import static org.tensorflow.framework.optimizers.AdaDelta.ACCUMULATOR_UPDATE; - /** Test cases for AdaDelta Optimizer */ public class AdaDeltaTest { @@ -89,11 +87,13 @@ public void testBasic() { float[] var1Init = {3.0F, 4.0F}; float[] fgrads = {grad, grad}; Shape shape = Shape.of(var0Init.length); - Variable var0 = tf.withName("var0").variable(shape, TFloat32.class); - Variable var1 = tf.withName("var1").variable(shape, TFloat32.class); + Variable var0 = + tf.withInitScope().withName("var0").variable(shape, TFloat32.class); + Variable var1 = + tf.withInitScope().withName("var1").variable(shape, TFloat32.class); - Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + tf.withInitScope().assign(var0, tf.constant(var0Init)); + tf.withInitScope().assign(var1, tf.constant(var1Init)); Constant cgrads = tf.constant(fgrads); @@ -129,12 +129,8 @@ public void testBasic() { slotUpdates[1] = adaDelta.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get(); assertEquals(slotUpdates[1].shape(), var1.shape()); - /* initialize the local variables */ - session.run(var0Initializer); - session.run(var1Initializer); - /* initialize the accumulators */ - session.run(tf.init()); + session.initialize(); /* make sure the variables were initialized properly */ session.evaluate(var0Init, var0); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java index d5b2657a4fc..b48359fc989 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java @@ -14,21 +14,19 @@ =======================================================================*/ package org.tensorflow.framework.optimizers; +import java.util.ArrayList; +import java.util.List; import org.junit.jupiter.api.*; import org.tensorflow.Graph; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.ArrayList; -import java.util.List; - /** Test cases for AdaGradDA Optimizer */ public class AdaGradDATest { @@ -67,18 +65,12 @@ public void testBasic() { Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); - Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + tf.withInitScope().assign(var0, tf.constant(var0Init)); + tf.withInitScope().assign(var1, tf.constant(var1Init)); Constant grads0 = tf.constant(grads0Init); Constant grads1 = tf.constant(grads1Init); - /* initialize the local variables */ - - session.run(var0Initializer); - session.run(var1Initializer); - - /* build the GradsAnvVars */ List> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); @@ -87,7 +79,7 @@ public void testBasic() { Op adaUpdate = instance.applyGradients(gradsAndVars, "AdGradDATest"); /* initialize the accumulators */ - session.run(tf.init()); + session.initialize(); session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java index 8182dc5b00d..70a185a8928 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java @@ -14,6 +14,11 @@ =======================================================================*/ package org.tensorflow.framework.optimizers; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.AdaGrad.ACCUMULATOR; + +import java.util.ArrayList; +import java.util.List; import org.junit.jupiter.api.*; import org.tensorflow.Graph; import org.tensorflow.framework.utils.ND; @@ -23,18 +28,11 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.tensorflow.framework.optimizers.AdaGrad.ACCUMULATOR; - /** Test cases for AdaGrad Optimizer */ public class AdaGradTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -82,14 +80,12 @@ public void testBasic() { Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); - Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + tf.withInitScope().assign(var0, tf.constant(var0Init)); + tf.withInitScope().assign(var1, tf.constant(var1Init)); Constant grads0 = tf.constant(grads0Init); Constant grads1 = tf.constant(grads1Init); - - /* build the GradsAnvVars */ List> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); @@ -104,12 +100,8 @@ public void testBasic() { accumulatorSlots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get(); assertEquals(accumulatorSlots[1].shape(), var1.shape()); - /* initialize the local variables */ - session.run(var0Initializer); - session.run(var1Initializer); - /* initialize the accumulators */ - session.run(tf.init()); + session.initialize(); /* make sure the variables were initialized properly */ session.evaluate(var0Init, var0); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java index 49154882a0f..00dfeecce51 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java @@ -14,6 +14,12 @@ =======================================================================*/ package org.tensorflow.framework.optimizers; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.Adam.FIRST_MOMENT; +import static org.tensorflow.framework.optimizers.Adam.SECOND_MOMENT; + +import java.util.ArrayList; +import java.util.List; import org.junit.jupiter.api.*; import org.tensorflow.Graph; import org.tensorflow.framework.utils.ND; @@ -23,19 +29,11 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.tensorflow.framework.optimizers.Adam.FIRST_MOMENT; -import static org.tensorflow.framework.optimizers.Adam.SECOND_MOMENT; - /** Test cases for Adam Optimizer */ public class AdamTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -82,25 +80,17 @@ public void testBasic() { Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); - Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + tf.withInitScope().assign(var0, tf.constant(var0Init)); + tf.withInitScope().assign(var1, tf.constant(var1Init)); Constant grads0 = tf.constant(grads0Init); Constant grads1 = tf.constant(grads1Init); - /* initialize the local variables */ - session.run(var0Initializer); - session.run(var1Initializer); - - - /* build the GradsAnvVars */ List> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - - Op update = instance.applyGradients(gradsAndVars, "AdamTest"); /* Create and validate the shapes of the slots */ @@ -122,7 +112,7 @@ public void testBasic() { assertEquals(secondMomentSlots[1].shape(), var1.shape()); /* initialize the accumulators */ - session.run(tf.init()); + session.initialize(); session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); @@ -140,21 +130,11 @@ public void testBasic() { }; try (TFloat32 result = - (TFloat32)session - .getGraphSession() - .runner() - .fetch("beta1_power") - .run() - .get(0)) { + (TFloat32) session.getGraphSession().runner().fetch("beta1_power").run().get(0)) { result.scalars().forEach(f -> assertEquals(powers[0], f.getFloat(), epsilon1)); } try (TFloat32 result = - (TFloat32)session - .getGraphSession() - .runner() - .fetch("beta2_power") - .run() - .get(0)) { + (TFloat32) session.getGraphSession().runner().fetch("beta2_power").run().get(0)) { result.scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1)); } session.run(update); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java index 60c17674dfe..7cf645bb935 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java @@ -14,6 +14,11 @@ =======================================================================*/ package org.tensorflow.framework.optimizers; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.Adamax.*; + +import java.util.ArrayList; +import java.util.List; import org.junit.jupiter.api.*; import org.tensorflow.Graph; import org.tensorflow.framework.utils.ND; @@ -23,18 +28,11 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.tensorflow.framework.optimizers.Adamax.*; - /** Test cases for Adamax Optimizer */ public class AdamaxTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -103,17 +101,12 @@ public void testBasic() { Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); - Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + tf.withInitScope().assign(var0, tf.constant(var0Init)); + tf.withInitScope().assign(var1, tf.constant(var1Init)); Constant grads0 = tf.constant(grads0Init); Constant grads1 = tf.constant(grads1Init); - /* initialize the local variables */ - session.run(var0Initializer); - session.run(var1Initializer); - - /* build the GradsAnvVars */ List> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput())); @@ -138,23 +131,15 @@ public void testBasic() { assertEquals(secondMomentSlots[1].shape(), var1.shape()); /* initialize the accumulators */ - session.run(tf.init()); + session.initialize(); - /* initialize the local variables */ - session.run(var0Initializer); - session.run(var1Initializer); session.setEpsilon(epsilon1); for (int step = 0; step < numSteps; step++) { // Test powers final float beta1Power = (float) Math.pow(BETA_ONE_DEFAULT, step + 1); try (TFloat32 result = - (TFloat32)session - .getGraphSession() - .runner() - .fetch("beta1_power") - .run() - .get(0)) { + (TFloat32) session.getGraphSession().runner().fetch("beta1_power").run().get(0)) { result.scalars().forEach(f -> assertEquals(beta1Power, f.getFloat(), epsilon1)); } session.run(update); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java index 7698d76f957..ac624424184 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java @@ -14,6 +14,10 @@ =======================================================================*/ package org.tensorflow.framework.optimizers; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.ArrayList; +import java.util.List; import org.junit.jupiter.api.*; import org.tensorflow.Graph; import org.tensorflow.framework.utils.TestSession; @@ -26,11 +30,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; - /** Test cases for Ftrl Optimizer */ public class FtrlTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -79,8 +78,8 @@ public void testFtrlWithL1L2L2Shrinkage() { Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); - Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + tf.withInitScope().assign(var0, tf.constant(var0Init)); + tf.withInitScope().assign(var1, tf.constant(var1Init)); Constant grads0 = tf.constant(grads0Init); Constant grads1 = tf.constant(grads1Init); @@ -105,12 +104,8 @@ public void testFtrlWithL1L2L2Shrinkage() { Op ftrlUpdate = instance.applyGradients(gradsAndVars, "FtrlTest"); - /* initialize the local variables */ - session.run(var0Initializer); - session.run(var1Initializer); - /* initialize the accumulators */ - session.run(tf.init()); + session.initialize(); session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); @@ -175,7 +170,7 @@ public void testFtrlWithL1() { session.run(var1Initializer); /* initialize the accumulators */ - session.run(tf.init()); + session.initialize(); session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); @@ -241,7 +236,7 @@ public void testFtrlWithL1L2() { session.run(var1Initializer); /* initialize the accumulators */ - session.run(tf.init()); + session.initialize(); session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); @@ -297,7 +292,7 @@ public void doTestFtrlwithoutRegularization() { session.run(var1Initializer); /* initialize the accumulators */ - session.run(tf.init()); + session.initialize(); session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index 835ed8fdcaa..909fd53ca27 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -20,9 +20,7 @@ import org.tensorflow.ndarray.buffer.DataBuffers; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; -import org.tensorflow.op.core.Init; import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.op.math.Add; @@ -80,8 +78,8 @@ public void testBasic() { Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); - Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + tf.withInitScope().assign(var0, tf.constant(var0Init)); + tf.withInitScope().assign(var1, tf.constant(var1Init)); Constant grads0 = tf.constant(grads0Init); Constant grads1 = tf.constant(grads1Init); @@ -94,12 +92,8 @@ public void testBasic() { GradientDescent instance = new GradientDescent(graph, learningRate); Op update = instance.applyGradients(gradsAndVars, "SGDTest"); - /* initialize the local variables */ - session.run(var0Initializer); - session.run(var1Initializer); - /* initialize the accumulators */ - session.run(tf.init()); + session.initialize(); /* make sure the variables were initialized properly */ session.evaluate(var0Init, var0); @@ -127,7 +121,6 @@ public void testDeterminism() { .build(); GraphDef def; - String initName; String trainName; String lossName; @@ -144,17 +137,19 @@ public void testDeterminism() { // Fully connected layer Variable fcWeights = - tf.variable(initializer.call(tf, tf.array(20L, 200L), TFloat32.class)); + tf.withInitScope().variable(initializer.call(tf, tf.array(20L, 200L), TFloat32.class)); fcWeightName = fcWeights.op().name(); - Variable fcBiases = tf.variable(tf.fill(tf.array(200), tf.constant(0.1f))); + Variable fcBiases = + tf.withInitScope().variable(tf.fill(tf.array(200), tf.constant(0.1f))); fcBiasName = fcBiases.op().name(); Relu relu = tf.nn.relu(tf.math.add(tf.linalg.matMul(input, fcWeights), fcBiases)); // Output layer Variable outputWeights = - tf.variable(initializer.call(tf, tf.array(200L, 2L), TFloat32.class)); + tf.withInitScope().variable(initializer.call(tf, tf.array(200L, 2L), TFloat32.class)); outputWeightName = outputWeights.op().name(); - Variable outputBiases = tf.variable(tf.fill(tf.array(2L), tf.constant(0.1f))); + Variable outputBiases = + tf.withInitScope().variable(tf.fill(tf.array(2L), tf.constant(0.1f))); outputBiasName = outputBiases.op().name(); Add output = tf.math.add(tf.linalg.matMul(relu, outputWeights), outputBiases); @@ -170,10 +165,6 @@ public void testDeterminism() { Op trainingOp = gd.minimize(loss); trainName = trainingOp.op().name(); - // Create the init op - Init init = tf.init(); - initName = init.op().name(); - def = g.toGraphDef(); } @@ -196,7 +187,7 @@ public void testDeterminism() { try (Graph g = new Graph(); Session s = new Session(g, config)) { g.importGraphDef(def); - s.run(initName); + s.initialize(); initialized.add( s.runner() diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java index 80a8d9b5fd6..8bd7dfc2d78 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java @@ -14,6 +14,11 @@ =======================================================================*/ package org.tensorflow.framework.optimizers; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.Momentum.MOMENTUM; + +import java.util.ArrayList; +import java.util.List; import org.junit.jupiter.api.*; import org.tensorflow.Graph; import org.tensorflow.framework.utils.TestSession; @@ -26,12 +31,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.tensorflow.framework.optimizers.Momentum.MOMENTUM; - /** Test cases for SGD Optimizer */ public class MomentumTest { @@ -80,8 +79,8 @@ public void testBasic() { Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); - Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + tf.withInitScope().assign(var0, tf.constant(var0Init)); + tf.withInitScope().assign(var1, tf.constant(var1Init)); Constant grads0 = tf.constant(grads0Init); Constant grads1 = tf.constant(grads1Init); @@ -94,12 +93,8 @@ public void testBasic() { Momentum instance = new Momentum(graph, learningRate); Op update = instance.applyGradients(gradsAndVars, "SGDTest"); - /* initialize the local variables */ - session.run(var0Initializer); - session.run(var1Initializer); - /* initialize the accumulators */ - session.run(tf.init()); + session.initialize(); /* make sure the variables were initialized properly */ session.evaluate(var0Init, var0); @@ -157,7 +152,7 @@ public void testMomentum() { session.run(var1Initializer); /* initialize the accumulators */ - session.run(tf.init()); + session.initialize(); /* make sure the variables were initialized properly */ session.evaluate(var0Init, var0); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java index 849f2fbfec1..e18d870107f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java @@ -14,6 +14,10 @@ =======================================================================*/ package org.tensorflow.framework.optimizers; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.ArrayList; +import java.util.List; import org.junit.jupiter.api.*; import org.tensorflow.Graph; import org.tensorflow.framework.utils.ND; @@ -23,17 +27,11 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; - /** Test cases for Nadam Optimizer */ public class NadamTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -104,8 +102,8 @@ public void testBasic() { Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); - Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + tf.withInitScope().assign(var0, tf.constant(var0Init)); + tf.withInitScope().assign(var1, tf.constant(var1Init)); Constant grads0 = tf.constant(grads0Init); Constant grads1 = tf.constant(grads1Init); @@ -134,12 +132,8 @@ public void testBasic() { secondMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.SECOND_MOMENT).get(); assertEquals(secondMomentSlots[1].shape(), var1.shape()); - /* initialize the local variables */ - session.run(var0Initializer); - session.run(var1Initializer); - /* initialize the accumulators */ - session.run(tf.init()); + session.initialize(); session.setEpsilon(epsilon1); @@ -147,12 +141,7 @@ public void testBasic() { session.evaluate(var1Init, var1); try (TFloat32 result = - (TFloat32)session - .getGraphSession() - .runner() - .fetch("momentum") - .run() - .get(0)) { + (TFloat32) session.getGraphSession().runner().fetch("momentum").run().get(0)) { result.scalars().forEach(f -> assertEquals(1F, f.getFloat(), epsilon1)); } momentum = 1F; @@ -166,12 +155,7 @@ public void testBasic() { momentum = momentum * mut; try (TFloat32 result = - (TFloat32)session - .getGraphSession() - .runner() - .fetch("momentum") - .run() - .get(0)) { + (TFloat32) session.getGraphSession().runner().fetch("momentum").run().get(0)) { result.scalars().forEach(f -> assertEquals(momentum, f.getFloat(), epsilon1)); } mcache = ND.mul(mcache, momentum); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java index 3b002cd1dbe..53d7cceae0d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java @@ -14,6 +14,10 @@ =======================================================================*/ package org.tensorflow.framework.optimizers; +import static org.tensorflow.framework.optimizers.RMSProp.*; + +import java.util.ArrayList; +import java.util.List; import org.junit.jupiter.api.*; import org.tensorflow.Graph; import org.tensorflow.framework.utils.ND; @@ -23,17 +27,11 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.ArrayList; -import java.util.List; - -import static org.tensorflow.framework.optimizers.RMSProp.*; - /** Test cases for RMSProp Optimizer */ public class RMSPropTest { final int VAR_T = 0; @@ -90,8 +88,8 @@ public void testDense() { Variable var0 = tf.withName("var0").variable(shape0, TFloat32.class); Variable var1 = tf.withName("var1").variable(shape1, TFloat32.class); - Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + tf.withInitScope().assign(var0, tf.constant(var0Init)); + tf.withInitScope().assign(var1, tf.constant(var1Init)); Constant grads0 = tf.constant(grads0Init); Constant grads1 = tf.constant(grads1Init); @@ -112,12 +110,8 @@ public void testDense() { Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); - /* initialize the local variables */ - session.run(var0Initializer); - session.run(var1Initializer); - /* initialize the accumulators */ - session.run(tf.init()); + session.initialize(); /* make sure the variables were initialized properly */ session.evaluate(var0Init, var0); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java index 43c0642939e..35efa292811 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java @@ -14,6 +14,12 @@ =======================================================================*/ package org.tensorflow.framework.utils; +import static org.junit.jupiter.api.Assertions.*; + +import java.io.PrintWriter; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Predicate; import org.tensorflow.*; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; @@ -23,42 +29,28 @@ import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; -import java.io.PrintWriter; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Predicate; - -import static org.junit.jupiter.api.Assertions.*; - -/** - * Graph Mode Test Session - */ +/** Graph Mode Test Session */ public class GraphTestSession extends TestSession { private final Graph graph; private final Session session; private final Ops tf; + private boolean hasInited = false; - /** - * Create a Graph mode test session. - */ + /** Create a Graph mode test session. */ public GraphTestSession() { graph = new Graph(); session = new Session(graph); tf = Ops.create(graph).withName("test"); } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public Ops getTF() { return tf; } - /** - * Get the Graph object that is represented by this Test Session - */ + /** Get the Graph object that is represented by this Test Session */ public Graph getGraph() { return graph; } @@ -72,134 +64,122 @@ public Session getSession() { return session; } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public void close() { session.close(); graph.close(); } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public boolean isEager() { return false; } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public Session getGraphSession() { return this.session; } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public EagerSession getEagerSession() { return null; } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override - public void initialize() { - graph.initializers().forEach(initializer -> session.runner().addTarget(initializer).run()); + public synchronized void initialize() { + session.forceInitialize(); + hasInited = true; } - /** - * {@inheritDoc} - */ + private synchronized void initIfNeeded() { + if (!hasInited) { + session.initialize(); + hasInited = true; + } + } + + /** {@inheritDoc} */ @Override public void run(Op op) { + initIfNeeded(); session.run(op); } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public void evaluate(double expected, Operand input) { + initIfNeeded(); Class inputType = input.type(); if (inputType == TFloat32.class) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat32) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } } index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TFloat32 result = (TFloat32) this.getGraphSession().runner().fetch(input).run().get(0)) { result.scalars().forEach(f -> assertEquals((float) expected, f.getFloat(), epsilon)); } } else if (inputType == TFloat64.class) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat64) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } } index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TFloat64 result = (TFloat64) this.getGraphSession().runner().fetch(input).run().get(0)) { result.scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); } } else if (inputType == TInt32.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt32 result = (TInt32) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } } index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt32 result = (TInt32) this.getGraphSession().runner().fetch(input).run().get(0)) { result.scalars().forEach(f -> assertEquals((int) expected, f.getInt())); } } else if (inputType == TInt64.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt64 result = (TInt64) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } } index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt64 result = (TInt64) this.getGraphSession().runner().fetch(input).run().get(0)) { result.scalars().forEach(f -> assertEquals((long) expected, f.getLong())); } } else if (inputType == TUint8.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TUint8 result = (TUint8) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); } } index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TUint8 result = (TUint8) this.getGraphSession().runner().fetch(input).run().get(0)) { result.scalars().forEach(f -> assertEquals((long) expected, f.getByte())); } } else { @@ -207,11 +187,10 @@ public void evaluate(double expected, Operand input) { } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public void evaluate(Number[] expected, Output input) { + initIfNeeded(); int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); if (size != Shape.UNKNOWN_SIZE) { assertEquals( @@ -225,15 +204,14 @@ public void evaluate(Number[] expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat32) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } } index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TFloat32 result = (TFloat32) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach( @@ -245,15 +223,14 @@ public void evaluate(Number[] expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat64) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } } index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TFloat64 result = (TFloat64) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach( @@ -264,16 +241,14 @@ public void evaluate(Number[] expected, Output input) { } else if (inputType == TInt32.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt32 result = (TInt32) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } } index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt32 result = (TInt32) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); @@ -281,16 +256,14 @@ public void evaluate(Number[] expected, Output input) { } else if (inputType == TInt64.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt64 result = (TInt64) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } } index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt64 result = (TInt64) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); @@ -298,16 +271,14 @@ public void evaluate(Number[] expected, Output input) { } else if (inputType == TUint8.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TUint8 result = (TUint8) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); } } index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TUint8 result = (TUint8) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getByte())); @@ -317,25 +288,23 @@ public void evaluate(Number[] expected, Output input) { } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public void evaluate(FloatNdArray expected, Output input) { + initIfNeeded(); Class inputType = input.type(); if (inputType == TFloat32.class) { AtomicLong index = new AtomicLong(); if (debug) { try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat32) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } } index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TFloat32 result = (TFloat32) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach( @@ -347,15 +316,14 @@ public void evaluate(FloatNdArray expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat64) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } } index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TFloat64 result = (TFloat64) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach( @@ -366,16 +334,14 @@ public void evaluate(FloatNdArray expected, Output input) { } else if (inputType == TInt32.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt32 result = (TInt32) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } } index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt32 result = (TInt32) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach( @@ -384,16 +350,14 @@ public void evaluate(FloatNdArray expected, Output input) { } else if (inputType == TInt64.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt64 result = (TInt64) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } } index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt64 result = (TInt64) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach( @@ -402,16 +366,14 @@ public void evaluate(FloatNdArray expected, Output input) { } else if (inputType == TUint8.class) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TUint8 result = (TUint8) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); } } index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TUint8 result = (TUint8) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach( @@ -422,11 +384,10 @@ public void evaluate(FloatNdArray expected, Output input) { } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public void evaluate(String[] expected, Output input) { + initIfNeeded(); int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); if (size != Shape.UNKNOWN_SIZE) { assertEquals( @@ -437,27 +398,22 @@ public void evaluate(String[] expected, Output input) { } AtomicInteger index = new AtomicInteger(); if (debug) { - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TString result = (TString) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); } } index.set(0); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); + try (TString result = (TString) this.getGraphSession().runner().fetch(input).run().get(0)) { + result.scalars().forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public void evaluate(Boolean[] expected, Output input) { + initIfNeeded(); int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); assertEquals( expected.length, @@ -465,31 +421,26 @@ public void evaluate(Boolean[] expected, Output input) { () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); AtomicInteger index = new AtomicInteger(); if (debug) { - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TBool result = (TBool) this.getGraphSession().runner().fetch(input).run().get(0)) { result .scalars() .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getObject())); } } index.set(0); - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); + try (TBool result = (TBool) this.getGraphSession().runner().fetch(input).run().get(0)) { + result.scalars().forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public void evaluate(Output expected, Output input) { + initIfNeeded(); assert input.shape().equals(expected.shape()) : String.format( - "expected shape (%s) != to input shape (%s)", - expected.shape().toString(), input.shape().toString()); + "expected shape (%s) != to input shape (%s)", + expected.shape().toString(), input.shape().toString()); AtomicInteger index = new AtomicInteger(); Class inputType = input.type(); if (!inputType.equals(expected.type())) { @@ -503,12 +454,11 @@ public void evaluate(Output expected, Output input) { final Output finalExpected = (Output) expected; if (debug) { try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0); + (TFloat32) this.getGraphSession().runner().fetch(input).run().get(0); TFloat32 expectedResult = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat32) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - System.out.printf( - "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); + System.out.printf("0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); } else { result .scalars() @@ -523,30 +473,27 @@ public void evaluate(Output expected, Output input) { } } index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0); + try (TFloat32 result = (TFloat32) this.getGraphSession().runner().fetch(input).run().get(0); TFloat32 expectedResult = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat32) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); } else { result .scalars() .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); + (idx, f) -> assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); } } } else if (inputType == TFloat64.class) { final Output finalExpected = (Output) expected; if (debug) { try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0); + (TFloat64) this.getGraphSession().runner().fetch(input).run().get(0); TFloat64 expectedResult = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat64) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - System.out.printf( - "0). %f <==> %f\n", expectedResult.getDouble(), result.getDouble()); + System.out.printf("0). %f <==> %f\n", expectedResult.getDouble(), result.getDouble()); } else { result .scalars() @@ -561,30 +508,27 @@ public void evaluate(Output expected, Output input) { } } index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0); + try (TFloat64 result = (TFloat64) this.getGraphSession().runner().fetch(input).run().get(0); TFloat64 expectedResult = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat64) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { assertEquals(expectedResult.getDouble(), result.getDouble(), epsilon); } else { result .scalars() .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getDouble(idx), f.getDouble(), epsilon)); + (idx, f) -> assertEquals(expectedResult.getDouble(idx), f.getDouble(), epsilon)); } } } else if (inputType == TFloat16.class) { final Output finalExpected = (Output) expected; if (debug) { try (TFloat16 result = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0); + (TFloat16) this.getGraphSession().runner().fetch(input).run().get(0); TFloat16 expectedResult = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat16) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - System.out.printf( - "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); + System.out.printf("0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); } else { result .scalars() @@ -599,30 +543,26 @@ public void evaluate(Output expected, Output input) { } } index.set(0); - try (TFloat16 result = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0); + try (TFloat16 result = (TFloat16) this.getGraphSession().runner().fetch(input).run().get(0); TFloat16 expectedResult = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat16) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); } else { result .scalars() .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); + (idx, f) -> assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); } } } else if (inputType == TInt32.class) { final Output finalExpected = (Output) expected; if (debug) { - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0); + try (TInt32 result = (TInt32) this.getGraphSession().runner().fetch(input).run().get(0); TInt32 expectedResult = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt32) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - System.out.printf( - "0). %d <==> %d\n", expectedResult.getInt(), result.getInt()); + System.out.printf("0). %d <==> %d\n", expectedResult.getInt(), result.getInt()); } else { result .scalars() @@ -630,15 +570,16 @@ public void evaluate(Output expected, Output input) { (idx, f) -> System.out.printf( "%d). %d <==> %d\n", - index.getAndIncrement(), finalExpected.asTensor().getInt(idx), f.getInt())); + index.getAndIncrement(), + finalExpected.asTensor().getInt(idx), + f.getInt())); } } } index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0); + try (TInt32 result = (TInt32) this.getGraphSession().runner().fetch(input).run().get(0); TInt32 expectedResult = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt32) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { assertEquals(expectedResult.getInt(), result.getInt(), epsilon); } else { @@ -651,13 +592,11 @@ public void evaluate(Output expected, Output input) { } else if (inputType == TInt64.class) { final Output finalExpected = (Output) expected; if (debug) { - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0); + try (TInt64 result = (TInt64) this.getGraphSession().runner().fetch(input).run().get(0); TInt64 expectedResult = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt64) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - System.out.printf( - "0). %d <==> %d\n", expectedResult.getLong(), result.getLong()); + System.out.printf("0). %d <==> %d\n", expectedResult.getLong(), result.getLong()); } else { result .scalars() @@ -672,30 +611,26 @@ public void evaluate(Output expected, Output input) { } } index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0); + try (TInt64 result = (TInt64) this.getGraphSession().runner().fetch(input).run().get(0); TInt64 expectedResult = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt64) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { assertEquals(expectedResult.getLong(), result.getLong(), epsilon); } else { result .scalars() .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getLong(idx), f.getLong(), epsilon)); + (idx, f) -> assertEquals(expectedResult.getLong(idx), f.getLong(), epsilon)); } } } else if (inputType == TUint8.class) { final Output finalExpected = (Output) expected; if (debug) { - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0); + try (TUint8 result = (TUint8) this.getGraphSession().runner().fetch(input).run().get(0); TUint8 expectedResult = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TUint8) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - System.out.printf( - "0). %d <==> %d\n", expectedResult.getByte(), result.getByte()); + System.out.printf("0). %d <==> %d\n", expectedResult.getByte(), result.getByte()); } else { result .scalars() @@ -710,30 +645,26 @@ public void evaluate(Output expected, Output input) { } } index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0); + try (TUint8 result = (TUint8) this.getGraphSession().runner().fetch(input).run().get(0); TUint8 expectedResult = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TUint8) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { assertEquals(expectedResult.getByte(), result.getByte(), epsilon); } else { result .scalars() .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getByte(idx), f.getByte(), epsilon)); + (idx, f) -> assertEquals(expectedResult.getByte(idx), f.getByte(), epsilon)); } } } else if (inputType == TBool.class) { final Output finalExpected = (Output) expected; if (debug) { - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0); + try (TBool result = (TBool) this.getGraphSession().runner().fetch(input).run().get(0); TBool expectedResult = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TBool) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - System.out.printf( - "0). %b <==> %b\n", expectedResult.getBoolean(), result.getBoolean()); + System.out.printf("0). %b <==> %b\n", expectedResult.getBoolean(), result.getBoolean()); } else { result .scalars() @@ -748,10 +679,9 @@ public void evaluate(Output expected, Output input) { } } index.set(0); - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0); + try (TBool result = (TBool) this.getGraphSession().runner().fetch(input).run().get(0); TBool expectedResult = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TBool) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { assertEquals(expectedResult.getBoolean(), result.getBoolean()); } else { @@ -764,13 +694,11 @@ public void evaluate(Output expected, Output input) { } else if (inputType == TString.class) { final Output finalExpected = (Output) expected; if (debug) { - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0); + try (TString result = (TString) this.getGraphSession().runner().fetch(input).run().get(0); TString expectedResult = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TString) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - System.out.printf( - "0). %s <==> %s\n", expectedResult.getObject(), result.getObject()); + System.out.printf("0). %s <==> %s\n", expectedResult.getObject(), result.getObject()); } else { result .scalars() @@ -785,10 +713,9 @@ public void evaluate(Output expected, Output input) { } } index.set(0); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0); + try (TString result = (TString) this.getGraphSession().runner().fetch(input).run().get(0); TString expectedResult = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TString) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { assertEquals(expectedResult.getObject(), result.getObject()); } else { @@ -803,20 +730,17 @@ public void evaluate(Output expected, Output input) { } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public void evaluateString(Output input, Predicate predicate) { + initIfNeeded(); boolean isScalar = input.shape().equals(Shape.scalar()); AtomicInteger index = new AtomicInteger(); if (debug) { - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TString result = (TString) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %b <==> %s\n", - predicate.test(result.getObject()), result.getObject()); + "0). %b <==> %s\n", predicate.test(result.getObject()), result.getObject()); } else { result .scalars() @@ -829,34 +753,29 @@ public void evaluateString(Output input, Predicate predicate) { } } index.set(0); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TString result = (TString) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { assertTrue(predicate.test(result.getObject())); } else { - result - .scalars() - .forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); + result.scalars().forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); } } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public void evaluate(Output input, Predicate predicate) { + initIfNeeded(); AtomicInteger index = new AtomicInteger(); Class inputType = input.type(); boolean isScalar = input.shape().equals(Shape.scalar()); if (inputType == TFloat32.class) { if (debug) { try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat32) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %b <==> %f\n", - predicate.test(result.getFloat()), result.getFloat()); + "0). %b <==> %f\n", predicate.test(result.getFloat()), result.getFloat()); } else { result .scalars() @@ -869,8 +788,7 @@ public void evaluate(Output input, Predicate predic } } index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TFloat32 result = (TFloat32) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { assertTrue(predicate.test(result.getFloat())); } else { @@ -882,11 +800,10 @@ public void evaluate(Output input, Predicate predic } else if (inputType == TFloat64.class) { if (debug) { try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat64) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %b <==> %f\n", - predicate.test(result.getDouble()), result.getDouble()); + "0). %b <==> %f\n", predicate.test(result.getDouble()), result.getDouble()); } else { result .scalars() @@ -899,8 +816,7 @@ public void evaluate(Output input, Predicate predic } } index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TFloat64 result = (TFloat64) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { assertTrue(predicate.test(result.getDouble())); } else { @@ -911,11 +827,9 @@ public void evaluate(Output input, Predicate predic } } else if (inputType == TInt32.class) { if (debug) { - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt32 result = (TInt32) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", predicate.test(result.getInt()), result.getInt()); + System.out.printf("0). %b <==> %d\n", predicate.test(result.getInt()), result.getInt()); } else { result .scalars() @@ -928,24 +842,19 @@ public void evaluate(Output input, Predicate predic } } index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt32 result = (TInt32) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { assertTrue(predicate.test(result.getInt())); } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getInt()))); + result.scalars().forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getInt()))); } } } else if (inputType == TInt64.class) { if (debug) { - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt64 result = (TInt64) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %b <==> %d\n", - predicate.test(result.getLong()), result.getLong()); + "0). %b <==> %d\n", predicate.test(result.getLong()), result.getLong()); } else { result .scalars() @@ -958,24 +867,19 @@ public void evaluate(Output input, Predicate predic } } index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt64 result = (TInt64) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { assertTrue(predicate.test(result.getLong())); } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getLong()))); + result.scalars().forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getLong()))); } } } else if (inputType == TUint8.class) { if (debug) { - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TUint8 result = (TUint8) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %b <==> %d\n", - predicate.test(result.getByte()), result.getByte()); + "0). %b <==> %d\n", predicate.test(result.getByte()), result.getByte()); } else { result .scalars() @@ -988,14 +892,11 @@ public void evaluate(Output input, Predicate predic } } index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TUint8 result = (TUint8) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { assertTrue(predicate.test(result.getByte())); } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getByte()))); + result.scalars().forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getByte()))); } } } else { @@ -1003,18 +904,16 @@ public void evaluate(Output input, Predicate predic } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public void print(PrintWriter writer, Output input) { + initIfNeeded(); boolean isScalar = input.shape().size() == 1; Class inputType = input.type(); if (inputType == TFloat32.class) { AtomicInteger index = new AtomicInteger(); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TFloat32 result = (TFloat32) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf("%d). %f\n", index.getAndIncrement(), result.getFloat()); } else { @@ -1027,11 +926,9 @@ public void print(PrintWriter writer, Output input) { } else if (inputType == TFloat64.class) { AtomicInteger index = new AtomicInteger(); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TFloat64 result = (TFloat64) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - writer.printf( - "%d). %f\n", index.getAndIncrement(), result.getDouble()); + writer.printf("%d). %f\n", index.getAndIncrement(), result.getDouble()); } else { result .scalars() @@ -1042,11 +939,9 @@ public void print(PrintWriter writer, Output input) { } else if (inputType == TInt32.class) { AtomicInteger index = new AtomicInteger(); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt32 result = (TInt32) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - writer.printf( - "%d). %d\n", index.getAndIncrement(),result.getInt()); + writer.printf("%d). %d\n", index.getAndIncrement(), result.getInt()); } else { result .scalars() @@ -1057,11 +952,9 @@ public void print(PrintWriter writer, Output input) { } else if (inputType == TInt64.class) { AtomicInteger index = new AtomicInteger(); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TInt64 result = (TInt64) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - writer.printf( - "%d). %d\n", index.getAndIncrement(), result.getLong()); + writer.printf("%d). %d\n", index.getAndIncrement(), result.getLong()); } else { result .scalars() @@ -1072,11 +965,9 @@ public void print(PrintWriter writer, Output input) { } else if (inputType == TUint8.class) { AtomicInteger index = new AtomicInteger(); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TUint8 result = (TUint8) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - writer.printf( - "%d). %x\n", index.getAndIncrement(), result.getByte()); + writer.printf("%d). %x\n", index.getAndIncrement(), result.getByte()); } else { result .scalars() @@ -1087,11 +978,9 @@ public void print(PrintWriter writer, Output input) { } else if (inputType == TBool.class) { AtomicInteger index = new AtomicInteger(); - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TBool result = (TBool) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - writer.printf( - "%d). %b\n", index.getAndIncrement(), result.getBoolean()); + writer.printf("%d). %b\n", index.getAndIncrement(), result.getBoolean()); } else { result .scalars() @@ -1102,11 +991,9 @@ public void print(PrintWriter writer, Output input) { } else if (inputType == TString.class) { AtomicInteger index = new AtomicInteger(); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { + try (TString result = (TString) this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - writer.printf( - "%d). %s\n", index.getAndIncrement(), result.getObject()); + writer.printf("%d). %s\n", index.getAndIncrement(), result.getObject()); } else { result .scalars() diff --git a/tensorflow-framework/tensorflow-data.md b/tensorflow-framework/tensorflow-data.md index df0ec190b9f..0978129b77b 100644 --- a/tensorflow-framework/tensorflow-data.md +++ b/tensorflow-framework/tensorflow-data.md @@ -173,7 +173,7 @@ try (Graph graph = new Graph()) { // instantiate graph-mode session try (Session session = new Session(graph)) { // Run graph initializers - session.run(tf.init()); + session.initialize(); // Iterate over dataset elements while (true) {