diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 84736ada6a5..007ee9d0d42 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -345,10 +345,10 @@ public final class Ops { public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { @@ -370,8 +370,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - train = new TrainOps(this); quantization = new QuantizationOps(this); + train = new TrainOps(this); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java index 290e4e80b57..894bd073758 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java @@ -62,7 +62,6 @@ * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. *

For a GlorotUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} * for the distribution parameter. - *

* * @param The TType for the call operation * @see VarianceScaling.Distribution diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java index 9b1a0887af0..3a91b72b0d0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java @@ -57,7 +57,6 @@ * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. *

For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} * for the distribution parameter. - *

* * @param The TType for the call operation * @see The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 363291fa5cc..5aac163c1e4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -154,24 +154,26 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. 
labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 */ public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, - * and a channel axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy Loss using a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS} * * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 */ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) { this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); @@ -183,9 +185,10 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float la * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. x=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. 
x=0.2 means + * that we will use a value of 0.1 for label 0 and 0.9 + * for label 1 * @param reduction Type of Reduction to apply to loss. */ public CategoricalCrossentropy( @@ -199,13 +202,14 @@ public CategoricalCrossentropy( * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 * @param reduction Type of Reduction to apply to loss. * @param axis The channels axis. axis=-1 corresponds to data format "Channels Last" - * and axis=1 corresponds to data format "Channels First". - * {@link Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST} + * and axis=1 corresponds to data format "Channels First". {@link + * Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST} * @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1. */ public CategoricalCrossentropy( @@ -242,13 +246,12 @@ public CategoricalCrossentropy( * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. 
*/ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index f592c19f8bb..73837ed1756 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -25,7 +25,7 @@ *

loss = maximum(neg - pos + 1, 0) where neg=maximum((1-labels)*predictions) * and pos=sum(labels*predictions) * - *

labels values are expected to be 0 or 1.

+ *

labels values are expected to be 0 or 1. * *

Standalone usage: * @@ -99,8 +99,8 @@ public CategoricalHinge(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.categoricalHinge(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java index 137c7025c04..0a18d93caf3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java @@ -22,12 +22,13 @@ /** * Computes the cosine similarity between labels and predictions. * - *

Note that it is a number between -1 and 1. When it is a negative number between -1 and 0, 0 - * indicates orthogonality and values closer to -1indicate greater similarity. The values closer to - * 1 indicate greater dissimilarity. This makes it usable as a loss function in a setting where you - * try to maximize the proximity between predictions and targets. If either labels or predictions is - * a zero vector, cosine similarity will be 0 regardless of the proximity between predictions and - * targets. + *

Note that it is a number between -1 and 1. When it is a negative + * number between -1 and 0, 0 indicates orthogonality and + * values closer to -1indicate greater similarity. The values closer to 1 + * indicate greater dissimilarity. This makes it usable as a loss function in a setting where you + * try to maximize the proximity between predictions and targets. If either labels or + * predictions is a zero vector, cosine similarity will be 0 regardless of + * the proximity between predictions and targets. * *

loss = -sum(l2Norm(labels) * l2Norm(predictions)) * @@ -71,7 +72,7 @@ public class CosineSimilarity extends Loss { public static final int DEFAULT_AXIS = -1; public static final Reduction DEFAULT_REDUCTION = Reduction.AUTO; - private final int axis; + private final int[] axis; /** * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, an axis @@ -107,6 +108,17 @@ public CosineSimilarity(Ops tf, int axis) { this(tf, null, axis, DEFAULT_REDUCTION); } + /** + * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, and a + * Loss Reduction of {@link #DEFAULT_REDUCTION} + * + * @param tf the TensorFlow Ops + * @param axis The dimension along which the cosine similarity is computed. + */ + public CosineSimilarity(Ops tf, int[] axis) { + + this(tf, null, axis, DEFAULT_REDUCTION); + } /** * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION} @@ -120,6 +132,18 @@ public CosineSimilarity(Ops tf, String name, int axis) { this(tf, name, axis, DEFAULT_REDUCTION); } + /** + * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION} + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param axis The dimension along which the cosine similarity is computed. + */ + public CosineSimilarity(Ops tf, String name, int[] axis) { + + this(tf, name, axis, DEFAULT_REDUCTION); + } + /** * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name and an * axis of {@link #DEFAULT_AXIS} @@ -153,6 +177,18 @@ public CosineSimilarity(Ops tf, String name, Reduction reduction) { */ public CosineSimilarity(Ops tf, int axis, Reduction reduction) { + this(tf, null, new int[] {axis}, reduction); + } + + /** + * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name + * + * @param tf the TensorFlow Ops + * @param axis The dimension along which the cosine similarity is computed. 
+ * @param reduction Type of Reduction to apply to the loss. + */ + public CosineSimilarity(Ops tf, int[] axis, Reduction reduction) { + this(tf, null, axis, reduction); } @@ -165,15 +201,28 @@ public CosineSimilarity(Ops tf, int axis, Reduction reduction) { * @param reduction Type of Reduction to apply to the loss. */ public CosineSimilarity(Ops tf, String name, int axis, Reduction reduction) { + this(tf, name, new int[] {axis}, reduction); + } + + /** + * Creates a Cosine Similarity Loss + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param axis The dimension along which the cosine similarity is computed. + * @param reduction Type of Reduction to apply to the loss. + */ + public CosineSimilarity(Ops tf, String name, int[] axis, Reduction reduction) { super(tf, name, reduction); this.axis = axis; } /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.cosineSimilarity(getTF(), labels, predictions, axis); + losses = tf.math.neg(losses); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index 88b4a7aa056..d4c350ef06c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -18,15 +18,16 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; + import static org.tensorflow.framework.utils.CastHelper.cast; /** * Computes the hinge loss between labels and predictions. * - *

loss = maximum(1 - labels * predictions, 0)

. + *

loss = maximum(1 - labels * predictions, 0). * - *

labels values are expected to be -1 or 1. - * If binary (0 or 1) labels are provided, they will be converted to -1 or 1.

+ *

labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, + * they will be converted to -1 or 1. * *

Standalone usage: * @@ -106,7 +107,7 @@ public Hinge(Ops tf, String name, Reduction reduction) { * label values are not in the set [-1., 0., 1.]. * * @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be - * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. + * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is * provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor @@ -116,21 +117,19 @@ public Hinge(Ops tf, String name, Reduction reduction) { * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - @SuppressWarnings("unchecked") - Operand tLabels = predictions.type() == labels.type() ? 
- (Operand)labels : cast(tf, labels, predictions.type()); - tLabels = LossesHelper.valueCheck( + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + Operand tLabels = cast(tf, labels, predictions.type()); + tLabels = + LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.type())); + cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); Operand losses = Losses.hinge(getTF(), tLabels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index 6d3e3f0c2ac..b1aee1b0656 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -89,6 +89,7 @@ public Huber(Ops tf) { * Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ public Huber(Ops tf, String name) { this(tf, name, DELTA_DEFAULT, Reduction.AUTO); @@ -109,6 +110,7 @@ public Huber(Ops tf, Reduction reduction) { * Creates a Huber Loss using {@link #DELTA_DEFAULT} as the delta * * @param tf the TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. */ public Huber(Ops tf, String name, Reduction reduction) { @@ -119,7 +121,7 @@ public Huber(Ops tf, String name, Reduction reduction) { * Creates a Huber Loss * * @param tf the TensorFlow Ops - * @param name the name of the loss + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. 
* @param delta the point where the Huber loss function changes from quadratic to linear. * @param reduction Type of Reduction to apply to the loss. */ @@ -130,8 +132,8 @@ public Huber(Ops tf, String name, float delta, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.huber(getTF(), labels, predictions, delta); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java index 8cf3db8d518..2aa1f72092b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -99,8 +99,8 @@ public KLDivergence(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index 1669669a768..a11d582e527 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -77,6 +77,7 @@ public LogCosh(Ops tf) { * Creates a LogCosh Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the 
TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ public LogCosh(Ops tf, String name) { this(tf, name, Reduction.AUTO); @@ -96,7 +97,7 @@ public LogCosh(Ops tf, Reduction reduction) { * Creates a LogCosh Loss * * @param tf the TensorFlow Ops - * @param name the name of the loss + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. */ public LogCosh(Ops tf, String name, Reduction reduction) { @@ -105,8 +106,8 @@ public LogCosh(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.logCosh(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index ae33d5dfa37..cdd35d28aba 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -25,7 +25,7 @@ public abstract class Loss { protected final Reduction reduction; /** - * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link + * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link * Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops @@ -62,10 +62,10 @@ protected Loss(Ops tf, String name, Reduction reduction) { * @param labels the truth values or labels * @param predictions the predictions * @param The data type of the predictions and loss. - * @param The data type of the labels. 
* @return the loss */ - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { return call(labels, predictions, null); } @@ -82,11 +82,10 @@ public Operand call(Operand labels, * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss */ - public abstract Operand call( - Operand labels, Operand predictions, Operand sampleWeights); + public abstract Operand call( + Operand labels, Operand predictions, Operand sampleWeights); /** * Gets the TensorFlow Ops diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 0d25bd5e7e2..9aa94cf7fcf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -48,11 +48,10 @@ public class Losses { * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean absolute error */ - public static Operand meanAbsoluteError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanAbsoluteError( + Ops tf, Operand labels, Operand predictions) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -70,11 +69,10 @@ public static Operand meanAbsoluteErro * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean squared error */ - public static Operand 
meanSquaredError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanSquaredError( + Ops tf, Operand labels, Operand predictions) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -91,11 +89,10 @@ public static Operand meanSquaredError * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean absolute percentage error */ - public static Operand meanAbsolutePercentageError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanAbsolutePercentageError( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -105,8 +102,10 @@ public static Operand meanAbsolutePerc tf.math.abs( tf.math.div( tf.math.sub(tLabels, predictions), - tf.math.maximum(tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), predictionType)))); - return tf.math.mul(cast(tf, tf.constant(100), predictionType), tf.math.mean(diff, tf.constant(-1))); + tf.math.maximum( + tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), predictionType)))); + return tf.math.mul( + cast(tf, tf.constant(100), predictionType), tf.math.mean(diff, tf.constant(-1))); } /** @@ -118,11 +117,10 @@ public static Operand meanAbsolutePerc * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean squared logarithmic percentage error */ - public static Operand meanSquaredLogarithmicError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanSquaredLogarithmicError( + Ops tf, Operand labels, Operand predictions) { Class 
predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -152,8 +150,12 @@ public static Operand meanSquaredLogar * @param the data type of the predictions and labels * @return the binary crossentropy loss. */ - public static Operand binaryCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { + public static Operand binaryCrossentropy( + Ops tf, + Operand labels, + Operand predictions, + boolean fromLogits, + float labelSmoothing) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -181,7 +183,7 @@ private static Operand binaryCrossentropyHelper( Ops tf, Operand target, Operand output, boolean fromLogits) { if (fromLogits) return tf.nn.sigmoidCrossEntropyWithLogits(target, output); - /* TODO - skip this loggic for now. It requires walking back the inputs which is not yet possible + /* TODO - skip this logic for now. It requires walking back the inputs which is not yet possible if (!(output instanceof Variable) && (!tf.scope().env().isEager())) { // TODO - this does not work // TODO output = backtrackIdentity(output); @@ -218,16 +220,17 @@ private static Operand binaryCrossentropyHelper( * @param labels true targets * @param predictions the predictions * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. 
labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 * @param axis the * @param the data type of the predictions and labels * @return the categorical crossentropy loss. */ - public static Operand categoricalCrossentropy( + public static Operand categoricalCrossentropy( Ops tf, - Operand labels, + Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing, @@ -283,8 +286,8 @@ public static Operand categoricalCross * @param the data type of the predictions and labels * @return the categorical hinge loss */ - public static Operand categoricalHinge( - Ops tf, Operand labels, Operand predictions) { + public static Operand categoricalHinge( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -329,8 +332,8 @@ public static Operand categoricalHinge * @param the data type of the predictions and labels * @return the cosine similarity loss */ - public static Operand cosineSimilarity( - Ops tf, Operand labels, Operand predictions, int axis) { + public static Operand cosineSimilarity( + Ops tf, Operand labels, Operand predictions, int[] axis) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); @@ -339,8 +342,7 @@ public static Operand cosineSimilarity tLabels = l2Normalize(tf, tLabels, axis); predictions = l2Normalize(tf, predictions, axis); Operand mathMul = tf.math.mul(tLabels, predictions); - Operand sum = tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); - return tf.math.neg(sum); + return tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); } /** @@ -355,8 +357,8 @@ public static Operand cosineSimilarity * @param the data type of the 
predictions and labels * @return the hinge loss */ - public static Operand hinge( - Ops tf, Operand labels, Operand predictions) { + public static Operand hinge( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -391,8 +393,8 @@ public static Operand hinge( * @param the data type of the predictions and labels * @return the Huber loss */ - public static Operand huber( - Ops tf, Operand labels, Operand predictions, float delta) { + public static Operand huber( + Ops tf, Operand labels, Operand predictions, float delta) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -422,8 +424,8 @@ public static Operand huber( * @see Kullback?Leibler * divergence */ - public static Operand kullbackLeiblerDivergence( - Ops tf, Operand labels, Operand predictions) { + public static Operand kullbackLeiblerDivergence( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -452,8 +454,8 @@ public static Operand kullbackLeiblerD * @param the data type of the predictions and labels * @return the hyperbolic cosine divergence loss */ - public static Operand logCosh( - Ops tf, Operand labels, Operand predictions) { + public static Operand logCosh( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -480,8 +482,8 @@ public static Operand logCosh( * @param the data type of the predictions and 
labels * @return the Poisson loss */ - public static Operand poisson( - Ops tf, Operand labels, Operand predictions) { + public static Operand poisson( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -507,8 +509,12 @@ public static Operand poisson( * @param the data type of the predictions and labels * @return the sparse categorical crossentropy loss */ - public static Operand sparseCategoricalCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { + public static Operand sparseCategoricalCrossentropy( + Ops tf, + Operand labels, + Operand predictions, + boolean fromLogits, + int axis) { Class predictionType = predictions.type(); Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); Operand one = cast(tf, tf.constant(1), predictionType); @@ -553,7 +559,7 @@ public static Operand sparseCategorica int labelsRank = labelsShape.numDimensions(); boolean updateShape = labelsRank != predictionsRank - 1; - if (updateShape) { // TODO check to see if this is right + if (updateShape) { Shape newShape = labelsShape.take(labelsRank - 1); iLabels = tf.reshape(iLabels, tf.constant(newShape)); // flatten one dimension predictions = @@ -584,8 +590,8 @@ public static Operand sparseCategorica * @param the data type of the predictions and labels * @return the squared hinge loss */ - public static Operand squaredHinge( - Ops tf, Operand labels, Operand predictions) { + public static Operand squaredHinge( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -649,14 +655,14 @@ private static Operand smoothCategoricalLabels( * @param tf The TensorFlow Ops * 
@param x the input * @param axis Dimension along which to normalize. + * @param the data type for the input and the result * @return the normalized values based on L2 norm */ - public static Operand l2Normalize(Ops tf, Operand x, int axis) { + public static Operand l2Normalize(Ops tf, Operand x, int[] axis) { Operand squareSum = tf.reduceSum(tf.math.square(x), tf.constant(axis), ReduceSum.keepDims(Boolean.TRUE)); Operand invNorm = - tf.math.rsqrt( - tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); + tf.math.rsqrt(tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); return tf.math.mul(x, invNorm); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index a2d5d5f8efc..03a3cf70110 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -95,8 +95,8 @@ public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsoluteError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 49133df610b..6c5242df4f2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -95,8 
+95,8 @@ public MeanAbsolutePercentageError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsolutePercentageError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index 2a6c2be885e..f975db55c44 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -95,8 +95,8 @@ public MeanSquaredError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index 2604e226b81..11b8e157e90 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -95,8 +95,8 @@ public MeanSquaredLogarithmicError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, 
Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java index c43be4f2821..78324acf8a5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java @@ -76,6 +76,7 @@ public Poisson(Ops tf) { * Creates a Poisson Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ public Poisson(Ops tf, String name) { this(tf, name, Reduction.AUTO); @@ -95,7 +96,7 @@ public Poisson(Ops tf, Reduction reduction) { * Creates a Poisson Loss * * @param tf the TensorFlow Ops - * @param name the name of the loss + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. 
*/ public Poisson(Ops tf, String name, Reduction reduction) { @@ -104,8 +105,8 @@ public Poisson(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.poisson(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index ea765e6f8fd..d04cc67d5d9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -18,6 +18,7 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; + import static org.tensorflow.framework.utils.CastHelper.cast; /** @@ -79,7 +80,8 @@ public class SparseCategoricalCrossentropy extends Loss { /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and fromLogits={@link + * #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops */ @@ -88,8 +90,8 @@ public SparseCategoricalCrossentropy(Ops tf) { } /** - * Creates a SparseCategoricalCrossentropy loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, - * and fromLogits={@link #FROM_LOGITS_DEFAULT}. 
+ * Creates a SparseCategoricalCrossentropy loss using a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param name the name of this loss function @@ -122,8 +124,8 @@ public SparseCategoricalCrossentropy(Ops tf, String name, Reduction reduction) { } /** - * Creates a SparseCategoricalCrossentropy using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and - * fromLogits={@link #FROM_LOGITS_DEFAULT}. + * Creates a SparseCategoricalCrossentropy using a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param name the name of this loss function @@ -135,7 +137,8 @@ public SparseCategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} and fromLogits={@link + * #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values @@ -176,9 +179,10 @@ public SparseCategoricalCrossentropy( /** * Generates an Operand the calculates the loss. * - * If run in Graph mode, the computation will throw {@link org.tensorflow.exceptions.TFInvalidArgumentException} - * if the predictions values are outside the range o [0. to 1.]. In Eager Mode, this call - * will throw {@link IllegalArgumentException}, if the predictions values are outside the range o [0. to 1.] + *

If run in Graph mode, the computation will throw {@link + * org.tensorflow.exceptions.TFInvalidArgumentException} if the predictions values are outside the + * range of [0. to 1.]. In Eager Mode, this call will throw {@link IllegalArgumentException}, if + * the predictions values are outside the range of [0. to 1.] * * @param labels the truth values or labels * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. @@ -190,23 +194,22 @@ public SparseCategoricalCrossentropy( * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = - LossesHelper.rangeCheck( - getTF(), - "predictions range check [0-1]", - predictions, - cast(getTF(), getTF().constant(0), predictions.type()), - cast(getTF(), getTF().constant(1), predictions.type())); + LossesHelper.rangeCheck( + getTF(), + "predictions range check [0-1]", + predictions, + cast(getTF(), getTF().constant(0), predictions.type()), + cast(getTF(), getTF().constant(1), predictions.type())); } else { lPredictions = predictions; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index 4ad4c1c726c..dadbdb3b95e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -18,6 +18,7 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; + import static org.tensorflow.framework.utils.CastHelper.cast; /** @@ -25,8 +26,8 @@ * *

loss = square(maximum(1 - labels * predictions, 0)) * - *

labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, they will be - * converted to -1 or 1. + *

labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, + * they will be converted to -1 or 1. * *

Standalone usage: * @@ -107,7 +108,7 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { * label values are not in the set [-1., 0., 1.]. * * @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be - * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. + * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. * @param sampleWeights Optional SampleWeights acts as a coefficient for the loss. If a scalar is * provided, then the loss is simply scaled by the given value. If SampleWeights is a tensor @@ -117,21 +118,23 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") - Operand tLabels = predictions.type() == labels.type() ? - (Operand)labels : cast(tf, labels, predictions.type()); - tLabels = LossesHelper.valueCheck( + Operand tLabels = + predictions.type() == labels.type() + ? 
(Operand) labels + : cast(tf, labels, predictions.type()); + tLabels = + LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.type())); + cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); Operand losses = Losses.squaredHinge(getTF(), tLabels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java index 2104937a979..f811549fbca 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java @@ -18,7 +18,7 @@ import org.tensorflow.types.family.TNumber; /** - * A helper class for loss methods to return labels, target, and sampleWeights + * A helper class for loss methods to return labels, target, and sampleWeights * * @param the data type of the LossTuple entries. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 10067db91ba..f6b0de71b0d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -32,8 +32,9 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * These are helper methods for Losses and Metrics and will be module private when Java modularity is applied to - * TensorFlow Java. These methods should not be used outside of the losses and metrics packages. 
+ * These are helper methods for Losses and Metrics and will be module private when Java modularity + * is applied to TensorFlow Java. These methods should not be used outside of the losses and metrics + * packages. */ public class LossesHelper { @@ -42,16 +43,17 @@ public class LossesHelper { * *

    *
  1. Squeezes last dim of predictions or labels if their rank - * differs by 1 (using {@link #removeSqueezableDimensions}).
  2. + * differs by 1 (using {@link #removeSqueezableDimensions}). *
  3. Squeezes or expands last dim of sampleWeight if its rank differs by 1 from * the new rank of predictions. If sampleWeight is scalar, it is - * kept scalar.
  4. + * kept scalar. *
* * @param tf the TensorFlow Ops * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction * . + * @param the data type for the labels, predictions and result * @return LossTuple of prediction, label,sampleWeight will * be null. Each of them possibly has the last dimension squeezed, sampleWeight * could be extended by one dimension. If sampleWeight is null, (prediction, @@ -77,12 +79,14 @@ public static LossTuple squeezeOrExpandDimensions( * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction * . - * @param sampleWeights Optional sample weight(s) Operand whose dimensions match + * @param sampleWeights Optional sample weight(s) Operand whose dimensions match + * * prediction. - * @return LossTuple of predictions, labels and sampleWeight. - * Each of them possibly has the last dimension squeezed, sampleWeight could be - * extended by one dimension. If sampleWeight is null, only the possibly shape modified predictions and labels are - * returned. + * @param the data type for the labels, predictions and result + * @return LossTuple of predictions, labels and sampleWeight + * . Each of them possibly has the last dimension squeezed, sampleWeight + * could be extended by one dimension. If sampleWeight is null, only the possibly + * shape modified predictions and labels are returned. */ public static LossTuple squeezeOrExpandDimensions( Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { @@ -178,6 +182,7 @@ private static Operand maybeExpandWeights( * @param labels Label values, a Tensor whose dimensions match predictions * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. + * @param the data type for the labels, predictions and result * @return labels and predictions, possibly with last dim squeezed. 
*/ public static LossTuple removeSqueezableDimensions( @@ -193,6 +198,7 @@ public static LossTuple removeSqueezableDimensions( * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. * @param expectedRankDiff Expected result of rank(predictions) - rank(labels). + * @param the data type for the labels, predictions and result * @return labels and predictions, possibly with last dim squeezed. */ public static LossTuple removeSqueezableDimensions( @@ -216,7 +222,8 @@ public static LossTuple removeSqueezableDimensions( } // Use dynamic rank. - // TODO Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); + // TODO: hold for lazy select feature, + // Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) { /* * TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze @@ -298,8 +305,7 @@ private static Operand reduceWeightedLoss( public static Operand safeMean( Ops tf, Operand losses, long numElements) { Operand totalLoss = tf.reduceSum(losses, allAxes(tf, losses)); - return tf.math.divNoNan( - totalLoss, cast(tf, tf.constant(numElements), losses.type())); + return tf.math.divNoNan(totalLoss, cast(tf, tf.constant(numElements), losses.type())); } /** @@ -383,8 +389,7 @@ public static Operand rangeCheck( */ public static Operand valueCheck( Ops tf, String prefix, Operand values, Operand allowedValues) { - Operand flatValues = - tf.reshape(values, tf.constant(Shape.of(values.shape().size()))); + Operand flatValues = tf.reshape(values, tf.constant(Shape.of(values.shape().size()))); SetDiff1d diff = tf.setDiff1d(flatValues, allowedValues, TInt32.class); long diffSize = diff.out().shape().size(); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 
651a6fac0b0..48ee244eafb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -21,17 +21,18 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. * *

This is the crossentropy metric class to be used when there are only two label classes (0 and * 1). * - * @param the data type for the predictions. * @param The data type for the metric result */ -public class BinaryCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class BinaryCrossentropy extends MeanMetricWrapper + implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; @@ -41,7 +42,8 @@ public class BinaryCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. * @param labelSmoothing value used to smooth labels, When 0, no smoothing occurs. When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, * where the smoothing squeezes the labels towards 0.5. 
Larger values of label_smoothing @@ -60,7 +62,10 @@ public BinaryCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.binaryCrossentropy(getTF(), labels, predictions, fromLogits, labelSmoothing); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.binaryCrossentropy(getTF(), tLabels, tPredictions, fromLogits, labelSmoothing); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index c330ea88eaa..b22e5415f79 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the categorical cross-entropy loss between true labels and predicted * labels. @@ -30,11 +32,10 @@ * [2, 0, 1], the labels Operand contains = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] * . * - * @param the data type for the predictions. * @param The data type for the metric result */ -public class CategoricalCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class CategoricalCrossentropy extends MeanMetricWrapper + implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; @@ -48,7 +49,8 @@ public class CategoricalCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. 
- * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to + * a probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and 0.9 @@ -68,7 +70,8 @@ public CategoricalCrossentropy( * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, * meaning the confidence on label values are relaxed. e.g. 
labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and 0.9 @@ -98,8 +101,11 @@ public CategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.categoricalCrossentropy( - getTF(), labels, predictions, fromLogits, labelSmoothing, axis); + getTF(), tLabels, tPredictions, fromLogits, labelSmoothing, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index 2741a36edb6..4266cc487c0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -21,13 +21,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the categorical hinge loss metric between labels and predictions. * - * @param the data type for the predictions. 
* @param The data type for the metric result */ -public class CategoricalHinge extends MeanMetricWrapper +public class CategoricalHinge extends MeanMetricWrapper implements LossMetric { /** @@ -46,7 +47,10 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.categoricalHinge(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.categoricalHinge(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 458de092bec..840f255c5ab 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -15,18 +15,20 @@ package org.tensorflow.framework.metrics; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the cosine similarity metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class CosineSimilarity extends MeanMetricWrapper +public class CosineSimilarity extends MeanMetricWrapper implements LossMetric { public static final int DEFAULT_AXIS = -1; private final int[] axis; @@ -76,8 +78,12 @@ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class typ /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - // NOTE: cosineProximity is a different algorithm than Losses.cosineSimilarity - return Metrics.cosineProximity(getTF(), labels, predictions, axis); + public Operand call( + Operand labels, Operand predictions) { + // NOTE: metrics.CosineSimilarity is Losses.cosineSimilarity, + // while losses.CosineSimilarity is the negative of Losses.cosineSimilarity + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.cosineSimilarity(getTF(), tLabels, tPredictions, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index baf9ad8ab7d..46ccd2859ff 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -21,14 +21,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the hinge loss metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class Hinge extends MeanMetricWrapper - implements LossMetric { +public class Hinge extends MeanMetricWrapper implements LossMetric { /** * Creates a Hinge metric @@ -46,7 +46,10 @@ public Hinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.hinge(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.hinge(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index efcbbcbb7f0..9ffcd6189f1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -21,15 +21,15 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the Kullback-Leibler divergence loss metric between labels and * predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class KLDivergence extends MeanMetricWrapper - implements LossMetric { +public class KLDivergence extends MeanMetricWrapper implements LossMetric { /** * Creates a KLDivergence metric @@ -47,7 +47,10 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.kullbackLeiblerDivergence(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index 3df8505d54b..59e24f57110 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -21,15 +21,15 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the logarithm of the hyperbolic cosine of the prediction error metric * between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class LogCoshError extends MeanMetricWrapper - implements LossMetric { +public class LogCoshError extends MeanMetricWrapper implements LossMetric { /** * Creates a LogCoshError metric @@ -47,7 +47,10 @@ public LogCoshError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.logCosh(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.logCosh(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java index de1f5a5629e..8902b329bcc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java @@ -21,10 +21,9 @@ /** * A metric that that implements a weighted mean {@link MetricReduction#WEIGHTED_MEAN } * - * @param The data type for the metric values * @param The data type for the metric result */ -public class Mean extends Reduce { +public class Mean extends Reduce { /** * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#SUM} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index e27676932ff..1cc6d0b6f99 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -21,13 +21,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static 
org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class MeanAbsoluteError extends MeanMetricWrapper +public class MeanAbsoluteError extends MeanMetricWrapper implements LossMetric { /** @@ -46,7 +47,10 @@ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanAbsoluteError(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanAbsoluteError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 84fa9b627b2..8c6720b58f6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -21,14 +21,15 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class MeanAbsolutePercentageError - extends MeanMetricWrapper implements LossMetric { +public class MeanAbsolutePercentageError extends MeanMetricWrapper + implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -46,7 +47,10 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanAbsolutePercentageError(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanAbsolutePercentageError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index c7edd6ebe93..3c4c79d39ba 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -21,13 +21,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class MeanSquaredError extends MeanMetricWrapper +public class MeanSquaredError extends MeanMetricWrapper implements LossMetric { /** @@ -46,7 +47,10 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanSquaredError(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanSquaredError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index 199b6e0e114..d525bb76648 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -21,14 +21,15 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class MeanSquaredLogarithmicError - extends MeanMetricWrapper implements LossMetric { +public class MeanSquaredLogarithmicError extends MeanMetricWrapper + implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -46,7 +47,10 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanSquaredLogarithmicError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index bbb2aa73da2..468919e696d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -25,10 +25,9 @@ /** * Base class for Metrics * - * @param The data type for the metric values * @param The data type for the metric result */ -public abstract class Metric { +public abstract class Metric { /** The TensorFlow Ops */ private final Ops tf; @@ -75,10 +74,10 @@ protected Metric(Ops tf, String name, long seed) { * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. 
* @return a List of Operations to update the metric state - * @param the data type for sampleWeights */ @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList( + Operand values, Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -90,13 +89,13 @@ public List updateStateList(Operand values, Operand the data type for the labels - * @param the data type for the sampleWeights * @return a List of Operations to update the metric state */ @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -105,10 +104,10 @@ public List updateStateList( * * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. - * @param the data type for sampleWeights * @return the Operation to update the metric state */ - public final Op updateState(Operand values, Operand sampleWeights) { + public final Op updateState( + Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -119,12 +118,12 @@ public final Op updateState(Operand values, Operand sa * @param labels the labels * @param predictions the predictions * @param sampleWeights sample weights to be applied to values, may be null. 
- * @param the data type for the labels - * @param the data type for the sampleWeights * @return the Operation to update the metric state */ - public final Op updateState( - Operand labels, Operand predictions, Operand sampleWeights) { + public final Op updateState( + Operand labels, + Operand predictions, + Operand sampleWeights) { List controlOps = updateStateList(labels, predictions, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -149,10 +148,9 @@ public final Op updateState( * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. * @return the result, possibly with control dependencies - * @param the data type for the sampleWeights. */ - public final Operand callOnce( - Operand values, Operand sampleWeights) { + public final Operand callOnce( + Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps); return ltf.identity(result()); @@ -186,7 +184,11 @@ public String getName() { return name; } - /** The random number generator seed value */ + /** + * Gets the random number generator seed value + * + * @return the random number generator seed value + */ public long getSeed() { return seed; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index 0169bc6b8bc..95b74bf1eea 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -17,7 +17,6 @@ import org.tensorflow.Operand; import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.ReduceSum; import org.tensorflow.types.TFloat32; import 
org.tensorflow.types.family.TNumber; @@ -46,89 +45,14 @@ public class Metrics { * @param predictions The prediction values. * @param k Number of top elements to look at for computing accuracy. * @param the data type for the predictions and results - * @param the data type ofr the labels. * @return the Operand for the Top K categorical accuracy value. */ - public static Operand topKCategoricalAccuracy( - Ops tf, Operand labels, Operand predictions, long k) { + public static Operand topKCategoricalAccuracy( + Ops tf, Operand labels, Operand predictions, long k) { Operand fPredictions = CastHelper.cast(tf, predictions, TFloat32.class); return CastHelper.cast( tf, tf.nn.inTopK(fPredictions, tf.math.argMax(labels, tf.constant(-1)), tf.constant(k)), predictions.type()); } - - /** - * Computes the cosine similarity between labels and predictions. - * - * @param tf the TensorFlow Ops - * @param labels The ground truth values. - * @param predictions The prediction values. - * @param axes The dimensions along which the cosine similarity is computed. - * @param the data type for the labels - * @param the data type for the predictions and result - * @return Cosine similarity value. - */ - public static Operand cosineProximity( - Ops tf, Operand labels, Operand predictions, int[] axes) { - Operand labelsNorm = CastHelper.cast(tf, labels, predictions.type()); - labelsNorm = l2Normalize(tf, labelsNorm, axes); - - Operand predictionsNorm = l2Normalize(tf, predictions, axes); - Operand mathMul = tf.math.mul(labelsNorm, predictionsNorm); - return tf.reduceSum(mathMul, tf.constant(axes), ReduceSum.keepDims(Boolean.FALSE)); - } - - /** - * Normalizes along dimension axis using an L2 norm with an epsilon of {@link - * #L2_NORM_EPSILON}. - * - *

For a 1-D tensor with axis = 0, computes - * - *

-   *       output = x / sqrt(max(sum(x**2), epsilon))
-   * 
- * - *

For x with more dimensions, independently normalizes each 1-D slice along - * dimension axis. - * - * @param tf The TensorFlow ops - * @param x The operand to normalize - * @param axes Dimension(s) along which to normalize. - * @param The data type for x. - * @return the normalized values of x. - */ - public static Operand l2Normalize(Ops tf, Operand x, int[] axes) { - return l2Normalize(tf, x, axes, L2_NORM_EPSILON); - } - - /** - * Normalizes along dimension axis using an L2 norm. - * - *

For a 1-D tensor with axis = 0, computes - * - *

-   *       output = x / sqrt(max(sum(x**2), epsilon))
-   * 
- * - *

For x with more dimensions, independently normalizes each 1-D slice along - * dimension axis. - * - * @param tf The TensorFlow ops - * @param x The operand to normalize - * @param axes Dimension(s) along which to normalize. - * @param epsilon A lower bound value for the norm. Will use sqrt(epsilon) as the - * divisor if norm < sqrt(epsilon). - * @param The data type for the values. - * @return the normalized values of x. - */ - public static Operand l2Normalize( - Ops tf, Operand x, int[] axes, float epsilon) { - Operand squareSum = - tf.reduceSum(tf.math.square(x), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE)); - Operand y = - tf.math.rsqrt( - tf.math.maximum(squareSum, CastHelper.cast(tf, tf.constant(epsilon), x.type()))); - return tf.math.mul(x, y); - } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 75a2031fbb5..422fd4808ff 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -21,14 +21,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the poisson loss metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class Poisson extends MeanMetricWrapper - implements LossMetric { +public class Poisson extends MeanMetricWrapper implements LossMetric { /** * Creates a Poisson metric @@ -46,7 +46,10 @@ public Poisson(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.poisson(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.poisson(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index 2e01f722de6..e954169b2af 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -21,15 +21,16 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the sparse categorical cross-entropy loss between true labels and * predicted labels. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class SparseCategoricalCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class SparseCategoricalCrossentropy extends MeanMetricWrapper + implements LossMetric { private final boolean fromLogits; private final int axis; @@ -39,7 +40,8 @@ public class SparseCategoricalCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. 
- * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. * @param axis The dimension along which the entropy is computed. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -55,7 +57,10 @@ public SparseCategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.sparseCategoricalCrossentropy(getTF(), tLabels, tPredictions, fromLogits, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 430dbbcc229..19b3b1d0ac4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -21,14 +21,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the squared hinge loss metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class SquaredHinge extends MeanMetricWrapper - implements LossMetric { +public class SquaredHinge extends MeanMetricWrapper implements LossMetric { /** * Creates a SquaredHinge metric @@ -46,7 +46,10 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.squaredHinge(getTF(), labels, predictions); + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.squaredHinge(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index b7b87d313aa..1fb3d3bb580 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -29,8 +29,7 @@ public interface LossMetric { * * @param labels the truth values or labels * @param predictions the predictions - * @param The data type of the labels. 
* @return the loss */ - Operand call(Operand labels, Operand predictions); + Operand call(Operand labels, Operand predictions); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 17c209a8fed..9a532a0294f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -17,13 +17,14 @@ import org.tensorflow.Operand; import org.tensorflow.framework.metrics.Mean; import org.tensorflow.framework.metrics.MetricReduction; -import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; import java.util.List; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A class that bridges a stateless loss function with the {@link Mean} metric using a reduction of * {@link MetricReduction#WEIGHTED_MEAN}. @@ -32,10 +33,9 @@ * then passes this loss to the {@link Mean} metric to calculate the weighted mean of the * loss over many iterations or epochs * - * @param the data type for the predictions. * @param The data type for the metric result */ -public class MeanMetricWrapper extends Mean { +public class MeanMetricWrapper extends Mean { /** The loss function interface */ protected LossMetric loss; @@ -85,22 +85,21 @@ protected void setLoss(LossMetric loss) { * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) - * @param the datatype of the labels - * @param the data type for sampleWeights * @return a List of control operations that updates the Mean state variables. 
*/ - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { if (labels == null || predictions == null) { throw new IllegalArgumentException("missing required inputs for labels and predictions"); } - Operand tLabels = CastHelper.cast(getTF(), labels, getResultType()); - Operand tPredictions = CastHelper.cast(getTF(), predictions, getResultType()); + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); Operand losses = loss.call(tLabels, tPredictions); - return super.updateStateList( - CastHelper.cast(getTF(), losses, predictions.type()), sampleWeights); + return super.updateStateList(cast(getTF(), losses, predictions.type()), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index ad8ff58e417..8a352322f52 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -21,12 +21,10 @@ import org.tensorflow.op.Ops; import org.tensorflow.op.math.Mean; import org.tensorflow.types.TBool; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; import java.util.Arrays; import java.util.Collections; @@ -57,13 +55,13 @@ public class MetricsHelper { * @param values the values to which weights are applied. 
* @return Operation with control dependencies to ensure sampleWeight * can be broadcast to values - * @param the type of Operand + * @param the type of Operand * @throws NotBroadcastableException If static checks determine sampleWeights has an * incorrect shape that prohibit broadcasting to values */ @SuppressWarnings("unchecked") - public static Op assertBroadcastable( - Ops tf, Operand sampleWeights, Operand values) { + public static Op assertBroadcastable( + Ops tf, Operand sampleWeights, Operand values) { // try static check for exact match @@ -129,7 +127,7 @@ public static Op assertBroadcastable( // hack to work around the non-lazy select for isValidShape, otherwise validNonscalar fails on a // scalar weight. If select was lazy, that branch wouldn't get executed when iScalar is true. - Operand reshapedWeights = + Operand reshapedWeights = tf.select(isScalar, tf.math.mul(sampleWeights, tf.onesLike(values)), sampleWeights); weightsShape = tf.shape(reshapedWeights); weightsRank = tf.rank(reshapedWeights); @@ -237,11 +235,10 @@ public static Operand mean(Ops tf, Operand x) { * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. * @param the type of the Operand. - * @param the type of the axes. * @return the mean of the operand, along the specified axes. */ - public static Operand mean( - Ops tf, Operand x, Operand axes) { + public static Operand mean( + Ops tf, Operand x, Operand axes) { return mean(tf, x, axes, false); } @@ -257,31 +254,27 @@ public static Operand mean( * @param the type of the operand * @return the mean of elements of x. */ - public static Operand mean( - Ops tf, Operand x, boolean keepDims) { + public static Operand mean(Ops tf, Operand x, boolean keepDims) { return mean(tf, x, null, keepDims); } - - /** * Calculates the mean of the operand, alongside the specified axes. * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. 
- * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained + * with length 1. * @param the data type of the Operand - * @param the data type of the axes * @return the mean of elements of x. */ - - public static Operand mean( - Ops tf, Operand x, Operand axes, boolean keepDims) { + public static Operand mean( + Ops tf, Operand x, Operand axes, boolean keepDims) { if (axes == null) { - axes = (Operand) allAxes(tf, x); + axes = allAxes(tf, x); } return tf.math.mean(x, axes, Mean.keepDims(keepDims)); } @@ -294,7 +287,7 @@ public static Operand mean( * @param x the Operand used to calculate the mean * @return the mean of the operand containing floating point numbers */ - public static Operand booleanMean(Ops tf, Operand x) { + public static Operand booleanMean(Ops tf, Operand x) { return booleanMean(tf, x, null, false); } @@ -305,44 +298,43 @@ public static Operand booleanMean(Ops tf, Operand x) { * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param the type of the axes. * @return the mean of the operand, along the specified axes, containing floating point numbers */ - public static Operand booleanMean( - Ops tf, Operand x,Operand axes) { + public static Operand booleanMean( + Ops tf, Operand x, Operand axes) { return booleanMean(tf, x, axes, false); } /** * Calculates the mean of the boolean operand, alongside all axes. * + * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. 
If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. - * @param the data type of the axes + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained + * with length 1. * @return the mean of elements of x containing floating point numbers */ - public static Operand booleanMean( - Ops tf, Operand x, boolean keepDims) { + public static Operand booleanMean(Ops tf, Operand x, boolean keepDims) { return booleanMean(tf, x, null, keepDims); } /** * Calculates the mean of the boolean operand, alongside the specified axes. * + * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. - * @param the data type of the axes + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained + * with length 1. 
* @return the mean of elements of x containing floating point numbers */ - public static Operand booleanMean( - Ops tf, Operand x, Operand axes, boolean keepDims) { + public static Operand booleanMean( + Ops tf, Operand x, Operand axes, boolean keepDims) { Operand xf = cast(tf, x, TFloat64.class); return mean(tf, xf, axes, keepDims); } - } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 8e48cb4e573..2a26967b9f2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -19,7 +19,6 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.framework.metrics.Metric; import org.tensorflow.framework.metrics.MetricReduction; -import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -29,13 +28,14 @@ import java.util.ArrayList; import java.util.List; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Encapsulates metrics that perform a reduce operation on the metric values. * - * @param The data type for the metric values * @param The data type for the metric result */ -public abstract class Reduce extends Metric { +public abstract class Reduce extends Metric { public static final String TOTAL = "total"; public static final String COUNT = "count"; protected final MetricReduction reduction; @@ -45,8 +45,10 @@ public abstract class Reduce extends Metri private final Class resultType; /** the variable that holds the total of the metric values */ protected Variable total; - /** the variable that holds the count of the metric values. 
- * For {@link MetricReduction#WEIGHTED_MEAN}, this count may be weighted */ + /** + * the variable that holds the count of the metric values. For {@link + * MetricReduction#WEIGHTED_MEAN}, this count may be weighted + */ protected Variable count; /** @@ -95,12 +97,10 @@ private void setupVars() { public Op resetStates() { List controls = new ArrayList<>(); if (total != null) { - controls.add( - getTF().assign(total, CastHelper.cast(getTF(), getTF().constant(0), total.type()))); + controls.add(getTF().assign(total, cast(getTF(), getTF().constant(0), total.type()))); } if (count != null) { - controls.add( - getTF().assign(count, CastHelper.cast(getTF(), getTF().constant(0), count.type()))); + controls.add(getTF().assign(count, cast(getTF(), getTF().constant(0), count.type()))); } return getTF().withControlDependencies(controls).noOp(); } @@ -115,67 +115,67 @@ public Op resetStates() { * @throws IllegalArgumentException if values is null */ @Override - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList( + Operand values, Operand sampleWeights) { if (values == null) { throw new IllegalArgumentException("values is required."); } + Ops tf = getTF(); List updateOperations = new ArrayList<>(); // cast everything to match the variables - Operand lSampleWeights = null; - Operand lValues = values; + Operand tSampleWeights = null; + Operand tValues = cast(tf, values, getResultType()); if (sampleWeights != null) { - lSampleWeights = CastHelper.cast(getTF(), sampleWeights, lValues.type()); - LossTuple tuple = - LossesHelper.squeezeOrExpandDimensions(getTF(), null, lValues, lSampleWeights); - lValues = tuple.getTarget(); - lSampleWeights = tuple.getSampleWeights(); + tSampleWeights = cast(getTF(), sampleWeights, getResultType()); + LossTuple tuple = + LossesHelper.squeezeOrExpandDimensions(getTF(), null, tValues, tSampleWeights); + tValues = tuple.getTarget(); + tSampleWeights = tuple.getSampleWeights(); try { - lSampleWeights = 
MetricsHelper.broadcastWeights(getTF(), lSampleWeights, lValues); + tSampleWeights = MetricsHelper.broadcastWeights(getTF(), tSampleWeights, tValues); } catch (IllegalArgumentException ex) { // if we get here we have static shapes with either // different ranks or different dimension sizes. // first, reduce the values down to the rank of the samples - int valuesRank = lValues.shape().numDimensions(); - int weightsRank = lSampleWeights.shape().numDimensions(); + int valuesRank = tValues.shape().numDimensions(); + int weightsRank = tSampleWeights.shape().numDimensions(); int numAxes = Math.min(0, valuesRank - weightsRank); if (numAxes > 0) { // values rank is greater than weights rank, reduce values to weights rank. int[] axes = new int[numAxes]; for (int i = 0; i < numAxes; i++) axes[i] = i + weightsRank; if (reduction == MetricReduction.SUM) { - lValues = getTF().reduceSum(lValues, getTF().constant(axes)); + tValues = getTF().reduceSum(tValues, getTF().constant(axes)); } else { - lValues = getTF().math.mean(lValues, getTF().constant(axes)); + tValues = getTF().math.mean(tValues, getTF().constant(axes)); } } } - lValues = getTF().math.mul(lValues, lSampleWeights); + tValues = getTF().math.mul(tValues, tSampleWeights); } - Operand weightedValueSum = - getTF().reduceSum(lValues, LossesHelper.allAxes(getTF(), lValues)); + Operand weightedValueSum = + getTF().reduceSum(tValues, LossesHelper.allAxes(getTF(), tValues)); Operand totalUpdate = - getTF().assignAdd(total, CastHelper.cast(getTF(), weightedValueSum, total.type())); + getTF().assignAdd(total, cast(getTF(), weightedValueSum, total.type())); updateOperations.add(totalUpdate); Operand numValues; if (reduction != MetricReduction.SUM) { switch (reduction) { case SUM_OVER_BATCH_SIZE: - numValues = - CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); + numValues = cast(getTF(), getTF().constant(tValues.shape().size()), resultType); break; case WEIGHTED_MEAN: - if (lSampleWeights == null) { 
- numValues = - CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); + if (tSampleWeights == null) { + numValues = cast(getTF(), getTF().constant(tValues.shape().size()), resultType); } else { numValues = - CastHelper.cast( + cast( getTF(), getTF() - .reduceSum(lSampleWeights, LossesHelper.allAxes(getTF(), lSampleWeights)), + .reduceSum(tSampleWeights, LossesHelper.allAxes(getTF(), tSampleWeights)), resultType); } break; @@ -202,7 +202,7 @@ public Operand result() { break; case WEIGHTED_MEAN: case SUM_OVER_BATCH_SIZE: - fResult = getTF().math.divNoNan(total, CastHelper.cast(getTF(), count, resultType)); + fResult = getTF().math.divNoNan(total, cast(getTF(), count, resultType)); break; default: throw new UnsupportedOperationException( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java index 1841c7ee238..467dea19b57 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java @@ -25,33 +25,6 @@ /** Implementation of set operations */ public class SetsOps { - /** - * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops - * function {@link SparseOps#denseToDenseSetOperation} - */ - public enum Operation { - A_MINUS_B("a-b"), - B_MINUS_A("b-a"), - INTERSECTION("intersection"), - UNION("union"); - - private final String setOperation; - - Operation(String setOperation) { - this.setOperation = setOperation; - } - - /** - * Gets the set operation String value used to pass as the stringOperation value to {@link - * SparseOps#denseToDenseSetOperation} - * - * @return the set operation String value - */ - public String getSetOperation() { - return setOperation; - } - } - /** * Computes set difference of elements in last dimension of a and b with * 
aMinusB set to true. @@ -69,6 +42,7 @@ public String getSetOperation() { public static Operand difference(Ops tf, Operand a, Operand b) { return difference(tf, a, b, true); } + /** * Computes set difference of elements in last dimension of a and b. * @@ -143,4 +117,31 @@ public static Operand setOperation( setOperationResult.resultValues(), cast(tf, tf.constant(0), a.type())); } + + /** + * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops + * function {@link SparseOps#denseToDenseSetOperation} + */ + public enum Operation { + A_MINUS_B("a-b"), + B_MINUS_A("b-a"), + INTERSECTION("intersection"), + UNION("union"); + + private final String setOperation; + + Operation(String setOperation) { + this.setOperation = setOperation; + } + + /** + * Gets the set operation String value used to pass as the stringOperation value to {@link + * SparseOps#denseToDenseSetOperation} + * + * @return the set operation String value + */ + public String getSetOperation() { + return setOperation; + } + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java index 822eb490f22..aadbfeea54b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java @@ -31,29 +31,29 @@ * learning rate per dimension to address two drawbacks: * *

    - *
  • the continual decay of learning rates throughout training - *
  • the need for a manually selected global learning rate + *
  • the continual decay of learning rates throughout training + *
  • the need for a manually selected global learning rate *
* - *

Adadelta is a more robust extension of Adagrad that adapts learning rates based on a - * moving window of gradient updates, instead of accumulating all past gradients. This way, - * Adadelta continues learning even when many updates have been done. Compared to Adagrad, in - * the original version of Adadelta you don't have to set an initial learning rate. In this - * version, initial learning rate can be set, as in most other optimizers. + *

Adadelta is a more robust extension of Adagrad that adapts learning rates based on a moving + * window of gradient updates, instead of accumulating all past gradients. This way, Adadelta + * continues learning even when many updates have been done. Compared to Adagrad, in the original + * version of Adadelta you don't have to set an initial learning rate. In this version, initial + * learning rate can be set, as in most other optimizers. * - *

According to section 4.3 ("Effective Learning rates"), near the end of training step sizes - * converge to 1 which is effectively a high learning rate which would cause divergence. This - * occurs only near the end of the training as gradients and step sizes are small, and the - * epsilon constant in the numerator and denominator dominate past gradients and parameter - * updates which converge the learning rate to 1. + *

According to section 4.3 ("Effective Learning rates"), near the end of training step sizes + * converge to 1 which is effectively a high learning rate which would cause divergence. This occurs + * only near the end of the training as gradients and step sizes are small, and the epsilon constant + * in the numerator and denominator dominate past gradients and parameter updates which converge the + * learning rate to 1. * - *

According to section 4.4("Speech Data"),where a large neural network with 4 hidden layers - * was trained on a corpus of US English data, ADADELTA was used with 100 network replicas.The - * epsilon used is 1e-6 with rho=0.95 which converged faster than ADAGRAD, by the following - * construction: new AdaDelta(graph, 1.0f, 0.95f, 1e-6f); + *

According to section 4.4("Speech Data"),where a large neural network with 4 hidden layers was + * trained on a corpus of US English data, ADADELTA was used with 100 network replicas.The epsilon + * used is 1e-6 with rho=0.95 which converged faster than ADAGRAD, by the following construction: + * new AdaDelta(graph, 1.0f, 0.95f, 1e-6f); * * @see Zeiler, M., 2012 ADADELTA: An Adaptive Learning - * Rate Method. + * Rate Method */ public class AdaDelta extends Optimizer { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index 08f5f18a9cd..2dd05ef31b3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -31,10 +31,10 @@ * how frequently a parameter gets updated during training. The more updates a parameter receives, * the smaller the updates. * - *

- * - * @see Duchi, J, et al., 2011, Adaptive Subgradient Methods for Online Learning and Stochastic Optimization - * @see Duchi, J, et al., 2013, Proximal and First-Order Methods for Convex Optimization, Introduction Section 1. + * @see Duchi, J, et al., 2011, + * Adaptive Subgradient Methods for Online Learning and Stochastic Optimization + * @see Duchi, J, et al., + * 2013, Proximal and First-Order Methods for Convex Optimization, Introduction Section 1 */ public class AdaGrad extends Optimizer { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java index df624e41c4e..7114c33339f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java @@ -40,7 +40,7 @@ * networks as it will require careful initialization of the gradient accumulators for it to train. * * @see Duchi, J, et al., 2011, - * Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. 
+ * Adaptive Subgradient Methods for Online Learning and Stochastic Optimization */ public class AdaGradDA extends Optimizer { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java index cd95bb3bd07..0ecc1ac1451 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java @@ -32,12 +32,10 @@ public class Adamax extends Optimizer { public static final float EPSILON_DEFAULT = 1e-07f; public static final float BETA_ONE_DEFAULT = 0.9f; public static final float BETA_TWO_DEFAULT = 0.999f; - - private float learningRate; private final float betaOne; private final float betaTwo; private final float epsilon; - + private final float learningRate; private Constant learningRateConst; private Constant epsilonConst; private Constant betaOneConst; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java index 66314d2ffe0..5d8c1478231 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java @@ -13,10 +13,11 @@ /** * Optimizer that implements the FTRL algorithm. * + *

This version has support for both online L2 (the L2 penalty given in the paper below) and + * shrinkage-type L2 (which is the addition of an L2 penalty to the loss function). + * * @see McMahan, et - * al., 2013, Algorithm 1 - *

This version has support for both online L2 (the L2 penalty given in the paper above) and - * shrinkage-type L2 (which is the addition of an L2 penalty to the loss function). + * al., 2013, Algorithm 1 */ public class Ftrl extends Optimizer { @@ -29,13 +30,12 @@ public class Ftrl extends Optimizer { public static final float L1STRENGTH_DEFAULT = 0.0f; public static final float L2STRENGTH_DEFAULT = 0.0f; public static final float L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT = 0.0f; - - private float learningRate; private final float learningRatePower; private final float initialAccumulatorValue; private final float l1RegularizationStrength; private final float l2RegularizationStrength; private final float l2ShrinkageRegularizationStrength; + private final float learningRate; /** * Creates a Ftrl Optimizer diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java index f9900a8ee78..5b94b548c0a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java @@ -24,8 +24,6 @@ */ public class Nadam extends Optimizer { - private static final float DECAY_BASE = 0.96f; - private static final float DECAY = 0.004f; public static final float LEARNING_RATE_DEFAULT = 0.001f; public static final float EPSILON_DEFAULT = 1e-8f; public static final float BETA_ONE_DEFAULT = 0.9f; @@ -33,7 +31,8 @@ public class Nadam extends Optimizer { public static final String FIRST_MOMENT = "m"; public static final String SECOND_MOMENT = "v"; public static final String MOMENTUM = "momentum"; - + private static final float DECAY_BASE = 0.96f; + private static final float DECAY = 0.004f; /** The learning rate. 
*/ private final float learningRate; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index fdf56da4a67..ed141831bbe 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -71,14 +71,6 @@ protected Optimizer(Graph graph, String name) { this.globals = new ArrayList<>(); } - /** - * Gets the Optimizer's Ops instance - * @return the Optimizer's Ops instance - */ - public final Ops getTF() { - return tf; - } - /** * Creates a name by combining a variable name and a slot name * @@ -90,6 +82,15 @@ public static String createName(Output variable, String slotNam return variable.op().name() + "-" + slotName; } + /** + * Gets the Optimizer's Ops instance + * + * @return the Optimizer's Ops instance + */ + public final Ops getTF() { + return tf; + } + /** * Minimizes the loss by updating the variables * @@ -299,7 +300,8 @@ private Options() {} * Sets the shared name * * @param sharedName If non-empty, this variable is named in the given bucket with this - * shared_name. Otherwise, the node name is used instead. + * sharedName. Otherwise, the node name is used instead. + * @return this options instance */ public Optimizer.Options sharedName(String sharedName) { this.sharedName = sharedName; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java index b3729dc367f..e86e64971a4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java @@ -27,17 +27,20 @@ /** * Optimizer that implements the RMSProp algorithm. * - *

The gist of RMSprop is to:

    - *
  • Maintain a moving (discounted) average of the square of gradients - *
  • Divide the gradient by the root of this average
+ *

The gist of RMSprop is to: * - *

This implementation of RMSprop uses plain momentum, not Nesterov momentum. + *

    + *
  • Maintain a moving (discounted) average of the square of gradients + *
  • Divide the gradient by the root of this average + *
* - *

The centered version additionally maintains a moving average of the gradients, and uses - * that average to estimate the variance. + *

This implementation of RMSprop uses plain momentum, not Nesterov momentum. + * + *

The centered version additionally maintains a moving average of the gradients, and uses that + * average to estimate the variance. * * @see Hinton G, - * et al. 2012, lecture notes that is inexplicably the canonical reference. + * et al. 2012, lecture notes, that is inexplicably the canonical reference. */ public class RMSProp extends Optimizer { @@ -165,24 +168,20 @@ protected void createSlots(List> variables) { } } - /** - * Creates the RMSProp Slots for Root Mean Squared (RMS), - * MOMENTUM, and Mean Gradient (MG) + * Creates the RMSProp Slots for Root Mean Squared (RMS), MOMENTUM, and Mean Gradient (MG) * * @param v the variable to install in the slot * @param the datatype of the variable. */ private void createRMSPropSlot(Output v) { - Operand rmsInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.type())); + Operand rmsInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.type())); createSlot(v.asOutput(), RMS, rmsInitializer); Operand momentumInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), MOMENTUM, momentumInitializer); if (centered) { - Operand mgInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); + Operand mgInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), MG, mgInitializer); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java index b0fe48967dd..1c027cb5ddf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java @@ -34,7 +34,7 @@ public class CastHelper { */ @SuppressWarnings("unchecked") public static Operand cast( - Ops tf, Operand value, Class requiredType) { + Ops tf, Operand value, Class requiredType) { return 
(value.type() == requiredType) ? (Operand) value : tf.dtypes.cast(value, requiredType); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java index 4ca2c789f28..e730c79cfbf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java @@ -14,8 +14,9 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.*; -import org.tensorflow.ndarray.NdArray; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; import org.tensorflow.types.TInt32; @@ -33,7 +34,9 @@ public class ShapeUtils { /** * Converts a shape operand to a Shape object * + * @param scope the TensorFlow scope * @param dims the Operand containing the shape values + * @param the date type for the shape dimensions. 
* @return a new Shape based on an Operand that contains dimensions */ public static Shape toShape(Scope scope, Operand dims) { @@ -45,8 +48,8 @@ public static Shape toShape(Scope scope, Operand dims) * Converts a TInt32 type Operand to a Java int array * * @param scope the TensorFlow scope - * @param dims the TInt32 Operand - * @return the int array + * @param dims the shape dimensions operand + * @return the int array of the dimensions */ public static int[] getIntArray(Scope scope, Operand dims) { long[] longDims = getLongArray(scope, dims); @@ -66,8 +69,8 @@ public static long[] getLongArray(Scope scope, Operand if (scope.env().isEager()) { return getLongArray(dims.asTensor()); } - try (Session session = new Session((Graph)scope.env()); - TIntegral tensor = (TIntegral)session.runner().fetch(dims).run().get(0)) { + try (Session session = new Session((Graph) scope.env()); + TIntegral tensor = (TIntegral) session.runner().fetch(dims).run().get(0)) { return getLongArray(tensor); } } @@ -76,20 +79,21 @@ public static long[] getLongArray(Scope scope, Operand * Converts a TInt32 or TInt64 to a java long array * * @param dims the dimension tensor + * @param the type of the dimensions, must either be TInt32 or TInt64 type * @return the long array * @throws java.lang.IllegalArgumentException if the dims type is not an integer */ public static long[] getLongArray(T dims) { List result = new ArrayList<>(); if (dims instanceof TInt32) { - ((TInt32)dims).scalars().forEach(s -> result.add((long) s.getInt())); + ((TInt32) dims).scalars().forEach(s -> result.add((long) s.getInt())); } else if (dims instanceof TInt64) { - ((TInt64)dims).scalars().forEach(s -> result.add(s.getLong())); + ((TInt64) dims).scalars().forEach(s -> result.add(s.getLong())); } else if (dims instanceof TUint8) { - ((TUint8)dims).scalars().forEach(s -> result.add(s.getObject().longValue())); - } else { // shouldn't happen - throw new IllegalArgumentException("the data type must be an integer type"); - } + 
((TUint8) dims).scalars().forEach(s -> result.add(s.getObject().longValue())); + } else { // shouldn't happen + throw new IllegalArgumentException("the data type must be an integer type"); + } return result.stream().mapToLong(i -> i).toArray(); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java index 7ceedded018..be46bb5c282 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java @@ -32,7 +32,7 @@ class BinaryCrossentropyTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testUnweighted", false, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 0, 1, 0}; @@ -55,7 +55,7 @@ public void testUnweighted() { public void testUnweightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testUnweightedLogits", true, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 0, 1, 0, 1, 1}; @@ -77,7 +77,7 @@ public void testUnweightedLogits() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testWeighted", false, 0, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 0, 1, 0}; @@ -102,7 +102,7 @@ public void testWeighted() { public void testWeightedLogits() { try (TestSession 
session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testWeightedLogits", true, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 0, 1, 0, 1, 1}; @@ -128,7 +128,7 @@ public void testLabelSmoothing() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); float labelSmoothing = 0.1F; - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>( tf, "BCE_testWeightedLabS", true, labelSmoothing, 1001L, TFloat64.class); session.run(instance.resetStates()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java index 2b4a1d75467..34fc3eef884 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java @@ -31,7 +31,7 @@ class CategoricalCrossentropyTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testUnweighted", false, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -55,7 +55,7 @@ public void testUnweighted() { public void testUnweightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testUnweightedLogits", true, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -79,7 +79,7 @@ public void testUnweightedLogits() { public void 
testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testWeighted", false, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -104,7 +104,7 @@ public void testWeighted() { public void testWeightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>(tf, "CCE_testWeighted", true, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {0, 1, 0, 0, 0, 1}; @@ -129,7 +129,7 @@ public void testLabelSmoothing() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); float labelSmoothing = 0.1F; - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testWeighted", true, labelSmoothing, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java index 87248d95e48..78b25a21b60 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java @@ -31,7 +31,7 @@ class CategoricalHingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalHinge instance = + CategoricalHinge instance = new CategoricalHinge<>(tf, "CH_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { @@ -64,7 +64,7 @@ public void testUnweighted() { public void 
testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalHinge instance = + CategoricalHinge instance = new CategoricalHinge<>(tf, "CH_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java index a9721ef2f8f..18410416c42 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java @@ -31,7 +31,7 @@ class CosineSimilarityTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CosineSimilarity instance = + CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; @@ -54,7 +54,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CosineSimilarity instance = + CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testWeighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; @@ -80,7 +80,7 @@ public void test_axis() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); int axis = 1; - CosineSimilarity instance = + CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testWeighted", axis, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java index 6af5fed4889..90531d21fde 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java @@ -32,8 +32,7 @@ class HingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = - new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); + Hinge instance = new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; double[] predArray = {-0.3, 0.2, -0.1, 1.6, -0.25, -1., 0.5, 0.6}; @@ -55,8 +54,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = - new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); + Hinge instance = new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { -1, 1, -1, 1, diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java index 28020c0fa1c..267578a492c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java @@ -31,7 +31,7 @@ class KLDivergenceTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - KLDivergence instance = + KLDivergence instance = new KLDivergence<>(tf, "KLD_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); float[][] trueArray = {{.5f, .8f, .12f}, {.7f, .43f, .8f}}; @@ 
-54,7 +54,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - KLDivergence instance = + KLDivergence instance = new KLDivergence<>(tf, "KLD_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java index 31c043e0473..1b5b8fb7d49 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java @@ -32,7 +32,7 @@ class LogCoshErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - LogCoshError instance = + LogCoshError instance = new LogCoshError<>(tf, "LogCosh_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 9, 2, -5, -2, 6}; @@ -56,7 +56,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - LogCoshError instance = + LogCoshError instance = new LogCoshError<>(tf, "LogCosh_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java index 73241ecbe9f..984895f2ad9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java @@ -32,7 +32,7 @@ class 
MeanAbsoluteErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanAbsoluteError instance = + MeanAbsoluteError instance = new MeanAbsoluteError<>(tf, "MAE_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0f, instance.getTotal()); @@ -74,7 +74,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanAbsoluteError instance = + MeanAbsoluteError instance = new MeanAbsoluteError<>(tf, "MAE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java index 4c92844b217..0b9e7f6b538 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java @@ -34,7 +34,7 @@ public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { session.setEpsilon(1E-6f); Ops tf = session.getTF(); - MeanAbsolutePercentageError instance = + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError<>(tf, "MAPE_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); session.evaluate(0.0f, instance.getTotal()); @@ -76,7 +76,7 @@ public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { session.setEpsilon(1E-6f); Ops tf = session.getTF(); - MeanAbsolutePercentageError instance = + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError<>(tf, "MAPE_testWeighted", 1001L, TFloat64.class); 
session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java index 0b760213015..e42052a9ef1 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java @@ -33,7 +33,7 @@ class MeanSquaredErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredError instance = + MeanSquaredError instance = new MeanSquaredError<>(tf, "MSE_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); @@ -70,7 +70,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredError instance = + MeanSquaredError instance = new MeanSquaredError<>(tf, "MSE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java index 098a5cb9725..e68d63b8778 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java @@ -32,7 +32,7 @@ class MeanSquaredLogarithmicErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredLogarithmicError instance 
= + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError<>(tf, "MSLE_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); session.evaluate(0.0f, instance.getTotal()); @@ -69,7 +69,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredLogarithmicError instance = + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError<>(tf, "MSLE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java index cf3c3e44719..5631bac15ee 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java @@ -32,7 +32,7 @@ class PoissonTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Poisson instance = + Poisson instance = new Poisson<>(tf, "Poisson_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {4, 8, 12, 8, 1, 3}; @@ -55,8 +55,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Poisson instance = - new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); + Poisson instance = new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {4, 8, 12, 8, 1, 3}; float[] predArray = {1, 9, 2, 5, 2, 6}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java index 87af1bd8448..0aece8c8ac9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java @@ -32,7 +32,7 @@ class SparseCategoricalCrossentropyTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testUnweighted", false, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -56,7 +56,7 @@ public void testUnweighted() { public void testUnweightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testWeighted", true, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -79,7 +79,7 @@ public void testUnweightedLogits() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testWeighted", false, -1, 1001L, TFloat32.class); session.run(instance.resetStates()); @@ -105,7 +105,7 @@ public void testWeighted() { public void testWeightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testWeighted", true, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); diff --git 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java index e3376c224f3..2c80b3451ad 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java @@ -32,7 +32,7 @@ class SquaredHingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SquaredHinge instance = + SquaredHinge instance = new SquaredHinge<>(tf, "SCE_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = { @@ -61,7 +61,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SquaredHinge instance = + SquaredHinge instance = new SquaredHinge<>(tf, "SCE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {