From b29edfdba6691276e3de3b0dee0d209d3eb215f3 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 1 Feb 2021 11:16:37 -0500 Subject: [PATCH 1/6] Simplify generic parameters across losses and metrics. --- .../framework/losses/BinaryCrossentropy.java | 5 +- .../losses/CategoricalCrossentropy.java | 5 +- .../framework/losses/CategoricalHinge.java | 4 +- .../framework/losses/CosineSimilarity.java | 67 +++++++++++++--- .../tensorflow/framework/losses/Hinge.java | 4 +- .../tensorflow/framework/losses/Huber.java | 4 +- .../framework/losses/KLDivergence.java | 4 +- .../tensorflow/framework/losses/LogCosh.java | 4 +- .../org/tensorflow/framework/losses/Loss.java | 8 +- .../tensorflow/framework/losses/Losses.java | 73 ++++++++--------- .../framework/losses/MeanAbsoluteError.java | 4 +- .../losses/MeanAbsolutePercentageError.java | 4 +- .../framework/losses/MeanSquaredError.java | 4 +- .../losses/MeanSquaredLogarithmicError.java | 4 +- .../tensorflow/framework/losses/Poisson.java | 4 +- .../losses/SparseCategoricalCrossentropy.java | 5 +- .../framework/losses/SquaredHinge.java | 5 +- .../framework/metrics/BinaryCrossentropy.java | 9 ++- .../metrics/CategoricalCrossentropy.java | 13 +-- .../framework/metrics/CategoricalHinge.java | 5 +- .../framework/metrics/CosineSimilarity.java | 8 +- .../tensorflow/framework/metrics/Hinge.java | 5 +- .../framework/metrics/KLDivergence.java | 5 +- .../framework/metrics/LogCoshError.java | 5 +- .../tensorflow/framework/metrics/Mean.java | 3 +- .../framework/metrics/MeanAbsoluteError.java | 5 +- .../metrics/MeanAbsolutePercentageError.java | 7 +- .../framework/metrics/MeanSquaredError.java | 5 +- .../metrics/MeanSquaredLogarithmicError.java | 7 +- .../tensorflow/framework/metrics/Metric.java | 32 ++++---- .../tensorflow/framework/metrics/Metrics.java | 80 +------------------ .../tensorflow/framework/metrics/Poisson.java | 6 +- .../SparseCategoricalCrossentropy.java | 9 +-- .../framework/metrics/SquaredHinge.java | 5 +- .../framework/metrics/impl/LossMetric.java | 3 +- .../metrics/impl/MeanMetricWrapper.java | 21 +++-- .../framework/metrics/impl/MetricsHelper.java | 66 +++++++-------- .../framework/metrics/impl/Reduce.java | 68 ++++++++-------- .../metrics/BinaryCrossentropyTest.java | 10 +-- .../metrics/CategoricalCrossentropyTest.java | 10 +-- .../metrics/CategoricalHingeTest.java | 4 +- .../metrics/CosineSimilarityTest.java | 6 +- .../framework/metrics/HingeTest.java | 4 +- .../framework/metrics/KLDivergenceTest.java | 4 +- .../framework/metrics/LogCoshErrorTest.java | 4 +- .../metrics/MeanAbsoluteErrorTest.java | 4 +- .../MeanAbsolutePercentageErrorTest.java | 4 +- .../metrics/MeanSquaredErrorTest.java | 4 +- .../MeanSquaredLogarithmicErrorTest.java | 4 +- .../framework/metrics/PoissonTest.java | 4 +- .../SparseCategoricalCrossentropyTest.java | 6 +- .../framework/metrics/SquaredHingeTest.java | 4 +- 52 files changed, 293 insertions(+), 354 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index c7edfcca24e..3417c07372a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -202,13 +202,12 @@ public BinaryCrossentropy( * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 363291fa5cc..035af9589ae 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -242,13 +242,12 @@ public CategoricalCrossentropy( * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index f592c19f8bb..4e9133d8835 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -99,8 +99,8 @@ public CategoricalHinge(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.categoricalHinge(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java index 137c7025c04..0a18d93caf3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java @@ -22,12 +22,13 @@ /** * Computes the cosine similarity between labels and predictions. * - *

Note that it is a number between -1 and 1. When it is a negative number between -1 and 0, 0 - * indicates orthogonality and values closer to -1indicate greater similarity. The values closer to - * 1 indicate greater dissimilarity. This makes it usable as a loss function in a setting where you - * try to maximize the proximity between predictions and targets. If either labels or predictions is - * a zero vector, cosine similarity will be 0 regardless of the proximity between predictions and - * targets. + *

Note that it is a number between -1 and 1. When it is a negative + * number between -1 and 0, 0 indicates orthogonality and + * values closer to -1indicate greater similarity. The values closer to 1 + * indicate greater dissimilarity. This makes it usable as a loss function in a setting where you + * try to maximize the proximity between predictions and targets. If either labels or + * predictions is a zero vector, cosine similarity will be 0 regardless of + * the proximity between predictions and targets. * *

loss = -sum(l2Norm(labels) * l2Norm(predictions)) * @@ -71,7 +72,7 @@ public class CosineSimilarity extends Loss { public static final int DEFAULT_AXIS = -1; public static final Reduction DEFAULT_REDUCTION = Reduction.AUTO; - private final int axis; + private final int[] axis; /** * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, an axis @@ -107,6 +108,17 @@ public CosineSimilarity(Ops tf, int axis) { this(tf, null, axis, DEFAULT_REDUCTION); } + /** + * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, and a + * Loss Reduction of {@link #DEFAULT_REDUCTION} + * + * @param tf the TensorFlow Ops + * @param axis The dimension along which the cosine similarity is computed. + */ + public CosineSimilarity(Ops tf, int[] axis) { + + this(tf, null, axis, DEFAULT_REDUCTION); + } /** * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION} @@ -120,6 +132,18 @@ public CosineSimilarity(Ops tf, String name, int axis) { this(tf, name, axis, DEFAULT_REDUCTION); } + /** + * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION} + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param axis The dimension along which the cosine similarity is computed. + */ + public CosineSimilarity(Ops tf, String name, int[] axis) { + + this(tf, name, axis, DEFAULT_REDUCTION); + } + /** * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name and an * axis of {@link #DEFAULT_AXIS} @@ -153,6 +177,18 @@ public CosineSimilarity(Ops tf, String name, Reduction reduction) { */ public CosineSimilarity(Ops tf, int axis, Reduction reduction) { + this(tf, null, new int[] {axis}, reduction); + } + + /** + * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name + * + * @param tf the TensorFlow Ops + * @param axis The dimension along which the cosine similarity is computed. + * @param reduction Type of Reduction to apply to the loss. + */ + public CosineSimilarity(Ops tf, int[] axis, Reduction reduction) { + this(tf, null, axis, reduction); } @@ -165,15 +201,28 @@ public CosineSimilarity(Ops tf, int axis, Reduction reduction) { * @param reduction Type of Reduction to apply to the loss. */ public CosineSimilarity(Ops tf, String name, int axis, Reduction reduction) { + this(tf, name, new int[] {axis}, reduction); + } + + /** + * Creates a Cosine Similarity Loss + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param axis The dimension along which the cosine similarity is computed. + * @param reduction Type of Reduction to apply to the loss. + */ + public CosineSimilarity(Ops tf, String name, int[] axis, Reduction reduction) { super(tf, name, reduction); this.axis = axis; } /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.cosineSimilarity(getTF(), labels, predictions, axis); + losses = tf.math.neg(losses); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index 88b4a7aa056..37e7e367b9b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -121,8 +121,8 @@ public Hinge(Ops tf, String name, Reduction reduction) { * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") Operand tLabels = predictions.type() == labels.type() ? (Operand)labels : cast(tf, labels, predictions.type()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index 6d3e3f0c2ac..e8de632eb09 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -130,8 +130,8 @@ public Huber(Ops tf, String name, float delta, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.huber(getTF(), labels, predictions, delta); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java index 8cf3db8d518..b3c0206b409 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -99,8 +99,8 @@ public KLDivergence(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index 1669669a768..812260d9881 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -105,8 +105,8 @@ public LogCosh(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.logCosh(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index ae33d5dfa37..0f9b183f38c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -62,10 +62,9 @@ protected Loss(Ops tf, String name, Reduction reduction) { * @param labels the truth values or labels * @param predictions the predictions * @param The data type of the predictions and loss. - * @param The data type of the labels. * @return the loss */ - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return call(labels, predictions, null); } @@ -82,11 +81,10 @@ public Operand call(Operand labels, * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss */ - public abstract Operand call( - Operand labels, Operand predictions, Operand sampleWeights); + public abstract Operand call( + Operand labels, Operand predictions, Operand sampleWeights); /** * Gets the TensorFlow Ops diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 0d25bd5e7e2..a5ced3d1df8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -48,11 +48,10 @@ public class Losses { * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean absolute error */ - public static Operand meanAbsoluteError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanAbsoluteError( + Ops tf, Operand labels, Operand predictions) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -70,11 +69,10 @@ public static Operand meanAbsoluteErro * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean squared error */ - public static Operand meanSquaredError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanSquaredError( + Ops tf, Operand labels, Operand predictions) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -91,11 +89,10 @@ public static Operand meanSquaredError * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean absolute percentage error */ - public static Operand meanAbsolutePercentageError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanAbsolutePercentageError( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -118,11 +115,10 @@ public static Operand meanAbsolutePerc * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean squared logarithmic percentage error */ - public static Operand meanSquaredLogarithmicError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanSquaredLogarithmicError( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -152,8 +148,8 @@ public static Operand meanSquaredLogar * @param the data type of the predictions and labels * @return the binary crossentropy loss. */ - public static Operand binaryCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { + public static Operand binaryCrossentropy( + Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -181,7 +177,7 @@ private static Operand binaryCrossentropyHelper( Ops tf, Operand target, Operand output, boolean fromLogits) { if (fromLogits) return tf.nn.sigmoidCrossEntropyWithLogits(target, output); - /* TODO - skip this loggic for now. It requires walking back the inputs which is not yet possible + /* TODO - skip this logic for now. It requires walking back the inputs which is not yet possible if (!(output instanceof Variable) && (!tf.scope().env().isEager())) { // TODO - this does not work // TODO output = backtrackIdentity(output); @@ -225,9 +221,9 @@ private static Operand binaryCrossentropyHelper( * @param the data type of the predictions and labels * @return the categorical crossentropy loss. */ - public static Operand categoricalCrossentropy( + public static Operand categoricalCrossentropy( Ops tf, - Operand labels, + Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing, @@ -283,8 +279,8 @@ public static Operand categoricalCross * @param the data type of the predictions and labels * @return the categorical hinge loss */ - public static Operand categoricalHinge( - Ops tf, Operand labels, Operand predictions) { + public static Operand categoricalHinge( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -329,8 +325,8 @@ public static Operand categoricalHinge * @param the data type of the predictions and labels * @return the cosine similarity loss */ - public static Operand cosineSimilarity( - Ops tf, Operand labels, Operand predictions, int axis) { + public static Operand cosineSimilarity( + Ops tf, Operand labels, Operand predictions, int[] axis) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); @@ -339,8 +335,7 @@ public static Operand cosineSimilarity tLabels = l2Normalize(tf, tLabels, axis); predictions = l2Normalize(tf, predictions, axis); Operand mathMul = tf.math.mul(tLabels, predictions); - Operand sum = tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); - return tf.math.neg(sum); + return tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); } /** @@ -355,8 +350,8 @@ public static Operand cosineSimilarity * @param the data type of the predictions and labels * @return the hinge loss */ - public static Operand hinge( - Ops tf, Operand labels, Operand predictions) { + public static Operand hinge( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -391,8 +386,8 @@ public static Operand hinge( * @param the data type of the predictions and labels * @return the Huber loss */ - public static Operand huber( - Ops tf, Operand labels, Operand predictions, float delta) { + public static Operand huber( + Ops tf, Operand labels, Operand predictions, float delta) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -422,8 +417,8 @@ public static Operand huber( * @see Kullback?Leibler * divergence */ - public static Operand kullbackLeiblerDivergence( - Ops tf, Operand labels, Operand predictions) { + public static Operand kullbackLeiblerDivergence( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -452,8 +447,8 @@ public static Operand kullbackLeiblerD * @param the data type of the predictions and labels * @return the hyperbolic cosine divergence loss */ - public static Operand logCosh( - Ops tf, Operand labels, Operand predictions) { + public static Operand logCosh( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -480,8 +475,8 @@ public static Operand logCosh( * @param the data type of the predictions and labels * @return the Poisson loss */ - public static Operand poisson( - Ops tf, Operand labels, Operand predictions) { + public static Operand poisson( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -507,8 +502,8 @@ public static Operand poisson( * @param the data type of the predictions and labels * @return the sparse categorical crossentropy loss */ - public static Operand sparseCategoricalCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { + public static Operand sparseCategoricalCrossentropy( + Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { Class predictionType = predictions.type(); Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); Operand one = cast(tf, tf.constant(1), predictionType); @@ -553,7 +548,7 @@ public static Operand sparseCategorica int labelsRank = labelsShape.numDimensions(); boolean updateShape = labelsRank != predictionsRank - 1; - if (updateShape) { // TODO check to see if this is right + if (updateShape) { Shape newShape = labelsShape.take(labelsRank - 1); iLabels = tf.reshape(iLabels, tf.constant(newShape)); // flatten one dimension predictions = @@ -584,8 +579,8 @@ public static Operand sparseCategorica * @param the data type of the predictions and labels * @return the squared hinge loss */ - public static Operand squaredHinge( - Ops tf, Operand labels, Operand predictions) { + public static Operand squaredHinge( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -651,7 +646,7 @@ private static Operand smoothCategoricalLabels( * @param axis Dimension along which to normalize. * @return the normalized values based on L2 norm */ - public static Operand l2Normalize(Ops tf, Operand x, int axis) { + public static Operand l2Normalize(Ops tf, Operand x, int[] axis) { Operand squareSum = tf.reduceSum(tf.math.square(x), tf.constant(axis), ReduceSum.keepDims(Boolean.TRUE)); Operand invNorm = diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index a2d5d5f8efc..594de1e1448 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -95,8 +95,8 @@ public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsoluteError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 49133df610b..275a2e136a0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -95,8 +95,8 @@ public MeanAbsolutePercentageError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsolutePercentageError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index 2a6c2be885e..31df3e70e0b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -95,8 +95,8 @@ public MeanSquaredError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index 2604e226b81..bef990d22bc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -95,8 +95,8 @@ public MeanSquaredLogarithmicError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java index c43be4f2821..9cf38aa0380 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java @@ -104,8 +104,8 @@ public Poisson(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.poisson(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index ea765e6f8fd..3ec33113e89 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -190,13 +190,12 @@ public SparseCategoricalCrossentropy( * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index 4ad4c1c726c..968624db202 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -117,13 +117,12 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") Operand tLabels = predictions.type() == labels.type() ? (Operand)labels : cast(tf, labels, predictions.type()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 651a6fac0b0..abd2dcbbf40 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -21,17 +21,18 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. * *

This is the crossentropy metric class to be used when there are only two label classes (0 and * 1). * - * @param the data type for the predictions. * @param The data type for the metric result */ -public class BinaryCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class BinaryCrossentropy + extends MeanMetricWrapper implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; @@ -60,7 +61,7 @@ public BinaryCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.binaryCrossentropy(getTF(), labels, predictions, fromLogits, labelSmoothing); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index c330ea88eaa..be43f34b92e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -30,11 +30,10 @@ * [2, 0, 1], the labels Operand contains = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] * . * - * @param the data type for the predictions. * @param The data type for the metric result */ -public class CategoricalCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class CategoricalCrossentropy extends MeanMetricWrapper + implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; @@ -48,7 +47,8 @@ public class CategoricalCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to + * a probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and 0.9 @@ -68,7 +68,8 @@ public CategoricalCrossentropy( * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and 0.9 @@ -98,7 +99,7 @@ public CategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.categoricalCrossentropy( getTF(), labels, predictions, fromLogits, labelSmoothing, axis); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index 2741a36edb6..c70f2d8643b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -24,10 +24,9 @@ /** * A Metric that computes the categorical hinge loss metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result */ -public class CategoricalHinge extends MeanMetricWrapper +public class CategoricalHinge< T extends TNumber> extends MeanMetricWrapper implements LossMetric { /** @@ -46,7 +45,7 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.categoricalHinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 458de092bec..5abbd095420 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.metrics; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; @@ -23,10 +24,9 @@ /** * A metric that computes the cosine similarity metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class CosineSimilarity extends MeanMetricWrapper +public class CosineSimilarity< T extends TNumber> extends MeanMetricWrapper< T> implements LossMetric { public static final int DEFAULT_AXIS = -1; private final int[] axis; @@ -76,8 +76,8 @@ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class typ /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { // NOTE: cosineProximity is a different algorithm than Losses.cosineSimilarity - return Metrics.cosineProximity(getTF(), labels, predictions, axis); + return Losses.cosineSimilarity(getTF(), labels, predictions, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index baf9ad8ab7d..e0aced6fa3e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -24,10 +24,9 @@ /** * A metric that computes the hinge loss metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class Hinge extends MeanMetricWrapper +public class Hinge extends MeanMetricWrapper implements LossMetric { /** @@ -46,7 +45,7 @@ public Hinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.hinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index efcbbcbb7f0..fa09f2784b5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -25,10 +25,9 @@ * A metric that computes the Kullback-Leibler divergence loss metric between labels and * predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class KLDivergence extends MeanMetricWrapper +public class KLDivergence extends MeanMetricWrapper implements LossMetric { /** @@ -47,7 +46,7 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index 3df8505d54b..c43551a6948 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -25,10 +25,9 @@ * A metric that computes the logarithm of the hyperbolic cosine of the prediction error metric * between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class LogCoshError extends MeanMetricWrapper +public class LogCoshError extends MeanMetricWrapper< T> implements LossMetric { /** @@ -47,7 +46,7 @@ public LogCoshError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.logCosh(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java index de1f5a5629e..8902b329bcc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java @@ -21,10 +21,9 @@ /** * A metric that that implements a weighted mean {@link MetricReduction#WEIGHTED_MEAN } * - * @param The data type for the metric values * @param The data type for the metric result */ -public class Mean extends Reduce { +public class Mean extends Reduce { /** * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#SUM} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index e27676932ff..d343ec77ab0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -24,10 +24,9 @@ /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class MeanAbsoluteError extends MeanMetricWrapper +public class MeanAbsoluteError extends MeanMetricWrapper< T> implements LossMetric { /** @@ -46,7 +45,7 @@ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanAbsoluteError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 84fa9b627b2..dd7d151260b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -24,11 +24,10 @@ /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class MeanAbsolutePercentageError - extends MeanMetricWrapper implements LossMetric { +public class MeanAbsolutePercentageError + extends MeanMetricWrapper implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -46,7 +45,7 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanAbsolutePercentageError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index c7edd6ebe93..c2bef576b30 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -24,10 +24,9 @@ /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class MeanSquaredError extends MeanMetricWrapper +public class MeanSquaredError< T extends TNumber> extends MeanMetricWrapper< T> implements LossMetric { /** @@ -46,7 +45,7 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanSquaredError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index 199b6e0e114..c1cf4ca6c9a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -24,11 +24,10 @@ /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class MeanSquaredLogarithmicError - extends MeanMetricWrapper implements LossMetric { +public class MeanSquaredLogarithmicError + extends MeanMetricWrapper implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -46,7 +45,7 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index bbb2aa73da2..8ab21c58218 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -25,10 +25,9 @@ /** * Base class for Metrics * - * @param The data type for the metric values * @param The data type for the metric result */ -public abstract class Metric { +public abstract class Metric { /** The TensorFlow Ops */ private final Ops tf; @@ -75,10 +74,10 @@ protected Metric(Ops tf, String name, long seed) { * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. * @return a List of Operations to update the metric state - * @param the data type for sampleWeights */ @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList( + Operand values, Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -90,13 +89,13 @@ public List updateStateList(Operand values, Operand the data type for the labels - * @param the data type for the sampleWeights * @return a List of Operations to update the metric state */ @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -105,10 +104,10 @@ public List updateStateList( * * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. - * @param the data type for sampleWeights * @return the Operation to update the metric state */ - public final Op updateState(Operand values, Operand sampleWeights) { + public final Op updateState( + Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -119,12 +118,12 @@ public final Op updateState(Operand values, Operand sa * @param labels the labels * @param predictions the predictions * @param sampleWeights sample weights to be applied to values, may be null. - * @param the data type for the labels - * @param the data type for the sampleWeights * @return the Operation to update the metric state */ - public final Op updateState( - Operand labels, Operand predictions, Operand sampleWeights) { + public final Op updateState( + Operand labels, + Operand predictions, + Operand sampleWeights) { List controlOps = updateStateList(labels, predictions, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -149,10 +148,9 @@ public final Op updateState( * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. * @return the result, possibly with control dependencies - * @param the data type for the sampleWeights. */ - public final Operand callOnce( - Operand values, Operand sampleWeights) { + public final Operand callOnce( + Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps); return ltf.identity(result()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index 0169bc6b8bc..95b74bf1eea 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -17,7 +17,6 @@ import org.tensorflow.Operand; import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.ReduceSum; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; @@ -46,89 +45,14 @@ public class Metrics { * @param predictions The prediction values. * @param k Number of top elements to look at for computing accuracy. * @param the data type for the predictions and results - * @param the data type ofr the labels. * @return the Operand for the Top K categorical accuracy value. */ - public static Operand topKCategoricalAccuracy( - Ops tf, Operand labels, Operand predictions, long k) { + public static Operand topKCategoricalAccuracy( + Ops tf, Operand labels, Operand predictions, long k) { Operand fPredictions = CastHelper.cast(tf, predictions, TFloat32.class); return CastHelper.cast( tf, tf.nn.inTopK(fPredictions, tf.math.argMax(labels, tf.constant(-1)), tf.constant(k)), predictions.type()); } - - /** - * Computes the cosine similarity between labels and predictions. - * - * @param tf the TensorFlow Ops - * @param labels The ground truth values. - * @param predictions The prediction values. - * @param axes The dimensions along which the cosine similarity is computed. - * @param the data type for the labels - * @param the data type for the predictions and result - * @return Cosine similarity value. - */ - public static Operand cosineProximity( - Ops tf, Operand labels, Operand predictions, int[] axes) { - Operand labelsNorm = CastHelper.cast(tf, labels, predictions.type()); - labelsNorm = l2Normalize(tf, labelsNorm, axes); - - Operand predictionsNorm = l2Normalize(tf, predictions, axes); - Operand mathMul = tf.math.mul(labelsNorm, predictionsNorm); - return tf.reduceSum(mathMul, tf.constant(axes), ReduceSum.keepDims(Boolean.FALSE)); - } - - /** - * Normalizes along dimension axis using an L2 norm with an epsilon of {@link - * #L2_NORM_EPSILON}. - * - *

For a 1-D tensor with axis = 0, computes - * - *

-   *       output = x / sqrt(max(sum(x**2), epsilon))
-   * 
- * - *

For x with more dimensions, independently normalizes each 1-D slice along - * dimension axis. - * - * @param tf The TensorFlow ops - * @param x The operand to normalize - * @param axes Dimension(s) along which to normalize. - * @param The data type for x. - * @return the normalized values of x. - */ - public static Operand l2Normalize(Ops tf, Operand x, int[] axes) { - return l2Normalize(tf, x, axes, L2_NORM_EPSILON); - } - - /** - * Normalizes along dimension axis using an L2 norm. - * - *

For a 1-D tensor with axis = 0, computes - * - *

-   *       output = x / sqrt(max(sum(x**2), epsilon))
-   * 
- * - *

For x with more dimensions, independently normalizes each 1-D slice along - * dimension axis. - * - * @param tf The TensorFlow ops - * @param x The operand to normalize - * @param axes Dimension(s) along which to normalize. - * @param epsilon A lower bound value for the norm. Will use sqrt(epsilon) as the - * divisor if norm < sqrt(epsilon). - * @param The data type for the values. - * @return the normalized values of x. - */ - public static Operand l2Normalize( - Ops tf, Operand x, int[] axes, float epsilon) { - Operand squareSum = - tf.reduceSum(tf.math.square(x), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE)); - Operand y = - tf.math.rsqrt( - tf.math.maximum(squareSum, CastHelper.cast(tf, tf.constant(epsilon), x.type()))); - return tf.math.mul(x, y); - } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 75a2031fbb5..af50b103a60 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -24,10 +24,10 @@ /** * A metric that computes the poisson loss metric between labels and predictions. * - * @param the data type for the predictions. + * @param The data type for the metric result. */ -public class Poisson extends MeanMetricWrapper +public class Poisson< T extends TNumber> extends MeanMetricWrapper< T> implements LossMetric { /** @@ -46,7 +46,7 @@ public Poisson(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.poisson(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index 2e01f722de6..a0c016b70b3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -24,12 +24,11 @@ /** * A metric that computes the sparse categorical cross-entropy loss between true labels and * predicted labels. - * - * @param the data type for the predictions. + *\ * @param The data type for the metric result. */ -public class SparseCategoricalCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class SparseCategoricalCrossentropy + extends MeanMetricWrapper implements LossMetric { private final boolean fromLogits; private final int axis; @@ -55,7 +54,7 @@ public SparseCategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 430dbbcc229..bd331a85eda 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -24,10 +24,9 @@ /** * A metric that computes the squared hinge loss metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class SquaredHinge extends MeanMetricWrapper +public class SquaredHinge extends MeanMetricWrapper implements LossMetric { /** @@ -46,7 +45,7 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.squaredHinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index b7b87d313aa..70bb8133698 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -29,8 +29,7 @@ public interface LossMetric { * * @param labels the truth values or labels * @param predictions the predictions - * @param The data type of the labels. * @return the loss */ - Operand call(Operand labels, Operand predictions); + Operand call(Operand labels, Operand predictions); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 17c209a8fed..9a532a0294f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -17,13 +17,14 @@ import org.tensorflow.Operand; import org.tensorflow.framework.metrics.Mean; import org.tensorflow.framework.metrics.MetricReduction; -import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; import java.util.List; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A class that bridges a stateless loss function with the {@link Mean} metric using a reduction of * {@link MetricReduction#WEIGHTED_MEAN}. @@ -32,10 +33,9 @@ * then passes this loss to the {@link Mean} metric to calculate the weighted mean of the * loss over many iterations or epochs * - * @param the data type for the predictions. * @param The data type for the metric result */ -public class MeanMetricWrapper extends Mean { +public class MeanMetricWrapper extends Mean { /** The loss function interface */ protected LossMetric loss; @@ -85,22 +85,21 @@ protected void setLoss(LossMetric loss) { * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) - * @param the datatype of the labels - * @param the data type for sampleWeights * @return a List of control operations that updates the Mean state variables. */ - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { if (labels == null || predictions == null) { throw new IllegalArgumentException("missing required inputs for labels and predictions"); } - Operand tLabels = CastHelper.cast(getTF(), labels, getResultType()); - Operand tPredictions = CastHelper.cast(getTF(), predictions, getResultType()); + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); Operand losses = loss.call(tLabels, tPredictions); - return super.updateStateList( - CastHelper.cast(getTF(), losses, predictions.type()), sampleWeights); + return super.updateStateList(cast(getTF(), losses, predictions.type()), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index ad8ff58e417..6cc089fce6d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -21,12 +21,10 @@ import org.tensorflow.op.Ops; import org.tensorflow.op.math.Mean; import org.tensorflow.types.TBool; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; import java.util.Arrays; import java.util.Collections; @@ -57,13 +55,13 @@ public class MetricsHelper { * @param values the values to which weights are applied. * @return Operation with control dependencies to ensure sampleWeight * can be broadcast to values - * @param the type of Operand + * @param the type of Operand * @throws NotBroadcastableException If static checks determine sampleWeights has an * incorrect shape that prohibit broadcasting to values */ @SuppressWarnings("unchecked") - public static Op assertBroadcastable( - Ops tf, Operand sampleWeights, Operand values) { + public static Op assertBroadcastable( + Ops tf, Operand sampleWeights, Operand values) { // try static check for exact match @@ -129,7 +127,7 @@ public static Op assertBroadcastable( // hack to work around the non-lazy select for isValidShape, otherwise validNonscalar fails on a // scalar weight. If select was lazy, that branch wouldn't get executed when iScalar is true. - Operand reshapedWeights = + Operand reshapedWeights = tf.select(isScalar, tf.math.mul(sampleWeights, tf.onesLike(values)), sampleWeights); weightsShape = tf.shape(reshapedWeights); weightsRank = tf.rank(reshapedWeights); @@ -237,11 +235,10 @@ public static Operand mean(Ops tf, Operand x) { * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. * @param the type of the Operand. - * @param the type of the axes. * @return the mean of the operand, along the specified axes. */ - public static Operand mean( - Ops tf, Operand x, Operand axes) { + public static Operand mean( + Ops tf, Operand x, Operand axes) { return mean(tf, x, axes, false); } @@ -257,31 +254,27 @@ public static Operand mean( * @param the type of the operand * @return the mean of elements of x. */ - public static Operand mean( - Ops tf, Operand x, boolean keepDims) { + public static Operand mean(Ops tf, Operand x, boolean keepDims) { return mean(tf, x, null, keepDims); } - - /** * Calculates the mean of the operand, alongside the specified axes. * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained + * with length 1. * @param the data type of the Operand - * @param the data type of the axes * @return the mean of elements of x. */ - - public static Operand mean( - Ops tf, Operand x, Operand axes, boolean keepDims) { + public static Operand mean( + Ops tf, Operand x, Operand axes, boolean keepDims) { if (axes == null) { - axes = (Operand) allAxes(tf, x); + axes = allAxes(tf, x); } return tf.math.mean(x, axes, Mean.keepDims(keepDims)); } @@ -294,7 +287,7 @@ public static Operand mean( * @param x the Operand used to calculate the mean * @return the mean of the operand containing floating point numbers */ - public static Operand booleanMean(Ops tf, Operand x) { + public static Operand booleanMean(Ops tf, Operand x) { return booleanMean(tf, x, null, false); } @@ -305,11 +298,10 @@ public static Operand booleanMean(Ops tf, Operand x) { * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param the type of the axes. * @return the mean of the operand, along the specified axes, containing floating point numbers */ - public static Operand booleanMean( - Ops tf, Operand x,Operand axes) { + public static Operand booleanMean( + Ops tf, Operand x, Operand axes) { return booleanMean(tf, x, axes, false); } @@ -317,14 +309,13 @@ public static Operand booleanMean( * Calculates the mean of the boolean operand, alongside all axes. * * @param x the boolean Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. - * @param the data type of the axes + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained + * with length 1. * @return the mean of elements of x containing floating point numbers */ - public static Operand booleanMean( - Ops tf, Operand x, boolean keepDims) { + public static Operand booleanMean(Ops tf, Operand x, boolean keepDims) { return booleanMean(tf, x, null, keepDims); } @@ -333,16 +324,15 @@ public static Operand booleanMean( * * @param x the boolean Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. - * @param the data type of the axes + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained + * with length 1. * @return the mean of elements of x containing floating point numbers */ - public static Operand booleanMean( - Ops tf, Operand x, Operand axes, boolean keepDims) { + public static Operand booleanMean( + Ops tf, Operand x, Operand axes, boolean keepDims) { Operand xf = cast(tf, x, TFloat64.class); return mean(tf, xf, axes, keepDims); } - } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 8e48cb4e573..2a26967b9f2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -19,7 +19,6 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.framework.metrics.Metric; import org.tensorflow.framework.metrics.MetricReduction; -import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -29,13 +28,14 @@ import java.util.ArrayList; import java.util.List; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Encapsulates metrics that perform a reduce operation on the metric values. * - * @param The data type for the metric values * @param The data type for the metric result */ -public abstract class Reduce extends Metric { +public abstract class Reduce extends Metric { public static final String TOTAL = "total"; public static final String COUNT = "count"; protected final MetricReduction reduction; @@ -45,8 +45,10 @@ public abstract class Reduce extends Metri private final Class resultType; /** the variable that holds the total of the metric values */ protected Variable total; - /** the variable that holds the count of the metric values. - * For {@link MetricReduction#WEIGHTED_MEAN}, this count may be weighted */ + /** + * the variable that holds the count of the metric values. For {@link + * MetricReduction#WEIGHTED_MEAN}, this count may be weighted + */ protected Variable count; /** @@ -95,12 +97,10 @@ private void setupVars() { public Op resetStates() { List controls = new ArrayList<>(); if (total != null) { - controls.add( - getTF().assign(total, CastHelper.cast(getTF(), getTF().constant(0), total.type()))); + controls.add(getTF().assign(total, cast(getTF(), getTF().constant(0), total.type()))); } if (count != null) { - controls.add( - getTF().assign(count, CastHelper.cast(getTF(), getTF().constant(0), count.type()))); + controls.add(getTF().assign(count, cast(getTF(), getTF().constant(0), count.type()))); } return getTF().withControlDependencies(controls).noOp(); } @@ -115,67 +115,67 @@ public Op resetStates() { * @throws IllegalArgumentException if values is null */ @Override - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList( + Operand values, Operand sampleWeights) { if (values == null) { throw new IllegalArgumentException("values is required."); } + Ops tf = getTF(); List updateOperations = new ArrayList<>(); // cast everything to match the variables - Operand lSampleWeights = null; - Operand lValues = values; + Operand tSampleWeights = null; + Operand tValues = cast(tf, values, getResultType()); if (sampleWeights != null) { - lSampleWeights = CastHelper.cast(getTF(), sampleWeights, lValues.type()); - LossTuple tuple = - LossesHelper.squeezeOrExpandDimensions(getTF(), null, lValues, lSampleWeights); - lValues = tuple.getTarget(); - lSampleWeights = tuple.getSampleWeights(); + tSampleWeights = cast(getTF(), sampleWeights, getResultType()); + LossTuple tuple = + LossesHelper.squeezeOrExpandDimensions(getTF(), null, tValues, tSampleWeights); + tValues = tuple.getTarget(); + tSampleWeights = tuple.getSampleWeights(); try { - lSampleWeights = MetricsHelper.broadcastWeights(getTF(), lSampleWeights, lValues); + tSampleWeights = MetricsHelper.broadcastWeights(getTF(), tSampleWeights, tValues); } catch (IllegalArgumentException ex) { // if we get here we have static shapes with either // different ranks or different dimension sizes. // first, reduce the values down to the rank of the samples - int valuesRank = lValues.shape().numDimensions(); - int weightsRank = lSampleWeights.shape().numDimensions(); + int valuesRank = tValues.shape().numDimensions(); + int weightsRank = tSampleWeights.shape().numDimensions(); int numAxes = Math.min(0, valuesRank - weightsRank); if (numAxes > 0) { // values rank is greater than weights rank, reduce values to weights rank. int[] axes = new int[numAxes]; for (int i = 0; i < numAxes; i++) axes[i] = i + weightsRank; if (reduction == MetricReduction.SUM) { - lValues = getTF().reduceSum(lValues, getTF().constant(axes)); + tValues = getTF().reduceSum(tValues, getTF().constant(axes)); } else { - lValues = getTF().math.mean(lValues, getTF().constant(axes)); + tValues = getTF().math.mean(tValues, getTF().constant(axes)); } } } - lValues = getTF().math.mul(lValues, lSampleWeights); + tValues = getTF().math.mul(tValues, tSampleWeights); } - Operand weightedValueSum = - getTF().reduceSum(lValues, LossesHelper.allAxes(getTF(), lValues)); + Operand weightedValueSum = + getTF().reduceSum(tValues, LossesHelper.allAxes(getTF(), tValues)); Operand totalUpdate = - getTF().assignAdd(total, CastHelper.cast(getTF(), weightedValueSum, total.type())); + getTF().assignAdd(total, cast(getTF(), weightedValueSum, total.type())); updateOperations.add(totalUpdate); Operand numValues; if (reduction != MetricReduction.SUM) { switch (reduction) { case SUM_OVER_BATCH_SIZE: - numValues = - CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); + numValues = cast(getTF(), getTF().constant(tValues.shape().size()), resultType); break; case WEIGHTED_MEAN: - if (lSampleWeights == null) { - numValues = - CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); + if (tSampleWeights == null) { + numValues = cast(getTF(), getTF().constant(tValues.shape().size()), resultType); } else { numValues = - CastHelper.cast( + cast( getTF(), getTF() - .reduceSum(lSampleWeights, LossesHelper.allAxes(getTF(), lSampleWeights)), + .reduceSum(tSampleWeights, LossesHelper.allAxes(getTF(), tSampleWeights)), resultType); } break; @@ -202,7 +202,7 @@ public Operand result() { break; case WEIGHTED_MEAN: case SUM_OVER_BATCH_SIZE: - fResult = getTF().math.divNoNan(total, CastHelper.cast(getTF(), count, resultType)); + fResult = getTF().math.divNoNan(total, cast(getTF(), count, resultType)); break; default: throw new UnsupportedOperationException( diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java index 7ceedded018..be46bb5c282 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java @@ -32,7 +32,7 @@ class BinaryCrossentropyTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testUnweighted", false, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 0, 1, 0}; @@ -55,7 +55,7 @@ public void testUnweighted() { public void testUnweightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testUnweightedLogits", true, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 0, 1, 0, 1, 1}; @@ -77,7 +77,7 @@ public void testUnweightedLogits() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testWeighted", false, 0, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 0, 1, 0}; @@ -102,7 +102,7 @@ public void testWeighted() { public void testWeightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testWeightedLogits", true, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 0, 1, 0, 1, 1}; @@ -128,7 +128,7 @@ public void testLabelSmoothing() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); float labelSmoothing = 0.1F; - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>( tf, "BCE_testWeightedLabS", true, labelSmoothing, 1001L, TFloat64.class); session.run(instance.resetStates()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java index 2b4a1d75467..34fc3eef884 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java @@ -31,7 +31,7 @@ class CategoricalCrossentropyTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testUnweighted", false, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -55,7 +55,7 @@ public void testUnweighted() { public void testUnweightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testUnweightedLogits", true, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -79,7 +79,7 @@ public void testUnweightedLogits() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testWeighted", false, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -104,7 +104,7 @@ public void testWeighted() { public void testWeightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>(tf, "CCE_testWeighted", true, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {0, 1, 0, 0, 0, 1}; @@ -129,7 +129,7 @@ public void testLabelSmoothing() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); float labelSmoothing = 0.1F; - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testWeighted", true, labelSmoothing, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java index 87248d95e48..78b25a21b60 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java @@ -31,7 +31,7 @@ class CategoricalHingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalHinge instance = + CategoricalHinge instance = new CategoricalHinge<>(tf, "CH_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { @@ -64,7 +64,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalHinge instance = + CategoricalHinge instance = new CategoricalHinge<>(tf, "CH_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java index a9721ef2f8f..18410416c42 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java @@ -31,7 +31,7 @@ class CosineSimilarityTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CosineSimilarity instance = + CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; @@ -54,7 +54,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CosineSimilarity instance = + CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testWeighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; @@ -80,7 +80,7 @@ public void test_axis() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); int axis = 1; - CosineSimilarity instance = + CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testWeighted", axis, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java index 6af5fed4889..a9bd5fac76e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java @@ -32,7 +32,7 @@ class HingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = + Hinge instance = new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; @@ -55,7 +55,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = + Hinge instance = new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java index 28020c0fa1c..267578a492c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java @@ -31,7 +31,7 @@ class KLDivergenceTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - KLDivergence instance = + KLDivergence instance = new KLDivergence<>(tf, "KLD_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); float[][] trueArray = {{.5f, .8f, .12f}, {.7f, .43f, .8f}}; @@ -54,7 +54,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - KLDivergence instance = + KLDivergence instance = new KLDivergence<>(tf, "KLD_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java index 31c043e0473..1b5b8fb7d49 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java @@ -32,7 +32,7 @@ class LogCoshErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - LogCoshError instance = + LogCoshError instance = new LogCoshError<>(tf, "LogCosh_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 9, 2, -5, -2, 6}; @@ -56,7 +56,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - LogCoshError instance = + LogCoshError instance = new LogCoshError<>(tf, "LogCosh_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java index 73241ecbe9f..984895f2ad9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java @@ -32,7 +32,7 @@ class MeanAbsoluteErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanAbsoluteError instance = + MeanAbsoluteError instance = new MeanAbsoluteError<>(tf, "MAE_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0f, instance.getTotal()); @@ -74,7 +74,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanAbsoluteError instance = + MeanAbsoluteError instance = new MeanAbsoluteError<>(tf, "MAE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java index 4c92844b217..0b9e7f6b538 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java @@ -34,7 +34,7 @@ public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { session.setEpsilon(1E-6f); Ops tf = session.getTF(); - MeanAbsolutePercentageError instance = + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError<>(tf, "MAPE_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); session.evaluate(0.0f, instance.getTotal()); @@ -76,7 +76,7 @@ public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { session.setEpsilon(1E-6f); Ops tf = session.getTF(); - MeanAbsolutePercentageError instance = + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError<>(tf, "MAPE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java index 0b760213015..e42052a9ef1 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java @@ -33,7 +33,7 @@ class MeanSquaredErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredError instance = + MeanSquaredError instance = new MeanSquaredError<>(tf, "MSE_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); @@ -70,7 +70,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredError instance = + MeanSquaredError instance = new MeanSquaredError<>(tf, "MSE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java index 098a5cb9725..e68d63b8778 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java @@ -32,7 +32,7 @@ class MeanSquaredLogarithmicErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredLogarithmicError instance = + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError<>(tf, "MSLE_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); session.evaluate(0.0f, instance.getTotal()); @@ -69,7 +69,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredLogarithmicError instance = + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError<>(tf, "MSLE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java index cf3c3e44719..75d9ef93168 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java @@ -32,7 +32,7 @@ class PoissonTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Poisson instance = + Poisson instance = new Poisson<>(tf, "Poisson_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {4, 8, 12, 8, 1, 3}; @@ -55,7 +55,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Poisson instance = + Poisson instance = new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {4, 8, 12, 8, 1, 3}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java index 87af1bd8448..8e1aaea0a8f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java @@ -32,7 +32,7 @@ class SparseCategoricalCrossentropyTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testUnweighted", false, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -56,7 +56,7 @@ public void testUnweighted() { public void testUnweightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testWeighted", true, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -105,7 +105,7 @@ public void testWeighted() { public void testWeightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testWeighted", true, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java index e3376c224f3..2c80b3451ad 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java @@ -32,7 +32,7 @@ class SquaredHingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SquaredHinge instance = + SquaredHinge instance = new SquaredHinge<>(tf, "SCE_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = { @@ -61,7 +61,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SquaredHinge instance = + SquaredHinge instance = new SquaredHinge<>(tf, "SCE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { From d7f7e4c5871b3a418e9e2335b95a3c9e0009213b Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 1 Feb 2021 14:42:48 -0500 Subject: [PATCH 2/6] Reformat code --- .../annotations/org/tensorflow/op/Ops.java | 6 +- .../losses/CategoricalCrossentropy.java | 36 ++++++------ .../framework/losses/CategoricalHinge.java | 4 +- .../tensorflow/framework/losses/Hinge.java | 20 ++++--- .../tensorflow/framework/losses/Huber.java | 2 +- .../framework/losses/KLDivergence.java | 2 +- .../tensorflow/framework/losses/LogCosh.java | 2 +- .../org/tensorflow/framework/losses/Loss.java | 5 +- .../tensorflow/framework/losses/Losses.java | 28 +++++++--- .../framework/losses/MeanAbsoluteError.java | 2 +- .../losses/MeanAbsolutePercentageError.java | 2 +- .../framework/losses/MeanSquaredError.java | 2 +- .../losses/MeanSquaredLogarithmicError.java | 2 +- .../losses/SparseCategoricalCrossentropy.java | 34 +++++++----- .../framework/losses/SquaredHinge.java | 18 +++--- .../framework/losses/impl/LossTuple.java | 2 +- .../framework/losses/impl/LossesHelper.java | 26 ++++----- .../framework/metrics/BinaryCrossentropy.java | 9 ++- .../framework/metrics/CategoricalHinge.java | 4 +- .../framework/metrics/CosineSimilarity.java | 2 +- .../tensorflow/framework/metrics/Hinge.java | 3 +- .../framework/metrics/KLDivergence.java | 5 +- .../framework/metrics/LogCoshError.java | 3 +- .../framework/metrics/MeanAbsoluteError.java | 2 +- .../metrics/MeanAbsolutePercentageError.java | 6 +- .../framework/metrics/MeanSquaredError.java | 4 +- .../metrics/MeanSquaredLogarithmicError.java | 6 +- .../tensorflow/framework/metrics/Poisson.java | 6 +- .../SparseCategoricalCrossentropy.java | 13 +++-- .../framework/metrics/SquaredHinge.java | 5 +- .../framework/metrics/impl/LossMetric.java | 2 +- .../framework/metrics/impl/SetsOps.java | 55 ++++++++++--------- .../framework/metrics/HingeTest.java | 6 +- .../framework/metrics/PoissonTest.java | 3 +- .../SparseCategoricalCrossentropyTest.java | 2 +- 35 files changed, 174 insertions(+), 155 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 84736ada6a5..007ee9d0d42 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -345,10 +345,10 @@ public final class Ops { public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { @@ -370,8 +370,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - train = new TrainOps(this); quantization = new QuantizationOps(this); + train = new TrainOps(this); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 035af9589ae..5aac163c1e4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -154,24 +154,26 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 */ public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, - * and a channel axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy Loss using a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS} * * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 */ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) { this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); @@ -183,9 +185,10 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float la * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. x=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. x=0.2 means + * that we will use a value of 0.1 for label 0 and 0.9 + * for label 1 * @param reduction Type of Reduction to apply to loss. */ public CategoricalCrossentropy( @@ -199,13 +202,14 @@ public CategoricalCrossentropy( * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 * @param reduction Type of Reduction to apply to loss. * @param axis The channels axis. axis=-1 corresponds to data format "Channels Last" - * and axis=1 corresponds to data format "Channels First". - * {@link Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST} + * and axis=1 corresponds to data format "Channels First". {@link + * Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST} * @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1. */ public CategoricalCrossentropy( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index 4e9133d8835..73837ed1756 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -25,7 +25,7 @@ *

loss = maximum(neg - pos + 1, 0) where neg=maximum((1-labels)*predictions) * and pos=sum(labels*predictions) * - *

labels values are expected to be 0 or 1.

+ *

labels values are expected to be 0 or 1. * *

Standalone usage: * @@ -100,7 +100,7 @@ public CategoricalHinge(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.categoricalHinge(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index 37e7e367b9b..db3569441ef 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -18,15 +18,16 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; + import static org.tensorflow.framework.utils.CastHelper.cast; /** * Computes the hinge loss between labels and predictions. * - *

loss = maximum(1 - labels * predictions, 0)

. + *

loss = maximum(1 - labels * predictions, 0). * - *

labels values are expected to be -1 or 1. - * If binary (0 or 1) labels are provided, they will be converted to -1 or 1.

+ *

labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, + * they will be converted to -1 or 1. * *

Standalone usage: * @@ -106,7 +107,7 @@ public Hinge(Ops tf, String name, Reduction reduction) { * label values are not in the set [-1., 0., 1.]. * * @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be - * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. + * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is * provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor @@ -124,13 +125,16 @@ public Hinge(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") - Operand tLabels = predictions.type() == labels.type() ? - (Operand)labels : cast(tf, labels, predictions.type()); - tLabels = LossesHelper.valueCheck( + Operand tLabels = + predictions.type() == labels.type() + ? (Operand) labels + : cast(tf, labels, predictions.type()); + tLabels = + LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.type())); + cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); Operand losses = Losses.hinge(getTF(), tLabels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index e8de632eb09..665a9ac157d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -131,7 +131,7 @@ public Huber(Ops tf, String name, float delta, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.huber(getTF(), labels, predictions, delta); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java index b3c0206b409..2aa1f72092b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -100,7 +100,7 @@ public KLDivergence(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index 812260d9881..78325713e3e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -106,7 +106,7 @@ public LogCosh(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.logCosh(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index 0f9b183f38c..cdd35d28aba 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -25,7 +25,7 @@ public abstract class Loss { protected final Reduction reduction; /** - * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link + * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link * Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops @@ -64,7 +64,8 @@ protected Loss(Ops tf, String name, Reduction reduction) { * @param The data type of the predictions and loss. * @return the loss */ - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { return call(labels, predictions, null); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index a5ced3d1df8..2222ebb41f8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -102,8 +102,10 @@ public static Operand meanAbsolutePercentageError( tf.math.abs( tf.math.div( tf.math.sub(tLabels, predictions), - tf.math.maximum(tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), predictionType)))); - return tf.math.mul(cast(tf, tf.constant(100), predictionType), tf.math.mean(diff, tf.constant(-1))); + tf.math.maximum( + tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), predictionType)))); + return tf.math.mul( + cast(tf, tf.constant(100), predictionType), tf.math.mean(diff, tf.constant(-1))); } /** @@ -149,7 +151,11 @@ public static Operand meanSquaredLogarithmicError( * @return the binary crossentropy loss. */ public static Operand binaryCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { + Ops tf, + Operand labels, + Operand predictions, + boolean fromLogits, + float labelSmoothing) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -214,9 +220,10 @@ private static Operand binaryCrossentropyHelper( * @param labels true targets * @param predictions the predictions * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 * @param axis the * @param the data type of the predictions and labels * @return the categorical crossentropy loss. @@ -503,7 +510,11 @@ public static Operand poisson( * @return the sparse categorical crossentropy loss */ public static Operand sparseCategoricalCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { + Ops tf, + Operand labels, + Operand predictions, + boolean fromLogits, + int axis) { Class predictionType = predictions.type(); Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); Operand one = cast(tf, tf.constant(1), predictionType); @@ -650,8 +661,7 @@ public static Operand l2Normalize(Ops tf, Operand x, i Operand squareSum = tf.reduceSum(tf.math.square(x), tf.constant(axis), ReduceSum.keepDims(Boolean.TRUE)); Operand invNorm = - tf.math.rsqrt( - tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); + tf.math.rsqrt(tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); return tf.math.mul(x, invNorm); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index 594de1e1448..03a3cf70110 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -96,7 +96,7 @@ public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsoluteError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 275a2e136a0..6c5242df4f2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -96,7 +96,7 @@ public MeanAbsolutePercentageError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsolutePercentageError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index 31df3e70e0b..f975db55c44 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -96,7 +96,7 @@ public MeanSquaredError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index bef990d22bc..11b8e157e90 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -96,7 +96,7 @@ public MeanSquaredLogarithmicError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index 3ec33113e89..d04cc67d5d9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -18,6 +18,7 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; + import static org.tensorflow.framework.utils.CastHelper.cast; /** @@ -79,7 +80,8 @@ public class SparseCategoricalCrossentropy extends Loss { /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and fromLogits={@link + * #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops */ @@ -88,8 +90,8 @@ public SparseCategoricalCrossentropy(Ops tf) { } /** - * Creates a SparseCategoricalCrossentropy loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, - * and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * Creates a SparseCategoricalCrossentropy loss using a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param name the name of this loss function @@ -122,8 +124,8 @@ public SparseCategoricalCrossentropy(Ops tf, String name, Reduction reduction) { } /** - * Creates a SparseCategoricalCrossentropy using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and - * fromLogits={@link #FROM_LOGITS_DEFAULT}. + * Creates a SparseCategoricalCrossentropy using a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param name the name of this loss function @@ -135,7 +137,8 @@ public SparseCategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} and fromLogits={@link + * #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values @@ -176,9 +179,10 @@ public SparseCategoricalCrossentropy( /** * Generates an Operand the calculates the loss. * - * If run in Graph mode, the computation will throw {@link org.tensorflow.exceptions.TFInvalidArgumentException} - * if the predictions values are outside the range o [0. to 1.]. In Eager Mode, this call - * will throw {@link IllegalArgumentException}, if the predictions values are outside the range o [0. to 1.] + *

If run in Graph mode, the computation will throw {@link + * org.tensorflow.exceptions.TFInvalidArgumentException} if the predictions values are outside the + * range o [0. to 1.]. In Eager Mode, this call will throw {@link IllegalArgumentException}, if + * the predictions values are outside the range o [0. to 1.] * * @param labels the truth values or labels * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. @@ -200,12 +204,12 @@ public Operand call( if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = - LossesHelper.rangeCheck( - getTF(), - "predictions range check [0-1]", - predictions, - cast(getTF(), getTF().constant(0), predictions.type()), - cast(getTF(), getTF().constant(1), predictions.type())); + LossesHelper.rangeCheck( + getTF(), + "predictions range check [0-1]", + predictions, + cast(getTF(), getTF().constant(0), predictions.type()), + cast(getTF(), getTF().constant(1), predictions.type())); } else { lPredictions = predictions; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index 968624db202..dadbdb3b95e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -18,6 +18,7 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; + import static org.tensorflow.framework.utils.CastHelper.cast; /** @@ -25,8 +26,8 @@ * *

loss = square(maximum(1 - labels * predictions, 0)) * - *

labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, they will be - * converted to -1 or 1. + *

labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, + * they will be converted to -1 or 1. * *

Standalone usage: * @@ -107,7 +108,7 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { * label values are not in the set [-1., 0., 1.]. * * @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be - * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. + * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. * @param sampleWeights Optional SampleWeights acts as a coefficient for the loss. If a scalar is * provided, then the loss is simply scaled by the given value. If SampleWeights is a tensor @@ -124,13 +125,16 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") - Operand tLabels = predictions.type() == labels.type() ? - (Operand)labels : cast(tf, labels, predictions.type()); - tLabels = LossesHelper.valueCheck( + Operand tLabels = + predictions.type() == labels.type() + ? (Operand) labels + : cast(tf, labels, predictions.type()); + tLabels = + LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.type())); + cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); Operand losses = Losses.squaredHinge(getTF(), tLabels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java index 2104937a979..f811549fbca 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java @@ -18,7 +18,7 @@ import org.tensorflow.types.family.TNumber; /** - * A helper class for loss methods to return labels, target, and sampleWeights + * A helper class for loss methods to return labels, target, and sampleWeights * * @param the data type of the LossTuple entries. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 10067db91ba..66bdd839f09 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -32,8 +32,9 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * These are helper methods for Losses and Metrics and will be module private when Java modularity is applied to - * TensorFlow Java. These methods should not be used outside of the losses and metrics packages. + * These are helper methods for Losses and Metrics and will be module private when Java modularity + * is applied to TensorFlow Java. These methods should not be used outside of the losses and metrics + * packages. */ public class LossesHelper { @@ -42,10 +43,10 @@ public class LossesHelper { * *

    *
  1. Squeezes last dim of predictions or labels if their rank - * differs by 1 (using {@link #removeSqueezableDimensions}).
  2. + * differs by 1 (using {@link #removeSqueezableDimensions}). *
  3. Squeezes or expands last dim of sampleWeight if its rank differs by 1 from * the new rank of predictions. If sampleWeight is scalar, it is - * kept scalar.
  4. + * kept scalar. *
* * @param tf the TensorFlow Ops @@ -77,12 +78,13 @@ public static LossTuple squeezeOrExpandDimensions( * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction * . - * @param sampleWeights Optional sample weight(s) Operand whose dimensions match + * @param sampleWeights Optional sample weight(s) Operand whose dimensions match + * * prediction. - * @return LossTuple of predictions, labels and sampleWeight. - * Each of them possibly has the last dimension squeezed, sampleWeight could be - * extended by one dimension. If sampleWeight is null, only the possibly shape modified predictions and labels are - * returned. + * @return LossTuple of predictions, labels and sampleWeight + * . Each of them possibly has the last dimension squeezed, sampleWeight + * could be extended by one dimension. If sampleWeight is null, only the possibly + * shape modified predictions and labels are returned. */ public static LossTuple squeezeOrExpandDimensions( Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { @@ -298,8 +300,7 @@ private static Operand reduceWeightedLoss( public static Operand safeMean( Ops tf, Operand losses, long numElements) { Operand totalLoss = tf.reduceSum(losses, allAxes(tf, losses)); - return tf.math.divNoNan( - totalLoss, cast(tf, tf.constant(numElements), losses.type())); + return tf.math.divNoNan(totalLoss, cast(tf, tf.constant(numElements), losses.type())); } /** @@ -383,8 +384,7 @@ public static Operand rangeCheck( */ public static Operand valueCheck( Ops tf, String prefix, Operand values, Operand allowedValues) { - Operand flatValues = - tf.reshape(values, tf.constant(Shape.of(values.shape().size()))); + Operand flatValues = tf.reshape(values, tf.constant(Shape.of(values.shape().size()))); SetDiff1d diff = tf.setDiff1d(flatValues, allowedValues, TInt32.class); long diffSize = diff.out().shape().size(); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index abd2dcbbf40..d8bb2a41116 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -21,8 +21,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. * @@ -31,8 +29,8 @@ * * @param The data type for the metric result */ -public class BinaryCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class BinaryCrossentropy extends MeanMetricWrapper + implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; @@ -42,7 +40,8 @@ public class BinaryCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. * @param labelSmoothing value used to smooth labels, When 0, no smoothing occurs. When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index c70f2d8643b..4800fc43c49 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -26,7 +26,7 @@ * * @param The data type for the metric result */ -public class CategoricalHinge< T extends TNumber> extends MeanMetricWrapper +public class CategoricalHinge extends MeanMetricWrapper implements LossMetric { /** @@ -45,7 +45,7 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.categoricalHinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 5abbd095420..3ae67072955 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -26,7 +26,7 @@ * * @param The data type for the metric result. */ -public class CosineSimilarity< T extends TNumber> extends MeanMetricWrapper< T> +public class CosineSimilarity extends MeanMetricWrapper implements LossMetric { public static final int DEFAULT_AXIS = -1; private final int[] axis; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index e0aced6fa3e..3b84b81e071 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -26,8 +26,7 @@ * * @param The data type for the metric result. */ -public class Hinge extends MeanMetricWrapper - implements LossMetric { +public class Hinge extends MeanMetricWrapper implements LossMetric { /** * Creates a Hinge metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index fa09f2784b5..f631f562e1d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -27,8 +27,7 @@ * * @param The data type for the metric result. */ -public class KLDivergence extends MeanMetricWrapper - implements LossMetric { +public class KLDivergence extends MeanMetricWrapper implements LossMetric { /** * Creates a KLDivergence metric @@ -46,7 +45,7 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index c43551a6948..046937e228b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -27,8 +27,7 @@ * * @param The data type for the metric result. */ -public class LogCoshError extends MeanMetricWrapper< T> - implements LossMetric { +public class LogCoshError extends MeanMetricWrapper implements LossMetric { /** * Creates a LogCoshError metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index d343ec77ab0..977f61648a1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -26,7 +26,7 @@ * * @param The data type for the metric result. */ -public class MeanAbsoluteError extends MeanMetricWrapper< T> +public class MeanAbsoluteError extends MeanMetricWrapper implements LossMetric { /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index dd7d151260b..bad5255969a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -26,8 +26,8 @@ * * @param The data type for the metric result. */ -public class MeanAbsolutePercentageError - extends MeanMetricWrapper implements LossMetric { +public class MeanAbsolutePercentageError extends MeanMetricWrapper + implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -45,7 +45,7 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanAbsolutePercentageError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index c2bef576b30..5b0d9ec43b3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -26,7 +26,7 @@ * * @param The data type for the metric result. */ -public class MeanSquaredError< T extends TNumber> extends MeanMetricWrapper< T> +public class MeanSquaredError extends MeanMetricWrapper implements LossMetric { /** @@ -45,7 +45,7 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanSquaredError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index c1cf4ca6c9a..35044fee956 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -26,8 +26,8 @@ * * @param The data type for the metric result. */ -public class MeanSquaredLogarithmicError - extends MeanMetricWrapper implements LossMetric { +public class MeanSquaredLogarithmicError extends MeanMetricWrapper + implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -45,7 +45,7 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index af50b103a60..700099d3375 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -24,11 +24,9 @@ /** * A metric that computes the poisson loss metric between labels and predictions. * - * @param The data type for the metric result. */ -public class Poisson< T extends TNumber> extends MeanMetricWrapper< T> - implements LossMetric { +public class Poisson extends MeanMetricWrapper implements LossMetric { /** * Creates a Poisson metric @@ -46,7 +44,7 @@ public Poisson(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.poisson(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index a0c016b70b3..aa7ca316378 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -23,12 +23,12 @@ /** * A metric that computes the sparse categorical cross-entropy loss between true labels and - * predicted labels. - *\ + * predicted labels. \ + * * @param The data type for the metric result. */ -public class SparseCategoricalCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class SparseCategoricalCrossentropy extends MeanMetricWrapper + implements LossMetric { private final boolean fromLogits; private final int axis; @@ -38,7 +38,8 @@ public class SparseCategoricalCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. * @param axis The dimension along which the entropy is computed. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -54,7 +55,7 @@ public SparseCategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index bd331a85eda..01f4a403f84 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -26,8 +26,7 @@ * * @param The data type for the metric result. */ -public class SquaredHinge extends MeanMetricWrapper - implements LossMetric { +public class SquaredHinge extends MeanMetricWrapper implements LossMetric { /** * Creates a SquaredHinge metric @@ -45,7 +44,7 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.squaredHinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index 70bb8133698..037d634cd4a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -31,5 +31,5 @@ public interface LossMetric { * @param predictions the predictions * @return the loss */ - Operand call(Operand labels, Operand predictions); + Operand call(Operand labels, Operand predictions); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java index 1841c7ee238..467dea19b57 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java @@ -25,33 +25,6 @@ /** Implementation of set operations */ public class SetsOps { - /** - * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops - * function {@link SparseOps#denseToDenseSetOperation} - */ - public enum Operation { - A_MINUS_B("a-b"), - B_MINUS_A("b-a"), - INTERSECTION("intersection"), - UNION("union"); - - private final String setOperation; - - Operation(String setOperation) { - this.setOperation = setOperation; - } - - /** - * Gets the set operation String value used to pass as the stringOperation value to {@link - * SparseOps#denseToDenseSetOperation} - * - * @return the set operation String value - */ - public String getSetOperation() { - return setOperation; - } - } - /** * Computes set difference of elements in last dimension of a and b with * aMinusB set to true. @@ -69,6 +42,7 @@ public String getSetOperation() { public static Operand difference(Ops tf, Operand a, Operand b) { return difference(tf, a, b, true); } + /** * Computes set difference of elements in last dimension of a and b. * @@ -143,4 +117,31 @@ public static Operand setOperation( setOperationResult.resultValues(), cast(tf, tf.constant(0), a.type())); } + + /** + * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops + * function {@link SparseOps#denseToDenseSetOperation} + */ + public enum Operation { + A_MINUS_B("a-b"), + B_MINUS_A("b-a"), + INTERSECTION("intersection"), + UNION("union"); + + private final String setOperation; + + Operation(String setOperation) { + this.setOperation = setOperation; + } + + /** + * Gets the set operation String value used to pass as the stringOperation value to {@link + * SparseOps#denseToDenseSetOperation} + * + * @return the set operation String value + */ + public String getSetOperation() { + return setOperation; + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java index a9bd5fac76e..90531d21fde 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java @@ -32,8 +32,7 @@ class HingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = - new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); + Hinge instance = new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; double[] predArray = {-0.3, 0.2, -0.1, 1.6, -0.25, -1., 0.5, 0.6}; @@ -55,8 +54,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = - new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); + Hinge instance = new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { -1, 1, -1, 1, diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java index 75d9ef93168..5631bac15ee 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java @@ -55,8 +55,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Poisson instance = - new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); + Poisson instance = new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {4, 8, 12, 8, 1, 3}; float[] predArray = {1, 9, 2, 5, 2, 6}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java index 8e1aaea0a8f..0aece8c8ac9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java @@ -79,7 +79,7 @@ public void testUnweightedLogits() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testWeighted", false, -1, 1001L, TFloat32.class); session.run(instance.resetStates()); From 6b4149ce0036fca5a3dfe33f283b1e6a980a7f9d Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 1 Feb 2021 18:40:00 -0500 Subject: [PATCH 3/6] Change order of TrainOps and QuantiQuantizationOps. For some reason, when I build it reverses these 2 from master's version. --- .../src/gen/annotations/org/tensorflow/op/Ops.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 007ee9d0d42..84736ada6a5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -345,10 +345,10 @@ public final class Ops { public final SignalOps signal; - public final QuantizationOps quantization; - public final TrainOps train; + public final QuantizationOps quantization; + private final Scope scope; private Ops(Scope scope) { @@ -370,8 +370,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - quantization = new QuantizationOps(this); train = new TrainOps(this); + quantization = new QuantizationOps(this); } /** From e486a9038eb22ae05b3e33ee63f8c371f0b509c6 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 3 Feb 2021 15:32:07 -0500 Subject: [PATCH 4/6] Fix LossMetric to change abstract "call" method to use gneric parameter for predictions instead of . --- .../framework/metrics/BinaryCrossentropy.java | 8 ++++++-- .../framework/metrics/CategoricalCrossentropy.java | 8 ++++++-- .../framework/metrics/CategoricalHinge.java | 8 ++++++-- .../framework/metrics/CosineSimilarity.java | 11 ++++++++--- .../java/org/tensorflow/framework/metrics/Hinge.java | 8 ++++++-- .../tensorflow/framework/metrics/KLDivergence.java | 8 ++++++-- .../tensorflow/framework/metrics/LogCoshError.java | 8 ++++++-- .../framework/metrics/MeanAbsoluteError.java | 8 ++++++-- .../metrics/MeanAbsolutePercentageError.java | 8 ++++++-- .../framework/metrics/MeanSquaredError.java | 8 ++++++-- .../metrics/MeanSquaredLogarithmicError.java | 8 ++++++-- .../org/tensorflow/framework/metrics/Poisson.java | 8 ++++++-- .../metrics/SparseCategoricalCrossentropy.java | 8 ++++++-- .../tensorflow/framework/metrics/SquaredHinge.java | 8 ++++++-- .../tensorflow/framework/metrics/impl/LossMetric.java | 2 +- 15 files changed, 87 insertions(+), 30 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index d8bb2a41116..263b8a789ed 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. * @@ -60,7 +62,9 @@ public BinaryCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.binaryCrossentropy(getTF(), labels, predictions, fromLogits, labelSmoothing); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.binaryCrossentropy(getTF(), tLabels, tPredictions, fromLogits, labelSmoothing); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index be43f34b92e..cbe0127295f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the categorical cross-entropy loss between true labels and predicted * labels. @@ -99,8 +101,10 @@ public CategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.categoricalCrossentropy( - getTF(), labels, predictions, fromLogits, labelSmoothing, axis); + getTF(), tLabels, tPredictions, fromLogits, labelSmoothing, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index 4800fc43c49..ff814ae6ed3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the categorical hinge loss metric between labels and predictions. * @@ -45,7 +47,9 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.categoricalHinge(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.categoricalHinge(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 3ae67072955..d64136d0d90 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the cosine similarity metric between labels and predictions. * @@ -76,8 +78,11 @@ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class typ /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - // NOTE: cosineProximity is a different algorithm than Losses.cosineSimilarity - return Losses.cosineSimilarity(getTF(), labels, predictions, axis); + public Operand call(Operand labels, Operand predictions) { + // NOTE: metrics.CosineSimilarity is Losses.cosineSimilarity, + // while losses.CosineSimilarity is the negative of Losses.cosineSimilarity + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.cosineSimilarity(getTF(), tLabels, tPredictions, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index 3b84b81e071..7a37cbeddbe 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the hinge loss metric between labels and predictions. * @@ -44,7 +46,9 @@ public Hinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.hinge(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.hinge(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index f631f562e1d..3027bb2f460 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the Kullback-Leibler divergence loss metric between labels and * predictions. @@ -45,7 +47,9 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.kullbackLeiblerDivergence(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index 046937e228b..ca84e651988 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the logarithm of the hyperbolic cosine of the prediction error metric * between labels and predictions. @@ -45,7 +47,9 @@ public LogCoshError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.logCosh(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.logCosh(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index 977f61648a1..c91cb0df1ef 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. * @@ -45,7 +47,9 @@ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanAbsoluteError(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanAbsoluteError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index bad5255969a..6cc96a4fb88 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. * @@ -45,7 +47,9 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanAbsolutePercentageError(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanAbsolutePercentageError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index 5b0d9ec43b3..1fce9998270 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. * @@ -45,7 +47,9 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanSquaredError(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanSquaredError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index 35044fee956..900359db88b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. * @@ -45,7 +47,9 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanSquaredLogarithmicError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 700099d3375..3572c155b96 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the poisson loss metric between labels and predictions. * @@ -44,7 +46,9 @@ public Poisson(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.poisson(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.poisson(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index aa7ca316378..a74f575a4a8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the sparse categorical cross-entropy loss between true labels and * predicted labels. \ @@ -55,7 +57,9 @@ public SparseCategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.sparseCategoricalCrossentropy(getTF(), tLabels, tPredictions, fromLogits, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 01f4a403f84..6bee2ccf8e4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the squared hinge loss metric between labels and predictions. * @@ -44,7 +46,9 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.squaredHinge(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.squaredHinge(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index 037d634cd4a..1fb3d3bb580 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -31,5 +31,5 @@ public interface LossMetric { * @param predictions the predictions * @return the loss */ - Operand call(Operand labels, Operand predictions); + Operand call(Operand labels, Operand predictions); } From c7115323dce138c6ed6ced16c3aaf435e8cc046e Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sat, 6 Feb 2021 16:17:11 -0500 Subject: [PATCH 5/6] Reformat code, fix javadoc --- .../annotations/org/tensorflow/op/Ops.java | 6 ++-- .../framework/initializers/Glorot.java | 1 - .../tensorflow/framework/initializers/He.java | 1 - .../tensorflow/framework/losses/Hinge.java | 7 +--- .../tensorflow/framework/losses/Huber.java | 4 ++- .../tensorflow/framework/losses/LogCosh.java | 3 +- .../tensorflow/framework/losses/Losses.java | 1 + .../tensorflow/framework/losses/Poisson.java | 3 +- .../framework/losses/impl/LossesHelper.java | 7 +++- .../framework/metrics/BinaryCrossentropy.java | 3 +- .../metrics/CategoricalCrossentropy.java | 3 +- .../framework/metrics/CategoricalHinge.java | 3 +- .../framework/metrics/CosineSimilarity.java | 3 +- .../tensorflow/framework/metrics/Hinge.java | 3 +- .../framework/metrics/KLDivergence.java | 3 +- .../framework/metrics/LogCoshError.java | 3 +- .../framework/metrics/MeanAbsoluteError.java | 3 +- .../metrics/MeanAbsolutePercentageError.java | 3 +- .../framework/metrics/MeanSquaredError.java | 3 +- .../metrics/MeanSquaredLogarithmicError.java | 3 +- .../tensorflow/framework/metrics/Metric.java | 6 +++- .../tensorflow/framework/metrics/Poisson.java | 3 +- .../SparseCategoricalCrossentropy.java | 3 +- .../framework/metrics/SquaredHinge.java | 3 +- .../framework/metrics/impl/MetricsHelper.java | 2 ++ .../framework/optimizers/AdaDelta.java | 34 +++++++++---------- .../framework/optimizers/AdaGrad.java | 8 ++--- .../framework/optimizers/AdaGradDA.java | 2 +- .../framework/optimizers/Adamax.java | 4 +-- .../tensorflow/framework/optimizers/Ftrl.java | 10 +++--- .../framework/optimizers/Nadam.java | 5 ++- .../framework/optimizers/Optimizer.java | 20 ++++++----- .../framework/optimizers/RMSProp.java | 27 +++++++-------- .../framework/utils/CastHelper.java | 2 +- .../framework/utils/ShapeUtils.java | 28 ++++++++------- 35 files changed, 124 insertions(+), 99 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 84736ada6a5..007ee9d0d42 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -345,10 +345,10 @@ public final class Ops { public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { @@ -370,8 +370,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - train = new TrainOps(this); quantization = new QuantizationOps(this); + train = new TrainOps(this); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java index 290e4e80b57..894bd073758 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java @@ -62,7 +62,6 @@ * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. *

For a GlorotUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} * for the distribution parameter. - *

* * @param The TType for the call operation * @see VarianceScaling.Distribution diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java index 9b1a0887af0..3a91b72b0d0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java @@ -57,7 +57,6 @@ * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. *

For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} * for the distribution parameter. - *

* * @param The TType for the call operation * @see The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { - @SuppressWarnings("unchecked") - Operand tLabels = - predictions.type() == labels.type() - ? (Operand) labels - : cast(tf, labels, predictions.type()); + Operand tLabels = cast(tf, labels, predictions.type()); tLabels = LossesHelper.valueCheck( getTF(), diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index 665a9ac157d..b1aee1b0656 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -89,6 +89,7 @@ public Huber(Ops tf) { * Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ public Huber(Ops tf, String name) { this(tf, name, DELTA_DEFAULT, Reduction.AUTO); @@ -109,6 +110,7 @@ public Huber(Ops tf, Reduction reduction) { * Creates a Huber Loss using {@link #DELTA_DEFAULT} as the delta * * @param tf the TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. */ public Huber(Ops tf, String name, Reduction reduction) { @@ -119,7 +121,7 @@ public Huber(Ops tf, String name, Reduction reduction) { * Creates a Huber Loss * * @param tf the TensorFlow Ops - * @param name the name of the loss + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param delta the point where the Huber loss function changes from quadratic to linear. * @param reduction Type of Reduction to apply to the loss. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index 78325713e3e..a11d582e527 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -77,6 +77,7 @@ public LogCosh(Ops tf) { * Creates a LogCosh Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ public LogCosh(Ops tf, String name) { this(tf, name, Reduction.AUTO); @@ -96,7 +97,7 @@ public LogCosh(Ops tf, Reduction reduction) { * Creates a LogCosh Loss * * @param tf the TensorFlow Ops - * @param name the name of the loss + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. */ public LogCosh(Ops tf, String name, Reduction reduction) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 2222ebb41f8..9aa94cf7fcf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -655,6 +655,7 @@ private static Operand smoothCategoricalLabels( * @param tf The TensorFlow Ops * @param x the input * @param axis Dimension along which to normalize. + * @param the data type for the input and the result * @return the normalized values based on L2 norm */ public static Operand l2Normalize(Ops tf, Operand x, int[] axis) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java index 9cf38aa0380..78324acf8a5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java @@ -76,6 +76,7 @@ public Poisson(Ops tf) { * Creates a Poisson Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ public Poisson(Ops tf, String name) { this(tf, name, Reduction.AUTO); @@ -95,7 +96,7 @@ public Poisson(Ops tf, Reduction reduction) { * Creates a Poisson Loss * * @param tf the TensorFlow Ops - * @param name the name of the loss + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. */ public Poisson(Ops tf, String name, Reduction reduction) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 66bdd839f09..f6b0de71b0d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -53,6 +53,7 @@ public class LossesHelper { * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction * . + * @param the data type for the labels, predictions and result * @return LossTuple of prediction, label,sampleWeight will * be null. Each of them possibly has the last dimension squeezed, sampleWeight * could be extended by one dimension. If sampleWeight is null, (prediction, @@ -81,6 +82,7 @@ public static LossTuple squeezeOrExpandDimensions( * @param sampleWeights Optional sample weight(s) Operand whose dimensions match * * prediction. + * @param the data type for the labels, predictions and result * @return LossTuple of predictions, labels and sampleWeight * . Each of them possibly has the last dimension squeezed, sampleWeight * could be extended by one dimension. If sampleWeight is null, only the possibly @@ -180,6 +182,7 @@ private static Operand maybeExpandWeights( * @param labels Label values, a Tensor whose dimensions match predictions * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. + * @param the data type for the labels, predictions and result * @return labels and predictions, possibly with last dim squeezed. */ public static LossTuple removeSqueezableDimensions( @@ -195,6 +198,7 @@ public static LossTuple removeSqueezableDimensions( *
. * @param predictions Predicted values, a Tensor of arbitrary dimensions. * @param expectedRankDiff Expected result of rank(predictions) - rank(labels). + * @param the data type for the labels, predictions and result * @return labels and predictions, possibly with last dim squeezed. */ public static LossTuple removeSqueezableDimensions( @@ -218,7 +222,8 @@ public static LossTuple removeSqueezableDimensions( } // Use dynamic rank. - // TODO Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); + // TODO: hold for lazy select feature, + // Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) { /* * TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 263b8a789ed..48ee244eafb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -62,7 +62,8 @@ public BinaryCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.binaryCrossentropy(getTF(), tLabels, tPredictions, fromLogits, labelSmoothing); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index cbe0127295f..b22e5415f79 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -101,7 +101,8 @@ public CategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.categoricalCrossentropy( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index ff814ae6ed3..4266cc487c0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -47,7 +47,8 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.categoricalHinge(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index d64136d0d90..840f255c5ab 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -78,7 +78,8 @@ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class typ /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { // NOTE: metrics.CosineSimilarity is Losses.cosineSimilarity, // while losses.CosineSimilarity is the negative of Losses.cosineSimilarity Operand tLabels = cast(getTF(), labels, getResultType()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index 7a37cbeddbe..46ccd2859ff 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -46,7 +46,8 @@ public Hinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.hinge(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index 3027bb2f460..9ffcd6189f1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -47,7 +47,8 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.kullbackLeiblerDivergence(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index ca84e651988..59e24f57110 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -47,7 +47,8 @@ public LogCoshError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.logCosh(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index c91cb0df1ef..1cc6d0b6f99 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -47,7 +47,8 @@ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.meanAbsoluteError(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 6cc96a4fb88..8c6720b58f6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -47,7 +47,8 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.meanAbsolutePercentageError(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index 1fce9998270..3c4c79d39ba 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -47,7 +47,8 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.meanSquaredError(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index 900359db88b..d525bb76648 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -47,7 +47,8 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.meanSquaredLogarithmicError(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 8ab21c58218..468919e696d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -184,7 +184,11 @@ public String getName() { return name; } - /** The random number generator seed value */ + /** + * Gets the random number generator seed value + * + * @return the random number generator seed value + */ public long getSeed() { return seed; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 3572c155b96..422fd4808ff 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -46,7 +46,8 @@ public Poisson(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.poisson(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index a74f575a4a8..9949f0c6b60 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -57,7 +57,8 @@ public SparseCategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.sparseCategoricalCrossentropy(getTF(), tLabels, tPredictions, fromLogits, axis); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 6bee2ccf8e4..19b3b1d0ac4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -46,7 +46,8 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.squaredHinge(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 6cc089fce6d..8a352322f52 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -308,6 +308,7 @@ public static Operand booleanMean( /** * Calculates the mean of the boolean operand, alongside all axes. * + * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is * false, the rank of the tensor is reduced by 1 for each entry in axes @@ -322,6 +323,7 @@ public static Operand booleanMean(Ops tf, Operand x, boolean ke /** * Calculates the mean of the boolean operand, alongside the specified axes. * + * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean * @param axes Axes to compute the mean. * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java index 822eb490f22..aadbfeea54b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java @@ -31,29 +31,29 @@ * learning rate per dimension to address two drawbacks: * *
    - *
  • the continual decay of learning rates throughout training - *
  • the need for a manually selected global learning rate + *
  • the continual decay of learning rates throughout training + *
  • the need for a manually selected global learning rate *
* - *

Adadelta is a more robust extension of Adagrad that adapts learning rates based on a - * moving window of gradient updates, instead of accumulating all past gradients. This way, - * Adadelta continues learning even when many updates have been done. Compared to Adagrad, in - * the original version of Adadelta you don't have to set an initial learning rate. In this - * version, initial learning rate can be set, as in most other optimizers. + *

Adadelta is a more robust extension of Adagrad that adapts learning rates based on a moving + * window of gradient updates, instead of accumulating all past gradients. This way, Adadelta + * continues learning even when many updates have been done. Compared to Adagrad, in the original + * version of Adadelta you don't have to set an initial learning rate. In this version, initial + * learning rate can be set, as in most other optimizers. * - *

According to section 4.3 ("Effective Learning rates"), near the end of training step sizes - * converge to 1 which is effectively a high learning rate which would cause divergence. This - * occurs only near the end of the training as gradients and step sizes are small, and the - * epsilon constant in the numerator and denominator dominate past gradients and parameter - * updates which converge the learning rate to 1. + *

According to section 4.3 ("Effective Learning rates"), near the end of training step sizes + * converge to 1 which is effectively a high learning rate which would cause divergence. This occurs + * only near the end of the training as gradients and step sizes are small, and the epsilon constant + * in the numerator and denominator dominate past gradients and parameter updates which converge the + * learning rate to 1. * - *

According to section 4.4("Speech Data"),where a large neural network with 4 hidden layers - * was trained on a corpus of US English data, ADADELTA was used with 100 network replicas.The - * epsilon used is 1e-6 with rho=0.95 which converged faster than ADAGRAD, by the following - * construction: new AdaDelta(graph, 1.0f, 0.95f, 1e-6f); + *

According to section 4.4("Speech Data"),where a large neural network with 4 hidden layers was + * trained on a corpus of US English data, ADADELTA was used with 100 network replicas.The epsilon + * used is 1e-6 with rho=0.95 which converged faster than ADAGRAD, by the following construction: + * new AdaDelta(graph, 1.0f, 0.95f, 1e-6f); * * @see Zeiler, M., 2012 ADADELTA: An Adaptive Learning - * Rate Method. + * Rate Method */ public class AdaDelta extends Optimizer { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index 08f5f18a9cd..2dd05ef31b3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -31,10 +31,10 @@ * how frequently a parameter gets updated during training. The more updates a parameter receives, * the smaller the updates. * - *

- * - * @see Duchi, J, et al., 2011, Adaptive Subgradient Methods for Online Learning and Stochastic Optimization - * @see Duchi, J, et al., 2013, Proximal and First-Order Methods for Convex Optimization, Introduction Section 1. + * @see Duchi, J, et al., 2011, + * Adaptive Subgradient Methods for Online Learning and Stochastic Optimization + * @see Duchi, J, et al., + * 2013, Proximal and First-Order Methods for Convex Optimization, Introduction Section 1 */ public class AdaGrad extends Optimizer { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java index df624e41c4e..7114c33339f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java @@ -40,7 +40,7 @@ * networks as it will require careful initialization of the gradient accumulators for it to train. * * @see Duchi, J, et al., 2011, - * Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. + * Adaptive Subgradient Methods for Online Learning and Stochastic Optimization */ public class AdaGradDA extends Optimizer { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java index cd95bb3bd07..0ecc1ac1451 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java @@ -32,12 +32,10 @@ public class Adamax extends Optimizer { public static final float EPSILON_DEFAULT = 1e-07f; public static final float BETA_ONE_DEFAULT = 0.9f; public static final float BETA_TWO_DEFAULT = 0.999f; - - private float learningRate; private final float betaOne; private final float betaTwo; private final float epsilon; - + private final float learningRate; private Constant learningRateConst; private Constant epsilonConst; private Constant betaOneConst; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java index 66314d2ffe0..5d8c1478231 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java @@ -13,10 +13,11 @@ /** * Optimizer that implements the FTRL algorithm. * + *

This version has support for both online L2 (the L2 penalty given in the paper below) and + * shrinkage-type L2 (which is the addition of an L2 penalty to the loss function). + * * @see McMahan, et - * al., 2013, Algorithm 1 - *

This version has support for both online L2 (the L2 penalty given in the paper above) and - * shrinkage-type L2 (which is the addition of an L2 penalty to the loss function). + * al., 2013, Algorithm 1 */ public class Ftrl extends Optimizer { @@ -29,13 +30,12 @@ public class Ftrl extends Optimizer { public static final float L1STRENGTH_DEFAULT = 0.0f; public static final float L2STRENGTH_DEFAULT = 0.0f; public static final float L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT = 0.0f; - - private float learningRate; private final float learningRatePower; private final float initialAccumulatorValue; private final float l1RegularizationStrength; private final float l2RegularizationStrength; private final float l2ShrinkageRegularizationStrength; + private final float learningRate; /** * Creates a Ftrl Optimizer diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java index f9900a8ee78..5b94b548c0a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java @@ -24,8 +24,6 @@ */ public class Nadam extends Optimizer { - private static final float DECAY_BASE = 0.96f; - private static final float DECAY = 0.004f; public static final float LEARNING_RATE_DEFAULT = 0.001f; public static final float EPSILON_DEFAULT = 1e-8f; public static final float BETA_ONE_DEFAULT = 0.9f; @@ -33,7 +31,8 @@ public class Nadam extends Optimizer { public static final String FIRST_MOMENT = "m"; public static final String SECOND_MOMENT = "v"; public static final String MOMENTUM = "momentum"; - + private static final float DECAY_BASE = 0.96f; + private static final float DECAY = 0.004f; /** The learning rate. */ private final float learningRate; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index fdf56da4a67..ed141831bbe 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -71,14 +71,6 @@ protected Optimizer(Graph graph, String name) { this.globals = new ArrayList<>(); } - /** - * Gets the Optimizer's Ops instance - * @return the Optimizer's Ops instance - */ - public final Ops getTF() { - return tf; - } - /** * Creates a name by combining a variable name and a slot name * @@ -90,6 +82,15 @@ public static String createName(Output variable, String slotNam return variable.op().name() + "-" + slotName; } + /** + * Gets the Optimizer's Ops instance + * + * @return the Optimizer's Ops instance + */ + public final Ops getTF() { + return tf; + } + /** * Minimizes the loss by updating the variables * @@ -299,7 +300,8 @@ private Options() {} * Sets the shared name * * @param sharedName If non-empty, this variable is named in the given bucket with this - * shared_name. Otherwise, the node name is used instead. + * sharedName. Otherwise, the node name is used instead. + * @return this options instance */ public Optimizer.Options sharedName(String sharedName) { this.sharedName = sharedName; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java index b3729dc367f..e86e64971a4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java @@ -27,17 +27,20 @@ /** * Optimizer that implements the RMSProp algorithm. * - *

The gist of RMSprop is to:

    - *
  • Maintain a moving (discounted) average of the square of gradients - *
  • Divide the gradient by the root of this average
+ *

The gist of RMSprop is to: * - *

This implementation of RMSprop uses plain momentum, not Nesterov momentum. + *

    + *
  • Maintain a moving (discounted) average of the square of gradients + *
  • Divide the gradient by the root of this average + *
* - *

The centered version additionally maintains a moving average of the gradients, and uses - * that average to estimate the variance. + *

This implementation of RMSprop uses plain momentum, not Nesterov momentum. + * + *

The centered version additionally maintains a moving average of the gradients, and uses that + * average to estimate the variance. * * @see Hinton G, - * et al. 2012, lecture notes that is inexplicably the canonical reference. + * et al. 2012, lecture notes, that is inexplicably the canonical reference. */ public class RMSProp extends Optimizer { @@ -165,24 +168,20 @@ protected void createSlots(List> variables) { } } - /** - * Creates the RMSProp Slots for Root Mean Squared (RMS), - * MOMENTUM, and Mean Gradient (MG) + * Creates the RMSProp Slots for Root Mean Squared (RMS), MOMENTUM, and Mean Gradient (MG) * * @param v the variable to install in the slot * @param the datatype of the variable. */ private void createRMSPropSlot(Output v) { - Operand rmsInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.type())); + Operand rmsInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.type())); createSlot(v.asOutput(), RMS, rmsInitializer); Operand momentumInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), MOMENTUM, momentumInitializer); if (centered) { - Operand mgInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); + Operand mgInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), MG, mgInitializer); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java index b0fe48967dd..1c027cb5ddf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java @@ -34,7 +34,7 @@ public class CastHelper { */ @SuppressWarnings("unchecked") public static Operand cast( - Ops tf, Operand value, Class requiredType) { + Ops tf, Operand value, Class requiredType) { return (value.type() == requiredType) ? (Operand) value : tf.dtypes.cast(value, requiredType); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java index 4ca2c789f28..e730c79cfbf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java @@ -14,8 +14,9 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.*; -import org.tensorflow.ndarray.NdArray; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; import org.tensorflow.types.TInt32; @@ -33,7 +34,9 @@ public class ShapeUtils { /** * Converts a shape operand to a Shape object * + * @param scope the TensorFlow scope * @param dims the Operand containing the shape values + * @param the date type for the shape dimensions. * @return a new Shape based on an Operand that contains dimensions */ public static Shape toShape(Scope scope, Operand dims) { @@ -45,8 +48,8 @@ public static Shape toShape(Scope scope, Operand dims) * Converts a TInt32 type Operand to a Java int array * * @param scope the TensorFlow scope - * @param dims the TInt32 Operand - * @return the int array + * @param dims the shape dimensions operand + * @return the int array of the dimensions */ public static int[] getIntArray(Scope scope, Operand dims) { long[] longDims = getLongArray(scope, dims); @@ -66,8 +69,8 @@ public static long[] getLongArray(Scope scope, Operand if (scope.env().isEager()) { return getLongArray(dims.asTensor()); } - try (Session session = new Session((Graph)scope.env()); - TIntegral tensor = (TIntegral)session.runner().fetch(dims).run().get(0)) { + try (Session session = new Session((Graph) scope.env()); + TIntegral tensor = (TIntegral) session.runner().fetch(dims).run().get(0)) { return getLongArray(tensor); } } @@ -76,20 +79,21 @@ public static long[] getLongArray(Scope scope, Operand * Converts a TInt32 or TInt64 to a java long array * * @param dims the dimension tensor + * @param the type of the dimensions, must either be TInt32 or TInt64 type * @return the long array * @throws java.lang.IllegalArgumentException if the dims type is not an integer */ public static long[] getLongArray(T dims) { List result = new ArrayList<>(); if (dims instanceof TInt32) { - ((TInt32)dims).scalars().forEach(s -> result.add((long) s.getInt())); + ((TInt32) dims).scalars().forEach(s -> result.add((long) s.getInt())); } else if (dims instanceof TInt64) { - ((TInt64)dims).scalars().forEach(s -> result.add(s.getLong())); + ((TInt64) dims).scalars().forEach(s -> result.add(s.getLong())); } else if (dims instanceof TUint8) { - ((TUint8)dims).scalars().forEach(s -> result.add(s.getObject().longValue())); - } else { // shouldn't happen - throw new IllegalArgumentException("the data type must be an integer type"); - } + ((TUint8) dims).scalars().forEach(s -> result.add(s.getObject().longValue())); + } else { // shouldn't happen + throw new IllegalArgumentException("the data type must be an integer type"); + } return result.stream().mapToLong(i -> i).toArray(); } From 6b8412827e5cfe398a2dbfb92bc5a7c4cfad7f99 Mon Sep 17 00:00:00 2001 From: Karl Lessard Date: Wed, 10 Feb 2021 23:22:24 -0500 Subject: [PATCH 6/6] Remove trailing character --- .../framework/metrics/SparseCategoricalCrossentropy.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index 9949f0c6b60..e954169b2af 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -25,7 +25,7 @@ /** * A metric that computes the sparse categorical cross-entropy loss between true labels and - * predicted labels. \ + * predicted labels. * * @param The data type for the metric result. */