diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LeakyRelu.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LeakyRelu.pbtxt index 31a4f01167b..3573bff4fa8 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LeakyRelu.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LeakyRelu.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "LeakyRelu" + visibility: VISIBLE endpoint { name: "nn.LeakyRelu" } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java index 33caf02d890..81a24514a08 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java @@ -59,6 +59,7 @@ import org.tensorflow.op.nn.FusedResizeAndPadConv2d; import org.tensorflow.op.nn.InTopK; import org.tensorflow.op.nn.L2Loss; +import org.tensorflow.op.nn.LeakyRelu; import org.tensorflow.op.nn.LearnedUnigramCandidateSampler; import org.tensorflow.op.nn.LocalResponseNormalization; import org.tensorflow.op.nn.LogSoftmax; @@ -1226,6 +1227,19 @@ public <T extends TNumber> L2Loss<T> l2Loss(Operand<T> t) { return L2Loss.create(scope, t); } + /** + * Computes rectified linear: `max(features, features * alpha)`. + * + * @param <T> data type for {@code activations()} output + * @param features + * @param options carries optional attributes values + * @return a new instance of LeakyRelu + */ + public <T extends TNumber> LeakyRelu<T> leakyRelu(Operand<T> features, + LeakyRelu.Options... options) { + return LeakyRelu.create(scope, features, options); + } + /** + * Generates labels for candidate sampling with a learned unigram distribution. *

diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LeakyRelu.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LeakyRelu.java index 230b4238047..8ca3f540cad 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LeakyRelu.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LeakyRelu.java @@ -33,6 +33,7 @@ * * @param <T> data type for {@code activations()} output */ +@Operator(group = "nn") public final class LeakyRelu<T extends TNumber> extends RawOp implements Operand<T> { /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java index faf698a6fe5..50f6ea49b06 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java @@ -30,7 +30,7 @@ import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; -import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TFloating; /** * Brain 16-bit float tensor type. * @@ -48,7 +48,7 @@ *

Note that some CPUs support the bfloat16 format natively, which can result in faster * computation compared to {@link TFloat16} when GPUs are not used. */ -public interface TBfloat16 extends FloatNdArray, TNumber { +public interface TBfloat16 extends FloatNdArray, TFloating { /** readable-name for the data type */ static final String NAME = "BFLOAT16"; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java index 6ce463ff2c0..0cd441a1ff1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java @@ -30,7 +30,7 @@ import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; -import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TFloating; /** * IEEE-754 half-precision 16-bit float tensor type. @@ -45,7 +45,7 @@ * most CPUs do not support this format natively. For CPU computation on 16-bit floats, the {@link * TBfloat16} tensor type might be a better option. */ -public interface TFloat16 extends FloatNdArray, TNumber { +public interface TFloat16 extends FloatNdArray, TFloating { /** readable-name for the data type */ static final String NAME = "FLOAT16"; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java index 968b3f2a539..571ec118ddc 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java @@ -29,10 +29,10 @@ import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; -import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TFloating; /** IEEE-754 single-precision 32-bit float tensor type. */ -public interface TFloat32 extends FloatNdArray, TNumber { +public interface TFloat32 extends FloatNdArray, TFloating { /** readable-name for the data type */ static final String NAME = "FLOAT"; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java index 9cf5fdaaeaa..5d2744c4b3c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java @@ -29,10 +29,11 @@ import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; -import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TFloating; + /** IEEE-754 double-precision 64-bit float tensor type. 
*/ -public interface TFloat64 extends DoubleNdArray, TNumber { +public interface TFloat64 extends DoubleNdArray, TFloating { /** readable-name for the data type */ static final String NAME = "DOUBLE"; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TFloating.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TFloating.java new file mode 100644 index 00000000000..92deaffdc68 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TFloating.java @@ -0,0 +1,19 @@ +package org.tensorflow.types.family; + +/** + * Marker interface for floating point tensor types. + * + *

Operations that accept only floating-point values for some of their operands require the tensor + * types of those operands to be bound to this interface. For example: + * + *

{@code
+ * TFloat32 tensor1 = TFloat32.vectorOf(1, 2, 3);
+ * TBool tensor2 = TBool.vectorOf(true, false, true);
+ *
+ * Ops tf = Ops.create();
+ * Exponential<TFloat32> exp = new Exponential<>(tf);
+ * exp.call(tf.constant(tensor1));  // OK
+ * exp.call(tf.constant(tensor2));  // Compilation failure
+ * }
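+ *
+ * <p>For illustration only (not part of this change), a utility method that should
+ * accept floating-point tensors exclusively can put the same bound on its own type
+ * parameter; the method name {@code halve} here is hypothetical:
+ *
+ * <pre>{@code
+ * static <T extends TFloating> Operand<T> halve(Ops tf, Operand<T> x) {
+ *   // Dividing by a floating-point constant only makes sense for TFloating
+ *   // operands, so integer and boolean tensor types are rejected at compile time.
+ *   return tf.math.div(x, tf.dtypes.cast(tf.constant(2.0), x.asOutput().dataType()));
+ * }
+ * }</pre>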
+ */ +public interface TFloating extends TNumber {} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java new file mode 100644 index 00000000000..e1482a51a8a --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java @@ -0,0 +1,68 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Abstract base class for Activations + * + *

Note: The {@link #tf} attribute must be set prior to invoking the call method. See + * {@link #setTF(Ops)} and the constructor {@link #Activation(Ops)}. + * + * @param <T> the data type of the activation + */ +public abstract class Activation<T extends TNumber> { + + /** The TensorFlow Ops */ + protected Ops tf; + + /** + * Creates the abstract class for an Activation + * + * @param tf the TensorFlow Ops + */ + protected Activation(Ops tf) { + this.tf = tf; + } + + /** + * Sets the TensorFlow Ops + * + * @param tf the TensorFlow Ops + */ + protected void setTF(Ops tf) { + this.tf = tf; + } + + /** + * Gets the TensorFlow Ops + * + * @return the TensorFlow Ops + */ + protected Ops getTF() { + return this.tf; + } + + /** + * Gets the calculation operation for the activation. + * + * @param input the input tensor + * @return The operand for the activation + */ + public abstract Operand<T> call(Operand<T> input); +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java new file mode 100644 index 00000000000..ae3d7e8c896 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java @@ -0,0 +1,98 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TFloating; + +/** + * Exponential linear unit. + * + *

The exponential linear unit (ELU) with alpha > 0 is: + * + *

x if x > 0 and alpha * (exp(x) - + * 1) if x < 0. + * + *

The ELU hyperparameter alpha controls the value to which an ELU saturates for + * negative net inputs. ELUs diminish the vanishing gradient effect. + * + *

ELUs have negative values, which push the mean of the activations closer to zero. Mean + * activations that are closer to zero enable faster learning, as they bring the gradient closer to + * the natural gradient. ELUs saturate to a negative value as the argument gets smaller; saturation + * implies a small derivative, which decreases both the variation and the information propagated to + * the next layer. + * + *

Example Usage: + * + *

+ *     Operand<TFloat32> input = ...;
+ *     ELU<TFloat32> elu = new ELU<>(tf, 2.0f);
+ *     Operand<TFloat32> result = elu.call(input);
+ * 
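+ *
+ * <p>Worked by hand (not output of this change): with alpha = 2.0 the input above
+ * maps to approximately [-1.9004259f, -1.2642411f, 0.0f, 1.0f, 3.0f], since each
+ * negative entry becomes alpha * (exp(x) - 1) while non-negative entries pass
+ * through unchanged.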
+ * + * @param the data type of the activation + * @see Clevert et al, 2016, Fast and Accurate Deep + * Network Learning by Exponential Linear Units (ELUs) + */ +public class ELU extends Activation { + + private static final double ALPHA_DEFAULT = 1.0; + + /** A scalar, slope of negative section. */ + private final double alpha; + + /** + * Creates a new ELU with alpha={@link #ALPHA_DEFAULT}. + * + * @param tf the TensorFlow Ops + */ + public ELU(Ops tf) { + this(tf, ALPHA_DEFAULT); + } + + /** + * Creates a new ELU + * + * @param tf the TensorFlow Ops + * @param alpha A scalar, slope of negative section. It controls the value to which an ELU + * saturates for negative net inputs. + */ + public ELU(Ops tf, double alpha) { + super(tf); + this.alpha = alpha; + } + + /** + * Gets the calculation operation for the activation. + * + * @param input the input tensor + * @return The operand for the activation + */ + @Override + public Operand call(Operand input) { + + Operand result = tf.nn.elu(input); + if (alpha == 1.0) return result; + else { + DataType dataType = input.asOutput().dataType(); + Operand y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), dataType)); + Operand cond = tf.math.greater(result, tf.dtypes.cast(tf.constant(0), dataType)); + return tf.select(cond, result, y); + } + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java new file mode 100644 index 00000000000..d5fdff36c61 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; + +/** + * Exponential activation function. + * + *

For example: + * + *

+ *   Operand<TFloat32> input = tf.constant(
+ *          new float[] {-3.0f,-1.0f, 0.0f,1.0f,3.0f});
+ *   Exponential<TFloat32> exp = new Exponential<>(tf);
+ *   Operand<TFloat32> result = exp.call(input);
+ *   // result is [0.04978707f,  0.36787945f,  1.f,  2.7182817f, 20.085537f]
+ * 
+ * + * @param the data type of the activation + */ +public class Exponential extends Activation { + + /** + * Creates an Exponential activation. + * + * @param tf the TensorFlow Ops + */ + public Exponential(Ops tf) { + super(tf); + } + + /** + * Calculates the Exponential activation. + * + * @param input the input tensor + * @return an Operand for the exponential activation: exp(x). + */ + @Override + public Operand call(Operand input) { + return tf.math.exp(input); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java new file mode 100644 index 00000000000..a486cbdc601 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java @@ -0,0 +1,74 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; + +/** + * Hard sigmoid activation. + * + *

A faster approximation of the sigmoid activation. + * + *

Defined as: + * + *

    + *
  • if x < -2.5: return 0 + *
  • if x > 2.5: return 1 + *
  • if -2.5 <= x <= 2.5: return 0.2 * x + 0.5 + *
+ * + *
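+ * <p>Equivalently, the three cases collapse to the single expression
+ * clip(0.2 * x + 0.5, 0, 1), which is exactly what {@code call} computes below
+ * via {@code clipByValue}.
+ *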

For example: + * + *

+ *     Operand<TFloat32> input = tf.constant(
+ *              new float[] {-3.0f,-1.0f, 0.0f,1.0f,3.0f});
+ *     HardSigmoid<TFloat32> hardSigmoid = new HardSigmoid<>(tf);
+ *     Operand<TFloat32> result = hardSigmoid.call(input);
+ *     // result is [0.f , 0.3f, 0.5f, 0.7f, 1.f]
+ * 
+ * + * @param the data type of the result + */ +public class HardSigmoid extends Activation { + + /** + * Creates Hard sigmoid activation. + * + * @param tf the TensorFlow Ops + */ + public HardSigmoid(Ops tf) { + super(tf); + } + + /** + * Gets the calculation operation for the activation. + * + * @param input the input tensor + * @return The operand for the activation + */ + @Override + public Operand call(Operand input) { + DataType dataType = input.asOutput().dataType(); + Operand point2 = tf.dtypes.cast(tf.constant(0.2), dataType); + Operand point5 = tf.dtypes.cast(tf.constant(0.5), dataType); + + Operand x = tf.math.add(tf.math.mul(input, point2), point5); + return tf.clipByValue( + x, tf.dtypes.cast(tf.constant(0), dataType), tf.dtypes.cast(tf.constant(1), dataType)); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java new file mode 100644 index 00000000000..d907397995d --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Linear activation function (pass-through). + * + *

The linear activation returns its input. It is also known as the Identity activation function.

+ * + *

For example: + * + *

+ *    Operand<TFloat32> input = tf.constant(
+ *              new float[] {-3.0f,-1.0f, 0.0f,1.0f,3.0f});
+ *    Linear<TFloat32> linear = new Linear<>(tf);
+ *    Operand<TFloat32> result = linear.call(input);
+ *    // result is [-3.0f,-1.0f, 0.0f,1.0f,3.0f]
+ * 
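+ *
+ * <p>Because Linear is bound to {@code TNumber} rather than {@code TFloating},
+ * integer tensors pass through as well; a sketch mirroring
+ * {@code LinearTest#testCallInt}:
+ *
+ * <pre>{@code
+ * Operand<TInt32> ints = tf.constant(new int[] {1, -2, 3});
+ * Linear<TInt32> linear = new Linear<>(tf);
+ * Operand<TInt32> out = linear.call(ints);  // [1, -2, 3], unchanged
+ * }</pre>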
+ */ +public class Linear extends Activation { + + /** + * Creates a linear activation. + * + * @param tf the TensorFlow Ops + */ + public Linear(Ops tf) { + super(tf); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand input) { + return input; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java new file mode 100644 index 00000000000..c24cf71077d --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java @@ -0,0 +1,144 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.op.math.Greater; +import org.tensorflow.op.nn.LeakyRelu; +import org.tensorflow.types.family.TNumber; + +/** + * Rectified Linear Unit(ReLU) activation. + * + *

With default values, this returns the standard ReLU activation: max(x, 0), the + * element-wise maximum of 0 and the input tensor. + * + *

Modifying the default parameters allows you to use non-zero thresholds, to change the max value of + * the activation, and to use a non-zero multiple of the input for values below the threshold. + * + *

For example: + * + *

+ *     Operand<TFloat32> input = tf.constant(
+ *              new float[] {-10f, -5f, 0.0f, 5f, 10f});
+ *
+ *     // With default parameters
+ *     ReLU<TFloat32> relu = new ReLU<>(tf);
+ *     Operand<TFloat32> result = relu.call(input);
+ *     // result is [0.f,  0.f,  0.f,  5.f, 10.f]
+ *
+ *     // With alpha = 0.5
+ *     relu = new ReLU<>(tf, 0.5f, ReLU.MAX_VALUE_DEFAULT, ReLU.THRESHOLD_DEFAULT);
+ *     result = relu.call(input);
+ *     // result is [-5.f , -2.5f,  0.f ,  5.f , 10.f]
+ *
+ *     // With maxValue = 5
+ *     relu = new ReLU<>(tf, ReLU.ALPHA_DEFAULT, 5f, ReLU.THRESHOLD_DEFAULT);
+ *     result = relu.call(input);
+ *     // result is [0.f, 0.f, 0.f, 5.f, 5.f]
+ *
+ *     // With threshold = 5
+ *     relu = new ReLU<>(tf, ReLU.ALPHA_DEFAULT, ReLU.MAX_VALUE_DEFAULT, 5f);
+ *     result = relu.call(input);
+ *     // result is [-0.f, -0.f,  0.f,  0.f, 10.f]
+ * 
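+ *
+ * <p>In general the three parameters compose as follows (summarizing the examples
+ * above, in the convention used by Keras):
+ *
+ * <pre>{@code
+ * f(x) = maxValue                 for x >= maxValue
+ * f(x) = x                        for threshold <= x < maxValue
+ * f(x) = alpha * (x - threshold)  otherwise
+ * }</pre>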
+ * + * @param the data type of the result + */ +public class ReLU extends Activation { + + public static final float ALPHA_DEFAULT = 0.0f; + public static final float MAX_VALUE_DEFAULT = Float.NaN; + public static final float THRESHOLD_DEFAULT = 0.0f; + + private final float alpha; + private final float maxValue; + private final float threshold; + + /** + * Creates a new ReLU with alpha={@link #ALPHA_DEFAULT}, maxValue={@link #MAX_VALUE_DEFAULT}, + * threshold={@link #THRESHOLD_DEFAULT}, + * + * @param tf the TensorFlow Ops + */ + public ReLU(Ops tf) { + this(tf, ALPHA_DEFAULT, MAX_VALUE_DEFAULT, THRESHOLD_DEFAULT); + } + + /** + * Creates a new ReLU + * + * @param tf the TensorFlow Ops + * @param alpha governs the slope for values lower than the threshold. + * @param maxValue sets the saturation threshold (the largest value the function will return). + * @param threshold the threshold value of the activation function below which values will be + * damped or set to zero. + */ + public ReLU(Ops tf, float alpha, float maxValue, float threshold) { + super(tf); + this.alpha = alpha; + this.maxValue = maxValue; + this.threshold = threshold; + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand input) { + + DataType dataType = input.asOutput().dataType(); + + boolean clipMax = !Float.isNaN(maxValue); + Operand negativePart = null; + if (alpha != 0) { + if (Float.isNaN(maxValue) && threshold == 0) { + return tf.nn.leakyRelu(input, LeakyRelu.alpha(alpha)); + } + if (threshold != 0) { + negativePart = + tf.nn.relu( + tf.math.add(tf.math.neg(input), tf.dtypes.cast(tf.constant(threshold), dataType))); + } else { + negativePart = tf.nn.relu(tf.math.neg(input)); + } + } + + Operand lInput; + if (threshold != 0) { + // computes input for input > threshold else 0 + Greater greater = tf.math.greater(input, tf.dtypes.cast(tf.constant(threshold), dataType)); + lInput = tf.math.mul(input, tf.dtypes.cast(greater, dataType)); + } else if (maxValue == 6) { + // if no threshold, then can use nn.relu6 native TF op for performance + lInput = tf.nn.relu6(input); + clipMax = false; + } else { + lInput = tf.nn.relu(input); + } + if (clipMax) { + Operand lmaxValue = tf.dtypes.cast(tf.constant(maxValue), dataType); + Operand zero = tf.dtypes.cast(tf.constant(0), dataType); + lInput = tf.clipByValue(lInput, zero, lmaxValue); + } + + if (alpha != 0.) { + lInput = + tf.math.sub( + lInput, tf.math.mul(tf.dtypes.cast(tf.constant(alpha), dataType), negativePart)); + } + return lInput; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java new file mode 100644 index 00000000000..f24731049fb --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java @@ -0,0 +1,69 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; + +/** + * Scaled Exponential Linear Unit (SELU). + * + *

The Scaled Exponential Linear Unit (SELU) activation function is defined as: + * + *

    + *
  • if x > 0: return scale * x + *
  • if x < 0: return scale * alpha * (exp(x) - 1) + *
+ * + *

where alpha and scale are pre-defined constants ( + * alpha=1.67326324 and scale=1.05070098). + * + *

Basically, the SELU activation function multiplies scale (> 1) with the output + * of the elu function to ensure a slope larger than one for positive inputs. + * + *
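+ * <p>As a quick hand check (matching the expected values in SELUTest):
+ * selu(1) = scale * 1 &asymp; 1.0507010, and
+ * selu(-1) = scale * alpha * (exp(-1) - 1) &asymp; 1.0507010 * 1.6732632 * (-0.6321206)
+ * &asymp; -1.1113307.
+ *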

The values of alpha and scale are chosen so that the mean and + * variance of the inputs are preserved between two consecutive layers as long as the weights are + * initialized correctly (see {@link org.tensorflow.framework.initializers.LeCun} with Normal + * Distribution) and the number of input units is "large enough". + * + *

Notes: To be used together with the {@link + * org.tensorflow.framework.initializers.LeCun} initializer with Normal Distribution. + * + * @param the data type of the activation + * @see Klambauer et al., 2017 + */ +public class SELU extends Activation { + + /** + * Creates a Scaled Exponential Linear Unit (SELU) activation. + * + * @param tf the TensorFlow Ops + */ + public SELU(Ops tf) { + super(tf); + } + + /** + * Gets the calculation operation for the activation. + * + * @param input the input tensor + * @return The operand for the activation + */ + @Override + public Operand call(Operand input) { + return tf.nn.selu(input); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java new file mode 100644 index 00000000000..5d507b38483 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; + +/** + * Sigmoid activation. sigmoid(x) = 1 / (1 + exp(-x)). + * + *

Applies the sigmoid activation function. For small values (<-5), sigmoid + * returns a value close to zero, and for large values (>5) the result of the function gets close + * to 1. + * + *

Sigmoid is equivalent to a 2-element Softmax, where the second element is assumed to be zero. + * The sigmoid function always returns a value between 0 and 1. + * + *
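+ * <p>The softmax equivalence is a one-line derivation:
+ * softmax([x, 0])[0] = exp(x) / (exp(x) + exp(0)) = 1 / (1 + exp(-x)) = sigmoid(x).
+ *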

For example: + * + *

+ *     Operand<TFloat32> input = tf.constant(
+ *              new float[] {-20f, -1.0f, 0.0f, 1.0f, 20f});
+ *     Sigmoid<TFloat32> sigmoid = new Sigmoid<>(tf);
+ *     Operand<TFloat32> result = sigmoid.call(input);
+ *     // result is [2.0611537e-09f, 2.6894143e-01f,
+ *     //                 5.0000000e-01f,7.3105860e-01f, 1.f]
+ * 
+ * + * @param the data type of the activation + */ +public class Sigmoid extends Activation { + + /** + * Creates a Sigmoid activation. + * + * @param tf the TensorFlow Ops + */ + public Sigmoid(Ops tf) { + super(tf); + } + + /** + * Gets the calculation operation for the activation. + * + * @param input the input tensor + * @return The operand for the activation + */ + @Override + public Operand call(Operand input) { + return tf.math.sigmoid(input); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java new file mode 100644 index 00000000000..d31eebd9007 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java @@ -0,0 +1,88 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReduceMax; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.types.family.TFloating; + +/** + * Softmax converts a real vector to a vector of categorical probabilities. + * + *

The elements of the output vector are in range (0, 1) and sum to 1. + * + *

Each vector is handled independently. The axis argument sets which axis of the + * input the function is applied along. + * + *

Softmax is often used as the activation for the last layer of a classification network because + * the result could be interpreted as a probability distribution. + * + *

The softmax of each vector x is computed as: exp(x) / tf.sum(exp(x)). + * + *
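+ * <p>Worked by hand: softmax([1, 2, 3]) =
+ * [exp(1), exp(2), exp(3)] / (exp(1) + exp(2) + exp(3))
+ * &asymp; [0.0900, 0.2447, 0.6652], which indeed sums to 1.
+ *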

The input values are the log-odds of the resulting probability. + * + * @param <T> the data type of the activation + */ +public class Softmax<T extends TFloating> extends Activation<T> { + + private static final int AXIS_DEFAULT = -1; + + private final int axis; + + /** + * Creates a softmax activation where the default axis is {@link #AXIS_DEFAULT} which indicates + * the last dimension. + * + * @param tf the TensorFlow Ops + */ + public Softmax(Ops tf) { + this(tf, AXIS_DEFAULT); + } + + /** + * Creates a Softmax activation. + * + * @param tf the TensorFlow Ops + * @param axis The dimension softmax would be performed on. + */ + public Softmax(Ops tf, int axis) { + super(tf); + this.axis = axis; + } + + /** + * Gets the calculation operation for the activation. + * + * @param input the input tensor + * @return The operand for the activation + */ + @Override + public Operand<T> call(Operand<T> input) { + Shape shape = input.asOutput().shape(); + int numDimensions = shape.numDimensions(); + // The fused kernel only handles the last axis, so guard on the axis as well. + if (numDimensions == 2 && (axis == -1 || axis == 1)) { + return tf.nn.softmax(input); + } else { + // Subtract the running max before exponentiating for numerical stability; + // softmax is shift-invariant, so the result is unchanged. + Operand<T> e = + tf.math.exp( + tf.math.sub(input, tf.reduceMax(input, tf.constant(axis), ReduceMax.keepDims(true)))); + Operand<T> s = tf.reduceSum(e, tf.constant(axis), ReduceSum.keepDims(true)); + return tf.math.div(e, s); + } + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java new file mode 100644 index 00000000000..65a183ea047 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java @@ -0,0 +1,56 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; + +/** + * Softplus activation function, softplus(x) = log(exp(x) + 1). + * + *

Example Usage: + * + *

+ *     Operand<TFloat32> input = tf.constant(
+ *              new float[] {-20f, -1.0f, 0.0f, 1.0f, 20f});
+ *     Softplus<TFloat32> softplus = new Softplus<>(tf);
+ *     Operand<TFloat32> result = softplus.call(input);
+ *     // result is [2.0611537e-09f, 3.1326166e-01f, 6.9314718e-01f,
+ *     //                 1.3132616e+00f, 2.0000000e+01f]
+ * 
+ */ +public class Softplus extends Activation { + + /** + * Creates a Softplus activation function. + * + * @param tf the TensorFlow Ops + */ + public Softplus(Ops tf) { + super(tf); + } + + /** + * Gets the calculation operation for the activation. + * + * @param input the input tensor + * @return The operand for the activation + */ + @Override + public Operand call(Operand input) { + return tf.math.softplus(input); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java new file mode 100644 index 00000000000..1f691e71862 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; + +/** + * Softsign activation function, softsign(x) = x / (abs(x) + 1). + * + *

Example Usage: + * + *

+ *     Operand<TFloat32> input = tf.constant(
+ *              new float[] {-1.0f, 0.0f, 1.0f});
+ *     Softsign<TFloat32> softsign = new Softsign<>(tf);
+ *     Operand<TFloat32> result = softsign.call(input);
+ *     // result is [-0.5f, 0.f, 0.5f]
+ * 
+ * + * @param the data type of the activation + */ +public class Softsign extends Activation { + + /** + * Creates a Softsign activation. + * + * @param tf the TensorFlow Ops + */ + public Softsign(Ops tf) { + super(tf); + } + + /** + * Gets the calculation operation for the activation. + * + * @param input the input tensor + * @return The operand for the activation + */ + @Override + public Operand call(Operand input) { + return tf.nn.softsign(input); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java new file mode 100644 index 00000000000..d9f73a422d5 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; + +/** + * Swish activation function. swish(x) = x * sigmoid(x). + * + *

The Swish activation function returns x * sigmoid(x). It is a smooth, + * non-monotonic function that consistently matches or outperforms ReLU on deep + * networks; it is unbounded above and bounded below. + * + *

Example Usage: + * + *

+ *     Operand<TFloat32> input = tf.constant(new float[]
+ *                                        {-20f, -1.0f, 0.0f, 1.0f, 20f});
+ *     Swish<TFloat32> swish = new Swish<>(tf);
+ *     Operand<TFloat32> result = swish.call(input);
+ *     // result = [-4.1223075e-08f, -2.6894143e-01f,  0.0000000e+00f,
+ *     //          7.3105860e-01f,  2.0000000e+01f ]
+ *
+ * 
+ * + * @param the data type of the activation + * @see Ramachandran et al., 2017 + */ +public class Swish extends Activation { + + /** + * Creates a Swish activation, swish(x) = x * sigmoid(x). + * + *

The Swish activation function returns x * sigmoid(x). It is a smooth, + * non-monotonic function that consistently matches or outperforms ReLU on deep networks; it is + * unbounded above and bounded below. + * + * @param tf the TensorFlow Ops + */ + public Swish(Ops tf) { + super(tf); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand input) { + // TODO Python Keras returns a "grad", which is an optimization not implemented in Java. + return tf.math.mul(input, tf.math.sigmoid(input)); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java new file mode 100644 index 00000000000..4fe02eed048 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; + +/** + * Hyperbolic tangent activation function. + * + *

For example: + * + *

+ *     Operand<TFloat32> input = tf.constant(new float[]
+ *                                        {-3.0f,-1.0f, 0.0f, 1.0f, 3.0f});
+ *     Tanh<TFloat32> tanh = new Tanh<>(tf);
+ *     Operand<TFloat32> result = tanh.call(input);
+ *     // result = [-0.9950547f, -0.7615942f,  0.f,  0.7615942f,  0.9950547f]
+ * 
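+ *
+ * <p>For reference (an identity, not used by the implementation):
+ * tanh(x) = 2 * sigmoid(2x) - 1, so tanh behaves like a sigmoid rescaled to the
+ * output range (-1, 1).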
+ * + * @param the data type of the activation + */ +public class Tanh extends Activation { + + /** + * Creates a Hyperbolic tangent activation. + * + * @param tf the TensorFlow Ops + */ + public Tanh(Ops tf) { + super(tf); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand input) { + return tf.math.tanh(input); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java new file mode 100644 index 00000000000..e608224a50d --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.junit.jupiter.api.*; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** @author Jim Clarke */ +public class ELUTest { + + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + public ELUTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + + + /** Test of ELU call method */ + @Test + public void testCallFloat() { + float[] input = {1, -2, 3, -4, -1, 2, -3, 4}; + float[] expected = {1f, -0.86466473f, 3f, -0.9816844f, -0.63212055f, 2f, -0.95021296f, 4f}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + ELU instance = new ELU<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } + + /** Test of ELU call method */ + @Test + public void testCallDouble() { + double[] input = {1, -2, 3, -4, -1, 2, -3, 4}; + double[] expected = {1, -0.86466473, 3, -0.9816844, -0.63212055, 2, -0.95021293, 4F}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + ELU instance = new ELU<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } + + /** Test of ELU call method */ + @Test + public void testAlpha() { + double[] input = {1, -2, 3, -4, -5, 6, 7, 8}; + double[] expected = {1, -1.7293295, 3, -1.9633688, -1.9865241, 6, 7, 8}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + ELU instance = new ELU<>(tf, 2.0f); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); 
+ } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java new file mode 100644 index 00000000000..a0fd1f60b47 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.junit.jupiter.api.*; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** @author Jim Clarke */ +public class ExponentialTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + public ExponentialTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + + + /** Test of Exponential call method. */ + @Test + public void testCallFloat() { + float[] input = {1, -2, 3, -4, -1, 2, -3, 4}; + float[] expected = { + 2.7182817F, + 0.13533528F, + 20.085537F, + 0.01831564F, + 0.36787945F, + 7.389056F, + 0.049787067F, + 54.598152F + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Exponential instance = new Exponential<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } + + /** TTest of Exponential call method. */ + @Test + public void testCallDouble() { + double[] input = {1, -2, 3, -4, -1, 2, -3, 4}; + double[] expected = { + 2.7182818284590455, 0.1353352832366127, 20.085536923187668, + 0.018315638888734182, 0.3678794411714423, 7.38905609893065, + 0.049787068367863944, 54.598150033144236, + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Exponential instance = new Exponential<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java new file mode 100644 index 00000000000..b1eaab8de22 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java @@ -0,0 +1,74 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.junit.jupiter.api.*; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** @author Jim Clarke */ +public class HardSigmoidTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + public HardSigmoidTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + + + /** Test of HardSigmoid call method. */ + @Test + public void testCallFloat() { + float[] input = {-3.0f, -1.0f, 0.0f, 1.0f, 3.0f}; + float[] expected = {0.f, 0.3f, 0.5f, 0.7f, 1.f}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + HardSigmoid instance = new HardSigmoid<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } + + /** Test of HardSigmoid call method. */ + @Test + public void testCallDouble() { + double[] input = {-3.0, -1.0, 0.0, 1.0, 3.0}; + double[] expected = {0., 0.3, 0.5, 0.7, 1.}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + HardSigmoid instance = new HardSigmoid<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java new file mode 100644 index 00000000000..7974035c680 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java @@ -0,0 +1,84 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.junit.jupiter.api.*; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +/** @author Jim Clarke */ +public class LinearTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + public LinearTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + /** Test of Linear call method. */ + @Test + public void testCallInt() { + int[] input = {1, -2, 3, -4, -1, 2, -3, 4}; + int[] expected = {1, -2, 3, -4, -1, 2, -3, 4}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Linear instance = new Linear<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } + + /** Test of Linear call method. */ + @Test + public void testCallFloat() { + float[] input = {1, -2, 3, -4, -1, 2, -3, 4}; + float[] expected = {1, -2, 3, -4, -1, 2, -3, 4}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Linear instance = new Linear<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } + + /** Test of Linear call method. */ + @Test + public void testCallDouble() { + double[] input = {1, -2, 3, -4, -1, 2, -3, 4}; + double[] expected = {1, -2, 3, -4, -1, 2, -3, 4}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Linear instance = new Linear<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java new file mode 100644 index 00000000000..f54401515ab --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java @@ -0,0 +1,150 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.junit.jupiter.api.*; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.*; + +/** @author Jim Clarke */ +public class ReLUTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + public ReLUTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + /** Test of ReLU call method */ + @Test + public void testCallFloat() { + float[][] input = {{1, -2}, {3, -4}, {-1, 2}, {-3, 4}}; + float[][] expected = {{1, 0}, {3, 0}, {0, 2}, {0, 4}}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + ReLU instance = new ReLU<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(tf.constant(expected), result); + } + } + + /** Test of ReLU call method */ + @Test + public void testCallInt() { + int[][] input = {{1, -2}, {3, -4}, {-1, 2}, {-3, 4}}; + int[][] expected = {{1, 0}, {3, 0}, {0, 2}, {0, 4}}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + ReLU instance = new ReLU<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(tf.constant(expected), result); + } + } + + /** Test of ReLU call method */ + @Test + public void testCallLong() { + long[][] input = {{1, -2}, {3, -4}, {-1, 2}, {-3, 4}}; + long[][] expected = {{1, 0}, {3, 0}, {0, 2}, {0, 4}}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + ReLU instance = new ReLU<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(tf.constant(expected), result); + } + } + + /** Test of ReLU call method */ + @Test + public void testCallFloat16() { + float[][] input = {{1, -2}, {3, -4}, {-1, 2}, {-3, 4}}; + float[][] expected = {{1, 0}, {3, 0}, {0, 2}, {0, 4}}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + ReLU instance = new ReLU<>(tf); + Operand result = + instance.call(tf.dtypes.cast(tf.constant(input), TFloat16.DTYPE)); + session.evaluate(tf.dtypes.cast(tf.constant(expected), TFloat16.DTYPE), result); + } + } + + /** Test of ReLU call method */ + @Test + public void testCallDouble() { + double[][] input = {{1, -2}, {3, -4}, {-1, 2}, {-3, 4}}; + double[][] expected = {{1, 0}, {3, 0}, {0, 2}, {0, 4}}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + ReLU instance = new ReLU<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(tf.constant(expected), result); + } + } + + @Test + public void testAlpha() { + double[] input = {-10., -5., 0.0, 5., 10.}; + double[] expected = {-5. 
, -2.5, 0., 5., 10.}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + ReLU instance = new ReLU<>(tf, 0.5f, ReLU.MAX_VALUE_DEFAULT, ReLU.THRESHOLD_DEFAULT); + Operand result = instance.call(tf.constant(input)); + session.evaluate(tf.constant(expected), result); + } + } + + @Test + public void testMaxValue() { + double[] input = {-10., -5., 0.0, 5., 10.}; + double[] expected = {0., 0., 0., 5., 5.}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + ReLU instance = new ReLU<>(tf, ReLU.ALPHA_DEFAULT, 5, ReLU.THRESHOLD_DEFAULT); + Operand result = instance.call(tf.constant(input)); + session.evaluate(tf.constant(expected), result); + } + } + + @Test + public void testThreshold() { + double[] input = {-10., -5., 0.0, 5., 10.}; + double[] expected = {-0., -0., 0., 0., 10.}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + ReLU instance = new ReLU<>(tf, ReLU.ALPHA_DEFAULT, ReLU.MAX_VALUE_DEFAULT, 5.0f); + Operand result = instance.call(tf.constant(input)); + session.evaluate(tf.constant(expected), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java new file mode 100644 index 00000000000..caba5c43ba8 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
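[Note on the ReLU tests above: testAlpha, testMaxValue and testThreshold follow Keras-style parameterized ReLU semantics. Assuming the framework's ReLU matches those semantics, a reference sketch that reproduces the expected arrays (the helper and its NaN convention for "no max value" are illustrative, not part of the patch):

    // relu(x) = max_value                if x >= max_value
    //         = x                        if threshold < x < max_value
    //         = alpha * (x - threshold)  otherwise
    static double relu(double x, double alpha, double maxValue, double threshold) {
      if (!Double.isNaN(maxValue) && x >= maxValue) return maxValue; // clip region
      if (x > threshold) return x;                                   // pass-through region
      return alpha * (x - threshold);                                // leaky region
    }
    // relu(-10, 0.5, NaN, 0) == -5.0 -> first element of testAlpha
    // relu( 10, 0.0, 5,   0) ==  5.0 -> last element of testMaxValue
    // relu(  5, 0.0, NaN, 5) ==  0.0 -> testThreshold (5 is not strictly above the threshold)
]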
+=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.junit.jupiter.api.*; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** @author Jim Clarke */ +public class SELUTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + public SELUTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + + + /** Test of SELU call method */ + @Test + public void testCallFloat() { + float[] input = {1, -2, 3, -4, -1, 2, -3, 4}; + float[] expected = { + 1.050701F, -1.5201665F, 3.152103F, -1.7258986F, -1.1113307F, 2.101402F, -1.6705687F, 4.202804F + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SELU instance = new SELU<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } + + /** Test of SELU call method */ + @Test + public void testCallDouble() { + double[] input = {1, -2, 3, -4, -1, 2, -3, 4}; + double[] expected = { + 1.0507009873554805, -1.520166468595695, 3.1521029620664414, + -1.7258986281898947, -1.1113307378125628, 2.101401974710961, + -1.670568728767112, 4.202803949421922, + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SELU instance = new SELU<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java new file mode 100644 index 00000000000..ffb16cf077a --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java @@ -0,0 +1,86 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
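[Note on the SELU expectations above: they come from the closed form selu(x) = scale * x for x > 0 and scale * alpha * (e^x - 1) for x <= 0, with alpha and scale from Klambauer et al. (2017). A sketch that reproduces the expected entries (illustrative only):

    static final double ALPHA = 1.6732632423543772;
    static final double SCALE = 1.0507009873554805;

    // selu(x) = SCALE * x for positive x, SCALE * ALPHA * (e^x - 1) otherwise
    static double selu(double x) {
      return x > 0 ? SCALE * x : SCALE * ALPHA * Math.expm1(x);
    }
    // selu(1)  ==  1.0507009873554805 -> first expected entry in testCallDouble
    // selu(-2) == -1.520166468595695  -> second expected entry
]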
+=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.junit.jupiter.api.*; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** @author Jim Clarke */ +public class SigmoidTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + public SigmoidTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + + /** Test of Sigmoid call method */ + @Test + public void testCallFloat() { + float[] input = {1, -2, 3, -4, -1, 2, -3, 4}; + float[] expected = { + 0.7310586F, + 0.11920291F, + 0.95257413F, + 0.017986238F, + 0.26894143F, + 0.8807971F, + 0.047425866F, + 0.98201376F, + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Sigmoid instance = new Sigmoid<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } + + /** Test of Sigmoid call method */ + @Test + public void testCallDouble() { + double[] input = {1, -2, 3, -4, -1, 2, -3, 4}; + double[] expected = { + 0.7310585786300049, 0.11920292202211755, 0.9525741268224334, + 0.01798620996209156, 0.2689414213699951, 0.8807970779778823, + 0.04742587317756678, 0.9820137900379085 + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Sigmoid instance = new Sigmoid<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java new file mode 100644 index 00000000000..a3ff89cc407 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java @@ -0,0 +1,133 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
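[Note on the Sigmoid expectations above: they are plain logistic sigmoid outputs, sigmoid(x) = 1 / (1 + e^-x). A spot-check sketch (illustrative only):

    static double sigmoid(double x) {
      return 1.0 / (1.0 + Math.exp(-x)); // logistic function
    }
    // sigmoid(1)  == 0.7310585786300049  -> first expected entry in testCallDouble
    // sigmoid(-4) == 0.01798620996209156 -> fourth expected entry
]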
+=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.junit.jupiter.api.*; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReduceMax; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** @author Jim Clarke */ +public class SoftmaxTest { + + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + public SoftmaxTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + + /** Test of Softmax method, of class Activations. */ + @Test + public void testSoftmaxOpsOperandFloat() { + float[][] input = {{1, 2, 3, 4}, {5, 6, 7, 8}}; + float[][] expected = { + {0.032059f, 0.087144f, 0.236883f, 0.643914f}, + {0.032059f, 0.087144f, 0.236883f, 0.643914f} + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Softmax instance = new Softmax<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(tf.constant(expected), result); + } + } + + /** Test of Softmax method, of class Activations. */ + @Test + public void testSoftmaxOpsOperandDouble() { + double[][] input = {{1, 2, 3, 4}, {5, 6, 7, 8}}; + double[][] expected = { + {0.032059, 0.087144, 0.236883, 0.643914}, + {0.032059, 0.087144, 0.236883, 0.643914} + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Softmax instance = new Softmax<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(tf.constant(expected), result); + } + } + + /** Test of Softmax method, of class Activations. */ + @Test + public void testSoftmaxOpsOperandDoubleNegative() { + double[][] input = {{1, -2, 3, -4}, {-5, 6, -7, 8}}; + double[][] expected = { + {1.18405115e-01, 5.89504354e-03, 8.74902034e-01, 7.97807387e-04}, + {1.99088704e-06, 1.19202653e-01, 2.69437261e-07, 8.80795087e-01} + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Softmax instance = new Softmax<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(tf.constant(expected), result); + } + } + + /** Test of Softmax method, of class Activations. */ + @Test + public void testSoftmax1D() { + double[] input = {1, -2, 3, -4, -5, 6, 7, 8}; + double[] expected = { + 6.0352829e-04, 3.0047902e-05, 4.4595040e-03, 4.0665414e-06, + 1.4959969e-06, 8.9571528e-02, 2.4348068e-01, 6.6184908e-01 + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Softmax instance = new Softmax<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(tf.constant(expected), result); + } + } + + /** Test of Softmax method, of class Activations. 
*/ + @Test + public void testSoftmax3D() { + double[][][] input = {{{1, -2}, {3, -4}}, {{-5, 6}, {-7, 8}}}; + double[][][] expected = { + {{9.5257413e-01, 4.7425874e-02}, {9.9908900e-01, 9.1105123e-04}}, + {{1.6701422e-05, 9.9998331e-01}, {3.0590220e-07, 9.9999964e-01}} + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Softmax instance = new Softmax<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(tf.constant(expected), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java new file mode 100644 index 00000000000..a17f2650d62 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.junit.jupiter.api.*; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +/** @author Jim Clarke */ +public class SoftplusTest { + + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + public SoftplusTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + /** Test of Softplus call method */ + @Test + public void testCallFloat() { + float[] input = {1, 2, 3, 4, 5, 6, 7, 8}; + float[] expected = { + 1.3132616F, 2.126928F, 3.0485873F, 4.01815F, 5.0067153F, 6.0024757F, 7.0009117F, 8.000336F + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Softplus instance = new Softplus<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } + + /** Test of Softplus call method */ + @Test + public void testCallDouble() { + double[] input = {1, 2, 3, 4, 5, 6, 7, 8}; + double[] expected = { + 1.3132616875182228, 2.1269280110429727, 3.048587351573742, + 4.0181499279178094, 5.006715348489118, 6.00247568513773, + 7.000911466453774, 8.000335406372896, + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Softplus instance = new Softplus<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java new file mode 100644 index 00000000000..43591ab4761 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java @@ -0,0 +1,79 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.junit.jupiter.api.*; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +/** @author Jim Clarke */ +public class SoftsignTest { + + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + public SoftsignTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + /** Test of Softsign call method */ + @Test + public void testCallFloat() { + float[] input = {1, 2, 3, 4, 5, 6, 7, 8}; + float[] expected = {0.5F, 0.6666667F, 0.75F, 0.8F, 0.8333333F, 0.85714287F, 0.875F, 0.8888889F}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Softsign instance = new Softsign<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } + + /** Test of Softsign call method */ + @Test + public void testCallDouble() { + double[] input = {1, -2, 3, -4, -5, 6, -7, 8}; + double[] expected = { + 0.5, + -0.6666666666666666, + 0.75, + -0.8, + -0.8333333333333334, + 0.8571428571428571, + -0.875, + 0.8888888888888888 + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Softsign instance = new Softsign<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java new file mode 100644 index 00000000000..5739bccd3d5 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java @@ -0,0 +1,92 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.junit.jupiter.api.*; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** @author Jim Clarke */ +public class SwishTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + public SwishTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + + + /** Test of Swish call method */ + @Test + public void testCallFloat() { + float[] input = {1, -2, 3, -4, -5, 6, -7, 8}; + float[] expected = { + 0.7310585786300049f, + -.238405844f, + 2.8577223804673f, + -7.19448398e-02f, + -3.34642546e-02f, + 5.985164261060192f, + -6.37735836e-03f, + 7.997317198956269f + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Swish instance = new Swish<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } + + /** Test of Swish call method */ + @Test + public void testCallDouble() { + double[] input = {1, -2, 3, -4, -5, 6, -7, 8}; + double[] expected = { + 0.7310585786300049, + -.238405844, + 2.8577223804673, + -7.19448398e-02, + -3.34642546e-02, + 5.985164261060192, + -6.37735836e-03, + 7.997317198956269 + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Swish instance = new Swish<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java new file mode 100644 index 00000000000..5162e141c44 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java @@ -0,0 +1,79 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
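[Note on the Softmax, Softplus, Softsign and Swish tests above: the expected arrays all follow from the textbook definitions. A compact reference sketch (an unstabilized softmax is fine at these magnitudes; helpers are illustrative, not part of the patch):

    static double softplus(double x) { return Math.log1p(Math.exp(x)); }  // ln(1 + e^x)
    static double softsign(double x) { return x / (1.0 + Math.abs(x)); }
    static double swish(double x)    { return x / (1.0 + Math.exp(-x)); } // x * sigmoid(x)

    // softmax: exponentials normalized along one axis (the last axis in the tests)
    static double[] softmax(double[] v) {
      double sum = 0;
      double[] e = new double[v.length];
      for (int i = 0; i < v.length; i++) sum += (e[i] = Math.exp(v[i]));
      for (int i = 0; i < e.length; i++) e[i] /= sum;
      return e;
    }
    // softmax({1, 2, 3, 4}) -> {0.032059, 0.087144, 0.236883, 0.643914}
    // softplus(1) == 1.3132616875182228, softsign(-2) == -0.6666666666666666
    // swish(-2)   == -0.2384058440442351
]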
+=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.junit.jupiter.api.*; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +/** @author Jim Clarke */ +public class TanhTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + public TanhTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + /** Test of Tanh call method. */ + @Test + public void testCallFloat() { + float[] input = {1, -2, 3, -4, -5, 6, -7, 8}; + float[] expected = { + 0.76159416F, -0.96402758F, + 0.99505475F, -0.9993293F, + -0.9999092F, 0.99998771F, + -0.99999834F, 0.99999977F + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Tanh instance = new Tanh<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } + + /** Test of Tanh call method. */ + @Test + public void testCallDouble() { + double[] input = {1, -2, 3, -4, -5, 6, -7, 8}; + double[] expected = { + 0.76159416, -0.96402758, + 0.99505475, -0.9993293, + -0.9999092, 0.99998771, + -0.99999834, 0.99999977 + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Tanh instance = new Tanh<>(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java index 9fb9885505c..bca90211e50 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java @@ -125,6 +125,16 @@ public void evaluate(double expected, Operand input) { } index.set(0); o.data().scalars().forEach(f -> assertEquals((long) expected, f.getLong())); + } else if (dtype == TUint8.DTYPE) { + Operand o = (Operand) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.data() + .scalars() + .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); + } + index.set(0); + o.data().scalars().forEach(f -> assertEquals((long) expected, f.getByte())); } } @@ -191,6 +201,18 @@ public void evaluate(Number[] expected, Output input) { o.data() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); + } else if (dtype == TUint8.DTYPE) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.data() + .scalars() + .forEach(f -> System.out.printf("%x). 
%d\n", index.getAndIncrement(), f.getByte()));
+      }
+      index.set(0);
+      o.data()
+          .scalars()
+          .forEach(f -> assertEquals(expected[index.getAndIncrement()].byteValue(), f.getByte()));
     }
   }
 
@@ -250,6 +272,47 @@ public void evaluate(FloatNdArray expected, Output input) {
           .scalars()
           .forEach(
               f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong()));
+    } else if (dtype == TUint8.DTYPE) {
+      Output o = (Output) input;
+      AtomicInteger index = new AtomicInteger();
+      if (debug) {
+        o.data()
+            .scalars()
+            .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte()));
+      }
+      index.set(0);
+      o.data()
+          .scalars()
+          .forEach(
+              f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte()));
+    }
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public void evaluateString(Output input, Predicate predicate) {
+    AtomicInteger index = new AtomicInteger();
+    boolean isScalar = input.shape().equals(Shape.scalar());
+    if (debug) {
+      if (isScalar) {
+        System.out.printf(
+            "0). %b <==> %s\n", predicate.test(input.data().getObject()), input.data().getObject());
+      } else {
+        input
+            .data()
+            .scalars()
+            .forEachIndexed(
+                (idx, s) ->
+                    System.out.printf(
+                        "%d). %b <==> %s\n",
+                        index.getAndIncrement(), predicate.test(s.getObject()), s.getObject()));
+      }
+    }
+    index.set(0);
+    if (isScalar) {
+      assertTrue(predicate.test(input.data().getObject()));
+    } else {
+      input.data().scalars().forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject())));
+    }
+  }
 
@@ -307,6 +370,30 @@ public void evaluate(Output input, Predicate predic
           .scalars()
           .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getDouble())));
       }
+    } else if (dtype == TFloat16.DTYPE) {
+      Output o = (Output) input;
+      if (debug) {
+        if (isScalar) {
+          System.out.printf(
+              "0). %b <==> %f\n", predicate.test(o.data().getFloat()), o.data().getFloat());
+        } else {
+          o.data()
+              .scalars()
+              .forEachIndexed(
+                  (idx, f) ->
+                      System.out.printf(
+                          "%d). %b <==> %f\n",
+                          index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat()));
+        }
+      }
+      index.set(0);
+      if (isScalar) {
+        assertTrue(predicate.test(o.data().getFloat()));
+      } else {
+        o.data()
+            .scalars()
+            .forEachIndexed((idx, f) -> assertTrue(predicate.test(f.getFloat())));
+      }
     } else if (dtype == TInt32.DTYPE) {
       Output o = (Output) input;
       if (debug) {
@@ -355,6 +442,30 @@ public void evaluate(Output input, Predicate predic
           .scalars()
           .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getLong())));
       }
+    } else if (dtype == TUint8.DTYPE) {
+      Output o = (Output) input;
+      if (debug) {
+        if (isScalar) {
+          System.out.printf(
+              "0). %b <==> %x\n", predicate.test(o.data().getByte()), o.data().getByte());
+        } else {
+          o.data()
+              .scalars()
+              .forEachIndexed(
+                  (idx, f) ->
+                      System.out.printf(
+                          "%d). %b <==> %x\n",
+                          index.getAndIncrement(), predicate.test(f.getByte()), f.getByte()));
+        }
+      }
+      index.set(0);
+      if (isScalar) {
+        assertTrue(predicate.test(o.data().getByte()));
+      } else {
+        o.data()
+            .scalars()
+            .forEachIndexed((idx, f) -> assertTrue(predicate.test(f.getByte())));
+      }
     } else {
       fail("Unexpected DataType: " + dtype);
     }
@@ -515,6 +626,31 @@ public void evaluate(Output expected, Output input) {
           .scalars()
           .forEachIndexed((idx, f) -> assertEquals(x.data().getLong(idx), f.getLong()));
       }
+    } else if (dtype == TUint8.DTYPE) {
+      Output x = (Output) expected;
+      Output o = (Output) input;
+      AtomicInteger index = new AtomicInteger();
+      if (debug) {
+        if (isScalar) {
+          System.out.printf("0). 
%x <==> %x\n", x.data().getByte(), o.data().getByte()); + } else { + o.data() + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %x <==> %x\n", + index.getAndIncrement(), x.data().getByte(idx), f.getByte())); + } + } + index.set(0); + if (isScalar) { + assertEquals(x.data().getByte(), o.data().getByte()); + } else { + o.data() + .scalars() + .forEachIndexed((idx, f) -> assertEquals(x.data().getByte(idx), f.getByte())); + } } else if (dtype == TString.DTYPE) { Output x = (Output) expected; Output o = (Output) input; @@ -596,6 +732,12 @@ public void print(PrintWriter writer, Output input) { o.data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + } else if (dtype == TUint8.DTYPE) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + o.data() + .scalars() + .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); } else if (dtype == TString.DTYPE) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java index 0416267ae59..33ddec6dce3 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java @@ -169,6 +169,22 @@ public void evaluate(double expected, Operand input) { this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { result.data().scalars().forEach(f -> assertEquals((long) expected, f.getLong())); } + } else if (dtype == TUint8.DTYPE) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (Tensor result = + this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + result + .data() + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); + } + } + index.set(0); + try (Tensor result = + this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals((long) expected, f.getByte())); + } } else { fail("Unexpected DataType: " + dtype); } @@ -265,6 +281,25 @@ public void evaluate(Number[] expected, Output input) { .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); } + } else if (dtype == TUint8.DTYPE) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (Tensor result = + this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + result + .data() + .scalars() + .forEach(f -> System.out.printf("%d). 
%d\n", index.getAndIncrement(), f.getByte()));
+        }
+      }
+      index.set(0);
+      try (Tensor result =
+          this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) {
+        result
+            .data()
+            .scalars()
+            .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getByte()));
+      }
     } else {
       fail("Unexpected DataType: " + dtype);
     }
@@ -358,6 +393,26 @@ public void evaluate(FloatNdArray expected, Output input) {
              .forEach(
                  f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong()));
      }
+    } else if (dtype == TUint8.DTYPE) {
+      AtomicInteger index = new AtomicInteger();
+      if (debug) {
+        try (Tensor result =
+            this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) {
+          result
+              .data()
+              .scalars()
+              .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte()));
+        }
+      }
+      index.set(0);
+      try (Tensor result =
+          this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) {
+        result
+            .data()
+            .scalars()
+            .forEach(
+                f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte()));
+      }
     } else {
       fail("Unexpected DataType: " + dtype);
     }
@@ -515,6 +570,46 @@ public void evaluate(Output expected, Output input) {
                      assertEquals(expectedResult.data().getDouble(idx), f.getDouble(), epsilon));
        }
      }
+    } else if (dtype == TFloat16.DTYPE) {
+      final Output finalExpected = (Output) expected;
+      if (debug) {
+        try (Tensor result =
+                this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat16.DTYPE);
+            Tensor expectedResult =
+                this.getGraphSession().runner().fetch(expected).run().get(0).expect(TFloat16.DTYPE)) {
+          if (isScalar) {
+            System.out.printf(
+                "0). %f <==> %f\n", expectedResult.data().getFloat(), result.data().getFloat());
+          } else {
+            result
+                .data()
+                .scalars()
+                .forEachIndexed(
+                    (idx, f) ->
+                        System.out.printf(
+                            "%d). %f <==> %f\n",
+                            index.getAndIncrement(),
+                            finalExpected.data().getFloat(idx),
+                            f.getFloat()));
+          }
+        }
+      }
+      index.set(0);
+      try (Tensor result =
+              this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat16.DTYPE);
+          Tensor expectedResult =
+              this.getGraphSession().runner().fetch(expected).run().get(0).expect(TFloat16.DTYPE)) {
+        if (isScalar) {
+          assertEquals(expectedResult.data().getFloat(), result.data().getFloat(), epsilon);
+        } else {
+          result
+              .data()
+              .scalars()
+              .forEachIndexed(
+                  (idx, f) ->
+                      assertEquals(expectedResult.data().getFloat(idx), f.getFloat(), epsilon));
+        }
+      }
     } else if (dtype == TInt32.DTYPE) {
       final Output finalExpected = (Output) expected;
       if (debug) {
@@ -592,6 +687,46 @@ public void evaluate(Output expected, Output input) {
                      assertEquals(expectedResult.data().getLong(idx), f.getLong(), epsilon));
        }
      }
+    } else if (dtype == TUint8.DTYPE) {
+      final Output finalExpected = (Output) expected;
+      if (debug) {
+        try (Tensor result =
+                this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE);
+            Tensor expectedResult =
+                this.getGraphSession().runner().fetch(expected).run().get(0).expect(TUint8.DTYPE)) {
+          if (isScalar) {
+            System.out.printf(
+                "0). %d <==> %d\n", expectedResult.data().getByte(), result.data().getByte());
+          } else {
+            result
+                .data()
+                .scalars()
+                .forEachIndexed(
+                    (idx, f) ->
+                        System.out.printf(
+                            "%d). 
%d <==> %d\n",
+                            index.getAndIncrement(),
+                            finalExpected.data().getByte(idx),
+                            f.getByte()));
+          }
+        }
+      }
+      index.set(0);
+      try (Tensor result =
+              this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE);
+          Tensor expectedResult =
+              this.getGraphSession().runner().fetch(expected).run().get(0).expect(TUint8.DTYPE)) {
+        if (isScalar) {
+          assertEquals(expectedResult.data().getByte(), result.data().getByte(), epsilon);
+        } else {
+          result
+              .data()
+              .scalars()
+              .forEachIndexed(
+                  (idx, f) ->
+                      assertEquals(expectedResult.data().getByte(idx), f.getByte(), epsilon));
+        }
+      }
     } else if (dtype == TBool.DTYPE) {
       final Output finalExpected = (Output) expected;
       if (debug) {
@@ -675,6 +810,44 @@
     }
   }
 
+  /** {@inheritDoc} */
+  @Override
+  public void evaluateString(Output input, Predicate predicate) {
+    boolean isScalar = input.shape().equals(Shape.scalar());
+    AtomicInteger index = new AtomicInteger();
+    if (debug) {
+      try (Tensor result =
+          this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) {
+        if (isScalar) {
+          System.out.printf(
+              "0). %b <==> %s\n",
+              predicate.test(result.data().getObject()), result.data().getObject());
+        } else {
+          result
+              .data()
+              .scalars()
+              .forEachIndexed(
+                  (idx, f) ->
+                      System.out.printf(
+                          "%d). %b <==> %s\n",
+                          index.getAndIncrement(), predicate.test(f.getObject()), f.getObject()));
+        }
+      }
+    }
+    index.set(0);
+    try (Tensor result =
+        this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) {
+      if (isScalar) {
+        assertTrue(predicate.test(result.data().getObject()));
+      } else {
+        result
+            .data()
+            .scalars()
+            .forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject())));
+      }
+    }
+  }
+
   /** {@inheritDoc} */
   @Override
   public void evaluate(Output input, Predicate predicate) {
@@ -808,6 +981,38 @@
          .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.data().getLong())));
        }
      }
+    } else if (dtype == TUint8.DTYPE) {
+      if (debug) {
+        try (Tensor result =
+            this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) {
+          if (isScalar) {
+            System.out.printf(
+                "0). %b <==> %d\n",
+                predicate.test(result.data().getByte()), result.data().getByte());
+          } else {
+            result
+                .data()
+                .scalars()
+                .forEachIndexed(
+                    (idx, f) ->
+                        System.out.printf(
+                            "%d). %b <==> %d\n",
+                            index.getAndIncrement(), predicate.test(f.getByte()), f.getByte()));
+          }
+        }
+      }
+      index.set(0);
+      try (Tensor result =
+          this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) {
+        if (isScalar) {
+          assertTrue(predicate.test(result.data().getByte()));
+        } else {
+          result
+              .data()
+              .scalars()
+              .forEachIndexed((idx, f) -> assertTrue(predicate.test(f.getByte())));
+        }
+      }
     } else {
       fail("Unexpected DataType: " + dtype);
     }
@@ -881,6 +1086,22 @@ public void print(PrintWriter writer, Output input) {
                  (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getLong()));
        }
      }
+    } else if (dtype == TUint8.DTYPE) {
+      AtomicInteger index = new AtomicInteger();
+
+      try (Tensor result =
+          this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) {
+        if (isScalar) {
+          writer.printf(
+              "%d). %x\n", index.getAndIncrement(), ((Output) input).data().getByte());
+        } else {
+          result
+              .data()
+              .scalars()
+              .forEachIndexed(
+                  (idx, f) -> writer.printf("%d). 
%x\n", index.getAndIncrement(), f.getByte()));
+        }
+      }
     } else if (dtype == TBool.DTYPE) {
       AtomicInteger index = new AtomicInteger();
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java
index a0855eb6260..3fccd0f0506 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java
@@ -37,12 +37,6 @@ public abstract class TestSession implements AutoCloseable {
   protected float epsilon = 1e-5F;
   protected boolean debug;
 
-  /** The Test Session mode, either Eager or Graph */
-  public enum Mode {
-    EAGER,
-    GRAPH
-  }
-
   /**
    * Creates an Eager Test Session
    *
@@ -469,6 +463,24 @@ public void evaluate(Operand input, Predicate predi
    */
   public abstract void evaluate(Output input, Predicate predicate);
 
+  /**
+   * Evaluates each string value of the input against the predicate
+   *
+   * @param input the operand to evaluate
+   * @param predicate the Predicate used to evaluate each value of the input
+   */
+  public void evaluateString(Operand input, Predicate predicate) {
+    evaluateString(input.asOutput(), predicate);
+  }
+
+  /**
+   * Evaluates each string value of the input against the predicate
+   *
+   * @param input the operand to evaluate
+   * @param predicate the Predicate used to evaluate each value of the input
+   */
+  public abstract void evaluateString(Output input, Predicate predicate);
+
   /**
    * Evaluates the input against the expected value
    *
@@ -624,4 +636,10 @@ public boolean isDebug() {
   public void setDebug(boolean debug) {
     this.debug = debug;
   }
+
+  /** The Test Session mode, either Eager or Graph */
+  public enum Mode {
+    EAGER,
+    GRAPH
+  }
 }
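[Note on the new evaluateString hook above: a hypothetical test-side usage sketch, assuming java.util.function.Predicate and a made-up string constant (illustrative, not part of the patch):

    try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) {
      Ops tf = session.getTF();
      Operand strings = tf.constant(new String[] {"alpha", "beta"});
      Predicate<String> nonEmpty = s -> !s.isEmpty();
      session.evaluateString(strings, nonEmpty); // asserts the predicate holds per element
    }
]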