From 9cc26757f102688789b58d32f18d6fd7e4941fc2 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 5 Oct 2020 10:18:55 -0400 Subject: [PATCH 01/26] Initial checkin to rebase to Initialziers to pick up changes to ndarry Shape --- .../org/tensorflow/framework/losses/Loss.java | 90 +++++++++++ .../tensorflow/framework/losses/Losses.java | 29 ++++ .../framework/losses/MeanAbsoluteError.java | 49 ++++++ .../framework/losses/Reduction.java | 19 +++ .../losses/impl/ConfusionMatrix.java | 81 ++++++++++ .../framework/losses/impl/LossesImpl.java | 142 ++++++++++++++++++ .../framework/losses/impl/Tuple.java | 107 +++++++++++++ 7 files changed, 517 insertions(+) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/ConfusionMatrix.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/Tuple.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java new file mode 100644 index 00000000000..95d79507f90 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -0,0 +1,90 @@ +package org.tensorflow.framework.losses; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +public abstract class Loss { + protected final Ops tf; + protected final Reduction reduction; + + /** + * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link + * Reduction#AUTO} + * + * @param tf the TensorFlow Ops + */ + protected Loss(Ops tf) { + this(tf, null, Reduction.AUTO); + } + + /** + * Creates a Loss using a Loss Reduction of {@link Reduction#AUTO} + * + * @param tf the TensorFlow Ops + * @param name the name of this Loss + */ + protected Loss(Ops tf, String name) { + this(tf, name, Reduction.AUTO); + } + + /** + * Creates a Loss + * + * @param tf the TensorFlow Ops + * @param name the name of this loss + * @param reduction Type of Reduction to apply to the loss. + */ + protected Loss(Ops tf, String name, Reduction reduction) { + this.tf = name != null ? tf.withSubScope(name) : tf.withSubScope(getClass().getSimpleName()); + this.reduction = reduction; + } + + /** + * Calculates the loss + * + * @param labels the truth values or labels + * @param predictions the predictions + * @param The data type of the labels, predictions and loss. + * @return the loss + */ + public Operand call(Operand labels, Operand predictions) { + return call(labels, predictions, null); + } + + /** + * Calculates the loss + * + * @param labels the truth values or labels + * @param predictions the predictions + * @param sampleWeights Optional sample_weight acts as a coefficient for the loss. If a scalar is + * provided, then the loss is simply scaled by the given value. If sample_weight is a tensor + * of size [batch_size], then the total loss for each sample of the batch is rescaled by the + * corresponding element in the sample_weight vector. If the shape of sample_weight is + * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of + * y_pred is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss + * functions reduce by 1 dimension, usually axis=-1.) + * @param The data type of the labels, predictions, sampleWeights and loss. + * @return the loss + */ + public abstract Operand call( + Operand labels, Operand predictions, Operand sampleWeights); + + /** + * Gets the TensorFlow Ops + * + * @return the TensorFlow Ops + */ + public Ops getTF() { + return tf; + } + + /** + * Gets the loss reduction + * + * @return the loss reduction + */ + public Reduction getReduction() { + return reduction; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java new file mode 100644 index 00000000000..65addd71c90 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -0,0 +1,29 @@ +package org.tensorflow.framework.losses; + +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.Tuple; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +public class Losses { + + /** + * + * @param tf + * @param yTrue + * @param yPred + * @param + * @return + */ + public static Operand mean_absolute_error(Ops tf, Operand labels, Operand predictions) { + Operand tLabels = tf.dtypes.cast(labels, predictions.asOutput().dataType()); + Tuple ops = squeezeOrExpandDimensions(tf, labels, predictions, null); + predictions = ops.getPredictions(); + tLabels = ops.getLabels(); + return mean(tf, tf.math.abs(tf.math.sub(yPred, yTrue)), tf.constant(-1)); + } + + +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java new file mode 100644 index 00000000000..4d240035523 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -0,0 +1,49 @@ +package org.tensorflow.framework.losses; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +public class MeanAbsoluteError extends Loss { + + /** + * Creates a MeanAbsoluteError Loss using {@link Class#getSimpleName()} as the loss name and a Loss Reduction of + * {@link * Reduction#AUTO} + * + * @param tf the TensorFlow Ops + */ + public MeanAbsoluteError(Ops tf) { + super(tf); + } + + /** + * Creates a MeanAbsoluteError Loss using {@link Class#getSimpleName()} as the loss name + * + * @param tf the TensorFlow Ops + * @param reduction Type of Reduction to apply to the loss. + */ + public MeanAbsoluteError(Ops tf, Reduction reduction) { + super(tf, null,reduction); + } + + /** + * Creates a MeanAbsoluteError + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param reduction Type of Reduction to apply to the loss. + */ + public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { + super(tf, name, reduction); + } + + /** + * {@inheritDoc} + */ + @Override + public Operand call(Operand labels, Operand predictions, Operand sampleWeights) { + Operand losses = Losses.mean_absolute_error(tf, labels, predictions); + return super.computeWeightedLoss(losses, getReduction(), sampleWeights); + } + +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java new file mode 100644 index 00000000000..af408abe36b --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java @@ -0,0 +1,19 @@ +package org.tensorflow.framework.losses; + +/** Type of Loss Reduction */ +public enum Reduction { + AUTO, + NONE, + SUM, + SUM_OVER_BATCH_SIZE; + + /** + * Get the Reduction based on name + * + * @param name the name of the reduction + * @return the Reduction + */ + public static Reduction ofName(String name) { + return Reduction.valueOf(name.toUpperCase()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/ConfusionMatrix.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/ConfusionMatrix.java new file mode 100644 index 00000000000..013c9c6d6a6 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/ConfusionMatrix.java @@ -0,0 +1,81 @@ +package org.tensorflow.framework.losses.impl; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Squeeze; +import org.tensorflow.types.family.TNumber; + +import java.util.Arrays; + +public class ConfusionMatrix { + + /** + * Squeeze last dim if ranks differ from expected by exactly 1. + * + * @param tf the TensorFlowOps + * @param labels Label values, a `Tensor` whose dimensions match + * `predictions`. + * @param predictions Predicted values, a `Tensor` of arbitrary dimensions. + * @return `labels` and `predictions`, possibly with last dim squeezed. + */ + public static Tuple removeSqueezableDimensions(Ops tf, Operand labels, + Operand predictions) { + return removeSqueezableDimensions(tf, labels, predictions, 0); + } + + /** + * Squeeze last dim if ranks differ from expected by exactly 1. + * + * @param tf the TensorFlowOps + * @param labels Label values, a `Tensor` whose dimensions match + * `predictions`. + * @param predictions Predicted values, a `Tensor` of arbitrary dimensions. + * @param expectedRankDiff Expected result of `rank(predictions) - + * rank(labels)`. + * @return `labels` and `predictions`, possibly with last dim squeezed. + */ + public static Tuple removeSqueezableDimensions(Ops tf, Operand labels, + Operand predictions, int expectedRankDiff) { + + tf = tf.withSubScope("removeSqueezableDimensions"); + Shape predictionsShape = predictions.asOutput().shape(); + int predictionsRank = predictionsShape.numDimensions(); + Shape labelsShape = labels.asOutput().shape(); + int labelsRank = labelsShape.numDimensions(); + + if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { + // Use static rank. + int rankDiff = predictionsRank - labelsRank; + if (rankDiff == expectedRankDiff + 1 && Shape.isCompatible(predictionsShape.size(-1), 1)) { + predictions = tf.squeeze(predictions); + } else if (rankDiff == expectedRankDiff - 1 && Shape.isCompatible(labelsShape.size(-1), 1)) { + labels = tf.squeeze(labels); + } + return new Tuple(labels, predictions); + } + // Use dynamic rank. + Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); + if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) { + /** + * TODO, if we ever get a select that does lazy evaluation, but for + * now do the tf.squeeze predictions = tf.select( + * tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ), + * tf.squeeze(predictions, Squeeze.axis(Arrays.asList(-1L))), + * predictions ); * + */ + predictions = tf.squeeze(predictions, Squeeze.axis(Arrays.asList(-1L))); + } + if (labelsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(labelsShape.size(-1), 1)) { + /** + * TODO, if we ever get a select that does lazy evaluation labels = + * tf.select( tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff + * ), tf.squeeze(labels, Squeeze.axis(Arrays.asList(-1L))), + * predictions ); * + */ + labels = tf.squeeze(labels, Squeeze.axis(Arrays.asList(-1L))); + } + return new Tuple(labels, predictions,true); + } + +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java new file mode 100644 index 00000000000..272e27b51ff --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java @@ -0,0 +1,142 @@ +package org.tensorflow.framework.losses.impl; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Squeeze; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TNumber; + +import java.util.Arrays; + +public class LossesImpl { + + /** + * Squeeze or expand last dimension if needed. + * + *
    + *
  1. Squeezes last dim of `yPred` or `yTrue` if their rank differs by 1 (using + * `confusion_matrix.remove_squeezable_dimensions`). + *
  2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1 from the new + * rank of `yPred`. If `sample_weight` is scalar, it is kept scalar./li> + *
+ * + * @param tf the TensorFlow Ops + * @param predictions Predicted values, a `Tensor` of arbitrary dimensions. + * @param predictions Optional label `Tensor` whose dimensions match `y_pred`. + * @return Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has the last + * dimension squeezed, `sample_weight` could be extended by one dimension. If `sample_weight` + * is null, (y_pred, y_true) is returned. + */ + /********** TODO need to move ConfusionMatrix to MathOps */ + public static Tuple squeezeOrExpandDimensions(Ops tf, Operand labels, Operand predictions) { + return squeezeOrExpandDimensions(tf, labels, predictions, null); + } + + /** + * Squeeze or expand last dimension if needed. * * + * + *
    + *
  1. Squeezes last dim of `yPred` or `yTrue` if their rank differs by 1 (using * + * `confusion_matrix.remove_squeezable_dimensions`). * + *
  2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1 from the new * + * rank of `yPred`. If `sample_weight` is scalar, it is kept scalar./li> * + *
+ * + * @param tf the TensorFlow Ops + * @param predictions Predicted values, a `Tensor` of arbitrary dimensions. + * @param labels Optional label `Tensor` whose dimensions match `y_pred`. + * @param sampleWeight Optional weight scalar or `Tensor` whose dimensions match `y_pred`. + * @return Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has the last + * dimension squeezed, `sample_weight` could be extended by one dimension. If `sample_weight` + * is null, (y_pred, y_true) is returned. + */ + /********** TODO need to move ConfusionMatrix to MathOps **/ + public static Tuple squeezeOrExpandDimensions( + Ops tf, Operand labels, Operand predictions, Operand sampleWeight) { + Tuple tuple = new Tuple<>(labels, predictions, true); + Shape predictionsShape = predictions.asOutput().shape(); + long ypredRank = predictionsShape.numDimensions(); + + if (labels != null) { + Shape labelsShape = labels.asOutput().shape(); + long ytrueRank = labelsShape.numDimensions(); + if (ytrueRank != Shape.UNKNOWN_SIZE && ypredRank != Shape.UNKNOWN_SIZE) { + // Use static rank for `y_true` and `y_pred`. + if (ypredRank - ytrueRank != 1 || predictionsShape.size(-1) == 1) { + // y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(y_true, y_pred) + tuple = ConfusionMatrix.removeSqueezableDimensions(tf, labels, predictions); + } + } else { // use dynamic rank + tuple = ConfusionMatrix.removeSqueezableDimensions(tf, labels, predictions); + } + } + if (sampleWeight == null) { + return tuple; + } + Shape weightsShape = sampleWeight.asOutput().shape(); + long weightsRank = weightsShape.numDimensions(); + if (weightsRank == 0) { // scalar + return new Tuple(labels, predictions, sampleWeight, true); + } + + if (ypredRank != Shape.UNKNOWN_SIZE && weightsRank != Shape.UNKNOWN_SIZE) { + + if (weightsRank - ypredRank == 1) { + sampleWeight = tf.squeeze(sampleWeight); + } else if (ypredRank - weightsRank == 1) { + sampleWeight = tf.expandDims(sampleWeight, tf.constant(-1L)); + } + return new Tuple(labels, predictions, sampleWeight, true); + } + // Use dynamic rank. + Operand weightsRankTensor = tf.rank(sampleWeight); + Operand rankDiff = tf.math.sub(weightsRankTensor, tf.rank(predictions)); + sampleWeight = + tf.select( + tf.math.equal(weightsRankTensor, tf.constant(0)), + sampleWeight, + maybeAdjustWeights(tf, sampleWeight, rankDiff)); + return new Tuple(labels, predictions, sampleWeight, true); + } + + + /** + * Squeeze or expand the sampleWeight based on the rank difference + * + *

If the rank difference is +1, squeeze the last dimension of sampleWeight, If the rank + * difference is -1, expand the last dimension of sampleWeight. Otherwise, leave the shape of + * sampleWeight as is. + * + * @param tf the TensorFlow Ops + * @param sampleWeight the sample weights + * @param rankDiff the difference in rank + * @param the data type for the Operands. + * @return the adjusted sampleWeight + */ + private static Operand maybeAdjustWeights( + Ops tf, Operand sampleWeight, Operand rankDiff) { + return tf.select( + tf.math.equal(rankDiff, tf.constant(1)), + tf.squeeze(sampleWeight, Squeeze.axis(Arrays.asList(-1L))), + maybeExpandWeights(tf, sampleWeight, rankDiff)); + } + + /** + * Expand the last dimension of sampleWeight. if the rank difference is -1. + * + * @param tf the TensorFlow Ops + * @param sampleWeight the sample weights + * @param rankDiff the difference in rank + * @param the data type for the Operands. + * @return the adjusted sampleWeight + */ + private static Operand maybeExpandWeights( + Ops tf, Operand sampleWeight, Operand rankDiff) { + return tf.select( + tf.math.equal(rankDiff, tf.constant(-1)), + tf.expandDims(sampleWeight, tf.constant(-1)), + sampleWeight); + } + +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/Tuple.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/Tuple.java new file mode 100644 index 00000000000..672c4ca4f6c --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/Tuple.java @@ -0,0 +1,107 @@ +package org.tensorflow.framework.losses.impl; + +import org.tensorflow.Operand; +import org.tensorflow.types.family.TNumber; + +/** + * A helper class for loss methods to return multiple responses + * + * @param the data type of the Tuple entries. + */ +public class Tuple { + private final Operand labels; + private final Operand losses; + private final Operand predictions; + private final Operand sampleWeights; + + /** + * Creates a Tuple of Operands for labels, predictions, and sampleWeights + * + * @param labels the labels + * @param lossesOrPredictions the losses or predictions + * @param isPredictions flag indicating that this Tuple will contain predictions or losses + */ + public Tuple(Operand labels, Operand lossesOrPredictions, boolean isPredictions) { + this(labels, lossesOrPredictions, null, isPredictions); + } + + /** + * Creates a Tuple of Operands for labels, predictions, and sampleWeights + * + * @param labels the labels + * @param lossesOrPredictions the losses or predictions + * @param sampleWeights the sample weights + * @param isPredictions flag indicating that this Tuple will contain predictions or losses + */ + public Tuple( + Operand labels, + Operand lossesOrPredictions, + Operand sampleWeights, + boolean isPredictions) { + this.labels = labels; + if (isPredictions) { + this.predictions = lossesOrPredictions; + this.losses = null; + } else { + this.predictions = null; + this.losses = lossesOrPredictions; + } + this.sampleWeights = sampleWeights; + } + + /** + * Indicates whether this Tuple contains Labels + * + * @return true is this Tuple contains Labels + */ + public boolean containsLabels() { + return labels != null; + } + + /** + * Indicates whether this Tuple contains Labels + * + * @return true is this Tuple contains Labels + */ + public boolean containsPredictions() { + return predictions != null; + } + + /** + * Indicates whether this Tuple contains Labels + * + * @return true is this Tuple contains Labels + */ + public boolean containsLosses() { + return losses != null; + } + + /** + * Indicates whether this Tuple contains Labels + * + * @return true is this Tuple contains Labels + */ + public boolean containsSampleWeights() { + return this.sampleWeights != null; + } + + /** @return the labels */ + public Operand getLabels() { + return labels; + } + + /** @return the predictions */ + public Operand getPredictions() { + return predictions; + } + + /** @return the predictions */ + public Operand getLosses() { + return losses; + } + + /** @return the sampleWeights */ + public Operand getSampleWeights() { + return sampleWeights; + } +} From 2508f5e58b59e18d3537d845491ce1e3f7afbd85 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 8 Oct 2020 13:07:38 -0400 Subject: [PATCH 02/26] Initial Checkin for losses --- .../framework/losses/BinaryCrossentropy.java | 179 +++++ .../losses/CategoricalCrossentropy.java | 219 ++++++ .../framework/losses/CategoricalHinge.java | 91 +++ .../framework/losses/CosineSimilarity.java | 164 +++++ .../tensorflow/framework/losses/Hinge.java | 92 +++ .../tensorflow/framework/losses/Huber.java | 124 ++++ .../framework/losses/KLDivergence.java | 93 +++ .../tensorflow/framework/losses/LogCosh.java | 99 +++ .../org/tensorflow/framework/losses/Loss.java | 11 +- .../tensorflow/framework/losses/Losses.java | 683 +++++++++++++++++- .../framework/losses/MeanAbsoluteError.java | 104 ++- .../losses/MeanAbsolutePercentageError.java | 89 +++ .../framework/losses/MeanSquaredError.java | 89 +++ .../losses/MeanSquaredLogarithmicError.java | 89 +++ .../tensorflow/framework/losses/Poisson.java | 99 +++ .../framework/losses/Reduction.java | 14 +- .../losses/SparseCategoricalCrossentropy.java | 170 +++++ .../framework/losses/SquaredHinge.java | 93 +++ .../losses/impl/ConfusionMatrix.java | 81 --- .../framework/losses/impl/LossesImpl.java | 398 +++++++--- .../framework/losses/impl/Tuple.java | 80 +- .../losses/BinaryCrossentropyTest.java | 179 +++++ .../losses/CategoricalCrossentropyTest.java | 213 ++++++ .../losses/CategoricalHingeTest.java | 131 ++++ .../losses/CosineSimilarityTest.java | 171 +++++ .../framework/losses/HingeTest.java | 108 +++ .../framework/losses/HuberTest.java | 123 ++++ .../framework/losses/KLDivergenceTest.java | 106 +++ .../framework/losses/LogCoshTest.java | 105 +++ .../losses/MeanAbsoluteErrorTest.java | 180 +++++ .../MeanAbsolutePercentageErrorTest.java | 153 ++++ .../losses/MeanSquaredErrorTest.java | 180 +++++ .../MeanSquaredLogarithmicErrorTest.java | 179 +++++ .../framework/losses/PoissonTest.java | 105 +++ .../SparseCategoricalCrossentropyTest.java | 180 +++++ .../framework/losses/SquaredHingeTest.java | 105 +++ 36 files changed, 4960 insertions(+), 319 deletions(-) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/ConfusionMatrix.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java new file mode 100644 index 00000000000..aa4e167c149 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -0,0 +1,179 @@ +package org.tensorflow.framework.losses; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes the cross-entropy loss between true labels and predicted labels. + * + *

Use this cross-entropy loss when there are only two label classes (assumed to be 0 and 1). For + * each example, there should be a single floating-point value per prediction. + * + *

Standalone usage: + * + *

+ *    Operand<TFloat32> labels =
+ *        tf.constant(new float[][] {{0.f, 1.f}, {0.f, 0.f}});
+ *    Operand<TFloat32> predictions =
+ *        tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}});
+ *    BinaryCrossentropy bce = new BinaryCrossentropy(tf);
+ *    Operand<TFloat32> result = bce.call(labels, predictions);
+ *    // produces 0.815
+ * 
+ * + *

Calling with sample weight: + * + *

+ *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
+ *    Operand<TFloat32> result = bce.call(labels, predictions, sampleWeight);
+ *    // produces 0.458f
+ * 
+ * + *

Using SUM reduction type: + * + *

+ *    BinaryCrossentropy bce = new BinaryCrossentropy(tf, Reduction.SUM);
+ *    Operand<TFloat32> result = bce.call(labels, predictions);
+ *    // produces 1.630f
+ * 
+ * + *

Using NONE reduction type: + * + *

+ *    BinaryCrossentropy bce = new BinaryCrossentropy(tf, Reduction.NONE);
+ *    Operand<TFloat32> result = bce.call(labels, predictions);
+ *    // produces [0.916f, 0.714f]
+ * 
+ * + */ +public class BinaryCrossentropy extends Loss { + public static final boolean FROM_LOGITS_DEFAULT = false; + public static final float LABEL_SMOOTHING_DEFAULT = 0.0f; + public static final Reduction REDUCTION_DEFAULT = Reduction.AUTO; + + private final boolean fromLogits; + private final float labelSmoothing; + + /** + * Creates a Binary Crossentropy Loss using {@link Class#getSimpleName()} as the loss name, {@link + * #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing and a + * Loss Reduction of {@link * Reduction#AUTO} + * + * + * + * @param tf the TensorFlow Ops + */ + public BinaryCrossentropy(Ops tf) { + this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); + } + + /** + * Creates a Binary Crossentropy loss using {@link Class#getSimpleName()} as the loss name, {@link + * #FROM_LOGITS_DEFAULT} for fromLogits, and {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing + * + * @param tf the TensorFlow Ops + * @param reduction Type of Reduction to apply to the loss. + */ + public BinaryCrossentropy(Ops tf, Reduction reduction) { + this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction); + } + + /** + * Creates a Binary Crossentropy loss using using {@link Class#getSimpleName()} as the loss name, + * labelSmoothing of {@link #LABEL_SMOOTHING_DEFAULT}, a reduction of {@link #REDUCTION_DEFAULT}, + * + * @param tf the TensorFlow Ops + * @param fromLogits Whether to interpret predictions as a tensor of logit values + */ + public BinaryCrossentropy(Ops tf, boolean fromLogits) { + this(tf, null, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); + } + + /** + * Creates a Binary Crossentropy loss using labelSmoothing of {@link #LABEL_SMOOTHING_DEFAULT} a + * reduction of {@link #REDUCTION_DEFAULT}. + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param fromLogits Whether to interpret predictions as a tensor of logit values + */ + public BinaryCrossentropy(Ops tf, String name, boolean fromLogits) { + this(tf, name, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT); + } + + /** + * Creates a Binary Crossentropy loss using using {@link Class#getSimpleName()} as the loss name, + * and a reduction of {@link #REDUCTION_DEFAULT}. + * + * @param tf the TensorFlow Ops + * @param fromLogits Whether to interpret predictions as a tensor of logit values + * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, + * compute the loss between the predicted labels and a smoothed version of the true labels, + * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing + * correspond to heavier smoothing. + */ + public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { + this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT); + } + + /** + * Creates a Binary Crossentropy loss using a reduction of {@link #REDUCTION_DEFAULT}. + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param fromLogits Whether to interpret predictions as a tensor of logit values + * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, + * compute the loss between the predicted labels and a smoothed version of the true labels, + * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing + * correspond to heavier smoothing. + */ + public BinaryCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) { + this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT); + } + + /** + * Creates a Binary Crossentropy loss + * + * @param tf the TensorFlow Ops + * @param fromLogits Whether to interpret predictions as a tensor of logit values + * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, + * compute the loss between the predicted labels and a smoothed version of the true labels, + * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing + * correspond to heavier smoothing. + * @param reduction Type of Reduction to apply to the loss. + */ + public BinaryCrossentropy( + Ops tf, boolean fromLogits, float labelSmoothing, Reduction reduction) { + this(tf, null, fromLogits, labelSmoothing, reduction); + } + + /** + * Creates a Binary Crossentropy loss + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param fromLogits Whether to interpret predictions as a tensor of logit values + * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, + * compute the loss between the predicted labels and a smoothed version of the true labels, + * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing + * correspond to heavier smoothing. + * @param reduction Type of Reduction to apply to the loss. + */ + public BinaryCrossentropy( + Ops tf, String name, boolean fromLogits, float labelSmoothing, Reduction reduction) { + super(tf, name, reduction); + this.fromLogits = fromLogits; + this.labelSmoothing = labelSmoothing; + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + Operand losses = + Losses.binaryCrossentropy(tf, labels, predictions, fromLogits, labelSmoothing); + return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java new file mode 100644 index 00000000000..b042a656405 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -0,0 +1,219 @@ +package org.tensorflow.framework.losses; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes the crossentropy loss between the labels and predictions. + * + *

Use this crossentropy loss function when there are two or more label classes. We expect labels + * to be provided in a one_hot representation. If you want to provide labels as integers, please use + * {@link SparseCategoricalCrossentropy} loss. There should be # classes floating point + * values per feature. + * + *

Standalone usage: + * + *

+ *    Operand<TFloat32> labels =
+ *        tf.constant(new float[][] {{0, 1, 0}, {0, 0, 1}});
+ *    Operand<TFloat32> predictions =
+ *        tf.constant(new float[][] {{0.05f, 0.95f, 0f}, {0.1f, 0.8f, 0.1f}});
+ *    CategoricalCrossentropy cce = new CategoricalCrossentropy(tf);
+ *    Operand<TFloat32> result = cce.call(labels, predictions);
+ *    // produces 1.177
+ * 
+ * + *

Calling with sample weight: + * + *

+ *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.3f, 0.7f});
+ *    Operand<TFloat32> result = cce.call(labels, predictions, sampleWeight);
+ *    // produces 0.814f
+ * 
+ * + *

Using SUM reduction type: + * + *

+ *    CategoricalCrossentropy cce = new CategoricalCrossentropy(tf, Reduction.SUM);
+ *    Operand<TFloat32> result = cce.call(labels, predictions);
+ *    // produces 2.354f
+ * 
+ * + *

Using NONE reduction type: + * + *

+ *    CategoricalCrossentropy cce =
+ *        new CategoricalCrossentropy(tf, Reduction.NONE);
+ *    Operand<TFloat32> result = cce.call(labels, predictions);
+ *    // produces [0.0513f, 2.303f]
+ * 
+ */ +public class CategoricalCrossentropy extends Loss { + public static final boolean FROM_LOGITS_DEFAULT = false; + public static final float LABEL_SMOOTHING_DEFAULT = 0.0f; + public static final Reduction REDUCTION_DEFAULT = Reduction.AUTO; + public static final int DEFAULT_AXIS = -1; + + private final boolean fromLogits; + private final float labelSmoothing; + private final int axis; + + /** + * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, + * {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for + * labelSmoothing, a Loss Reduction of {@link * Reduction#AUTO}, and an axis of {@link + * #DEFAULT_AXIS} + * + * @param tf the TensorFlow Ops + */ + public CategoricalCrossentropy(Ops tf) { + this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); + } + + /** + * Creates a categorical cross entropy Loss using {@link #FROM_LOGITS_DEFAULT} for fromLogits, + * {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a Loss Reduction of {@link * + * Reduction#AUTO}, and an axis of {@link #DEFAULT_AXIS} + * + * @param tf the TensorFlow Ops + * @param name the name of this loss + */ + public CategoricalCrossentropy(Ops tf, String name) { + this(tf, name, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); + } + + /** + * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, + * {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for + * labelSmoothing and an axis of {@link #DEFAULT_AXIS} + * + * @param tf the TensorFlow Ops + * @param reduction Type of Reduction to apply to loss. + */ + public CategoricalCrossentropy(Ops tf, Reduction reduction) { + this(tf, null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction, DEFAULT_AXIS); + } + + /** + * Creates a categorical cross entropy Loss {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link + * #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, and an axis of {@link #DEFAULT_AXIS} + * + * @param tf the TensorFlow Ops + * @param name the name of this loss + * @param reduction Type of Reduction to apply to loss. + */ + public CategoricalCrossentropy(Ops tf, String name, Reduction reduction) { + this(tf, name, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction, DEFAULT_AXIS); + } + + /** + * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, + * {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a Loss Reduction of {@link * + * Reduction#AUTO}, and an axis of {@link #DEFAULT_AXIS} + * + * @param tf the TensorFlow Ops + * @param fromLogits Whether to interpret predictions as a tensor of logit values + */ + public CategoricalCrossentropy(Ops tf, boolean fromLogits) { + this(tf, null, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); + } + + /** + * Creates a categorical cross entropy Loss using {@link #LABEL_SMOOTHING_DEFAULT} for + * labelSmoothing, a Loss Reduction of {@link * Reduction#AUTO}, and a channel axis of {@link + * #DEFAULT_AXIS} + * + * @param tf the TensorFlow Ops + * @param name the name of this loss + * @param fromLogits Whether to interpret predictions as a tensor of logit values + */ + public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { + this(tf, name, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT, DEFAULT_AXIS); + } + + /** + * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, + * a Loss Reduction of {@link * Reduction#AUTO}, and a channel axis of {@link #DEFAULT_AXIS} + * + * @param tf the TensorFlow Ops + * @param fromLogits Whether to interpret predictions as a tensor of logit values + * @param labelSmoothing Float in [0, 1]. When 0, no smoothing occurs. When > 0, we compute the + * loss between the predicted labels and a smoothed version of the true labels, where the + * smoothing squeezes the labels towards 0.5. Larger values of label_smoothing correspond to + * heavier smoothing. + */ + public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { + this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); + } + + /** + * Creates a categorical cross entropy Loss using a Loss Reduction of {@link * Reduction#AUTO}, + * and a channel axis of {@link #DEFAULT_AXIS} + * + * @param tf the TensorFlow Ops + * @param name the name of this loss + * @param fromLogits Whether to interpret predictions as a tensor of logit values + * @param labelSmoothing Float in [0, 1]. When 0, no smoothing occurs. When > 0, we compute the + * loss between the predicted labels and a smoothed version of the true labels, where the + * smoothing squeezes the labels towards 0.5. Larger values of label_smoothing correspond to + * heavier smoothing. + */ + public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) { + this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); + } + + /** + * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name + * and a channel axis of {@link #DEFAULT_AXIS} + * + * @param tf the TensorFlow Ops + * @param fromLogits Whether to interpret predictions as a tensor of logit values + * @param labelSmoothing Float in [0, 1]. When 0, no smoothing occurs. When > 0, we compute the + * loss between the predicted labels and a smoothed version of the true labels, where the + * smoothing squeezes the labels towards 0.5. Larger values of label_smoothing correspond to + * heavier smoothing. + * @param reduction Type of Reduction to apply to loss. + */ + public CategoricalCrossentropy( + Ops tf, boolean fromLogits, float labelSmoothing, Reduction reduction) { + this(tf, null, fromLogits, labelSmoothing, reduction, DEFAULT_AXIS); + } + + /** + * Creates a categorical cross entropy Loss + * + * @param tf the TensorFlow Ops + * @param name the name of this loss + * @param fromLogits Whether to interpret predictions as a tensor of logit values + * @param labelSmoothing Float in [0, 1]. When 0, no smoothing occurs. When > 0, we compute the + * loss between the predicted labels and a smoothed version of the true labels, where the + * smoothing squeezes the labels towards 0.5. Larger values of label_smoothing correspond to + * heavier smoothing. + * @param reduction Type of Reduction to apply to loss. + * @param axis The channels axis. axis=-1 corresponds to data format `Channels Last' + * and axis=1 corresponds to data format 'Channels First'. + */ + public CategoricalCrossentropy( + Ops tf, + String name, + boolean fromLogits, + float labelSmoothing, + Reduction reduction, + int axis) { + super(tf, name, reduction); + this.fromLogits = fromLogits; + this.labelSmoothing = labelSmoothing; + this.axis = axis; + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + Operand losses = + Losses.categoricalCrossentropy(tf, labels, predictions, fromLogits, labelSmoothing, axis); + return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java new file mode 100644 index 00000000000..6c828fd2d16 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -0,0 +1,91 @@ +package org.tensorflow.framework.losses; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes the categorical hinge loss between labels and predictions. + * + *

loss = maximum(neg - pos + 1, 0) where neg=maximum((1-labels)*predictions) + * and pos=sum(labels*predictions) + * + *

Standalone usage: + * + *

+ *    Operand<TFloat32> labels =
+ *        tf.constant(new float[][] {{0, 1}, {0, 0}});
+ *    Operand<TFloat32> predictions =
+ *        tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}});
+ *    CategoricalHinge categoricalHinge = new CategoricalHinge(tf);
+ *    Operand<TFloat32> result = categoricalHinge.call(labels, predictions);
+ *    // produces 1.4
+ * 
+ * + *

Calling with sample weight: + * + *

+ *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1f, 0.f});
+ *    Operand<TFloat32> result = categoricalHinge.call(labels, predictions, sampleWeight);
+ *    // produces 0.6f
+ * 
+ * + *

Using SUM reduction type: + * + *

+ *    CategoricalHinge categoricalHinge = new CategoricalHinge(tf, Reduction.SUM);
+ *    Operand<TFloat32> result = categoricalHinge.call(labels, predictions);
+ *    // produces 2.8f
+ * 
+ * + *

Using NONE reduction type: + * + *

+ *    CategoricalHinge categoricalHinge =
+ *        new CategoricalHinge(tf, Reduction.NONE);
+ *    Operand<TFloat32> result = categoricalHinge.call(labels, predictions);
+ *    // produces [1.2f, 1.6f]
+ * 
+ */ +public class CategoricalHinge extends Loss { + + /** + * Creates a Categorical Hinge Loss using {@link Class#getSimpleName()} as the loss name and a + * Loss Reduction of {@link * Reduction#AUTO} + * + * @param tf the TensorFlow Ops + */ + public CategoricalHinge(Ops tf) { + super(tf); + } + + /** + * Creates a Categorical Hinge Loss using {@link Class#getSimpleName()} as the loss name + * + * @param tf the TensorFlow Ops + * @param reduction Type of Reduction to apply to the loss. + */ + public CategoricalHinge(Ops tf, Reduction reduction) { + super(tf, null, reduction); + } + + /** + * Creates a Categorical Hinge + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param reduction Type of Reduction to apply to the loss. + */ + public CategoricalHinge(Ops tf, String name, Reduction reduction) { + super(tf, name, reduction); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + Operand losses = Losses.categoricalHinge(tf, labels, predictions); + return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java new file mode 100644 index 00000000000..e5d9d6a5d7b --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java @@ -0,0 +1,164 @@ +package org.tensorflow.framework.losses; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes the cosine similarity between labels and predictions. + * + *

Note that it is a negative quantity between -1 and 0, where 0 indicates orthogonality and + * values closer to -1 indicate greater similarity. This makes it usable as a loss function in a + * setting where you try to maximize the proximity between predictions and targets. If either labels + * or predictions is a zero vector, cosine similarity will be 0 regardless of the proximity between + * predictions and targets. + * + *

loss = -sum(l2Norm(labels) * l2Norm(predictions)) + * + *

Standalone usage: + * + *

+ *    Operand<TFloat32> labels =
+ *        tf.constant(new float[][] {{0.f, 1.f}, {1.f, 1.f}});
+ *    Operand<TFloat32> predictions =
+ *        tf.constant(new float[][] {{1.f, 0.f}, {1.f, 1.f}});
+ *    CosineSimilarity cosineLoss = new CosineSimilarity(tf);
+ *    Operand<TFloat32> result = cosineLoss.call(labels, predictions);
+ *    // produces -0.5
+ * 
+ * + *

Calling with sample weight: + * + *

+ *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.8f, 0.2f});
+ *    Operand<TFloat32> result = cosineLoss.call(labels, predictions, sampleWeight);
+ *    // produces -0.0999f
+ * 
+ * + *

Using SUM reduction type: + * + *

+ *    CosineSimilarity cosineLoss = new CosineSimilarity(tf, Reduction.SUM);
+ *    Operand<TFloat32> result = cosineLoss.call(labels, predictions);
+ *    // produces -0.999f
+ * 
+ * + *

Using NONE reduction type: + * + *

+ *    CosineSimilarity cosineLoss = new CosineSimilarity(tf, Reduction.NONE);
+ *    Operand<TFloat32> result = cosineLoss.call(labels, predictions);
+ *    // produces [-0.f, -0.999f]
+ * 
+ */ +public class CosineSimilarity extends Loss { + public static final int DEFAULT_AXIS = -1; + public static final Reduction DEFAULT_REDUCTION = Reduction.AUTO; + + private final int axis; + + /** + * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, an axis + * of {@link #DEFAULT_AXIS}, and a Loss Reduction of {@link #DEFAULT_REDUCTION} + * + * @param tf the TensorFlow Ops + */ + public CosineSimilarity(Ops tf) { + + this(tf, null, DEFAULT_AXIS, DEFAULT_REDUCTION); + } + + /** + * Creates a Cosine Similarity Loss using an axis of {@link #DEFAULT_AXIS}, and a Loss Reduction + * of {@link #DEFAULT_REDUCTION} + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + */ + public CosineSimilarity(Ops tf, String name) { + + this(tf, name, DEFAULT_AXIS, DEFAULT_REDUCTION); + } + + /** + * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, and a + * Loss Reduction of {@link #DEFAULT_REDUCTION} + * + * @param tf the TensorFlow Ops + * @param axis The dimension along which the cosine similarity is computed. + */ + public CosineSimilarity(Ops tf, int axis) { + + this(tf, null, axis, DEFAULT_REDUCTION); + } + + /** + * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION} + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param axis The dimension along which the cosine similarity is computed. + */ + public CosineSimilarity(Ops tf, String name, int axis) { + + this(tf, name, axis, DEFAULT_REDUCTION); + } + + /** + * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name and an + * axis of {@link #DEFAULT_AXIS} + * + * @param tf the TensorFlow Ops + * @param reduction Type of Reduction to apply to the loss. + */ + public CosineSimilarity(Ops tf, Reduction reduction) { + + this(tf, null, DEFAULT_AXIS, reduction); + } + + /** + * Creates a Cosine Similarity Loss using an axis of {@link #DEFAULT_AXIS} + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param reduction Type of Reduction to apply to the loss. + */ + public CosineSimilarity(Ops tf, String name, Reduction reduction) { + + this(tf, name, DEFAULT_AXIS, reduction); + } + + /** + * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name + * + * @param tf the TensorFlow Ops + * @param axis The dimension along which the cosine similarity is computed. + * @param reduction Type of Reduction to apply to the loss. + */ + public CosineSimilarity(Ops tf, int axis, Reduction reduction) { + + this(tf, null, axis, reduction); + } + + /** + * Creates a Cosine Similarity Loss + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param axis The dimension along which the cosine similarity is computed. + * @param reduction Type of Reduction to apply to the loss. + */ + public CosineSimilarity(Ops tf, String name, int axis, Reduction reduction) { + super(tf, name, reduction); + this.axis = axis; + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + Operand losses = Losses.cosineSimilarity(tf, labels, predictions, axis); + return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java new file mode 100644 index 00000000000..5209d8df360 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -0,0 +1,92 @@ +package org.tensorflow.framework.losses; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes the hinge loss between labels and predictions. + * + *

loss = maximum(1 - labels * predictions, 0)

. + * + *

labels/code> values are expected to be -1 or 1. + * If binary (0 or 1) labels are provided, they will be converted to -1 or 1.

+ * + *

Standalone usage: + * + *

+ *    Operand<TFloat32> labels =
+ *        tf.constant(new float[][] {{0.f, 1.f}, {0.f, 0.f}});
+ *    Operand<TFloat32> predictions =
+ *        tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}});
+ *    Hinge hingeLoss = new Hinge(tf);
+ *    Operand<TFloat32> result = hingeLoss.call(labels, predictions);
+ *    // produces 1.3f
+ * 
+ * + *

Calling with sample weight: + * + *

+ *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
+ *    Operand<TFloat32> result = hingeLoss.call(labels, predictions, sampleWeight);
+ *    // produces 0.55f
+ * 
+ * + *

Using SUM reduction type: + * + *

+ *    Hinge hingeLoss = new Hinge(tf, Reduction.SUM);
+ *    Operand<TFloat32> result = hingeLoss.call(labels, predictions);
+ *    // produces 2.6f
+ * 
+ * + *

Using NONE reduction type: + * + *

+ *    Hinge hingeLoss = new Hinge(tf, Reduction.NONE);
+ *    Operand<TFloat32> result = hingeLoss.call(labels, predictions);
+ *    // produces [1.1f, 1.5f]
+ * 
+ */ +public class Hinge extends Loss { + + /** + * Creates a Hinge Loss using {@link Class#getSimpleName()} as the loss name and a Loss Reduction + * of {@link * Reduction#AUTO} + * + * @param tf the TensorFlow Ops + */ + public Hinge(Ops tf) { + this(tf, null, Reduction.AUTO); + } + + /** + * Creates a Hinge Loss using {@link Class#getSimpleName()} as the loss name + * + * @param tf the TensorFlow Ops + * @param reduction Type of Reduction to apply to the loss. + */ + public Hinge(Ops tf, Reduction reduction) { + super(tf, null, reduction); + } + + /** + * Creates a Hinge + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param reduction Type of Reduction to apply to the loss. + */ + public Hinge(Ops tf, String name, Reduction reduction) { + super(tf, name, reduction); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + Operand losses = Losses.hinge(tf, labels, predictions); + return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java new file mode 100644 index 00000000000..3b2949eeb03 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -0,0 +1,124 @@ +package org.tensorflow.framework.losses; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes the Huber loss between labels and predictions. + * + *

For each value x in error = y_true - y_pred: + * + *

+ *     loss = 0.5 * x^2                  if |x| <= d
+ *     loss = 0.5 * d^2 + d * (|x| - d)  if |x| > d
+ * 
+ * + *

where d is delta. + * + *

Standalone usage: + * + *

+ *    Operand<TFloat32> labels =
+ *        tf.constant(new float[][] {{0.f, 1.f}, {0.f, 0.f}});
+ *    Operand<TFloat32> predictions =
+ *        tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}});
+ *    Huber huberLoss = new Huber(tf);
+ *    Operand<TFloat32> result = huberLoss.call(labels, predictions);
+ *    // produces 0.155
+ * 
+ * + *

Calling with sample weight: + * + *

+ *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
+ *    Operand<TFloat32> result = huberLoss.call(labels, predictions, sampleWeight);
+ *    // produces 0.09f
+ * 
+ * + *

Using SUM reduction type: + * + *

+ *    Huber huberLoss = new Huber(tf, Reduction.SUM);
+ *    Operand<TFloat32> result = huberLoss.call(labels, predictions);
+ *    // produces 0.32f
+ * 
+ * + *

Using NONE reduction type: + * + *

+ *    Huber huberLoss = new Huber(tf, Reduction.NONE);
+ *    Operand<TFloat32> result = huberLoss.call(labels, predictions);
+ *    // produces [0.18f, 0.13f]
+ * 
+ * + * @see Huber loss + */ +public class Huber extends Loss { + public static final float DELTA_DEFAULT = 1.0f; + + private final float delta; + + /** + * Creates a Huber Loss using {@link Class#getSimpleName()} as the loss name, {@link + * #DELTA_DEFAULT} as the delta and a Loss Reduction of {@link Reduction#AUTO} + * + * @param tf the TensorFlow Ops + */ + public Huber(Ops tf) { + this(tf, null, DELTA_DEFAULT, Reduction.AUTO); + } + + /** + * Creates a Huber Loss using {@link #DELTA_DEFAULT} as the delta and a Loss Reduction of {@link + * Reduction#AUTO} + * + * @param tf the TensorFlow Ops + */ + public Huber(Ops tf, String name) { + this(tf, name, DELTA_DEFAULT, Reduction.AUTO); + } + + /** + * Creates a Huber Loss using {@link Class#getSimpleName()} as the loss name and and {@link + * #DELTA_DEFAULT} as the delta + * + * @param tf the TensorFlow Ops + * @param reduction Type of Reduction to apply to the loss. + */ + public Huber(Ops tf, Reduction reduction) { + this(tf, null, DELTA_DEFAULT, reduction); + } + + /** + * Creates a Huber Loss using {@link #DELTA_DEFAULT} as the delta + * + * @param tf the TensorFlow Ops + * @param reduction Type of Reduction to apply to the loss. + */ + public Huber(Ops tf, String name, Reduction reduction) { + this(tf, name, DELTA_DEFAULT, reduction); + } + + /** + * Creates a Huber Loss + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param delta the point where the Huber loss function changes from quadratic to linear. + * @param reduction Type of Reduction to apply to the loss. + */ + public Huber(Ops tf, String name, float delta, Reduction reduction) { + super(tf, name, reduction); + this.delta = delta; + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + Operand losses = Losses.huber(tf, labels, predictions, delta); + return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java new file mode 100644 index 00000000000..71c348069bd --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -0,0 +1,93 @@ +package org.tensorflow.framework.losses; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes Kullback-Leibler divergence loss between labels and predictions. + * + *

loss = labels * log(labels / predictions) + * + *

Standalone usage: + * + *

+ *    Operand<TFloat32> labels =
+ *        tf.constant(new float[][] {{0.f, 1.f}, {0.f, 0.f}});
+ *    Operand<TFloat32> predictions =
+ *        tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}});
+ *    KLDivergence kld = new KLDivergence(tf);
+ *    Operand<TFloat32> result = kld.call(labels, predictions);
+ *    // produces 0.458
+ * 
+ * + *

Calling with sample weight: + * + *

+ *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.8f, 0.2f});
+ *    Operand<TFloat32> result = kld.call(labels, predictions, sampleWeight);
+ *    // produces 0.366f
+ * 
+ * + *

Using SUM reduction type: + * + *

+ *    KLDivergence kld = new KLDivergence(tf, Reduction.SUM);
+ *    Operand<TFloat32> result = kld.call(labels, predictions);
+ *    // produces 0.916f
+ * 
+ * + *

Using NONE reduction type: + * + *

+ *    KLDivergence kld = new KLDivergence(tf, Reduction.NONE);
+ *    Operand<TFloat32> result = kld.call(labels, predictions);
+ *    // produces [0.916f, -3.08e-06f]
+ * 
+ * + * @see Kullback?Leibler + * divergence + */ +public class KLDivergence extends Loss { + + /** + * Creates a Kullback Leibler Divergence Loss using {@link Class#getSimpleName()} as the loss name + * and a Loss Reduction of {@link * Reduction#AUTO} + * + * @param tf the TensorFlow Ops + */ + public KLDivergence(Ops tf) { + super(tf); + } + + /** + * Creates a Kullback Leibler Divergence Loss Loss using {@link Class#getSimpleName()} as the loss + * name + * + * @param tf the TensorFlow Ops + * @param reduction Type of Reduction to apply to the loss. + */ + public KLDivergence(Ops tf, Reduction reduction) { + super(tf, null, reduction); + } + + /** + * Creates a Kullback Leibler Divergence Loss + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param reduction Type of Reduction to apply to the loss. + */ + public KLDivergence(Ops tf, String name, Reduction reduction) { + super(tf, name, reduction); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + Operand losses = Losses.kullbackLeiblerDivergence(tf, labels, predictions); + return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java new file mode 100644 index 00000000000..6ddb0b2daac --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -0,0 +1,99 @@ +package org.tensorflow.framework.losses; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes Computes the logarithm of the hyperbolic cosine of the prediction error. + * + *

logcosh = log((exp(x) + exp(-x))/2), where x is the error + * predictions - y_true. + * + *

Standalone usage: + * + *

+ *    Operand<TFloat32> labels =
+ *        tf.constant(new float[][] {{0.f, 1.f}, {0.f, 0.f}});
+ *    Operand<TFloat32> predictions =
+ *        tf.constant(new float[][] {{1.f, 1.f}, {0.f, 0.f}});
+ *    LogCosh logcosh = new LogCosh(tf);
+ *    Operand<TFloat32> result = logcosh.call(labels, predictions);
+ *    // produces 0.108
+ * 
+ * + *

Calling with sample weight: + * + *

+ *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.8f, 0.2f});
+ *    Operand<TFloat32> result = logcosh.call(labels, predictions, sampleWeight);
+ *    // produces 0.087f
+ * 
+ * + *

Using SUM reduction type: + * + *

+ *    LogCosh logcosh = new LogCosh(tf, Reduction.SUM);
+ *    Operand<TFloat32> result = logcosh.call(labels, predictions);
+ *    // produces 0.217f
+ * 
+ * + *

Using NONE reduction type: + * + *

+ *    LogCosh logcosh = new LogCosh(tf, Reduction.NONE);
+ *    Operand<TFloat32> result = logcosh.call(labels, predictions);
+ *    // produces [0.217f, 0f]
+ * 
+ */ +public class LogCosh extends Loss { + + /** + * Creates a LogCosh Loss using {@link Class#getSimpleName()} as the loss name and a Loss + * Reduction of {@link * Reduction#AUTO} + * + * @param tf the TensorFlow Ops + */ + public LogCosh(Ops tf) { + this(tf, null, Reduction.AUTO); + } + + /** + * Creates a LogCosh Loss using a Loss Reduction of {@link * Reduction#AUTO} + * + * @param tf the TensorFlow Ops + */ + public LogCosh(Ops tf, String name) { + this(tf, name, Reduction.AUTO); + } + + /** + * Creates a LogCosh Loss using {@link Class#getSimpleName()} as the loss name + * + * @param tf the TensorFlow Ops + * @param reduction Type of Reduction to apply to the loss. + */ + public LogCosh(Ops tf, Reduction reduction) { + this(tf, null, reduction); + } + + /** + * Creates a Kullback Leibler Divergence Loss + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param reduction Type of Reduction to apply to the loss. + */ + public LogCosh(Ops tf, String name, Reduction reduction) { + super(tf, name, reduction); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + Operand losses = Losses.logCosh(tf, labels, predictions); + return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index 95d79507f90..9c0976f2c6f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -48,7 +48,7 @@ protected Loss(Ops tf, String name, Reduction reduction) { * @param The data type of the labels, predictions and loss. * @return the loss */ - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return call(labels, predictions, null); } @@ -62,13 +62,14 @@ public Operand call(Operand labels, Operand predict * of size [batch_size], then the total loss for each sample of the batch is rescaled by the * corresponding element in the sample_weight vector. If the shape of sample_weight is * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of - * y_pred is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss + * predictions is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) - * @param The data type of the labels, predictions, sampleWeights and loss. + * @param The data type of the predictions, sampleWeights and loss. + * @param The data type of the labels. * @return the loss */ - public abstract Operand call( - Operand labels, Operand predictions, Operand sampleWeights); + public abstract Operand call( + Operand labels, Operand predictions, Operand sampleWeights); /** * Gets the TensorFlow Ops diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 65addd71c90..604ffc1b474 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -2,28 +2,681 @@ import org.tensorflow.DataType; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesImpl; import org.tensorflow.framework.losses.impl.Tuple; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReduceAll; +import org.tensorflow.op.core.ReduceMax; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.op.core.Variable; +import org.tensorflow.op.math.Mean; +import org.tensorflow.op.math.Sigmoid; +import org.tensorflow.op.math.Softplus; +import org.tensorflow.op.nn.Softmax; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; +/** Built-in loss functions. */ public class Losses { - /** - * - * @param tf - * @param yTrue - * @param yPred - * @param - * @return - */ - public static Operand mean_absolute_error(Ops tf, Operand labels, Operand predictions) { - Operand tLabels = tf.dtypes.cast(labels, predictions.asOutput().dataType()); - Tuple ops = squeezeOrExpandDimensions(tf, labels, predictions, null); - predictions = ops.getPredictions(); - tLabels = ops.getLabels(); - return mean(tf, tf.math.abs(tf.math.sub(yPred, yTrue)), tf.constant(-1)); + /** Default Fuzz factor. */ + public static final float EPSILON = 1e-7f; + + /** + * Calculates the mean absolute error between labels and predictions. + * + *

loss = reduceMean(abs(labels - predictions)) + * + * @param tf The TensorFlow Ops + * @param labels the labels + * @param predictions the predictions + * @param the data type of the result + * @param the data type of the labels + * @return the mean absolute error + */ + public static Operand meanAbsoluteError( + Ops tf, Operand labels, Operand predictions) { + Operand tLabels = tf.dtypes.cast(labels, predictions.asOutput().dataType()); + Tuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = ops.getTarget(); + tLabels = ops.getLabels(); + return tf.math.mean( + tf.math.abs(tf.math.sub(tLabels, predictions)), tf.constant(-1), Mean.keepDims(false)); + } + + /** + * Computes the mean squared error between labels and predictions. + * + *

loss = reduceMean(square(labels - predictions)) + * + * @param tf The TensorFlow Ops + * @param labels the labels + * @param predictions the predictions + * @param the data type of the result + * @param the data type of the labels + * @return the mean squared error + */ + public static Operand meanSquaredError( + Ops tf, Operand labels, Operand predictions) { + Operand tLabels = tf.dtypes.cast(labels, predictions.asOutput().dataType()); + Tuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = ops.getTarget(); + tLabels = ops.getLabels(); + return tf.math.mean(tf.math.squaredDifference(predictions, tLabels), tf.constant(-1)); + } + + /** + * Calculates the mean absolute percentage error between labels and predictions. + * + *

loss = 100 * reduceMean(abs((labels - predictions) / labels)) + * + * @param tf The TensorFlow Ops + * @param labels the labels + * @param predictions the predictions + * @param the data type of the result + * @param the data type of the labels + * @return the mean absolute percentage error + */ + public static Operand meanAbsolutePercentageError( + Ops tf, Operand labels, Operand predictions) { + DataType dataType = predictions.asOutput().dataType(); + Operand tLabels = tf.dtypes.cast(labels, predictions.asOutput().dataType()); + Tuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = ops.getTarget(); + tLabels = ops.getLabels(); + Operand diff = + tf.math.abs( + tf.math.div( + tf.math.sub(tLabels, predictions), + tf.math.maximum( + tf.math.abs(tLabels), tf.dtypes.cast(tf.constant(EPSILON), dataType)))); + return tf.math.mul( + tf.dtypes.cast(tf.constant(100), dataType), tf.math.mean(diff, tf.constant(-1))); + } + + /** + * Calculates the mean squared logarithmic percentage error between labels and predictions. + * + *

loss = reduceMean(square(log(labels + 1) - log(predictions + 1))) + * + * @param tf The TensorFlow Ops + * @param labels the labels + * @param predictions the predictions + * @param the data type of the result + * @param the data type of the labels + * @return the mean squared logarithmic percentage error + */ + public static Operand meanSquaredLogarithmicError( + Ops tf, Operand labels, Operand predictions) { + DataType dataType = predictions.asOutput().dataType(); + Operand tLabels = tf.dtypes.cast(labels, predictions.asOutput().dataType()); + Tuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = ops.getTarget(); + tLabels = ops.getLabels(); + + Operand epsilonConst = tf.dtypes.cast(tf.constant(EPSILON), dataType); + Operand one = tf.dtypes.cast(tf.constant(1), dataType); + + Operand firstLog = tf.math.log(tf.math.add(tf.math.maximum(predictions, epsilonConst), one)); + Operand secondLog = tf.math.log(tf.math.add(tf.math.maximum(tLabels, epsilonConst), one)); + + return tf.math.mean(tf.math.squaredDifference(firstLog, secondLog), tf.constant(-1)); + } + + /** + * Computes the binary crossentropy loss between labels and predictions. + * + * @param tf the TensorFlow Ops + * @param labels true targets + * @param predictions the predictions + * @param fromLogits Whether to interpret predictions as a tensor of logit values + * @param labelSmoothing A number in the range [0, 1]. When 0, no smoothing occurs. When > 0, + * compute the loss between the predicted labels and a smoothed version of the true labels, + * where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing + * correspond to heavier smoothing. + * @param the data type of the predictions and labels + * @return the binary crossentropy loss. + */ + public static Operand binaryCrossentropy( + Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { + DataType dataType = predictions.asOutput().dataType(); + Operand tLabels = tf.dtypes.cast(labels, dataType); + Tuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = ops.getTarget(); + tLabels = ops.getLabels(); + + if (labelSmoothing != 0.0f) { + tLabels = smoothLabelsBinaryX(tf, tLabels, labelSmoothing); + } + Operand bce = binaryCrossentropy(tf, tLabels, predictions, fromLogits); + return tf.math.mean(bce, tf.constant(-1)); + } + + /** + * Compute binary crossentropy loss between labels and predictions. + * + * @param tf the TensorFlow Ops + * @param target the target Operand + * @param output the output, either logits or a probability distribution + * @param fromLogits whether `output` is expected to be a logits tensor. By default, we consider + * that `output` encodes a probability distribution. + * @param the data type of the Operands + * @return the binary crossentropy loss. + */ + private static Operand binaryCrossentropy( + Ops tf, Operand target, Operand output, boolean fromLogits) { + if (fromLogits) { + return tf.nn.sigmoidCrossEntropyWithLogits(target, output); + } + + if (!(output instanceof Variable) && (!tf.scope().env().isEager())) { + // TODO - this does not work, cannot walk back, work around is only go back 1. + // output = backtrackIdentity(output); + if (output.op().type().equals(Sigmoid.OP_NAME)) { + if (output.op().numOutputs() != 1) + throw new IllegalArgumentException("output can only have 1 output"); + output = output.op().output(0); + return tf.nn.sigmoidCrossEntropyWithLogits(target, output); + } + } + DataType dataType = output.asOutput().dataType(); + Operand one = tf.dtypes.cast(tf.constant(1), dataType); + Operand epsilonConst = tf.dtypes.cast(tf.constant(EPSILON), dataType); + Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); + output = tf.clipByValue(output, epsilonConst, oneMinusEpsilonConst); + + // Compute cross entropy from probabilities. + Operand bce = tf.math.mul(target, tf.math.log(tf.math.add(output, epsilonConst))); + bce = + tf.math.add( + bce, + tf.math.mul( + tf.math.sub(one, target), + tf.math.log(tf.math.add(tf.math.sub(one, output), epsilonConst)))); + return tf.math.neg(bce); + } + + /** + * Computes the categorical crossentropy loss between labels and predictions. + * + * @param tf the TensorFlow Ops + * @param labels true targets + * @param predictions the predictions + * @param fromLogits Whether to interpret predictions as a tensor of logit values + * @param labelSmoothing Float in [0, 1]. When 0, no smoothing occurs. When > 0, compute the + * loss between the predicted labels and a smoothed version of the true labels, where the + * smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing correspond to + * heavier smoothing. + * @param axis the + * @param the data type of the predictions and labels + * @return the categorical crossentropy loss. + */ + public static Operand categoricalCrossentropy( + Ops tf, + Operand labels, + Operand predictions, + boolean fromLogits, + float labelSmoothing, + int axis) { + DataType dataType = predictions.asOutput().dataType(); + Operand tLabels = tf.dtypes.cast(labels, dataType); + Tuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = ops.getTarget(); + tLabels = ops.getLabels(); + + if (labelSmoothing != 0.0f) { + tLabels = smoothLabelsCatX(tf, tLabels, labelSmoothing); + } + if (fromLogits) { + return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, -1); + } + if (!(predictions instanceof Variable) && (!tf.scope().env().isEager())) { + // TODO output = backtrackIdentity(output); doesn't seem to work with Java version. + if (predictions.op().type().equals("Softmax")) { + if (predictions.op().numOutputs() != 1) + throw new IllegalArgumentException("output can only have 1 output"); + predictions = predictions.op().output(0); + return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, -1); + } + } + Operand one = tf.dtypes.cast(tf.constant(1), dataType); + Operand epsilonConst = tf.dtypes.cast(tf.constant(EPSILON), dataType); + Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); + predictions = + tf.math.div( + predictions, tf.reduceSum(predictions, tf.constant(axis), ReduceSum.keepDims(true))); + predictions = tf.clipByValue(predictions, epsilonConst, oneMinusEpsilonConst); + + // Compute cross entropy from probabilities. + Operand cce = + tf.reduceSum( + tf.math.mul(tLabels, tf.math.log(predictions)), + tf.constant(axis), + ReduceSum.keepDims(false)); + return tf.math.neg(cce); + } + + /** + * Computes the categorical hinge loss between labels and predictions. + * + * @param tf the TensorFlow Ops + * @param labels true targets + * @param predictions the predictions + * @param the data type of the predictions and labels + * @return the categorical hinge loss + */ + public static Operand categoricalHinge( + Ops tf, Operand labels, Operand predictions) { + DataType dataType = predictions.asOutput().dataType(); + Operand tLabels = tf.dtypes.cast(labels, dataType); + Tuple tuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = tuple.getTarget(); + tLabels = tuple.getLabels(); + Operand one = tf.dtypes.cast(tf.constant(1), dataType); + Operand zero = tf.dtypes.cast(tf.constant(0), dataType); + + Operand pos = + tf.reduceSum( + tf.math.mul(tLabels, predictions), tf.constant(-1), ReduceSum.keepDims(Boolean.FALSE)); + Operand neg = + tf.reduceMax( + tf.math.mul(tf.math.sub(one, tLabels), predictions), + tf.constant(-1), + ReduceMax.keepDims(Boolean.FALSE)); + Operand sub = tf.math.sub(neg, pos); + Operand add = tf.math.add(sub, one); + return tf.math.maximum(zero, add); + } + + /** + * Computes the cosine similarity loss between labels and predictions. + * + *

Note that it is a number between -1 and 1. When it is a negative number between -1 and 0, 0 + * indicates orthogonality and values closer to -1 indicate greater similarity. The values closer + * to 1 indicate greater dissimilarity. This makes it usable as a loss function in a setting where + * you try to maximize the proximity between predictions and targets. If either labels or + * predictions is a zero vector, cosine similarity will be 0 regardless of the proximity between + * predictions and targets. + * + *

loss = -sum(l2Norm(labels) * l2Norm(predictions)) + * + * @param tf the TensorFlow Ops + * @param labels true targets + * @param predictions the predictions + * @param axis Axis along which to determine similarity. + * @param the data type of the predictions and labels + * @return the cosine similarity loss + */ + public static Operand cosineSimilarity( + Ops tf, Operand labels, Operand predictions, int axis) { + DataType dataType = predictions.asOutput().dataType(); + Operand tLabels = tf.dtypes.cast(labels, dataType); + Tuple tuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = tuple.getTarget(); + tLabels = tuple.getLabels(); + + tLabels = l2Normalize(tf, tLabels, axis); + predictions = l2Normalize(tf, predictions, axis); + Operand mathMul = tf.math.mul(tLabels, predictions); + Operand sum = tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); + return tf.math.neg(sum); + } + + /** + * Computes the hinge loss between labels and predictions + * + *

loss = reduceMean(maximum(1 - labels * predictions, 0)) + * + * @param tf the TensorFlow Ops + * @param labels true targets + * @param predictions the predictions + * @param the data type of the predictions and labels + * @return the hinge loss + */ + public static Operand hinge( + Ops tf, Operand labels, Operand predictions) { + DataType dataType = predictions.asOutput().dataType(); + Operand tLabels = tf.dtypes.cast(labels, dataType); + Tuple tuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = tuple.getTarget(); + tLabels = tuple.getLabels(); + Operand one = tf.dtypes.cast(tf.constant(1), dataType); + Operand zero = tf.dtypes.cast(tf.constant(0), dataType); + + tLabels = maybeConvertLabels(tf, tLabels); + + return tf.math.mean( + tf.math.maximum(tf.math.sub(one, tf.math.mul(tLabels, predictions)), zero), + tf.constant(-1)); + } + + /** + * Computes the Huber loss between labels and predictions. + * + *

For each value x in error = labels - predictions: + * + *

+   *     loss = 0.5 * x^2                  if |x| <= d
+   *     loss = 0.5 * d^2 + d * (|x| - d)  if |x| > d
+   * 
+ * + *

where d is delta. + * + * @param tf the TensorFlow Ops + * @param labels true targets + * @param predictions the predictions + * @param delta the point where the Huber loss function changes from quadratic to linear. + * @param the data type of the predictions and labels + * @return the Huber loss + */ + public static Operand huber( + Ops tf, Operand labels, Operand predictions, float delta) { + DataType dataType = predictions.asOutput().dataType(); + Operand tLabels = tf.dtypes.cast(labels, dataType); + Tuple tuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = tuple.getTarget(); + tLabels = tuple.getLabels(); + + Operand error = tf.math.sub(predictions, tLabels); + Operand deltaConst = tf.dtypes.cast(tf.constant(delta), dataType); + Operand point5 = tf.dtypes.cast(tf.constant(0.5), dataType); + Operand absError = tf.math.abs(error); + Operand quadratic = tf.math.minimum(absError, deltaConst); + Operand linear = tf.math.sub(absError, quadratic); + Operand q2Point5 = tf.math.mul(point5, tf.math.mul(quadratic, quadratic)); + Operand deltaLinear = tf.math.mul(deltaConst, linear); + Operand loss = tf.math.add(q2Point5, deltaLinear); + return tf.math.mean(loss, tf.constant(-1)); + } + + /** + * Computes the Kullback-Leibler divergence loss between labels and predictions. + * + * @param tf the TensorFlow Ops + * @param labels true targets + * @param predictions the predictions + * @param the data type of the predictions and labels + * @return the Kullback-Leibler divergence loss + * @see Kullback?Leibler + * divergence + */ + public static Operand kullbackLeiblerDivergence( + Ops tf, Operand labels, Operand predictions) { + DataType dataType = predictions.asOutput().dataType(); + Operand tLabels = tf.dtypes.cast(labels, dataType); + Tuple tuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = tuple.getTarget(); + tLabels = tuple.getLabels(); + Operand one = tf.dtypes.cast(tf.constant(1), dataType); + Operand epsilonConst = tf.dtypes.cast(tf.constant(EPSILON), dataType); + + tLabels = tf.clipByValue(tLabels, epsilonConst, one); + predictions = tf.clipByValue(predictions, epsilonConst, one); + return tf.reduceSum( + tf.math.mul(tLabels, tf.math.log(tf.math.div(tLabels, predictions))), tf.constant(-1)); + } + + /** + * Computes the hyperbolic cosine loss between labels and predictions. + * + *

log(cosh(x)) is approximately equal to (x ** 2) / 2 for small + * x and to abs(x) - log(2) for large x. This means that + * 'logCosh' works mostly like the mean squared error, but will not be so strongly affected by the + * occasional wildly incorrect prediction. + * + * @param tf the TensorFlow Ops + * @param labels true targets + * @param predictions the predictions + * @param the data type of the predictions and labels + * @return the hyperbolic cosine divergence loss + */ + public static Operand logCosh( + Ops tf, Operand labels, Operand predictions) { + DataType dataType = predictions.asOutput().dataType(); + Operand tLabels = tf.dtypes.cast(labels, dataType); + Tuple tuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = tuple.getTarget(); + tLabels = tuple.getLabels(); + Operand minusTwo = tf.dtypes.cast(tf.constant(-2), dataType); + Operand two = tf.dtypes.cast(tf.constant(2), dataType); + + Operand diff = tf.math.sub(predictions, tLabels); + Softplus softplus = tf.math.softplus(tf.math.mul(minusTwo, diff)); + Operand logcosh = tf.math.sub(tf.math.add(diff, softplus), tf.math.log(two)); + return tf.math.mean(logcosh, tf.constant(-1)); + } + + /** + * Computes the Poisson loss between labels and predictions. + * + *

The Poisson loss is the mean of the elements of the Tensor + * predictions - labels * log(predictions). + * + * @param tf the TensorFlow Ops + * @param labels true targets + * @param predictions the predictions + * @param the data type of the predictions and labels + * @return the Poisson loss + */ + public static Operand poisson( + Ops tf, Operand labels, Operand predictions) { + DataType dataType = predictions.asOutput().dataType(); + Operand tLabels = tf.dtypes.cast(labels, dataType); + Tuple tuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = tuple.getTarget(); + tLabels = tuple.getLabels(); + Operand epsilonConst = tf.dtypes.cast(tf.constant(EPSILON), dataType); + + return tf.math.mean( + tf.math.sub( + predictions, tf.math.mul(tLabels, tf.math.log(tf.math.add(predictions, epsilonConst)))), + tf.constant(-1)); + } + + /** + * Computes the sparse categorical crossentropy loss between labels and predictions. + * + * @param tf the TensorFlow Ops + * @param labels true targets + * @param predictions the predictions + * @param fromLogits Whether predictions is expected to be logits. By default, it is assumed that + * predictions encodes a probability distribution. + * @param axis The dimension along which the entropy is computed. + * @param the data type of the predictions and labels + * @return the sparse categorical crossentropy loss + */ + public static Operand sparseCategoricalCrossentropy( + Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { + DataType dataType = predictions.asOutput().dataType(); + Operand epsilonConst = tf.dtypes.cast(tf.constant(EPSILON), dataType); + Operand one = tf.dtypes.cast(tf.constant(1), dataType); + Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); + + if (!fromLogits && !(predictions instanceof Variable) && (!tf.scope().env().isEager())) { + // TODO output = backtrackIdentity(output); doesn't seem to work with Java version. + if (predictions.op().type().equals(Softmax.OP_NAME)) { + // When softmax activation function is used for output operation, we + // use logits from the softmax function directly to compute loss in order + // to prevent collapsing zero when training. + // TODO if( output.op().numOutputs() != 1) + // throw new IllegalArgumentException("output can only have 1 output"); + // TODO output = output.op.inputs[0] + fromLogits = true; + } } + if (!fromLogits) { + predictions = tf.clipByValue(predictions, epsilonConst, oneMinusEpsilonConst); + predictions = tf.math.log(predictions); + } + Shape outputShape = predictions.asOutput().shape(); + int outputRank = outputShape.numDimensions(); + axis %= outputRank; + if (axis < 0) { + axis += outputRank; + } + if (axis != outputRank - 1) { + int[] axisNew = moveAxisToEnd(axis, outputRank); + predictions = tf.linalg.transpose(predictions, tf.constant(axisNew)); + } + Operand iLabels = tf.dtypes.cast(labels, TInt64.DTYPE); + + // Try to adjust the shape so that rank of labels = rank of logits - 1. + Shape labelsShape = labels.asOutput().shape(); + int labelsRank = labelsShape.numDimensions(); + + boolean updateShape = labelsRank != outputRank - 1; + if (updateShape) { // TODO check to see if this is right + iLabels = tf.reshape(iLabels, tf.constant(-1)); // flatten one dimension + predictions = + tf.reshape( + predictions, + tf.constant(new long[] {-1L, outputShape.size(outputShape.numDimensions() - 1)})); + } + + + @SuppressWarnings("unchecked") + Operand loss = tf.nn.sparseSoftmaxCrossEntropyWithLogits(iLabels, predictions); + if (updateShape && outputRank >= 3) { + Shape newShape = outputShape.take(outputShape.numDimensions() - 1); + loss = tf.reshape(loss, tf.constant(newShape)); + } + return loss; + } + + /** + * Computes the squared hinge loss between labels and predictions. + * + *

loss = reduceMean(square(maximum(1 - labels * predictions, 0))) + * + * @param tf the TensorFlow Ops + * @param labels true targets + * @param predictions the predictions + * @param the data type of the predictions and labels + * @return the squared hinge loss + */ + public static Operand squaredHinge( + Ops tf, Operand labels, Operand predictions) { + DataType dataType = predictions.asOutput().dataType(); + Operand tLabels = tf.dtypes.cast(labels, dataType); + Tuple tuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = tuple.getTarget(); + tLabels = tuple.getLabels(); + Operand one = tf.dtypes.cast(tf.constant(1), dataType); + Operand zero = tf.dtypes.cast(tf.constant(0), dataType); + + tLabels = maybeConvertLabels(tf, tLabels); + return tf.math.mean( + tf.math.square(tf.math.maximum(tf.math.sub(one, tf.math.mul(tLabels, predictions)), zero)), + tf.constant(-1)); + } + + // private methods + + /** + * Smooths binary labels + * + * @param tf the TensorFlow Ops + * @param labels true targets + * @param labelSmoothing A number in the range [0, 1]. When 0, no smoothing occurs. When > 0, + * compute the loss between the predicted labels and a smoothed version of the true labels, + * where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing + * correspond to heavier smoothing. + * @param the data type of the labels + * @return the smoothed binary labels + */ + private static Operand smoothLabelsBinaryX( + Ops tf, Operand labels, float labelSmoothing) { + DataType dataType = labels.asOutput().dataType(); + Operand oneMinusSmoothing = tf.dtypes.cast(tf.constant(1.f - labelSmoothing), dataType); + Operand halfSmoothing = tf.dtypes.cast(tf.constant(0.5F * labelSmoothing), dataType); + return tf.math.add(tf.math.mul(labels, oneMinusSmoothing), halfSmoothing); + } + + /** + * Smooths categorical labels + * + * @param tf the TensorFlow Ops + * @param labels true targets + * @param labelSmoothing A number in the range [0, 1]. When 0, no smoothing occurs. When > 0, + * compute the loss between the predicted labels and a smoothed version of the true labels, + * where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing + * correspond to heavier smoothing. + * @param the data type of the labels + * @return the smoothed categorical labels + */ + private static Operand smoothLabelsCatX( + Ops tf, Operand labels, float labelSmoothing) { + DataType dataType = labels.asOutput().dataType(); + Operand smoothing = tf.dtypes.cast(tf.constant(labelSmoothing), dataType); + Shape labelsShape = labels.asOutput().shape(); + int numDims = labelsShape.numDimensions(); + Operand numClasses = tf.dtypes.cast(tf.constant(labelsShape.size(numDims - 1)), dataType); + Operand oneMinusSmoothing = tf.dtypes.cast(tf.constant(1.f - labelSmoothing), dataType); + return tf.math.add(tf.math.mul(labels, oneMinusSmoothing), tf.math.div(smoothing, numClasses)); + } + + // TODO this was tf.math.l2_normalize in TF Python + /** + * Normalizes along dimension axis using an L2 norm. + * + * @param tf The TensorFlow Ops + * @param x the input + * @param axis Dimension along which to normalize. + * @return the normalized values based on L2 norm + */ + public static Operand l2Normalize(Ops tf, Operand x, int axis) { + Operand squareSum = + tf.reduceSum(tf.math.square(x), tf.constant(axis), ReduceSum.keepDims(Boolean.TRUE)); + Operand invNorm = + tf.math.rsqrt( + tf.math.maximum( + squareSum, tf.dtypes.cast(tf.constant(1e-12F), x.asOutput().dataType()))); + return tf.math.mul(x, invNorm); + } + + /** + * Converts binary labels into -1/1. + * + * @param tf the TensorFlow Ops + * @param labels true targets + * @param the data type of the labels + * @return the labels, possibly converted into -1/1. + */ + private static Operand maybeConvertLabels(Ops tf, Operand labels) { + DataType dataType = labels.asOutput().dataType(); + + Operand one = tf.dtypes.cast(tf.constant(1), dataType); + Operand zero = tf.dtypes.cast(tf.constant(0), dataType); + Operand two = tf.dtypes.cast(tf.constant(2), dataType); + Operand areZeros = tf.math.equal(labels, zero); + Operand areOnes = tf.math.equal(labels, one); + Operand isBinary = + tf.reduceAll( + tf.math.logicalOr(areZeros, areOnes), tf.constant(-1), ReduceAll.keepDims(true)); + Operand convertBinaryLabels = tf.math.sub(tf.math.mul(two, labels), one); + return tf.select(isBinary, convertBinaryLabels, labels); + } + + /** + * Move the specified axis to end, to be used with transposes + * + * @param axis the axis to move + * @param outputRank the rank of the shape + * @return the new dimension array with the axis moved to the end. + */ + private static int[] moveAxisToEnd(int axis, int outputRank) { + int[] axisNew = new int[outputRank]; + for (int i = 0; i < axis; i++) { + axisNew[i] = i; + } + for (int i = axis + 1; i < outputRank; i++) { + axisNew[i - 1] = i; + } + axisNew[outputRank - 1] = axis; + return axisNew; + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index 4d240035523..31592f8188b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -1,49 +1,89 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesImpl; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +/** + * Computes the mean of absolute difference between labels and predictions. + * + *

loss = abs(labels - predictions) + * + *

Standalone usage: + * + *

+ *    Operand<TFloat32> labels =
+ *        tf.constant(new float[][] {{0.f, 1.f}, {0.f, 0.f}});
+ *    Operand<TFloat32> predictions =
+ *        tf.constant(new float[][] {{1.f, 1.f}, {1.f, 0.f}});
+ *    MeanAbsoluteError mae = new MeanAbsoluteError(tf);
+ *    Operand<TFloat32> result = mae.call(labels, predictions);
+ *    // produces 0.5f
+ * 
+ * + *

Calling with sample weight: + * + *

+ *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.7f, 0.3f});
+ *    Operand<TFloat32> result = mae.call(labels, predictions, sampleWeight);
+ *    // produces 0.25f
+ * 
+ * + *

Using SUM reduction type: + * + *

+ *    MeanAbsoluteError mae = new MeanAbsoluteError(tf, Reduction.SUM);
+ *    Operand<TFloat32> result = mae.call(labels, predictions);
+ *    // produces 1.0f
+ * 
+ * + *

Using NONE reduction type: + * + *

+ *    MeanAbsoluteError mae = new MeanAbsoluteError(tf, Reduction.NONE);
+ *    Operand<TFloat32> result = mae.call(labels, predictions);
+ *    // produces [0.5f, 0.5f]
+ * 
+ */ public class MeanAbsoluteError extends Loss { /** - * Creates a MeanAbsoluteError Loss using {@link Class#getSimpleName()} as the loss name and a Loss Reduction of - * {@link * Reduction#AUTO} + * Creates a MeanAbsoluteError Loss using {@link Class#getSimpleName()} as the loss name and a + * Loss Reduction of {@link * Reduction#AUTO} * * @param tf the TensorFlow Ops */ public MeanAbsoluteError(Ops tf) { - super(tf); - } + super(tf); + } - /** - * Creates a MeanAbsoluteError Loss using {@link Class#getSimpleName()} as the loss name - * - * @param tf the TensorFlow Ops - * @param reduction Type of Reduction to apply to the loss. - */ - public MeanAbsoluteError(Ops tf, Reduction reduction) { - super(tf, null,reduction); - } - - /** - * Creates a MeanAbsoluteError - * - * @param tf the TensorFlow Ops - * @param name the name of the loss - * @param reduction Type of Reduction to apply to the loss. - */ - public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { - super(tf, name, reduction); - } + /** + * Creates a MeanAbsoluteError Loss using {@link Class#getSimpleName()} as the loss name + * + * @param tf the TensorFlow Ops + * @param reduction Type of Reduction to apply to the loss. + */ + public MeanAbsoluteError(Ops tf, Reduction reduction) { + super(tf, null, reduction); + } - /** - * {@inheritDoc} - */ - @Override - public Operand call(Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.mean_absolute_error(tf, labels, predictions); - return super.computeWeightedLoss(losses, getReduction(), sampleWeights); - } + /** + * Creates a MeanAbsoluteError + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param reduction Type of Reduction to apply to the loss. + */ + public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { + super(tf, name, reduction); + } + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + Operand losses = Losses.meanAbsoluteError(tf, labels, predictions); + return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java new file mode 100644 index 00000000000..7e2ab3fa8ae --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -0,0 +1,89 @@ +package org.tensorflow.framework.losses; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes the mean absolute percentage error between labels and predictions. + * + *

loss = 100 * abs(labels - predictions) / labels + * + *

Standalone usage: + * + *

+ *    Operand<TFloat32> labels =
+ *        tf.constant(new float[][] {{2.f, 1.f}, {2.f, 3.f}});
+ *    Operand<TFloat32> predictions =
+ *        tf.constant(new float[][] {{1.f, 1.f}, {1.f, 0.f}});
+ *    MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(tf);
+ *    Operand<TFloat32> result = mape.call(labels, predictions);
+ *    // produces 50f
+ * 
+ * + *

Calling with sample weight: + * + *

+ *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.7f, 0.3f});
+ *    Operand<TFloat32> result = mape.call(labels, predictions, sampleWeight);
+ *    // produces 20f
+ * 
+ * + *

Using SUM reduction type: + * + *

+ *    MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(tf, Reduction.SUM);
+ *    Operand<TFloat32> result = mape.call(labels, predictions);
+ *    // produces 100.0f
+ * 
+ * + *

Using NONE reduction type: + * + *

+ *    MeanAbsolutePercentageError mape = new MeanAbsolutePercentageError(tf, Reduction.NONE);
+ *    Operand<TFloat32> result = mape.call(labels, predictions);
+ *    // produces [25f, 75f]
+ * 
+ */ +public class MeanAbsolutePercentageError extends Loss { + + /** + * Creates a MeanAbsolutePercentageError Loss using {@link Class#getSimpleName()} as the loss name + * and a Loss Reduction of {@link * Reduction#AUTO} + * + * @param tf the TensorFlow Ops + */ + public MeanAbsolutePercentageError(Ops tf) { + super(tf); + } + + /** + * Creates a MeanAbsolutePercentageError Loss using {@link Class#getSimpleName()} as the loss name + * + * @param tf the TensorFlow Ops + * @param reduction Type of Reduction to apply to the loss. + */ + public MeanAbsolutePercentageError(Ops tf, Reduction reduction) { + super(tf, null, reduction); + } + + /** + * Creates a MeanAbsolutePercentageError + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param reduction Type of Reduction to apply to the loss. + */ + public MeanAbsolutePercentageError(Ops tf, String name, Reduction reduction) { + super(tf, name, reduction); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + Operand losses = Losses.meanAbsolutePercentageError(tf, labels, predictions); + return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java new file mode 100644 index 00000000000..1b892d2fa16 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -0,0 +1,89 @@ +package org.tensorflow.framework.losses; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes the mean of squares of errors between labels and predictions. + * + *

loss = loss = square(labels - predictions) + * + *

Standalone usage: + * + *

+ *    Operand<TFloat32> labels =
+ *        tf.constant(new float[][] {{0.f, 1.f}, {0.f, 0.f}});
+ *    Operand<TFloat32> predictions =
+ *        tf.constant(new float[][] {{1.f, 1.f}, {1.f, 0.f}});
+ *    MeanSquaredError mse = new MeanSquaredError(tf);
+ *    Operand<TFloat32> result = mse.call(labels, predictions);
+ *    // produces 0.5f
+ * 
+ * + *

Calling with sample weight: + * + *

+ *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.7f, 0.3f});
+ *    Operand<TFloat32> result = mse.call(labels, predictions, sampleWeight);
+ *    // produces 0.25f
+ * 
+ * + *

Using SUM reduction type: + * + *

+ *    MeanSquaredError mse = new MeanSquaredError(tf, Reduction.SUM);
+ *    Operand<TFloat32> result = mse.call(labels, predictions);
+ *    // produces 1.0f
+ * 
+ * + *

Using NONE reduction type: + * + *

+ *    MeanSquaredError mse = new MeanSquaredError(tf, Reduction.NONE);
+ *    Operand<TFloat32> result = mse.call(labels, predictions);
+ *    // produces [0.5f, 0.5f]
+ * 
+ */ +public class MeanSquaredError extends Loss { + + /** + * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name and a Loss + * Reduction of {@link * Reduction#AUTO} + * + * @param tf the TensorFlow Ops + */ + public MeanSquaredError(Ops tf) { + super(tf); + } + + /** + * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name + * + * @param tf the TensorFlow Ops + * @param reduction Type of Reduction to apply to the loss. + */ + public MeanSquaredError(Ops tf, Reduction reduction) { + super(tf, null, reduction); + } + + /** + * Creates a MeanSquaredError + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param reduction Type of Reduction to apply to the loss. + */ + public MeanSquaredError(Ops tf, String name, Reduction reduction) { + super(tf, name, reduction); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + Operand losses = Losses.meanSquaredError(tf, labels, predictions); + return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java new file mode 100644 index 00000000000..4efe1fb0b7b --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -0,0 +1,89 @@ +package org.tensorflow.framework.losses; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes the mean squared logarithmic errors between labels and predictions. + * + *

loss = square(log(labels + 1.) - log(predictions + 1.)) + * + *

Standalone usage: + * + *

+ *    Operand<TFloat32> labels =
+ *        tf.constant(new float[][] {{0.f, 1.f}, {0.f, 0.f}});
+ *    Operand<TFloat32> predictions =
+ *        tf.constant(new float[][] {{1.f, 1.f}, {1.f, 0.f}});
+ *    MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(tf);
+ *    Operand<TFloat32> result = msle.call(labels, predictions);
+ *    // produces 0.240f
+ * 
+ * + *

Calling with sample weight: + * + *

+ *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.7f, 0.3f});
+ *    Operand<TFloat32> result = msle.call(labels, predictions, sampleWeight);
+ *    // produces 0.120f
+ * 
+ * + *

Using SUM reduction type: + * + *

+ *    MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(tf, Reduction.SUM);
+ *    Operand<TFloat32> result = msle.call(labels, predictions);
+ *    // produces 0.480f
+ * 
+ * + *

Using NONE reduction type: + * + *

+ *    MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError(tf, Reduction.NONE);
+ *    Operand<TFloat32> result = msle.call(labels, predictions);
+ *    // produces [0.240f, 0.240f]
+ * 
+ */ +public class MeanSquaredLogarithmicError extends Loss { + + /** + * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name and a Loss + * Reduction of {@link * Reduction#AUTO} + * + * @param tf the TensorFlow Ops + */ + public MeanSquaredLogarithmicError(Ops tf) { + super(tf); + } + + /** + * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name + * + * @param tf the TensorFlow Ops + * @param reduction Type of Reduction to apply to the loss. + */ + public MeanSquaredLogarithmicError(Ops tf, Reduction reduction) { + super(tf, null, reduction); + } + + /** + * Creates a MeanSquaredError + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param reduction Type of Reduction to apply to the loss. + */ + public MeanSquaredLogarithmicError(Ops tf, String name, Reduction reduction) { + super(tf, name, reduction); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + Operand losses = Losses.meanSquaredLogarithmicError(tf, labels, predictions); + return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java new file mode 100644 index 00000000000..c221f49eb90 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java @@ -0,0 +1,99 @@ +package org.tensorflow.framework.losses; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +/** + * Computes the Poisson loss between labels and predictions. + * + *

loss = predictions - labels * log(predictions) + * + *

Standalone usage: + * + *

+ *    Operand<TFloat32> labels =
+ *        tf.constant(new float[][] {{0.f, 1.f}, {0.f, 0.f}});
+ *    Operand<TFloat32> predictions =
+ *        tf.constant(new float[][] {{1.f, 1.f}, {0.f, 0.f}});
+ *    Poisson poissonLoss = new Poisson(tf);
+ *    Operand<TFloat32> result = poissonLoss.call(labels, predictions);
+ *    // produces 0.5f
+ * 
+ * + *

Calling with sample weight: + * + *

+ *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.8f, 0.2f});
+ *    Operand<TFloat32> result = poissonLoss.call(labels, predictions, sampleWeight);
+ *    // produces 0.4f
+ * 
+ * + *

Using SUM reduction type: + * + *

+ *    Poisson poissonLoss = new Poisson(tf, Reduction.SUM);
+ *    Operand<TFloat32> result = poissonLoss.call(labels, predictions);
+ *    // produces 0.999f
+ * 
+ * + *

Using NONE reduction type: + * + *

+ *    Poisson poissonLoss = new Poisson(tf, Reduction.NONE);
+ *    Operand<TFloat32> result = poissonLoss.call(labels, predictions);
+ *    // produces [0.999f, 0f]
+ * 
+ */ +public class Poisson extends Loss { + + /** + * Creates a Poisson Loss using {@link Class#getSimpleName()} as the loss name and a Loss + * Reduction of {@link * Reduction#AUTO} + * + * @param tf the TensorFlow Ops + */ + public Poisson(Ops tf) { + this(tf, null, Reduction.AUTO); + } + + /** + * Creates a Poisson Loss using a Loss Reduction of {@link * Reduction#AUTO} + * + * @param tf the TensorFlow Ops + */ + public Poisson(Ops tf, String name) { + this(tf, name, Reduction.AUTO); + } + + /** + * Creates a Poisson Loss using {@link Class#getSimpleName()} as the loss name + * + * @param tf the TensorFlow Ops + * @param reduction Type of Reduction to apply to the loss. + */ + public Poisson(Ops tf, Reduction reduction) { + this(tf, null, reduction); + } + + /** + * Creates a Poisson Loss + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param reduction Type of Reduction to apply to the loss. + */ + public Poisson(Ops tf, String name, Reduction reduction) { + super(tf, name, reduction); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + Operand losses = Losses.poisson(tf, labels, predictions); + return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java index af408abe36b..1e4573118c5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java @@ -1,6 +1,18 @@ package org.tensorflow.framework.losses; -/** Type of Loss Reduction */ +/** + * Type of Loss Reduction + * + *

{@link #AUTO} indicates that the reduction option will be determined by the usage context. For + * almost all cases this defaults to {@link #SUM_OVER_BATCH_SIZE}. + * + *

{@link #NONE} Weighted losses with one dimension reduced (axis=-1, or axis specified by loss + * function). + * + *

{@link #SUM} Scalar sum of weighted losses. + * + *

{@link #SUM_OVER_BATCH_SIZE} Scalar SUM divided by number of elements in losses. + */ public enum Reduction { AUTO, NONE, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java new file mode 100644 index 00000000000..f4af06e64a6 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -0,0 +1,170 @@ +package org.tensorflow.framework.losses; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes the crossentropy loss between labels and predictions. + * + *

Use this crossentropy loss function when there are two or more label classes. The labels are + * expected to be provided as integers. If you want to provide labels using one-hot + * representation, please use {@link CategoricalCrossentropy} loss. There should be # classes + * floating point values per feature for predictions and a single floating + * point value per feature for label. + * + *

In the snippet below, there is a single floating point value per example for labels + * and # classes floating pointing values per example for predictions + * . The shape of labels is [batch_size] and the shape of + * predictions is [batch_size, num_classes]. + * + *

Standalone usage: + * + *

+ *    Operand<TFloat32> labels =
+ *        tf.constant(new float[] {1, 2});
+ *    Operand<TFloat32> predictions =
+ *        tf.constant(new float[][] {{0.05f, 0.95f, 0f}, {0.1f, 0.8f, 0.1f}});
+ *    SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(tf);
+ *    Operand<TFloat32> result = sparseCCE.call(labels, predictions);
+ *    // produces 1.177f
+ * 
+ * + *

Calling with sample weight: + * + *

+ *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {0.3f, 0.7f});
+ *    Operand<TFloat32> result = sparseCCE.call(labels, predictions, sampleWeight);
+ *    // produces 0.814f
+ * 
+ * + *

Using SUM reduction type: + * + *

+ *    SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(tf, Reduction.SUM);
+ *    Operand<TFloat32> result = sparseCCE.call(labels, predictions);
+ *    // produces 2.354f
+ * 
+ * + *

Using NONE reduction type: + * + *

+ *    SparseCategoricalCrossentropy sparseCCE = new SparseCategoricalCrossentropy(tf, Reduction.NONE);
+ *    Operand<TFloat32> result = sparseCCE.call(labels, predictions);
+ *    // produces [0.0513f, 2.303f]
+ * 
+ */ +public class SparseCategoricalCrossentropy extends Loss { + public static final boolean FROM_LOGITS_DEFAULT = false; + public static final int AXIS_DEFAULT = -1; + public static final Reduction REDUCTION_DEFAULT = Reduction.AUTO; + + private final boolean fromLogits; + private final int axis; + + /** + * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss + * name, a Loss Reduction of {@link Reduction#AUTO}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * + * @param tf the TensorFlow Ops + */ + public SparseCategoricalCrossentropy(Ops tf) { + this(tf, null, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT); + } + + /** + * Creates a SparseCategoricalCrossentropy loss using a Loss Reduction of {@link Reduction#AUTO}, + * and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * + * @param tf the TensorFlow Ops + * @param name the name of this loss function + */ + public SparseCategoricalCrossentropy(Ops tf, String name) { + this(tf, name, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT); + } + + /** + * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss + * name, with Reduction.AUTO and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * + * @param tf the TensorFlow Ops + * @param reduction Type of Reduction to apply to loss. + */ + public SparseCategoricalCrossentropy(Ops tf, Reduction reduction) { + this(tf, null, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT); + } + + /** + * Creates a SparseCategoricalCrossentropy loss with Reduction.AUTO and fromLogits={@link + * #FROM_LOGITS_DEFAULT}. + * + * @param tf the TensorFlow Ops + * @param name the name of this loss function + * @param reduction Type of Reduction to apply to loss. + */ + public SparseCategoricalCrossentropy(Ops tf, String name, Reduction reduction) { + this(tf, name, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT); + } + + /** + * Creates a SparseCategoricalCrossentropy using a Loss Reduction of {@link Reduction#AUTO}, and + * fromLogits={@link #FROM_LOGITS_DEFAULT}. + * + * @param tf the TensorFlow Ops + * @param name the name of this loss function + * @param fromLogits Whether to interpret predictions as a tensor of logit values + */ + public SparseCategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { + this(tf, name, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT); + } + + /** + * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss + * name, a Loss Reduction of {@link Reduction#AUTO} and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * + * @param tf the TensorFlow Ops + * @param fromLogits Whether to interpret predictions as a tensor of logit values + */ + public SparseCategoricalCrossentropy(Ops tf, boolean fromLogits) { + this(tf, null, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT); + } + + /** + * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss + * name, + * + * @param tf the TensorFlow Ops + * @param fromLogits Whether to interpret predictions as a tensor of logit values + * @param reduction Type of Reduction to apply to loss. + */ + public SparseCategoricalCrossentropy(Ops tf, boolean fromLogits, Reduction reduction) { + this(tf, null, fromLogits, reduction, AXIS_DEFAULT); + } + + /** + * Creates a SparseCategoricalCrossentropy + * + * @param tf the TensorFlow Ops + * @param name the name of this loss function + * @param fromLogits Whether to interpret predictions as a tensor of logit values + * @param reduction Type of Reduction to apply to loss. + * @param axis The channels axis. axis=-1 corresponds to data format `Channels Last' + * and axis=1 corresponds to data format 'Channels First'. + */ + public SparseCategoricalCrossentropy( + Ops tf, String name, boolean fromLogits, Reduction reduction, int axis) { + super(tf, name, reduction); + this.fromLogits = fromLogits; + this.axis = axis; + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + Operand losses = + Losses.sparseCategoricalCrossentropy(tf, labels, predictions, fromLogits, axis); + return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java new file mode 100644 index 00000000000..a8ec49be835 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -0,0 +1,93 @@ +package org.tensorflow.framework.losses; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes the squared hinge loss between labels and predictions. + * + *

loss = square(maximum(1 - labels * predictions, 0)) + * + *

labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, they will be + * converted to -1 or 1. + * + *

Standalone usage: + * + *

+ *    Operand<TFloat32> labels =
+ *        tf.constant(new float[][] {{0., 1.}, {0., 0.}});
+ *    Operand<TFloat32> predictions =
+ *        tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}});
+ *    SquaredHinge squaredHinge = new SquaredHinge(tf);
+ *    Operand<TFloat32> result = squaredHinge.call(labels, predictions);
+ *    // produces 1.86f
+ * 
+ * + *

Calling with sample weight: + * + *

+ *    Operand<TFloat32> sampleWeight = tf.constant(new float[] {1.f, 0.f});
+ *    Operand<TFloat32> result = squaredHinge.call(labels, predictions,
+ *                                                  sampleWeight);
+ *    // produces 0.73f
+ * 
+ * + *

Using SUM reduction type: + * + *

+ *    SquaredHinge squaredHinge = new SquaredHinge(tf, Reduction.SUM);
+ *    Operand<TFloat32> result = squaredHinge.call(labels, predictions);
+ *    // produces 3.72f
+ * 
+ * + *

Using NONE reduction type: + * + *

+ *    SquaredHinge squaredHinge = new SquaredHinge(tf, Reduction.NONE);
+ *    Operand<TFloat32> result = squaredHinge.call(labels, predictions);
+ *    // produces [1.46f, 2.26f]
+ * 
+ */ +public class SquaredHinge extends Loss { + + /** + * Creates a Squared Hinge Loss using {@link Class#getSimpleName()} as the loss name and a Loss + * Reduction of {@link * Reduction#AUTO} + * + * @param tf the TensorFlow Ops + */ + public SquaredHinge(Ops tf) { + super(tf); + } + + /** + * Creates a Squared Hinge Loss using {@link Class#getSimpleName()} as the loss name + * + * @param tf the TensorFlow Ops + * @param reduction Type of Reduction to apply to the loss. + */ + public SquaredHinge(Ops tf, Reduction reduction) { + super(tf, null, reduction); + } + + /** + * Creates a Squared Hinge + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param reduction Type of Reduction to apply to the loss. + */ + public SquaredHinge(Ops tf, String name, Reduction reduction) { + super(tf, name, reduction); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + Operand losses = Losses.squaredHinge(tf, labels, predictions); + return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/ConfusionMatrix.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/ConfusionMatrix.java deleted file mode 100644 index 013c9c6d6a6..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/ConfusionMatrix.java +++ /dev/null @@ -1,81 +0,0 @@ -package org.tensorflow.framework.losses.impl; - -import org.tensorflow.Operand; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Squeeze; -import org.tensorflow.types.family.TNumber; - -import java.util.Arrays; - -public class ConfusionMatrix { - - /** - * Squeeze last dim if ranks differ from expected by exactly 1. - * - * @param tf the TensorFlowOps - * @param labels Label values, a `Tensor` whose dimensions match - * `predictions`. - * @param predictions Predicted values, a `Tensor` of arbitrary dimensions. - * @return `labels` and `predictions`, possibly with last dim squeezed. - */ - public static Tuple removeSqueezableDimensions(Ops tf, Operand labels, - Operand predictions) { - return removeSqueezableDimensions(tf, labels, predictions, 0); - } - - /** - * Squeeze last dim if ranks differ from expected by exactly 1. - * - * @param tf the TensorFlowOps - * @param labels Label values, a `Tensor` whose dimensions match - * `predictions`. - * @param predictions Predicted values, a `Tensor` of arbitrary dimensions. - * @param expectedRankDiff Expected result of `rank(predictions) - - * rank(labels)`. - * @return `labels` and `predictions`, possibly with last dim squeezed. - */ - public static Tuple removeSqueezableDimensions(Ops tf, Operand labels, - Operand predictions, int expectedRankDiff) { - - tf = tf.withSubScope("removeSqueezableDimensions"); - Shape predictionsShape = predictions.asOutput().shape(); - int predictionsRank = predictionsShape.numDimensions(); - Shape labelsShape = labels.asOutput().shape(); - int labelsRank = labelsShape.numDimensions(); - - if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { - // Use static rank. - int rankDiff = predictionsRank - labelsRank; - if (rankDiff == expectedRankDiff + 1 && Shape.isCompatible(predictionsShape.size(-1), 1)) { - predictions = tf.squeeze(predictions); - } else if (rankDiff == expectedRankDiff - 1 && Shape.isCompatible(labelsShape.size(-1), 1)) { - labels = tf.squeeze(labels); - } - return new Tuple(labels, predictions); - } - // Use dynamic rank. - Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); - if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) { - /** - * TODO, if we ever get a select that does lazy evaluation, but for - * now do the tf.squeeze predictions = tf.select( - * tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ), - * tf.squeeze(predictions, Squeeze.axis(Arrays.asList(-1L))), - * predictions ); * - */ - predictions = tf.squeeze(predictions, Squeeze.axis(Arrays.asList(-1L))); - } - if (labelsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(labelsShape.size(-1), 1)) { - /** - * TODO, if we ever get a select that does lazy evaluation labels = - * tf.select( tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff - * ), tf.squeeze(labels, Squeeze.axis(Arrays.asList(-1L))), - * predictions ); * - */ - labels = tf.squeeze(labels, Squeeze.axis(Arrays.asList(-1L))); - } - return new Tuple(labels, predictions,true); - } - -} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java index 272e27b51ff..eb8032569c2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java @@ -1,142 +1,304 @@ package org.tensorflow.framework.losses.impl; +import org.tensorflow.DataType; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Reduction; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReduceSum; import org.tensorflow.op.core.Squeeze; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import java.util.Arrays; +import java.util.Collections; public class LossesImpl { - /** - * Squeeze or expand last dimension if needed. - * - *
    - *
  1. Squeezes last dim of `yPred` or `yTrue` if their rank differs by 1 (using - * `confusion_matrix.remove_squeezable_dimensions`). - *
  2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1 from the new - * rank of `yPred`. If `sample_weight` is scalar, it is kept scalar./li> - *
- * - * @param tf the TensorFlow Ops - * @param predictions Predicted values, a `Tensor` of arbitrary dimensions. - * @param predictions Optional label `Tensor` whose dimensions match `y_pred`. - * @return Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has the last - * dimension squeezed, `sample_weight` could be extended by one dimension. If `sample_weight` - * is null, (y_pred, y_true) is returned. - */ - /********** TODO need to move ConfusionMatrix to MathOps */ - public static Tuple squeezeOrExpandDimensions(Ops tf, Operand labels, Operand predictions) { - return squeezeOrExpandDimensions(tf, labels, predictions, null); - } + /** + * Squeeze or expand last dimension if needed with a sampleWeights of one. + * + *
    + *
  1. Squeezes last dim of predictions or labels if their rank differs by 1 (using + * {@link #removeSqueezableDimensions}). + *
  2. Squeezes or expands last dim of sampleWeight` if its rank differs by 1 from the new + * rank of predictions`. If sampleWeight` is scalar, it is kept scalar./li> + *
+ * + * @param tf the TensorFlow Ops + * @param predictions Predicted values, a Operand of arbitrary dimensions. + * @param labels Optional label Operand whose dimensions match prediction. + * @return Tuple of prediction, label and sampleWeight. Each of them possibly has the last + * dimension squeezed, sampleWeight could be extended by one dimension. If sampleWeight + * is null, (prediction, label) is returned. + */ + public static Tuple squeezeOrExpandDimensions( + Ops tf, Operand labels, Operand predictions) { + return squeezeOrExpandDimensions(tf, labels, predictions, null); + } - /** - * Squeeze or expand last dimension if needed. * * - * - *
    - *
  1. Squeezes last dim of `yPred` or `yTrue` if their rank differs by 1 (using * - * `confusion_matrix.remove_squeezable_dimensions`). * - *
  2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1 from the new * - * rank of `yPred`. If `sample_weight` is scalar, it is kept scalar./li> * - *
- * - * @param tf the TensorFlow Ops - * @param predictions Predicted values, a `Tensor` of arbitrary dimensions. - * @param labels Optional label `Tensor` whose dimensions match `y_pred`. - * @param sampleWeight Optional weight scalar or `Tensor` whose dimensions match `y_pred`. - * @return Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has the last - * dimension squeezed, `sample_weight` could be extended by one dimension. If `sample_weight` - * is null, (y_pred, y_true) is returned. - */ - /********** TODO need to move ConfusionMatrix to MathOps **/ - public static Tuple squeezeOrExpandDimensions( - Ops tf, Operand labels, Operand predictions, Operand sampleWeight) { - Tuple tuple = new Tuple<>(labels, predictions, true); - Shape predictionsShape = predictions.asOutput().shape(); - long ypredRank = predictionsShape.numDimensions(); - - if (labels != null) { - Shape labelsShape = labels.asOutput().shape(); - long ytrueRank = labelsShape.numDimensions(); - if (ytrueRank != Shape.UNKNOWN_SIZE && ypredRank != Shape.UNKNOWN_SIZE) { - // Use static rank for `y_true` and `y_pred`. - if (ypredRank - ytrueRank != 1 || predictionsShape.size(-1) == 1) { - // y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(y_true, y_pred) - tuple = ConfusionMatrix.removeSqueezableDimensions(tf, labels, predictions); - } - } else { // use dynamic rank - tuple = ConfusionMatrix.removeSqueezableDimensions(tf, labels, predictions); - } - } - if (sampleWeight == null) { - return tuple; - } - Shape weightsShape = sampleWeight.asOutput().shape(); - long weightsRank = weightsShape.numDimensions(); - if (weightsRank == 0) { // scalar - return new Tuple(labels, predictions, sampleWeight, true); - } + /** + * Squeeze or expand last dimension if needed. + * + *
    + *
  1. Squeezes last dim of `predictions` or `labels` if their rank differs by 1 (using * + * `confusion_matrix.remove_squeezable_dimensions`). * + *
  2. Squeezes or expands last dim of `sampleWeight` if its rank differs by 1 from the new * + * rank of `predictions`. If `sampleWeight` is scalar, it is kept scalar./li> * + *
+ * + * @param tf the TensorFlow Ops + * @param predictions Predicted values, a Operand of arbitrary dimensions. + * @param labels Optional label Operand whose dimensions match prediction + * . + * @param sampleWeight Optional sample weight(s) Operand whose dimensions match + * prediction. + * @return Tuple of prediction, label and sampleWeight. + * Each of them possibly has the last dimension squeezed, sampleWeight could be + * extended by one dimension. If sampleWeight is null, (prediction, label) is + * returned. + */ + public static Tuple squeezeOrExpandDimensions( + Ops tf, Operand labels, Operand predictions, Operand sampleWeight) { + Tuple tuple = new Tuple<>(labels, predictions); - if (ypredRank != Shape.UNKNOWN_SIZE && weightsRank != Shape.UNKNOWN_SIZE) { + Shape predictionsShape = predictions.asOutput().shape(); + long predictionsRank = predictionsShape.numDimensions(); - if (weightsRank - ypredRank == 1) { - sampleWeight = tf.squeeze(sampleWeight); - } else if (ypredRank - weightsRank == 1) { - sampleWeight = tf.expandDims(sampleWeight, tf.constant(-1L)); - } - return new Tuple(labels, predictions, sampleWeight, true); + if (labels != null) { + Shape labelsShape = labels.asOutput().shape(); + long labelRank = labelsShape.numDimensions(); + if (labelRank != Shape.UNKNOWN_SIZE && predictionsRank != Shape.UNKNOWN_SIZE) { + // Use static rank for `label` and `prediction`. + if (predictionsRank - labelRank != 1 || predictionsShape.size(-1) == 1) { + // label, prediction = confusion_matrix.remove_squeezable_dimensions(label, prediction) + tuple = removeSqueezableDimensions(tf, labels, predictions); } - // Use dynamic rank. - Operand weightsRankTensor = tf.rank(sampleWeight); - Operand rankDiff = tf.math.sub(weightsRankTensor, tf.rank(predictions)); - sampleWeight = - tf.select( - tf.math.equal(weightsRankTensor, tf.constant(0)), - sampleWeight, - maybeAdjustWeights(tf, sampleWeight, rankDiff)); - return new Tuple(labels, predictions, sampleWeight, true); + } else { // use dynamic rank + tuple = removeSqueezableDimensions(tf, labels, predictions); + } + } + if (sampleWeight == null) { + return tuple; + } + Shape weightsShape = sampleWeight.asOutput().shape(); + long weightsRank = weightsShape.numDimensions(); + if (weightsRank == 0) { // scalar + return new Tuple<>(labels, predictions, sampleWeight); + } + + if (predictionsRank != Shape.UNKNOWN_SIZE && weightsRank != Shape.UNKNOWN_SIZE) { + + if (weightsRank - predictionsRank == 1) { + sampleWeight = tf.squeeze(sampleWeight); + } else if (predictionsRank - weightsRank == 1) { + sampleWeight = tf.expandDims(sampleWeight, tf.constant(-1L)); + } + return new Tuple<>(labels, predictions, sampleWeight); + } + // Use dynamic rank. + Operand weightsRankTensor = tf.rank(sampleWeight); + Operand rankDiff = tf.math.sub(weightsRankTensor, tf.rank(predictions)); + sampleWeight = + tf.select( + tf.math.equal(weightsRankTensor, tf.constant(0)), + sampleWeight, + maybeAdjustWeights(tf, sampleWeight, rankDiff)); + return new Tuple<>(labels, predictions, sampleWeight); + } + + /** + * Squeeze or expand the sampleWeight based on the rank difference + * + *

If the rank difference is +1, squeeze the last dimension of sampleWeight, If the rank + * difference is -1, expand the last dimension of sampleWeight. Otherwise, leave the shape of + * sampleWeight as is. + * + * @param tf the TensorFlow Ops + * @param sampleWeight the sample weights + * @param rankDiff the difference in rank + * @param the data type for the Operands. + * @return the adjusted sampleWeight + */ + private static Operand maybeAdjustWeights( + Ops tf, Operand sampleWeight, Operand rankDiff) { + return tf.select( + tf.math.equal(rankDiff, tf.constant(1)), + tf.squeeze(sampleWeight, Squeeze.axis(Collections.singletonList(-1L))), + maybeExpandWeights(tf, sampleWeight, rankDiff)); + } + + /** + * Expand the last dimension of sampleWeight. if the rank difference is -1. + * + * @param tf the TensorFlow Ops + * @param sampleWeight the sample weights + * @param rankDiff the difference in rank + * @param the data type for the Operands. + * @return the adjusted sampleWeight + */ + private static Operand maybeExpandWeights( + Ops tf, Operand sampleWeight, Operand rankDiff) { + return tf.select( + tf.math.equal(rankDiff, tf.constant(-1)), + tf.expandDims(sampleWeight, tf.constant(-1)), + sampleWeight); + } + + /** + * Squeeze last dim if ranks differ from expected by exactly 1. + * + * @param tf the TensorFlowOps + * @param labels Label values, a `Tensor` whose dimensions match `predictions`. + * @param predictions Predicted values, a `Tensor` of arbitrary dimensions. + * @return `labels` and `predictions`, possibly with last dim squeezed. + */ + public static Tuple removeSqueezableDimensions( + Ops tf, Operand labels, Operand predictions) { + return removeSqueezableDimensions(tf, labels, predictions, 0); + } + + /** + * Squeeze last dim if ranks differ from expected by exactly 1. + * + * @param tf the TensorFlowOps + * @param labels Label values, a `Tensor` whose dimensions match `predictions`. + * @param predictions Predicted values, a `Tensor` of arbitrary dimensions. + * @param expectedRankDiff Expected result of `rank(predictions) - rank(labels)`. + * @return `labels` and `predictions`, possibly with last dim squeezed. + */ + public static Tuple removeSqueezableDimensions( + Ops tf, Operand labels, Operand predictions, int expectedRankDiff) { + + tf = tf.withSubScope("removeSqueezableDimensions"); + Shape predictionsShape = predictions.asOutput().shape(); + int predictionsRank = predictionsShape.numDimensions(); + Shape labelsShape = labels.asOutput().shape(); + int labelsRank = labelsShape.numDimensions(); + + if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { + // Use static rank. + int rankDiff = predictionsRank - labelsRank; + if (rankDiff == expectedRankDiff + 1 && Shape.isCompatible(predictionsShape.size(-1), 1)) { + predictions = tf.squeeze(predictions); + } else if (rankDiff == expectedRankDiff - 1 && Shape.isCompatible(labelsShape.size(-1), 1)) { + labels = tf.squeeze(labels); + } + return new Tuple<>(labels, predictions); } + // Use dynamic rank. + // TODO Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); + if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) { + /* + * TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze + * predictions = tf.select( tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ), + * tf.squeeze(predictions, Squeeze.axis(Arrays.asList(-1L))), predictions ); * + */ + predictions = tf.squeeze(predictions, Squeeze.axis(Collections.singletonList(-1L))); + } + if (labelsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(labelsShape.size(-1), 1)) { + /* + * TODO, if we ever get a select that does lazy evaluation labels = tf.select( + * tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ), tf.squeeze(labels, + * Squeeze.axis(Arrays.asList(-1L))), predictions ); * + */ + labels = tf.squeeze(labels, Squeeze.axis(Collections.singletonList(-1L))); + } + return new Tuple<>(labels, predictions); + } - /** - * Squeeze or expand the sampleWeight based on the rank difference - * - *

If the rank difference is +1, squeeze the last dimension of sampleWeight, If the rank - * difference is -1, expand the last dimension of sampleWeight. Otherwise, leave the shape of - * sampleWeight as is. - * - * @param tf the TensorFlow Ops - * @param sampleWeight the sample weights - * @param rankDiff the difference in rank - * @param the data type for the Operands. - * @return the adjusted sampleWeight - */ - private static Operand maybeAdjustWeights( - Ops tf, Operand sampleWeight, Operand rankDiff) { - return tf.select( - tf.math.equal(rankDiff, tf.constant(1)), - tf.squeeze(sampleWeight, Squeeze.axis(Arrays.asList(-1L))), - maybeExpandWeights(tf, sampleWeight, rankDiff)); + /** + * Computes the weighted loss + * + * @param tf the TensorFlow Ops + * @param loss the unweighted loss + * @param reduction the type of reduction + * @param sampleWeight the sample weight, if null then this defaults to one. + * @param the data type of the loss + * @return the weighted loss + */ + public static Operand computeWeightedLoss( + Ops tf, Operand loss, Reduction reduction, Operand sampleWeight) { + DataType dataType = loss.asOutput().dataType(); + if (sampleWeight == null) { + sampleWeight = tf.dtypes.cast(tf.constant(1), dataType); } + Tuple result = squeezeOrExpandDimensions(tf, null, loss, sampleWeight); + loss = result.getTarget(); + sampleWeight = result.getSampleWeights(); + + Operand weighted_losses = tf.math.mul(loss, tf.dtypes.cast(sampleWeight, dataType)); + loss = reduceWeightedLoss(tf, weighted_losses, reduction); + return tf.dtypes.cast(loss, dataType); + } - /** - * Expand the last dimension of sampleWeight. if the rank difference is -1. - * - * @param tf the TensorFlow Ops - * @param sampleWeight the sample weights - * @param rankDiff the difference in rank - * @param the data type for the Operands. - * @return the adjusted sampleWeight - */ - private static Operand maybeExpandWeights( - Ops tf, Operand sampleWeight, Operand rankDiff) { - return tf.select( - tf.math.equal(rankDiff, tf.constant(-1)), - tf.expandDims(sampleWeight, tf.constant(-1)), - sampleWeight); + /** + * Reduces the weighted loss based on the reduction type + * + * @param tf the TensorFlow Ops + * @param weightedLoss the weighted loss + * @param reduction the type of reduction + * @param the data type of the weighted loss + * @return the reduced weighted loss + */ + private static Operand reduceWeightedLoss( + Ops tf, Operand weightedLoss, Reduction reduction) { + Operand loss; + if (reduction == Reduction.NONE) { + loss = weightedLoss; + } else { + loss = + tf.reduceSum(weightedLoss, allAxis(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE)); + if (reduction == Reduction.AUTO || reduction == Reduction.SUM_OVER_BATCH_SIZE) { + loss = safeMean(tf, loss, weightedLoss.asOutput().shape().size()); + } } + return loss; + } + + /** + * Computes a safe mean of the losses. + * + * @param tf the TensorFlow Ops + * @param losses Operand whose elements contain individual loss measurements. + * @param numElements The number of measurable elements in losses. + * @param the data type of the losses + * @return A scalar representing the mean of losses. If numElements is + * zero, then zero is returned. + */ + public static Operand safeMean( + Ops tf, Operand losses, long numElements) { + Operand totalLoss = tf.reduceSum(losses, allAxis(tf, losses)); + return tf.math.divNoNan( + totalLoss, tf.dtypes.cast(tf.constant(numElements), losses.asOutput().dataType())); + } + /** + * Gets a Constant integer array representing all the axes of the operand. + * + * @param tf the TensorFlow Ops + * @param op the TensorFlow Ops + * @param the type of Operand + * @return a Constant that represents all the axes of the operand. + */ + public static Operand allAxis(Ops tf, Operand op) { + int[] ranks = allAxis(op); + return tf.constant(ranks); + } + + /** + * Gets an integer array representing all the axes of the operand. + * + * @param op the Operand + * @param the type of Operand + * @return the integer array representing all the axes of the operand. + */ + private static int[] allAxis(Operand op) { + int rank = op.asOutput().shape().numDimensions(); + int[] axes = new int[rank]; + for (int i = 0; i < rank; i++) { + axes[i] = i; + } + return axes; + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/Tuple.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/Tuple.java index 672c4ca4f6c..402cac96bec 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/Tuple.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/Tuple.java @@ -4,100 +4,46 @@ import org.tensorflow.types.family.TNumber; /** - * A helper class for loss methods to return multiple responses + * A helper class for loss methods to return multiple labels, target, and sampleWeights * * @param the data type of the Tuple entries. */ public class Tuple { private final Operand labels; - private final Operand losses; - private final Operand predictions; + private final Operand target; private final Operand sampleWeights; /** - * Creates a Tuple of Operands for labels, predictions, and sampleWeights + * Creates a Tuple of Operands for labels, target, and sampleWeights * * @param labels the labels - * @param lossesOrPredictions the losses or predictions - * @param isPredictions flag indicating that this Tuple will contain predictions or losses + * @param target the losses or target */ - public Tuple(Operand labels, Operand lossesOrPredictions, boolean isPredictions) { - this(labels, lossesOrPredictions, null, isPredictions); + public Tuple(Operand labels, Operand target) { + this(labels, target, null); } /** - * Creates a Tuple of Operands for labels, predictions, and sampleWeights + * Creates a Tuple of Operands for labels, target, and sampleWeights * * @param labels the labels - * @param lossesOrPredictions the losses or predictions + * @param target the losses or target * @param sampleWeights the sample weights - * @param isPredictions flag indicating that this Tuple will contain predictions or losses */ - public Tuple( - Operand labels, - Operand lossesOrPredictions, - Operand sampleWeights, - boolean isPredictions) { + public Tuple(Operand labels, Operand target, Operand sampleWeights) { this.labels = labels; - if (isPredictions) { - this.predictions = lossesOrPredictions; - this.losses = null; - } else { - this.predictions = null; - this.losses = lossesOrPredictions; - } + this.target = target; this.sampleWeights = sampleWeights; } - /** - * Indicates whether this Tuple contains Labels - * - * @return true is this Tuple contains Labels - */ - public boolean containsLabels() { - return labels != null; - } - - /** - * Indicates whether this Tuple contains Labels - * - * @return true is this Tuple contains Labels - */ - public boolean containsPredictions() { - return predictions != null; - } - - /** - * Indicates whether this Tuple contains Labels - * - * @return true is this Tuple contains Labels - */ - public boolean containsLosses() { - return losses != null; - } - - /** - * Indicates whether this Tuple contains Labels - * - * @return true is this Tuple contains Labels - */ - public boolean containsSampleWeights() { - return this.sampleWeights != null; - } - /** @return the labels */ public Operand getLabels() { return labels; } - /** @return the predictions */ - public Operand getPredictions() { - return predictions; - } - - /** @return the predictions */ - public Operand getLosses() { - return losses; + /** @return the target */ + public Operand getTarget() { + return target; } /** @return the sampleWeights */ diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java new file mode 100644 index 00000000000..83f474f3f88 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java @@ -0,0 +1,179 @@ +package org.tensorflow.framework.losses; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +public class BinaryCrossentropyTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + /** Test of call method, of class BinaryCrossentropy. */ + @Test + public void testAllCorrectUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + BinaryCrossentropy instance = new BinaryCrossentropy(tf); + float[] trueArray = {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); + + Operand loss = instance.call(yTrue, yTrue); + + float expected = 0.0f; + testSession.evaluate(expected, loss); + // Test with logits. + float[] logitsArray = { + 100.0f, -100.0f, -100.0f, + -100.0f, 100.0f, -100.0f, + -100.0f, -100.0f, 100.0f + }; + Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); + instance = new BinaryCrossentropy(tf, true); + + loss = instance.call(yTrue, logits); + testSession.evaluate(expected, loss); + } + } + + /** Test of call method, of class BinaryCrossentropy. */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + BinaryCrossentropy instance = new BinaryCrossentropy(tf); + float[] trueArray = {1f, 0f, 1f, 0f}; + float[] predArray = {1f, 1f, 1f, 0f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); + Operand loss = instance.call(yTrue, yPred); + float expected = 3.83331f; + testSession.evaluate(expected, loss); + + // Test with logits. + float[] trueArray1 = {1f, 0f, 1f, 0f, 1f, 1f}; + float[] logitsArray = { + 100.0f, -100.0f, 100.0f, + 100.0f, 100.0f, -100.0f + }; + Operand yTrue1 = tf.reshape(tf.constant(trueArray1), tf.constant(Shape.of(2, 3))); + Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); + instance = new BinaryCrossentropy(tf, true); + loss = instance.call(yTrue1, logits); + expected = 33.33333f; + testSession.evaluate(expected, loss); + } + } + + /** Test of call method, of class BinaryCrossentropy. */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + BinaryCrossentropy instance = new BinaryCrossentropy(tf); + float[] trueArray = {1f, 0f, 1f, 0f}; + float[] predArray = {1f, 1f, 1f, 0f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 8.816612f; + testSession.evaluate(expected, loss); + + // Test with logits. + float[] trueArray1 = {1f, 0f, 1f, 0f, 1f, 1f}; + float[] logitsArray = { + 100.0f, -100.0f, 100.0f, + 100.0f, 100.0f, -100.0f + }; + Operand yTrue1 = tf.reshape(tf.constant(trueArray1), tf.constant(Shape.of(2, 3))); + Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); + instance = new BinaryCrossentropy(tf, true); + loss = instance.call(yTrue1, logits, sampleWeight); + expected = 76.66667f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + BinaryCrossentropy instance = new BinaryCrossentropy(tf); + float[] trueArray = {1f, 0f, 1f, 0f}; + float[] predArray = {1f, 1f, 1f, 0f}; + float[] sampleWeightArray = {1.2f, 3.4f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleWeightArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 4.59997f; + testSession.evaluate(expected, loss); + + // Test with logits. + float[] trueArray1 = {1f, 0f, 1f, 0f, 1f, 1f}; + float[] logitsArray = { + 100.0f, -100.0f, 100.0f, + 100.0f, 100.0f, -100.0f + }; + float[] sampleWeightArray1 = {4f, 3f}; + Operand yTrue1 = tf.reshape(tf.constant(trueArray1), tf.constant(Shape.of(2, 3))); + Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight1 = tf.constant(sampleWeightArray1); + instance = new BinaryCrossentropy(tf, true); + loss = instance.call(yTrue1, logits, sampleWeight1); + expected = 100f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testNoReduction() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + + // Test with logits. + float[] trueArray = {1f, 0f, 1f, 0f, 1f, 1f}; + float[] logitsArray = { + 100.0f, -100.0f, 100.0f, + 100.0f, 100.0f, -100.0f + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); + BinaryCrossentropy instance = + new BinaryCrossentropy( + tf, true, BinaryCrossentropy.LABEL_SMOOTHING_DEFAULT, Reduction.NONE); + Operand loss = instance.call(yTrue, logits); + Float[] expected = {0.f, 66.666664f}; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testLabelSmoothing() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + float labelSmoothing = 0.1f; + float[] trueArray = {1f, 0f, 1f}; + float[] logitsArray = {100.0f, -100.0f, -100.0f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(1, 3))); + Operand logits = + tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(1, 3))); + + BinaryCrossentropy instance = new BinaryCrossentropy(tf, true, labelSmoothing); + Operand loss = instance.call(yTrue, logits); + float expected = (100.0f + 50.0f * labelSmoothing) / 3.0f; + testSession.evaluate(expected, loss); + } catch (Exception expected) { + + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java new file mode 100644 index 00000000000..f1bf8f0b2be --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java @@ -0,0 +1,213 @@ +package org.tensorflow.framework.losses; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; + +public class CategoricalCrossentropyTest { + + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + /** Test of call method, of class CategoricalCrossentropy. */ + @Test + public void testAllCorrectUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + + long[] trueArray = { + 1L, 0L, 0L, + 0L, 1L, 0L, + 0L, 0L, 1L + }; + float[] predArray = { + 1.f, 0.f, 0.f, + 0.f, 1.f, 0.f, + 0.f, 0.f, 1.F + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); + CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); + Operand loss = instance.call(yTrue, yPred); + float expected = 0f; + testSession.evaluate(expected, loss); + + // Test with logits. + float[] logitsArray = { + 10.f, 0.f, 0.f, + 0.f, 10.f, 0.f, + 0.f, 0.f, 10.F + }; + yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); + Operand logits = + tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); + instance = new CategoricalCrossentropy(tf, true); + loss = instance.call(yTrue, logits); + testSession.setEpsilon(1e-3f); + testSession.evaluate(0.0f, loss); + } + } + + /** Test of call method, of class CategoricalCrossentropy. */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); + int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1}; + float[] predArray = { + .9f, .05f, .05f, + .5f, .89f, .6f, + .05f, .01f, .94f + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); + Operand loss = instance.call(yTrue, yPred); + float expected = 0.32396814f; + testSession.evaluate(expected, loss); + + // Test with logits. + float[] logitsArray = { + 8.f, 1.f, 1.f, + 0.f, 9.f, 1.f, + 2.f, 3.f, 5.F + }; + Operand logits = + tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); + instance = new CategoricalCrossentropy(tf, true); + loss = instance.call(yTrue, logits); + expected = 0.0573755f; + testSession.evaluate(expected, loss); + } + } + + /** Test of call method, of class CategoricalCrossentropy. */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + + int[] trueArray = { + 1, 0, 0, + 0, 1, 0, + 0, 0, 1 + }; + float[] predArray = { + .9f, .05f, .05f, + .5f, .89f, .6f, + .05f, .01f, .94f + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); + Operand sampleWeight = tf.constant(2.3f); + + CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = .7451267f; + testSession.evaluate(expected, loss); + + // Test with logits. + float[] logitsArray = { + 8.f, 1.f, 1.f, + 0.f, 9.f, 1.f, + 2.f, 3.f, 5.F + }; + Operand logits = + tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); + instance = new CategoricalCrossentropy(tf, true); + loss = instance.call(yTrue, logits, sampleWeight); + expected = 0.13196386f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testSsampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); + float[] sampeWeightArray = {1.2f, 3.4f, 5.6f}; + int[] trueArray = { + 1, 0, 0, + 0, 1, 0, + 0, 0, 1 + }; + float[] predArray = { + .9f, .05f, .05f, + .5f, .89f, .6f, + .05f, .01f, .94f + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); + Operand sampleWeight = + tf.reshape(tf.constant(sampeWeightArray), tf.constant(Shape.of(3, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 1.0696f; + testSession.evaluate(expected, loss); + + // Test with logits. + float[] logitsArray = { + 8.f, 1.f, 1.f, + 0.f, 9.f, 1.f, + 2.f, 3.f, 5.F + }; + Operand logits = + tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); + instance = new CategoricalCrossentropy(tf, true); + loss = instance.call(yTrue, logits, sampleWeight); + expected = 0.31829f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testNoReduction() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + + // Test with logits. + int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1}; + float[] logitsArray = { + 8.f, 1.f, 1.f, + 0.f, 9.f, 1.f, + 2.f, 3.f, 5.F + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); + Operand logits = + tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); + CategoricalCrossentropy instance = + new CategoricalCrossentropy(tf, true, 0.0f, Reduction.NONE); + Operand loss = instance.call(yTrue, logits); + Float[] expected = {0.001822f, 0.000459f, 0.169846f}; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testLabelSmoothing() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + float labelSmoothing = 0.1f; + int[] trueArray = {1, 0, 0}; + float[] logitsArray = {100.0f, -100.0f, -100.0f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(1, 3))); + Operand logits = + tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(1, 3))); + + CategoricalCrossentropy instance = new CategoricalCrossentropy(tf, true, labelSmoothing); + Operand loss = instance.call(yTrue, logits); + float expected = 400.0f * labelSmoothing / 3.0f; + testSession.evaluate(expected, loss); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java new file mode 100644 index 00000000000..b455d58740b --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java @@ -0,0 +1,131 @@ +package org.tensorflow.framework.losses; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +public class CategoricalHingeTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + /** + * Test of call method, of class CategoricalHinge. + */ + @Test + public void testReductionNone() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + CategoricalHinge instance = new CategoricalHinge(tf, Reduction.NONE); + int[] trueArray = {1, 9, 2, -5}; + float[] predArray = {4f, 8f, 12f, 8f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); + Operand loss = instance.call(yTrue, yPred); + Float[] expected = {0.0f, 65.0f}; + testSession.evaluate(expected, loss); + } + } + + /** + * Test of call method, of class CategoricalHinge. + */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + CategoricalHinge instance = new CategoricalHinge(tf); + int[] trueArray = {1, 9, 2, -5}; + float[] predArray = {4f, 8f, 12f, 8f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); + Operand loss = instance.call(yTrue, yPred); + float expected = 32.5f; + testSession.evaluate(expected, loss); + } + } + + /** + * Test of call method, of class CategoricalHinge. + */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + CategoricalHinge instance = new CategoricalHinge(tf); + int[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 83.95f; + testSession.evaluate(expected, loss); + + Operand loss2 = instance.call(yTrue, yPred, sampleWeight); + testSession.evaluate(loss, loss2); + } + } + + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + CategoricalHinge instance = new CategoricalHinge(tf); + int[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] weightsNp = {1.2f, 3.4f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.reshape(tf.constant(weightsNp), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 124.1f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + CategoricalHinge instance = new CategoricalHinge(tf); + int[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(0f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + + } + } + + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + CategoricalHinge instance = new CategoricalHinge(tf); + int[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] weightsNp = {3, 6, 5, 0, 4, 2}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Operand sampleWeight = tf.reshape(tf.constant(weightsNp), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 4.0f; + testSession.evaluate(expected, loss); + + } + } + + +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java new file mode 100644 index 00000000000..ca7aea553d1 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java @@ -0,0 +1,171 @@ +package org.tensorflow.framework.losses; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +public class CosineSimilarityTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + /** Test of call method, of class CosineSimilarity. */ + @Test + public void testReductionNone() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + CosineSimilarity instance = new CosineSimilarity(tf, Reduction.NONE); + Shape shape = Shape.of(2, 3); + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); + Operand loss = instance.call(yTrue, yPred); + Float[] expected = {-0.720488f, 0.3460499f}; + testSession.evaluate(expected, loss); + } + } + + /** Test of call method, of class CosineSimilarity. */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] expectedLoss = {0.720488f, -0.3460499f}; + CosineSimilarity instance = new CosineSimilarity(tf); + Shape shape = Shape.of(2, 3); + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); + Operand loss = instance.call(yTrue, yPred); + float expected = -mean(expectedLoss); + testSession.evaluate(expected, loss); + } + } + + /** Test of call method, of class CosineSimilarity. */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] expectedLoss = {0.720488f, -0.3460499f}; + CosineSimilarity instance = new CosineSimilarity(tf); + Shape shape = Shape.of(2, 3); + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = -mean(mul(expectedLoss, 2.3f)); + testSession.evaluate(expected, loss); + } + } + + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] expectedLoss = {0.720488f, -0.3460499f}; + CosineSimilarity instance = new CosineSimilarity(tf); + float[] weightsArray = {1.2f, 3.4f}; + Shape shape = Shape.of(2, 3); + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); + Operand sampleWeight = + tf.reshape(tf.constant(weightsArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = -mean(mul(expectedLoss, weightsArray)); + testSession.evaluate(expected, loss); + } + } + + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + CosineSimilarity instance = new CosineSimilarity(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Shape shape = Shape.of(2, 3); + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); + Operand sampleWeight = tf.constant(0f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + CosineSimilarity instance = new CosineSimilarity(tf); + Shape shape = Shape.of(2, 3, 1); + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); + float[] weightsArray = {3, 6, 5, 0, 4, 2}; + Operand sampleWeight = + tf.reshape(tf.constant(weightsArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = -2.0f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testAxis() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] expectedLoss = {0.720488f, -0.3460499f}; + CosineSimilarity instance = new CosineSimilarity(tf, 1); + Shape shape = Shape.of(2, 3); + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(shape)); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(shape)); + Operand loss = instance.call(yTrue, yPred); + float expected = -mean(expectedLoss); + testSession.evaluate(expected, loss); + } + } + + private float mean(float[] v) { + float sum = 0; + for (float value : v) { + sum += value; + } + return sum / v.length; + } + + private float[] mul(float[] v, float scalar) { + float[] result = new float[v.length]; + for (int i = 0; i < v.length; i++) { + result[i] = v[i] * scalar; + } + return result; + } + + private float[] mul(float[] v, float[] b) { + float[] result = new float[v.length]; + for (int i = 0; i < v.length; i++) { + result[i] = v[i] * b[i]; + } + return result; + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java new file mode 100644 index 00000000000..1f13d0392d7 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java @@ -0,0 +1,108 @@ +package org.tensorflow.framework.losses; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +public class HingeTest { + + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + /** + * Test of call method, of class Hinge. + */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Hinge instance = new Hinge(tf); + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand loss = instance.call(yTrue, yPred); + float expected = 0.50625f; + testSession.evaluate(expected, loss); + } + } + + /** + * Test of call method, of class Hinge. + */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Hinge instance = new Hinge(tf); + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 1.164375f; + testSession.evaluate(expected, loss); + + // todo Verify we get the same output when the same input is given + } + } + + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Hinge instance = new Hinge(tf); + float[] sampleArray = {1.2f, 3.4f}; + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 1.06125f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Hinge instance = new Hinge(tf); + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand sampleWeight = tf.constant(0.F); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Hinge instance = new Hinge(tf, Reduction.AUTO); + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f, 1f, 3f}; + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4, 1))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4, 1))); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 4))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + + float expected = 2.0125f; + testSession.evaluate(expected, loss); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java new file mode 100644 index 00000000000..d7acab0126f --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java @@ -0,0 +1,123 @@ +package org.tensorflow.framework.losses; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +public class HuberTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + @Test + public void testAllCorrect() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + float[] trueArray = {.9f, .2f, .2f, .8f, .4f, .6f}; + + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Huber instance = new Huber(tf); + Operand loss = instance.call(yTrue, yTrue); + float expected = 0.0f; + testSession.evaluate(expected, loss); + } + } + + /** Test of call method, of class Huber. */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + + float[] trueArray = {.9f, .2f, .2f, .8f, .4f, .6f}; + float[] predArray = {1.f, 0.f, 1.f, 1.f, 0.f, 0.f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Huber instance = new Huber(tf); + Operand loss = instance.call(yTrue, yPred); + float expected = 0.10416666666666669f; + testSession.evaluate(expected, loss); + } + } + + /** Test of call method, of class Huber. */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + float[] trueArray = {.9f, .2f, .2f, .8f, .4f, .6f}; + float[] predArray = {1.f, 0.f, 1.f, 1.f, 0.f, 0.f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Huber instance = new Huber(tf); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0.23958333333333337f; + testSession.evaluate(expected, loss); + + // todo Verify we get the same output when the same input is given + } + } + + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + float[] sampleArray = {1.2f, 3.4f}; + float[] trueArray = {.9f, .2f, .2f, .8f, .4f, .6f}; + float[] predArray = {1.f, 0.f, 1.f, 1.f, 0.f, 0.f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Huber instance = new Huber(tf); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0.22766666666666668f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + float[] trueArray = {.9f, .2f, .2f, .8f, .4f, .6f}; + float[] predArray = {1.f, 0.f, 1.f, 1.f, 0.f, 0.f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Huber instance = new Huber(tf); + Operand sampleWeight = tf.constant(0.F); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; + float[] trueArray = {.9f, .2f, .2f, .8f, .4f, .6f}; + float[] predArray = {1.f, 0.f, 1.f, 1.f, 0.f, 0.f}; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Huber instance = new Huber(tf); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + + float expected = .4025f; + testSession.evaluate(expected, loss); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java new file mode 100644 index 00000000000..d40b4286f3d --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java @@ -0,0 +1,106 @@ +package org.tensorflow.framework.losses; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +public class KLDivergenceTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + /** + * Test of call method, of class KLDivergence. + */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + KLDivergence instance = new KLDivergence(tf); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; + float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred); + float expected = 0.5960738398643668f; + testSession.evaluate(expected, loss); + } + } + + /** + * Test of call method, of class KLDivergence. + */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + KLDivergence instance = new KLDivergence(tf); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; + float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 1.3709698316880434f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + KLDivergence instance = new KLDivergence(tf); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; + float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; + float[] sampleArray = {1.2f, 3.4f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 2.0075711736936492f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + KLDivergence instance = new KLDivergence(tf); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; + float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(0.F); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + KLDivergence instance = new KLDivergence(tf, Reduction.AUTO); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; + float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + + float expected = 0.2495994912084345f; + testSession.evaluate(expected, loss); + } + } + +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java new file mode 100644 index 00000000000..0828471062a --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java @@ -0,0 +1,105 @@ +package org.tensorflow.framework.losses; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +public class LogCoshTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + /** + * Test of call method, of class LogCosh. + */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + LogCosh instance = new LogCosh(tf); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred); + float expected = 4.829245330860459f; + testSession.evaluate(expected, loss); + } + } + + /** + * Test of call method, of class LogCosh. + */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + LogCosh instance = new LogCosh(tf); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 11.107264260979056f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + LogCosh instance = new LogCosh(tf); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {1.2f, 3.4f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 12.001114667519486f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + LogCosh instance = new LogCosh(tf); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(0.F); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + LogCosh instance = new LogCosh(tf, Reduction.AUTO); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + + float expected = 11.653484271934046f; + testSession.evaluate(expected, loss); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java new file mode 100644 index 00000000000..91747900c2c --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java @@ -0,0 +1,180 @@ +package org.tensorflow.framework.losses; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +public class MeanAbsoluteErrorTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + /** Test of call method, of class MeanAbsoluteError. */ + @Test + public void testAllCorrectUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanAbsoluteError instance = new MeanAbsoluteError(tf); + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yTrue); + float expected = 0.0f; + testSession.evaluate(expected, loss); + } + } + + /** Test of call method, of class MeanAbsoluteError. */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanAbsoluteError instance = new MeanAbsoluteError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred); + float expected = 5.5f; + testSession.evaluate(expected, loss); + } + } + + /** Test of call method, of class MeanAbsoluteError. */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanAbsoluteError instance = new MeanAbsoluteError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 12.65f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanAbsoluteError instance = new MeanAbsoluteError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {1.2f, 3.4f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 81.4f / 6f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanAbsoluteError instance = new MeanAbsoluteError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(0.F); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanAbsoluteError instance = new MeanAbsoluteError(tf, Reduction.AUTO); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + + float expected = 83f / 6f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testInvalidSampleWeight() { + for (TestSession.Mode tfMode : tfModes) + Assertions.assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanAbsoluteError instance = new MeanAbsoluteError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {3f, 6f, 5f, 0f}; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 2))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 83f / 6f; + testSession.evaluate(expected, loss); + } + }); + } + + @Test + public void testNoReduction() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanAbsoluteError instance = new MeanAbsoluteError(tf, Reduction.NONE); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + Float[] expected = {10.733333f, 14.566667f}; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testSumReduction() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanAbsoluteError instance = new MeanAbsoluteError(tf, Reduction.SUM); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + Float[] expected = {25.29999f}; + testSession.evaluate(expected, loss); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java new file mode 100644 index 00000000000..5c2521f900c --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java @@ -0,0 +1,153 @@ +package org.tensorflow.framework.losses; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +public class MeanAbsolutePercentageErrorTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + /** Test of call method, of class MeanAbsolutePercentageError. */ + @Test + public void testAllCorrectUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yTrue); + float expected = 0.0f; + testSession.evaluate(expected, loss); + } + } + + /** Test of call method, of class MeanAbsolutePercentageError. */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred); + float expected = 211.85184f; + testSession.evaluate(expected, loss); + } + } + + /** Test of call method, of class MeanAbsolutePercentageError. */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 487.25922f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {1.2f, 3.4f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 422.8889f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(0.F); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf, Reduction.AUTO); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 694.4445f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testNoReduction() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf, Reduction.NONE); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + Float[] expected = {621.8518f, 352.66666f}; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testSumReduction() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError(tf, Reduction.SUM); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 974.51843f; + testSession.evaluate(expected, loss); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java new file mode 100644 index 00000000000..02ef7621e38 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java @@ -0,0 +1,180 @@ +package org.tensorflow.framework.losses; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +public class MeanSquaredErrorTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + /** Test of call method, of class MeanSquaredError. */ + @Test + public void testAllCorrectUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredError instance = new MeanSquaredError(tf); + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yTrue); + float expected = 0.0f; + testSession.evaluate(expected, loss); + } + } + + /** Test of call method, of class MeanSquaredError. */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredError instance = new MeanSquaredError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred); + float expected = 49.5f; + testSession.evaluate(expected, loss); + } + } + + /** Test of call method, of class MeanSquaredError. */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredError instance = new MeanSquaredError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 113.85f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredError instance = new MeanSquaredError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {1.2f, 3.4f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 127.96667f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredError instance = new MeanSquaredError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(0.F); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredError instance = new MeanSquaredError(tf, Reduction.AUTO); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + + float expected = 97.833336f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testInvalidSampleWeight() { + for (TestSession.Mode tfMode : tfModes) + Assertions.assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredError instance = new MeanSquaredError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {3f, 6f, 5f, 0f}; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 2))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 173.25f; + testSession.evaluate(expected, loss); + } + }); + } + + @Test + public void testNoReduction() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredError instance = new MeanSquaredError(tf, Reduction.NONE); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + Float[] expected = {84.333336f, 143.36665f}; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testSumReduction() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredError instance = new MeanSquaredError(tf, Reduction.SUM); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + Float[] expected = {227.69998f}; + testSession.evaluate(expected, loss); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java new file mode 100644 index 00000000000..66caa215219 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java @@ -0,0 +1,179 @@ +package org.tensorflow.framework.losses; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +public class MeanSquaredLogarithmicErrorTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + /** + * Test of call method, of class MeanSquaredLogarithmicError. + */ + @Test + public void testAllCorrectUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yTrue); + float expected = 0.0f; + testSession.evaluate(expected, loss); + } + } + + /** + * Test of call method, of class MeanSquaredLogarithmicError. + */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred); + float expected = 1.4370421f; + testSession.evaluate(expected, loss); + } + } + + /** + * Test of call method, of class MeanSquaredLogarithmicError. + */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 3.3051968f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {1.2f, 3.4f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 3.7856376f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(0.F); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.AUTO); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + + float expected = 2.647374f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testInvalidSampleWeight() { + for (TestSession.Mode tfMode : tfModes) + Assertions.assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {3f, 6f, 5f, 0f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 2))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 83f / 6f; + testSession.evaluate(expected, loss); + }}); + } + + @Test + public void testNoReduction() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.NONE); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + Float[] expected = {2.3006392f, 4.3097544f}; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testSumReduction() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.SUM); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + Float[] expected = {6.6103935f}; + testSession.evaluate(expected, loss); + } + } + +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java new file mode 100644 index 00000000000..0a086a37b96 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java @@ -0,0 +1,105 @@ +package org.tensorflow.framework.losses; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +public class PoissonTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + /** + * Test of call method, of class Poisson. + */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Poisson instance = new Poisson(tf); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred); + float expected = -3.306581945521002f; + testSession.evaluate(expected, loss); + } + } + + /** + * Test of call method, of class Poisson. + */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Poisson instance = new Poisson(tf); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = -7.605138474698304f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Poisson instance = new Poisson(tf); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {1.2f, 3.4f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = -6.147338926788071f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Poisson instance = new Poisson(tf); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(0.F); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Poisson instance = new Poisson(tf, Reduction.AUTO); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + + float expected = -12.263126013890561f; + testSession.evaluate(expected, loss); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java new file mode 100644 index 00000000000..89b96bad198 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java @@ -0,0 +1,180 @@ +package org.tensorflow.framework.losses; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; + +public class SparseCategoricalCrossentropyTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + /** Test of call method, of class SparseCategoricalCrossentropy. */ + @Test + public void testAllCorrectUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + testSession.setEpsilon(1e-3f); + Ops tf = testSession.getTF(); + + int[] trueArray = {0, 1, 2}; + float[] predArray = { + 1.F, 0.F, 0.F, + 0.F, 1.F, 0.F, + 0.F, 0.F, 1.F + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 1))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); + Operand loss = instance.call(yTrue, yPred); + float expected = 0.0f; + testSession.evaluate(expected, loss); + + // Test with logits. + float[] logitsArray = { + 10.F, 0.F, 0.F, + 0.F, 10.F, 0.F, + 0.F, 0.F, 10.F + }; + Operand logits = + tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); + instance = new SparseCategoricalCrossentropy(tf, true); + loss = instance.call(yTrue, logits); + testSession.evaluate(0.0f, loss); + } + } + + /** Test of call method, of class SparseCategoricalCrossentropy. */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); + int[] trueArray = {0, 1, 2}; + float[] predArray = { + .9f, .05f, .05f, + .5f, .89f, .6f, + .05f, .01f, .94f + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 1))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); + Operand loss = instance.call(yTrue, yPred); + float expected = 0.32396814f; + testSession.evaluate(expected, loss); + + // Test with logits. + float[] logitsArray = { + 8.F, 1.F, 1.F, + 0.F, 9.F, 1.F, + 2.F, 3.F, 5.F + }; + Operand logits = + tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); + instance = new SparseCategoricalCrossentropy(tf, true); + loss = instance.call(yTrue, logits); + expected = 0.05737559f; + testSession.evaluate(expected, loss); + } + } + + /** Test of call method, of class SparseCategoricalCrossentropy. */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + + int[] trueArray = {0, 1, 2}; + float[] predArray = { + .9f, .05f, .05f, + .5f, .89f, .6f, + .05f, .01f, .94f + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 1))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); + Operand sampleWeight = tf.constant(2.3f); + + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = .7451267f; + testSession.evaluate(expected, loss); + + // Test with logits. + float[] logitsArray = { + 8.F, 1.F, 1.F, + 0.F, 9.F, 1.F, + 2.F, 3.F, 5.F + }; + Operand logits = + tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); + instance = new SparseCategoricalCrossentropy(tf, true); + loss = instance.call(yTrue, logits, sampleWeight); + expected = 0.13196386f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); + float[] sampleWeightArray = {1.2f, 3.4f, 5.6f}; + int[] trueArray = {0, 1, 2}; + float[] predArray = { + .9f, .05f, .05f, + .5f, .89f, .6f, + .05f, .01f, .94f + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 1))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleWeightArray), tf.constant(Shape.of(3, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 1.0696f; + testSession.evaluate(expected, loss); + + // Test with logits. + float[] logitsArray = { + 8.F, 1.F, 1.F, + 0.F, 9.F, 1.F, + 2.F, 3.F, 5.F + }; + Operand logits = + tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); + instance = new SparseCategoricalCrossentropy(tf, true); + loss = instance.call(yTrue, logits, sampleWeight); + expected = 0.31829f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testNoReduction() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + + // Test with logits. + long[] trueArray = {0L, 1L, 2L}; + float[] logitsArray = { + 8.F, 1.F, 1.F, + 0.F, 9.F, 1.F, + 2.F, 3.F, 5.F + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 1))); + Operand logits = + tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); + SparseCategoricalCrossentropy instance = + new SparseCategoricalCrossentropy(tf, true, Reduction.NONE); + Operand loss = instance.call(yTrue, logits); + Float[] expected = {0.001822f, 0.000459f, 0.169846f}; + testSession.evaluate(expected, loss); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java new file mode 100644 index 00000000000..19236f50749 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java @@ -0,0 +1,105 @@ +package org.tensorflow.framework.losses; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +public class SquaredHingeTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + /** + * Test of call method, of class SquaredHinge. + */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + SquaredHinge instance = new SquaredHinge(tf); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand loss = instance.call(yTrue, yPred); + float expected = 0.364062f; + testSession.evaluate(expected, loss); + } + } + + /** + * Test of call method, of class SquaredHinge. + */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + SquaredHinge instance = new SquaredHinge(tf); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0.8373437f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + SquaredHinge instance = new SquaredHinge(tf); + float[] sampleArray = {1.2f, 3.4f}; + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0.7043125f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + SquaredHinge instance = new SquaredHinge(tf); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand sampleWeight = tf.constant(0.F); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } + + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + SquaredHinge instance = new SquaredHinge(tf, Reduction.AUTO); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4, 1))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4, 1))); + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f, 1f, 3f}; + Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 4))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + + float expected = 1.54250000f; + testSession.evaluate(expected, loss); + } + } +} From 17e96b5ab78ec7d4d87f24b0f8f97a54c3e9e882 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 8 Oct 2020 18:25:50 -0400 Subject: [PATCH 03/26] Fix reshape in sparseCategoricalCrossentropy() --- .../tensorflow/framework/losses/Losses.java | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 604ffc1b474..9eeab1357e6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -511,14 +511,14 @@ public static Operand sparseCategorica predictions = tf.clipByValue(predictions, epsilonConst, oneMinusEpsilonConst); predictions = tf.math.log(predictions); } - Shape outputShape = predictions.asOutput().shape(); - int outputRank = outputShape.numDimensions(); - axis %= outputRank; + Shape predictionsShape = predictions.asOutput().shape(); + int predictionsRank = predictionsShape.numDimensions(); + axis %= predictionsRank; if (axis < 0) { - axis += outputRank; + axis += predictionsRank; } - if (axis != outputRank - 1) { - int[] axisNew = moveAxisToEnd(axis, outputRank); + if (axis != predictionsRank - 1) { + int[] axisNew = moveAxisToEnd(axis, predictionsRank); predictions = tf.linalg.transpose(predictions, tf.constant(axisNew)); } @@ -528,20 +528,21 @@ public static Operand sparseCategorica Shape labelsShape = labels.asOutput().shape(); int labelsRank = labelsShape.numDimensions(); - boolean updateShape = labelsRank != outputRank - 1; + boolean updateShape = labelsRank != predictionsRank - 1; if (updateShape) { // TODO check to see if this is right - iLabels = tf.reshape(iLabels, tf.constant(-1)); // flatten one dimension + Shape newShape = labelsShape.take(labelsRank-1); + iLabels = tf.reshape(iLabels, tf.constant(newShape)); // flatten one dimension predictions = tf.reshape( predictions, - tf.constant(new long[] {-1L, outputShape.size(outputShape.numDimensions() - 1)})); + tf.constant(new long[] {-1L, predictionsShape.size(predictionsShape.numDimensions() - 1)})); } @SuppressWarnings("unchecked") Operand loss = tf.nn.sparseSoftmaxCrossEntropyWithLogits(iLabels, predictions); - if (updateShape && outputRank >= 3) { - Shape newShape = outputShape.take(outputShape.numDimensions() - 1); + if (updateShape && predictionsRank >= 3) { + Shape newShape = predictionsShape.take(predictionsShape.numDimensions() - 1); loss = tf.reshape(loss, tf.constant(newShape)); } return loss; From ee1c48a443810260be7319caab94bde8a3dae529 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 11 Oct 2020 12:05:56 -0400 Subject: [PATCH 04/26] Apply various fixes to JavaDoc --- .../framework/losses/BinaryCrossentropy.java | 6 +- .../losses/CategoricalCrossentropy.java | 16 ++-- .../framework/losses/CategoricalHinge.java | 8 +- .../framework/losses/CosineSimilarity.java | 4 +- .../tensorflow/framework/losses/Hinge.java | 6 +- .../tensorflow/framework/losses/Huber.java | 4 +- .../framework/losses/KLDivergence.java | 6 +- .../tensorflow/framework/losses/LogCosh.java | 8 +- .../org/tensorflow/framework/losses/Loss.java | 3 +- .../tensorflow/framework/losses/Losses.java | 18 ++-- .../framework/losses/MeanAbsoluteError.java | 6 +- .../losses/MeanAbsolutePercentageError.java | 6 +- .../framework/losses/MeanSquaredError.java | 6 +- .../losses/MeanSquaredLogarithmicError.java | 6 +- .../tensorflow/framework/losses/Poisson.java | 8 +- .../losses/SparseCategoricalCrossentropy.java | 4 +- .../framework/losses/SquaredHinge.java | 8 +- .../framework/losses/impl/LossesImpl.java | 86 ++++++++++--------- 18 files changed, 111 insertions(+), 98 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index aa4e167c149..c8be0463403 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -59,7 +59,7 @@ public class BinaryCrossentropy extends Loss { /** * Creates a Binary Crossentropy Loss using {@link Class#getSimpleName()} as the loss name, {@link * #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing and a - * Loss Reduction of {@link * Reduction#AUTO} + * Loss Reduction of {@link Reduction#AUTO} * * * @@ -173,7 +173,7 @@ public BinaryCrossentropy( public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = - Losses.binaryCrossentropy(tf, labels, predictions, fromLogits, labelSmoothing); - return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + Losses.binaryCrossentropy(getTF(), labels, predictions, fromLogits, labelSmoothing); + return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index b042a656405..a7491285a68 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -63,7 +63,7 @@ public class CategoricalCrossentropy extends Loss { /** * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, * {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for - * labelSmoothing, a Loss Reduction of {@link * Reduction#AUTO}, and an axis of {@link + * labelSmoothing, a Loss Reduction of {@link Reduction#AUTO}, and an axis of {@link * #DEFAULT_AXIS} * * @param tf the TensorFlow Ops @@ -74,7 +74,7 @@ public CategoricalCrossentropy(Ops tf) { /** * Creates a categorical cross entropy Loss using {@link #FROM_LOGITS_DEFAULT} for fromLogits, - * {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a Loss Reduction of {@link * + * {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a Loss Reduction of {@link * Reduction#AUTO}, and an axis of {@link #DEFAULT_AXIS} * * @param tf the TensorFlow Ops @@ -110,7 +110,7 @@ public CategoricalCrossentropy(Ops tf, String name, Reduction reduction) { /** * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, - * {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a Loss Reduction of {@link * + * {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a Loss Reduction of {@link * Reduction#AUTO}, and an axis of {@link #DEFAULT_AXIS} * * @param tf the TensorFlow Ops @@ -122,7 +122,7 @@ public CategoricalCrossentropy(Ops tf, boolean fromLogits) { /** * Creates a categorical cross entropy Loss using {@link #LABEL_SMOOTHING_DEFAULT} for - * labelSmoothing, a Loss Reduction of {@link * Reduction#AUTO}, and a channel axis of {@link + * labelSmoothing, a Loss Reduction of {@link Reduction#AUTO}, and a channel axis of {@link * #DEFAULT_AXIS} * * @param tf the TensorFlow Ops @@ -135,7 +135,7 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { /** * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, - * a Loss Reduction of {@link * Reduction#AUTO}, and a channel axis of {@link #DEFAULT_AXIS} + * a Loss Reduction of {@link Reduction#AUTO}, and a channel axis of {@link #DEFAULT_AXIS} * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values @@ -149,7 +149,7 @@ public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) } /** - * Creates a categorical cross entropy Loss using a Loss Reduction of {@link * Reduction#AUTO}, + * Creates a categorical cross entropy Loss using a Loss Reduction of {@link Reduction#AUTO}, * and a channel axis of {@link #DEFAULT_AXIS} * * @param tf the TensorFlow Ops @@ -213,7 +213,7 @@ public CategoricalCrossentropy( public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = - Losses.categoricalCrossentropy(tf, labels, predictions, fromLogits, labelSmoothing, axis); - return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + Losses.categoricalCrossentropy(getTF(), labels, predictions, fromLogits, labelSmoothing, axis); + return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index 6c828fd2d16..2b1f1e044f9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -11,6 +11,8 @@ *

loss = maximum(neg - pos + 1, 0) where neg=maximum((1-labels)*predictions) * and pos=sum(labels*predictions) * + *

labels values are expected to be 0 or 1.

+ * *

Standalone usage: * *

@@ -52,7 +54,7 @@ public class CategoricalHinge extends Loss {
 
   /**
    * Creates a Categorical Hinge Loss using {@link Class#getSimpleName()} as the loss name and a
-   * Loss Reduction of {@link * Reduction#AUTO}
+   * Loss Reduction of {@link Reduction#AUTO}
    *
    * @param tf the TensorFlow Ops
    */
@@ -85,7 +87,7 @@ public CategoricalHinge(Ops tf, String name, Reduction reduction) {
   @Override
   public  Operand call(
           Operand labels, Operand predictions, Operand sampleWeights) {
-    Operand losses = Losses.categoricalHinge(tf, labels, predictions);
-    return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights);
+    Operand losses = Losses.categoricalHinge(getTF(), labels, predictions);
+    return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
   }
 }
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java
index e5d9d6a5d7b..26b5fdb6ee9 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java
@@ -158,7 +158,7 @@ public CosineSimilarity(Ops tf, String name, int axis, Reduction reduction) {
   @Override
   public  Operand call(
           Operand labels, Operand predictions, Operand sampleWeights) {
-    Operand losses = Losses.cosineSimilarity(tf, labels, predictions, axis);
-    return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights);
+    Operand losses = Losses.cosineSimilarity(getTF(), labels, predictions, axis);
+    return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
   }
 }
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java
index 5209d8df360..ad23901d2f1 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java
@@ -53,7 +53,7 @@ public class Hinge extends Loss {
 
   /**
    * Creates a Hinge Loss using {@link Class#getSimpleName()} as the loss name and a Loss Reduction
-   * of {@link * Reduction#AUTO}
+   * of {@link Reduction#AUTO}
    *
    * @param tf the TensorFlow Ops
    */
@@ -86,7 +86,7 @@ public Hinge(Ops tf, String name, Reduction reduction) {
   @Override
   public  Operand call(
           Operand labels, Operand predictions, Operand sampleWeights) {
-    Operand losses = Losses.hinge(tf, labels, predictions);
-    return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights);
+    Operand losses = Losses.hinge(getTF(), labels, predictions);
+    return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
   }
 }
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java
index 3b2949eeb03..261a0439f83 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java
@@ -118,7 +118,7 @@ public Huber(Ops tf, String name, float delta, Reduction reduction) {
   @Override
   public  Operand call(
           Operand labels, Operand predictions, Operand sampleWeights) {
-    Operand losses = Losses.huber(tf, labels, predictions, delta);
-    return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights);
+    Operand losses = Losses.huber(getTF(), labels, predictions, delta);
+    return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
   }
 }
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java
index 71c348069bd..6c9371ff8fd 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java
@@ -53,7 +53,7 @@ public class KLDivergence extends Loss {
 
   /**
    * Creates a Kullback Leibler Divergence Loss using {@link Class#getSimpleName()} as the loss name
-   * and a Loss Reduction of {@link * Reduction#AUTO}
+   * and a Loss Reduction of {@link Reduction#AUTO}
    *
    * @param tf the TensorFlow Ops
    */
@@ -87,7 +87,7 @@ public KLDivergence(Ops tf, String name, Reduction reduction) {
   @Override
   public  Operand call(
           Operand labels, Operand predictions, Operand sampleWeights) {
-    Operand losses = Losses.kullbackLeiblerDivergence(tf, labels, predictions);
-    return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights);
+    Operand losses = Losses.kullbackLeiblerDivergence(getTF(), labels, predictions);
+    return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
   }
 }
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java
index 6ddb0b2daac..a0e99180ef8 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java
@@ -51,7 +51,7 @@ public class LogCosh extends Loss {
 
   /**
    * Creates a LogCosh Loss using {@link Class#getSimpleName()} as the loss name and a Loss
-   * Reduction of {@link * Reduction#AUTO}
+   * Reduction of {@link Reduction#AUTO}
    *
    * @param tf the TensorFlow Ops
    */
@@ -60,7 +60,7 @@ public LogCosh(Ops tf) {
   }
 
   /**
-   * Creates a LogCosh Loss using a Loss Reduction of {@link * Reduction#AUTO}
+   * Creates a LogCosh Loss using a Loss Reduction of {@link Reduction#AUTO}
    *
    * @param tf the TensorFlow Ops
    */
@@ -93,7 +93,7 @@ public LogCosh(Ops tf, String name, Reduction reduction) {
   @Override
   public  Operand call(
           Operand labels, Operand predictions, Operand sampleWeights) {
-    Operand losses = Losses.logCosh(tf, labels, predictions);
-    return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights);
+    Operand losses = Losses.logCosh(getTF(), labels, predictions);
+    return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
   }
 }
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java
index 9c0976f2c6f..445ebf99565 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java
@@ -45,7 +45,8 @@ protected Loss(Ops tf, String name, Reduction reduction) {
    *
    * @param labels the truth values or labels
    * @param predictions the predictions
-   * @param  The data type of the labels, predictions and loss.
+   * @param  The data type of the predictions and loss.
+   * @param  The data type of the labels.
    * @return the loss
    */
   public  Operand call(Operand labels, Operand predictions) {
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java
index 9eeab1357e6..e3e91a41bc1 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java
@@ -32,7 +32,7 @@ public class Losses {
    * @param tf The TensorFlow Ops
    * @param labels the labels
    * @param predictions the predictions
-   * @param  the data type of the result
+   * @param  the data type of the predictions and result
    * @param  the data type of the labels
    * @return the mean absolute error
    */
@@ -54,7 +54,7 @@ public static  Operand meanAbsoluteErro
    * @param tf The TensorFlow Ops
    * @param labels the labels
    * @param predictions the predictions
-   * @param  the data type of the result
+   * @param  the data type of the predictions and result
    * @param  the data type of the labels
    * @return the mean squared error
    */
@@ -75,14 +75,14 @@ public static  Operand meanSquaredError
    * @param tf The TensorFlow Ops
    * @param labels the labels
    * @param predictions the predictions
-   * @param  the data type of the result
+   * @param  the data type of the predictions and result
    * @param  the data type of the labels
    * @return the mean absolute percentage error
    */
   public static  Operand meanAbsolutePercentageError(
       Ops tf, Operand labels, Operand predictions) {
     DataType dataType = predictions.asOutput().dataType();
-    Operand tLabels = tf.dtypes.cast(labels, predictions.asOutput().dataType());
+    Operand tLabels = tf.dtypes.cast(labels,dataType);
     Tuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null);
     predictions = ops.getTarget();
     tLabels = ops.getLabels();
@@ -104,7 +104,7 @@ public static  Operand meanAbsolutePerc
    * @param tf The TensorFlow Ops
    * @param labels the labels
    * @param predictions the predictions
-   * @param  the data type of the result
+   * @param  the data type of the predictions and result
    * @param  the data type of the labels
    * @return the mean squared logarithmic percentage error
    */
@@ -262,7 +262,7 @@ public static  Operand categoricalCross
    * Computes the categorical hinge loss between labels and predictions.
    *
    * @param tf the TensorFlow Ops
-   * @param labels true targets
+   * @param labels true targets,  values are expected to be 0 or 1.
    * @param predictions the predictions
    * @param  the data type of the predictions and labels
    * @return the categorical hinge loss
@@ -330,7 +330,8 @@ public static  Operand cosineSimilarity
    * 

loss = reduceMean(maximum(1 - labels * predictions, 0)) * * @param tf the TensorFlow Ops - * @param labels true targets + * @param labels true targets, values are expected to be -1 or 1. If binary (0 or 1) labels are + * provided, they will be converted to -1 or 1. * @param predictions the predictions * @param the data type of the predictions and labels * @return the hinge loss @@ -554,7 +555,8 @@ public static Operand sparseCategorica *

loss = reduceMean(square(maximum(1 - labels * predictions, 0))) * * @param tf the TensorFlow Ops - * @param labels true targets + * @param labels true targets, values are expected to be -1 or 1. If binary (0 or 1) labels are * + * provided, they will be converted to -1 or 1. * @param predictions the predictions * @param the data type of the predictions and labels * @return the squared hinge loss diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index 31592f8188b..6aa39218ac8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -50,7 +50,7 @@ public class MeanAbsoluteError extends Loss { /** * Creates a MeanAbsoluteError Loss using {@link Class#getSimpleName()} as the loss name and a - * Loss Reduction of {@link * Reduction#AUTO} + * Loss Reduction of {@link Reduction#AUTO} * * @param tf the TensorFlow Ops */ @@ -83,7 +83,7 @@ public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { @Override public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.meanAbsoluteError(tf, labels, predictions); - return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + Operand losses = Losses.meanAbsoluteError(getTF(), labels, predictions); + return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 7e2ab3fa8ae..73bb62e4686 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -50,7 +50,7 @@ public class MeanAbsolutePercentageError extends Loss { /** * Creates a MeanAbsolutePercentageError Loss using {@link Class#getSimpleName()} as the loss name - * and a Loss Reduction of {@link * Reduction#AUTO} + * and a Loss Reduction of {@link Reduction#AUTO} * * @param tf the TensorFlow Ops */ @@ -83,7 +83,7 @@ public MeanAbsolutePercentageError(Ops tf, String name, Reduction reduction) { @Override public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.meanAbsolutePercentageError(tf, labels, predictions); - return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + Operand losses = Losses.meanAbsolutePercentageError(getTF(), labels, predictions); + return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index 1b892d2fa16..6e5160cd288 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -50,7 +50,7 @@ public class MeanSquaredError extends Loss { /** * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link * Reduction#AUTO} + * Reduction of {@link Reduction#AUTO} * * @param tf the TensorFlow Ops */ @@ -83,7 +83,7 @@ public MeanSquaredError(Ops tf, String name, Reduction reduction) { @Override public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.meanSquaredError(tf, labels, predictions); - return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + Operand losses = Losses.meanSquaredError(getTF(), labels, predictions); + return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index 4efe1fb0b7b..8b12c2b525e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -50,7 +50,7 @@ public class MeanSquaredLogarithmicError extends Loss { /** * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link * Reduction#AUTO} + * Reduction of {@link Reduction#AUTO} * * @param tf the TensorFlow Ops */ @@ -83,7 +83,7 @@ public MeanSquaredLogarithmicError(Ops tf, String name, Reduction reduction) { @Override public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.meanSquaredLogarithmicError(tf, labels, predictions); - return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + Operand losses = Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); + return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java index c221f49eb90..66dd356ab33 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java @@ -51,7 +51,7 @@ public class Poisson extends Loss { /** * Creates a Poisson Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link * Reduction#AUTO} + * Reduction of {@link Reduction#AUTO} * * @param tf the TensorFlow Ops */ @@ -60,7 +60,7 @@ public Poisson(Ops tf) { } /** - * Creates a Poisson Loss using a Loss Reduction of {@link * Reduction#AUTO} + * Creates a Poisson Loss using a Loss Reduction of {@link Reduction#AUTO} * * @param tf the TensorFlow Ops */ @@ -93,7 +93,7 @@ public Poisson(Ops tf, String name, Reduction reduction) { @Override public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.poisson(tf, labels, predictions); - return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + Operand losses = Losses.poisson(getTF(), labels, predictions); + return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index f4af06e64a6..7776e23e40b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -164,7 +164,7 @@ public SparseCategoricalCrossentropy( public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = - Losses.sparseCategoricalCrossentropy(tf, labels, predictions, fromLogits, axis); - return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); + return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index a8ec49be835..611b93a1a98 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -10,7 +10,7 @@ * *

loss = square(maximum(1 - labels * predictions, 0)) * - *

labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, they will be + *

labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, they will be * converted to -1 or 1. * *

Standalone usage: @@ -54,7 +54,7 @@ public class SquaredHinge extends Loss { /** * Creates a Squared Hinge Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link * Reduction#AUTO} + * Reduction of {@link Reduction#AUTO} * * @param tf the TensorFlow Ops */ @@ -87,7 +87,7 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { @Override public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.squaredHinge(tf, labels, predictions); - return LossesImpl.computeWeightedLoss(tf, losses, getReduction(), sampleWeights); + Operand losses = Losses.squaredHinge(getTF(), labels, predictions); + return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java index eb8032569c2..089e264e04a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java @@ -18,18 +18,21 @@ public class LossesImpl { * Squeeze or expand last dimension if needed with a sampleWeights of one. * *

    - *
  1. Squeezes last dim of predictions or labels if their rank differs by 1 (using - * {@link #removeSqueezableDimensions}). - *
  2. Squeezes or expands last dim of sampleWeight` if its rank differs by 1 from the new - * rank of predictions`. If sampleWeight` is scalar, it is kept scalar./li> + *
  3. Squeezes last dim of predictions or labels if their rank + * differs by 1 (using {@link #removeSqueezableDimensions}). + *
  4. Squeezes or expands last dim of sampleWeight if its rank differs by 1 from + * the new rank of predictions. If sampleWeight is scalar, it is + * kept scalar./li> *
* * @param tf the TensorFlow Ops * @param predictions Predicted values, a Operand of arbitrary dimensions. - * @param labels Optional label Operand whose dimensions match prediction. - * @return Tuple of prediction, label and sampleWeight. Each of them possibly has the last - * dimension squeezed, sampleWeight could be extended by one dimension. If sampleWeight - * is null, (prediction, label) is returned. + * @param labels Optional label Operand whose dimensions match prediction + * . + * @return Tuple of prediction, label,sampleWeight will be + * null. Each of them possibly has the last dimension squeezed, sampleWeight + * could be extended by one dimension. If sampleWeight is null, (prediction, + * label) is returned. */ public static Tuple squeezeOrExpandDimensions( Ops tf, Operand labels, Operand predictions) { @@ -40,36 +43,39 @@ public static Tuple squeezeOrExpandDimensions( * Squeeze or expand last dimension if needed. * *
    - *
  1. Squeezes last dim of `predictions` or `labels` if their rank differs by 1 (using * - * `confusion_matrix.remove_squeezable_dimensions`). * - *
  2. Squeezes or expands last dim of `sampleWeight` if its rank differs by 1 from the new * - * rank of `predictions`. If `sampleWeight` is scalar, it is kept scalar./li> * + *
  3. Squeezes last dim of predictions or labels if their rank do not + * differ by 1. + *
  4. Squeezes or expands last dim of sampleWeight if its rank differs by 1 from + * the new rank of predictions. If sampleWeight is scalar, it is + * kept scalar. *
* * @param tf the TensorFlow Ops * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction * . - * @param sampleWeight Optional sample weight(s) Operand whose dimensions match + * @param sampleWeights Optional sample weight(s) Operand whose dimensions match * prediction. - * @return Tuple of prediction, label and sampleWeight. + * @return Tuple of prediction, labels and sampleWeight. * Each of them possibly has the last dimension squeezed, sampleWeight could be - * extended by one dimension. If sampleWeight is null, (prediction, label) is + * extended by one dimension. If sampleWeight is null, only the possibly shape modified predictions and labels are * returned. */ public static Tuple squeezeOrExpandDimensions( - Ops tf, Operand labels, Operand predictions, Operand sampleWeight) { - Tuple tuple = new Tuple<>(labels, predictions); + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + Shape predictionsShape = predictions.asOutput().shape(); long predictionsRank = predictionsShape.numDimensions(); + // Default case when no modifications are made. + Tuple tuple = new Tuple<>(labels, predictions, sampleWeights); if (labels != null) { Shape labelsShape = labels.asOutput().shape(); - long labelRank = labelsShape.numDimensions(); - if (labelRank != Shape.UNKNOWN_SIZE && predictionsRank != Shape.UNKNOWN_SIZE) { - // Use static rank for `label` and `prediction`. - if (predictionsRank - labelRank != 1 || predictionsShape.size(-1) == 1) { + long labelsRank = labelsShape.numDimensions(); + if (labelsRank != Shape.UNKNOWN_SIZE && predictionsRank != Shape.UNKNOWN_SIZE) { + // Use static rank for 'label' and 'prediction'. + if (predictionsRank - labelsRank != 1 || predictionsShape.size(-1) == 1) { // label, prediction = confusion_matrix.remove_squeezable_dimensions(label, prediction) tuple = removeSqueezableDimensions(tf, labels, predictions); } @@ -77,33 +83,33 @@ public static Tuple squeezeOrExpandDimensions( tuple = removeSqueezableDimensions(tf, labels, predictions); } } - if (sampleWeight == null) { + if (sampleWeights == null) { // nothing more to do. return tuple; } - Shape weightsShape = sampleWeight.asOutput().shape(); + Shape weightsShape = sampleWeights.asOutput().shape(); long weightsRank = weightsShape.numDimensions(); if (weightsRank == 0) { // scalar - return new Tuple<>(labels, predictions, sampleWeight); + return new Tuple<>(labels, predictions, sampleWeights); } if (predictionsRank != Shape.UNKNOWN_SIZE && weightsRank != Shape.UNKNOWN_SIZE) { if (weightsRank - predictionsRank == 1) { - sampleWeight = tf.squeeze(sampleWeight); + sampleWeights = tf.squeeze(sampleWeights); } else if (predictionsRank - weightsRank == 1) { - sampleWeight = tf.expandDims(sampleWeight, tf.constant(-1L)); + sampleWeights = tf.expandDims(sampleWeights, tf.constant(-1L)); } - return new Tuple<>(labels, predictions, sampleWeight); + return new Tuple<>(labels, predictions, sampleWeights); } // Use dynamic rank. - Operand weightsRankTensor = tf.rank(sampleWeight); + Operand weightsRankTensor = tf.rank(sampleWeights); Operand rankDiff = tf.math.sub(weightsRankTensor, tf.rank(predictions)); - sampleWeight = + sampleWeights = tf.select( tf.math.equal(weightsRankTensor, tf.constant(0)), - sampleWeight, - maybeAdjustWeights(tf, sampleWeight, rankDiff)); - return new Tuple<>(labels, predictions, sampleWeight); + sampleWeights, + maybeAdjustWeights(tf, sampleWeights, rankDiff)); + return new Tuple<>(labels, predictions, sampleWeights); } /** @@ -148,9 +154,10 @@ private static Operand maybeExpandWeights( * Squeeze last dim if ranks differ from expected by exactly 1. * * @param tf the TensorFlowOps - * @param labels Label values, a `Tensor` whose dimensions match `predictions`. - * @param predictions Predicted values, a `Tensor` of arbitrary dimensions. - * @return `labels` and `predictions`, possibly with last dim squeezed. + * @param labels Label values, a Tensor whose dimensions match predictions + * . + * @param predictions Predicted values, a Tensor of arbitrary dimensions. + * @return labels and predictions, possibly with last dim squeezed. */ public static Tuple removeSqueezableDimensions( Ops tf, Operand labels, Operand predictions) { @@ -161,10 +168,11 @@ public static Tuple removeSqueezableDimensions( * Squeeze last dim if ranks differ from expected by exactly 1. * * @param tf the TensorFlowOps - * @param labels Label values, a `Tensor` whose dimensions match `predictions`. - * @param predictions Predicted values, a `Tensor` of arbitrary dimensions. - * @param expectedRankDiff Expected result of `rank(predictions) - rank(labels)`. - * @return `labels` and `predictions`, possibly with last dim squeezed. + * @param labels Label values, a Operand whose dimensions match predictions + * . + * @param predictions Predicted values, a Tensor of arbitrary dimensions. + * @param expectedRankDiff Expected result of rank(predictions) - rank(labels). + * @return labels and predictions, possibly with last dim squeezed. */ public static Tuple removeSqueezableDimensions( Ops tf, Operand labels, Operand predictions, int expectedRankDiff) { From 287c96e34eea177303716e6a2b72509c2c749333 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 11 Oct 2020 12:46:18 -0400 Subject: [PATCH 05/26] Change Tuple to LossTuple --- .../tensorflow/framework/losses/Losses.java | 62 +++++++++---------- .../impl/{Tuple.java => LossTuple.java} | 14 ++--- .../framework/losses/impl/LossesImpl.java | 32 +++++----- 3 files changed, 54 insertions(+), 54 deletions(-) rename tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/{Tuple.java => LossTuple.java} (65%) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index e3e91a41bc1..36c04fb2df4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -3,7 +3,7 @@ import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.LossesImpl; -import org.tensorflow.framework.losses.impl.Tuple; +import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.op.core.ReduceAll; @@ -39,7 +39,7 @@ public class Losses { public static Operand meanAbsoluteError( Ops tf, Operand labels, Operand predictions) { Operand tLabels = tf.dtypes.cast(labels, predictions.asOutput().dataType()); - Tuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); return tf.math.mean( @@ -61,7 +61,7 @@ public static Operand meanAbsoluteErro public static Operand meanSquaredError( Ops tf, Operand labels, Operand predictions) { Operand tLabels = tf.dtypes.cast(labels, predictions.asOutput().dataType()); - Tuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); return tf.math.mean(tf.math.squaredDifference(predictions, tLabels), tf.constant(-1)); @@ -83,7 +83,7 @@ public static Operand meanAbsolutePerc Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = tf.dtypes.cast(labels,dataType); - Tuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); Operand diff = @@ -112,7 +112,7 @@ public static Operand meanSquaredLogar Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = tf.dtypes.cast(labels, predictions.asOutput().dataType()); - Tuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -143,7 +143,7 @@ public static Operand binaryCrossentro Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = tf.dtypes.cast(labels, dataType); - Tuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -222,7 +222,7 @@ public static Operand categoricalCross int axis) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = tf.dtypes.cast(labels, dataType); - Tuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -271,9 +271,9 @@ public static Operand categoricalHinge Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = tf.dtypes.cast(labels, dataType); - Tuple tuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); - predictions = tuple.getTarget(); - tLabels = tuple.getLabels(); + LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = lossTuple.getTarget(); + tLabels = lossTuple.getLabels(); Operand one = tf.dtypes.cast(tf.constant(1), dataType); Operand zero = tf.dtypes.cast(tf.constant(0), dataType); @@ -313,9 +313,9 @@ public static Operand cosineSimilarity Ops tf, Operand labels, Operand predictions, int axis) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = tf.dtypes.cast(labels, dataType); - Tuple tuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); - predictions = tuple.getTarget(); - tLabels = tuple.getLabels(); + LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = lossTuple.getTarget(); + tLabels = lossTuple.getLabels(); tLabels = l2Normalize(tf, tLabels, axis); predictions = l2Normalize(tf, predictions, axis); @@ -340,9 +340,9 @@ public static Operand hinge( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = tf.dtypes.cast(labels, dataType); - Tuple tuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); - predictions = tuple.getTarget(); - tLabels = tuple.getLabels(); + LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = lossTuple.getTarget(); + tLabels = lossTuple.getLabels(); Operand one = tf.dtypes.cast(tf.constant(1), dataType); Operand zero = tf.dtypes.cast(tf.constant(0), dataType); @@ -376,9 +376,9 @@ public static Operand huber( Ops tf, Operand labels, Operand predictions, float delta) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = tf.dtypes.cast(labels, dataType); - Tuple tuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); - predictions = tuple.getTarget(); - tLabels = tuple.getLabels(); + LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = lossTuple.getTarget(); + tLabels = lossTuple.getLabels(); Operand error = tf.math.sub(predictions, tLabels); Operand deltaConst = tf.dtypes.cast(tf.constant(delta), dataType); @@ -407,9 +407,9 @@ public static Operand kullbackLeiblerD Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = tf.dtypes.cast(labels, dataType); - Tuple tuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); - predictions = tuple.getTarget(); - tLabels = tuple.getLabels(); + LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = lossTuple.getTarget(); + tLabels = lossTuple.getLabels(); Operand one = tf.dtypes.cast(tf.constant(1), dataType); Operand epsilonConst = tf.dtypes.cast(tf.constant(EPSILON), dataType); @@ -437,9 +437,9 @@ public static Operand logCosh( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = tf.dtypes.cast(labels, dataType); - Tuple tuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); - predictions = tuple.getTarget(); - tLabels = tuple.getLabels(); + LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = lossTuple.getTarget(); + tLabels = lossTuple.getLabels(); Operand minusTwo = tf.dtypes.cast(tf.constant(-2), dataType); Operand two = tf.dtypes.cast(tf.constant(2), dataType); @@ -465,9 +465,9 @@ public static Operand poisson( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = tf.dtypes.cast(labels, dataType); - Tuple tuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); - predictions = tuple.getTarget(); - tLabels = tuple.getLabels(); + LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = lossTuple.getTarget(); + tLabels = lossTuple.getLabels(); Operand epsilonConst = tf.dtypes.cast(tf.constant(EPSILON), dataType); return tf.math.mean( @@ -565,9 +565,9 @@ public static Operand squaredHinge( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = tf.dtypes.cast(labels, dataType); - Tuple tuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); - predictions = tuple.getTarget(); - tLabels = tuple.getLabels(); + LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + predictions = lossTuple.getTarget(); + tLabels = lossTuple.getLabels(); Operand one = tf.dtypes.cast(tf.constant(1), dataType); Operand zero = tf.dtypes.cast(tf.constant(0), dataType); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/Tuple.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java similarity index 65% rename from tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/Tuple.java rename to tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java index 402cac96bec..596fb31c0d5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/Tuple.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java @@ -4,33 +4,33 @@ import org.tensorflow.types.family.TNumber; /** - * A helper class for loss methods to return multiple labels, target, and sampleWeights + * A helper class for loss methods to return labels, target, and sampleWeights * - * @param the data type of the Tuple entries. + * @param the data type of the LossTuple entries. */ -public class Tuple { +public class LossTuple { private final Operand labels; private final Operand target; private final Operand sampleWeights; /** - * Creates a Tuple of Operands for labels, target, and sampleWeights + * Creates a LossTuple of Operands for labels, target, and sampleWeights * * @param labels the labels * @param target the losses or target */ - public Tuple(Operand labels, Operand target) { + public LossTuple(Operand labels, Operand target) { this(labels, target, null); } /** - * Creates a Tuple of Operands for labels, target, and sampleWeights + * Creates a LossTuple of Operands for labels, target, and sampleWeights * * @param labels the labels * @param target the losses or target * @param sampleWeights the sample weights */ - public Tuple(Operand labels, Operand target, Operand sampleWeights) { + public LossTuple(Operand labels, Operand target, Operand sampleWeights) { this.labels = labels; this.target = target; this.sampleWeights = sampleWeights; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java index 089e264e04a..e483a305be0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java @@ -29,12 +29,12 @@ public class LossesImpl { * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction * . - * @return Tuple of prediction, label,sampleWeight will be + * @return LossTuple of prediction, label,sampleWeight will be * null. Each of them possibly has the last dimension squeezed, sampleWeight * could be extended by one dimension. If sampleWeight is null, (prediction, * label) is returned. */ - public static Tuple squeezeOrExpandDimensions( + public static LossTuple squeezeOrExpandDimensions( Ops tf, Operand labels, Operand predictions) { return squeezeOrExpandDimensions(tf, labels, predictions, null); } @@ -56,12 +56,12 @@ public static Tuple squeezeOrExpandDimensions( * . * @param sampleWeights Optional sample weight(s) Operand whose dimensions match * prediction. - * @return Tuple of prediction, labels and sampleWeight. + * @return LossTuple of prediction, labels and sampleWeight. * Each of them possibly has the last dimension squeezed, sampleWeight could be * extended by one dimension. If sampleWeight is null, only the possibly shape modified predictions and labels are * returned. */ - public static Tuple squeezeOrExpandDimensions( + public static LossTuple squeezeOrExpandDimensions( Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { @@ -69,7 +69,7 @@ public static Tuple squeezeOrExpandDimensions( long predictionsRank = predictionsShape.numDimensions(); // Default case when no modifications are made. - Tuple tuple = new Tuple<>(labels, predictions, sampleWeights); + LossTuple lossTuple = new LossTuple<>(labels, predictions, sampleWeights); if (labels != null) { Shape labelsShape = labels.asOutput().shape(); long labelsRank = labelsShape.numDimensions(); @@ -77,19 +77,19 @@ public static Tuple squeezeOrExpandDimensions( // Use static rank for 'label' and 'prediction'. if (predictionsRank - labelsRank != 1 || predictionsShape.size(-1) == 1) { // label, prediction = confusion_matrix.remove_squeezable_dimensions(label, prediction) - tuple = removeSqueezableDimensions(tf, labels, predictions); + lossTuple = removeSqueezableDimensions(tf, labels, predictions); } } else { // use dynamic rank - tuple = removeSqueezableDimensions(tf, labels, predictions); + lossTuple = removeSqueezableDimensions(tf, labels, predictions); } } if (sampleWeights == null) { // nothing more to do. - return tuple; + return lossTuple; } Shape weightsShape = sampleWeights.asOutput().shape(); long weightsRank = weightsShape.numDimensions(); if (weightsRank == 0) { // scalar - return new Tuple<>(labels, predictions, sampleWeights); + return new LossTuple<>(labels, predictions, sampleWeights); } if (predictionsRank != Shape.UNKNOWN_SIZE && weightsRank != Shape.UNKNOWN_SIZE) { @@ -99,7 +99,7 @@ public static Tuple squeezeOrExpandDimensions( } else if (predictionsRank - weightsRank == 1) { sampleWeights = tf.expandDims(sampleWeights, tf.constant(-1L)); } - return new Tuple<>(labels, predictions, sampleWeights); + return new LossTuple<>(labels, predictions, sampleWeights); } // Use dynamic rank. Operand weightsRankTensor = tf.rank(sampleWeights); @@ -109,7 +109,7 @@ public static Tuple squeezeOrExpandDimensions( tf.math.equal(weightsRankTensor, tf.constant(0)), sampleWeights, maybeAdjustWeights(tf, sampleWeights, rankDiff)); - return new Tuple<>(labels, predictions, sampleWeights); + return new LossTuple<>(labels, predictions, sampleWeights); } /** @@ -159,7 +159,7 @@ private static Operand maybeExpandWeights( * @param predictions Predicted values, a Tensor of arbitrary dimensions. * @return labels and predictions, possibly with last dim squeezed. */ - public static Tuple removeSqueezableDimensions( + public static LossTuple removeSqueezableDimensions( Ops tf, Operand labels, Operand predictions) { return removeSqueezableDimensions(tf, labels, predictions, 0); } @@ -174,7 +174,7 @@ public static Tuple removeSqueezableDimensions( * @param expectedRankDiff Expected result of rank(predictions) - rank(labels). * @return labels and predictions, possibly with last dim squeezed. */ - public static Tuple removeSqueezableDimensions( + public static LossTuple removeSqueezableDimensions( Ops tf, Operand labels, Operand predictions, int expectedRankDiff) { tf = tf.withSubScope("removeSqueezableDimensions"); @@ -191,7 +191,7 @@ public static Tuple removeSqueezableDimensions( } else if (rankDiff == expectedRankDiff - 1 && Shape.isCompatible(labelsShape.size(-1), 1)) { labels = tf.squeeze(labels); } - return new Tuple<>(labels, predictions); + return new LossTuple<>(labels, predictions); } // Use dynamic rank. @@ -212,7 +212,7 @@ public static Tuple removeSqueezableDimensions( */ labels = tf.squeeze(labels, Squeeze.axis(Collections.singletonList(-1L))); } - return new Tuple<>(labels, predictions); + return new LossTuple<>(labels, predictions); } /** @@ -231,7 +231,7 @@ public static Operand computeWeightedLoss( if (sampleWeight == null) { sampleWeight = tf.dtypes.cast(tf.constant(1), dataType); } - Tuple result = squeezeOrExpandDimensions(tf, null, loss, sampleWeight); + LossTuple result = squeezeOrExpandDimensions(tf, null, loss, sampleWeight); loss = result.getTarget(); sampleWeight = result.getSampleWeights(); From 642069c34d9e6b6c3df92cab4672c315029555de Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 11 Oct 2020 15:29:37 -0400 Subject: [PATCH 06/26] Repair JavaDOx --- .../src/main/java/org/tensorflow/framework/losses/Losses.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 36c04fb2df4..258c761af07 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -97,7 +97,7 @@ public static Operand meanAbsolutePerc } /** - * Calculates the mean squared logarithmic percentage error between labels and predictions. + * Calculates the mean squared logarithmic error between labels and predictions. * *

loss = reduceMean(square(log(labels + 1) - log(predictions + 1))) * @@ -111,7 +111,7 @@ public static Operand meanAbsolutePerc public static Operand meanSquaredLogarithmicError( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = tf.dtypes.cast(labels, predictions.asOutput().dataType()); + Operand tLabels = tf.dtypes.cast(labels, dataType); LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); From 249b65194bb055decf02d61f56378e7771e6d05f Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 11 Oct 2020 15:30:19 -0400 Subject: [PATCH 07/26] Fixed AllAxis to hanlde dynamic shape when static shape rank is unknown. --- .../framework/losses/impl/LossesImpl.java | 31 +++++++------------ 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java index e483a305be0..9cc77b504c6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java @@ -29,8 +29,8 @@ public class LossesImpl { * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction * . - * @return LossTuple of prediction, label,sampleWeight will be - * null. Each of them possibly has the last dimension squeezed, sampleWeight + * @return LossTuple of prediction, label,sampleWeight will + * be null. Each of them possibly has the last dimension squeezed, sampleWeight * could be extended by one dimension. If sampleWeight is null, (prediction, * label) is returned. */ @@ -64,7 +64,6 @@ public static LossTuple squeezeOrExpandDimensions( public static LossTuple squeezeOrExpandDimensions( Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { - Shape predictionsShape = predictions.asOutput().shape(); long predictionsRank = predictionsShape.numDimensions(); @@ -183,7 +182,7 @@ public static LossTuple removeSqueezableDimensions( Shape labelsShape = labels.asOutput().shape(); int labelsRank = labelsShape.numDimensions(); - if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { + if (predictionsRank != Shape.UNKNOWN_SIZE || labelsRank != Shape.UNKNOWN_SIZE) { // Use static rank. int rankDiff = predictionsRank - labelsRank; if (rankDiff == expectedRankDiff + 1 && Shape.isCompatible(predictionsShape.size(-1), 1)) { @@ -290,23 +289,15 @@ public static Operand safeMean( * @return a Constant that represents all the axes of the operand. */ public static Operand allAxis(Ops tf, Operand op) { - int[] ranks = allAxis(op); - return tf.constant(ranks); - } - - /** - * Gets an integer array representing all the axes of the operand. - * - * @param op the Operand - * @param the type of Operand - * @return the integer array representing all the axes of the operand. - */ - private static int[] allAxis(Operand op) { int rank = op.asOutput().shape().numDimensions(); - int[] axes = new int[rank]; - for (int i = 0; i < rank; i++) { - axes[i] = i; + if (rank != Shape.UNKNOWN_SIZE) { + int[] axes = new int[rank]; + for (int i = 0; i < rank; i++) { + axes[i] = i; + } + return tf.constant(axes); + } else { + return tf.range(tf.constant(0), tf.rank(op), tf.constant(1)); } - return axes; } } From 794cfdca096223e521c8c45138fc872cc2a3ec75 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 11 Oct 2020 15:36:18 -0400 Subject: [PATCH 08/26] change method name allAxis to allAxes --- .../org/tensorflow/framework/losses/impl/LossesImpl.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java index 9cc77b504c6..4413036e4df 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java @@ -255,7 +255,7 @@ private static Operand reduceWeightedLoss( loss = weightedLoss; } else { loss = - tf.reduceSum(weightedLoss, allAxis(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE)); + tf.reduceSum(weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE)); if (reduction == Reduction.AUTO || reduction == Reduction.SUM_OVER_BATCH_SIZE) { loss = safeMean(tf, loss, weightedLoss.asOutput().shape().size()); } @@ -275,7 +275,7 @@ private static Operand reduceWeightedLoss( */ public static Operand safeMean( Ops tf, Operand losses, long numElements) { - Operand totalLoss = tf.reduceSum(losses, allAxis(tf, losses)); + Operand totalLoss = tf.reduceSum(losses, allAxes(tf, losses)); return tf.math.divNoNan( totalLoss, tf.dtypes.cast(tf.constant(numElements), losses.asOutput().dataType())); } @@ -288,7 +288,7 @@ public static Operand safeMean( * @param the type of Operand * @return a Constant that represents all the axes of the operand. */ - public static Operand allAxis(Ops tf, Operand op) { + public static Operand allAxes(Ops tf, Operand op) { int rank = op.asOutput().shape().numDimensions(); if (rank != Shape.UNKNOWN_SIZE) { int[] axes = new int[rank]; From fb26c59f40f45836c62f7e0421949cb5bd8e3e3c Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 13 Oct 2020 13:12:40 -0400 Subject: [PATCH 09/26] change private method binaryCrossentropy to binaryCrossentropyHelper --- .../main/java/org/tensorflow/framework/losses/Losses.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 258c761af07..cb6baa8c8f9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -150,12 +150,12 @@ public static Operand binaryCrossentro if (labelSmoothing != 0.0f) { tLabels = smoothLabelsBinaryX(tf, tLabels, labelSmoothing); } - Operand bce = binaryCrossentropy(tf, tLabels, predictions, fromLogits); + Operand bce = binaryCrossentropyHelper(tf, tLabels, predictions, fromLogits); return tf.math.mean(bce, tf.constant(-1)); } /** - * Compute binary crossentropy loss between labels and predictions. + * Computes the unreduced crossentropy loss between labels and predictions. * * @param tf the TensorFlow Ops * @param target the target Operand @@ -165,7 +165,7 @@ public static Operand binaryCrossentro * @param the data type of the Operands * @return the binary crossentropy loss. */ - private static Operand binaryCrossentropy( + private static Operand binaryCrossentropyHelper( Ops tf, Operand target, Operand output, boolean fromLogits) { if (fromLogits) { return tf.nn.sigmoidCrossEntropyWithLogits(target, output); From 928ef066f8d250b4ae41799eea40ab03fe3ecd23 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 13 Oct 2020 15:25:17 -0400 Subject: [PATCH 10/26] Fixed squeezeOrExpandDimensions to make sure the updated labels, predictions and weights are returned in LossTuple --- .../org/tensorflow/framework/losses/impl/LossesImpl.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java index 4413036e4df..4a276d6804f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java @@ -88,7 +88,7 @@ public static LossTuple squeezeOrExpandDimensions( Shape weightsShape = sampleWeights.asOutput().shape(); long weightsRank = weightsShape.numDimensions(); if (weightsRank == 0) { // scalar - return new LossTuple<>(labels, predictions, sampleWeights); + return new LossTuple<>(lossTuple.getLabels(), lossTuple.getTarget(), sampleWeights); } if (predictionsRank != Shape.UNKNOWN_SIZE && weightsRank != Shape.UNKNOWN_SIZE) { @@ -98,7 +98,7 @@ public static LossTuple squeezeOrExpandDimensions( } else if (predictionsRank - weightsRank == 1) { sampleWeights = tf.expandDims(sampleWeights, tf.constant(-1L)); } - return new LossTuple<>(labels, predictions, sampleWeights); + return new LossTuple<>(lossTuple.getLabels(), lossTuple.getTarget(), sampleWeights); } // Use dynamic rank. Operand weightsRankTensor = tf.rank(sampleWeights); @@ -108,7 +108,7 @@ public static LossTuple squeezeOrExpandDimensions( tf.math.equal(weightsRankTensor, tf.constant(0)), sampleWeights, maybeAdjustWeights(tf, sampleWeights, rankDiff)); - return new LossTuple<>(labels, predictions, sampleWeights); + return new LossTuple<>(lossTuple.getLabels(), lossTuple.getTarget(), sampleWeights); } /** From 2bc54dd821b01c368914efdae87e503c3a61d989 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 27 Oct 2020 12:24:22 -0400 Subject: [PATCH 11/26] Fix JavaDoc, Add in rangeCheck and valueCheck Misc fixes based on review --- .../framework/losses/BinaryCrossentropy.java | 62 ++++-- .../losses/CategoricalCrossentropy.java | 61 ++++-- .../framework/losses/CategoricalHinge.java | 2 +- .../framework/losses/CosineSimilarity.java | 11 +- .../tensorflow/framework/losses/Hinge.java | 40 +++- .../tensorflow/framework/losses/Huber.java | 4 +- .../framework/losses/KLDivergence.java | 2 +- .../tensorflow/framework/losses/LogCosh.java | 6 +- .../org/tensorflow/framework/losses/Loss.java | 12 +- .../tensorflow/framework/losses/Losses.java | 75 +++++-- .../framework/losses/MeanAbsoluteError.java | 2 +- .../losses/MeanAbsolutePercentageError.java | 2 +- .../framework/losses/MeanSquaredError.java | 2 +- .../losses/MeanSquaredLogarithmicError.java | 2 +- .../tensorflow/framework/losses/Poisson.java | 4 +- .../losses/SparseCategoricalCrossentropy.java | 47 +++- .../framework/losses/SquaredHinge.java | 37 +++- .../framework/losses/impl/LossesImpl.java | 99 +++++++++ .../losses/BinaryCrossentropyTest.java | 43 +++- .../losses/CategoricalCrossentropyTest.java | 123 +++++++---- .../framework/losses/HingeTest.java | 205 ++++++++++-------- .../SparseCategoricalCrossentropyTest.java | 29 +++ .../framework/losses/SquaredHingeTest.java | 27 +++ 23 files changed, 673 insertions(+), 224 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index c8be0463403..d194f0843dc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -46,12 +46,10 @@ * Operand<TFloat32> result = bce.call(labels, predictions); * // produces [0.916f, 0.714f] *

- * */ public class BinaryCrossentropy extends Loss { public static final boolean FROM_LOGITS_DEFAULT = false; public static final float LABEL_SMOOTHING_DEFAULT = 0.0f; - public static final Reduction REDUCTION_DEFAULT = Reduction.AUTO; private final boolean fromLogits; private final float labelSmoothing; @@ -59,9 +57,7 @@ public class BinaryCrossentropy extends Loss { /** * Creates a Binary Crossentropy Loss using {@link Class#getSimpleName()} as the loss name, {@link * #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing and a - * Loss Reduction of {@link Reduction#AUTO} - * - * + * Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops */ @@ -82,7 +78,8 @@ public BinaryCrossentropy(Ops tf, Reduction reduction) { /** * Creates a Binary Crossentropy loss using using {@link Class#getSimpleName()} as the loss name, - * labelSmoothing of {@link #LABEL_SMOOTHING_DEFAULT}, a reduction of {@link #REDUCTION_DEFAULT}, + * labelSmoothing of {@link #LABEL_SMOOTHING_DEFAULT}, a reduction of {@link + * Loss#REDUCTION_DEFAULT}, * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values @@ -93,7 +90,7 @@ public BinaryCrossentropy(Ops tf, boolean fromLogits) { /** * Creates a Binary Crossentropy loss using labelSmoothing of {@link #LABEL_SMOOTHING_DEFAULT} a - * reduction of {@link #REDUCTION_DEFAULT}. + * reduction of {@link Loss#REDUCTION_DEFAULT}. * * @param tf the TensorFlow Ops * @param name the name of the loss @@ -105,7 +102,7 @@ public BinaryCrossentropy(Ops tf, String name, boolean fromLogits) { /** * Creates a Binary Crossentropy loss using using {@link Class#getSimpleName()} as the loss name, - * and a reduction of {@link #REDUCTION_DEFAULT}. + * and a reduction of {@link Loss#REDUCTION_DEFAULT}. * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values @@ -119,7 +116,7 @@ public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { } /** - * Creates a Binary Crossentropy loss using a reduction of {@link #REDUCTION_DEFAULT}. + * Creates a Binary Crossentropy loss using a reduction of {@link Loss#REDUCTION_DEFAULT}. * * @param tf the TensorFlow Ops * @param name the name of the loss @@ -144,9 +141,8 @@ public BinaryCrossentropy(Ops tf, String name, boolean fromLogits, float labelSm * correspond to heavier smoothing. * @param reduction Type of Reduction to apply to the loss. */ - public BinaryCrossentropy( - Ops tf, boolean fromLogits, float labelSmoothing, Reduction reduction) { - this(tf, null, fromLogits, labelSmoothing, reduction); + public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing, Reduction reduction) { + this(tf, null, fromLogits, labelSmoothing, reduction); } /** @@ -160,20 +156,58 @@ public BinaryCrossentropy( * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing * correspond to heavier smoothing. * @param reduction Type of Reduction to apply to the loss. + * @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1. */ public BinaryCrossentropy( Ops tf, String name, boolean fromLogits, float labelSmoothing, Reduction reduction) { super(tf, name, reduction); + if(labelSmoothing < 0 || labelSmoothing > 1) + throw new IllegalArgumentException("labelSmoothing must be >= 0. and <= 1, found " + labelSmoothing); this.fromLogits = fromLogits; this.labelSmoothing = labelSmoothing; } - /** {@inheritDoc} */ + /** + * Generates an Operand that calculates the loss. + * + * If run in Graph mode, the computation will throw {@link org.tensorflow.exceptions.TFInvalidArgumentException} + * if the predictions values are outside the range o [0. to 1.]. In Eager Mode, this call + * will throw {@link IllegalArgumentException}, if the predictions values are outside the range o [0. to 1.] + * + * @param labels the truth values or labels + * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. + * @param sampleWeights Optional sample_weight acts as a coefficient for the loss. If a scalar is + * provided, then the loss is simply scaled by the given value. If sample_weight is a tensor + * of size [batch_size], then the total loss for each sample of the batch is rescaled by the + * corresponding element in the sample_weight vector. If the shape of sample_weight is + * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of + * predictions is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss + * functions reduce by 1 dimension, usually axis=-1.) + * @param The data type of the predictions, sampleWeights and loss. + * @param The data type of the labels. + * @return the loss + * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. + */ @Override public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { + Operand lPredictions; + if (!fromLogits) { + // add predictions range check for 0 - 1 + lPredictions = + LossesImpl.rangeCheck( + getTF(), + "predictions range check [0-1]", + predictions, + getTF().dtypes.cast(getTF().constant(0), predictions.asOutput().dataType()), + getTF().dtypes.cast(getTF().constant(1), predictions.asOutput().dataType())); + + } else { + lPredictions = predictions; + } + Operand losses = - Losses.binaryCrossentropy(getTF(), labels, predictions, fromLogits, labelSmoothing); + Losses.binaryCrossentropy(getTF(), labels, lPredictions, fromLogits, labelSmoothing); return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index a7491285a68..1550042d8b5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -53,7 +53,6 @@ public class CategoricalCrossentropy extends Loss { public static final boolean FROM_LOGITS_DEFAULT = false; public static final float LABEL_SMOOTHING_DEFAULT = 0.0f; - public static final Reduction REDUCTION_DEFAULT = Reduction.AUTO; public static final int DEFAULT_AXIS = -1; private final boolean fromLogits; @@ -63,7 +62,7 @@ public class CategoricalCrossentropy extends Loss { /** * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, * {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for - * labelSmoothing, a Loss Reduction of {@link Reduction#AUTO}, and an axis of {@link + * labelSmoothing, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and an axis of {@link * #DEFAULT_AXIS} * * @param tf the TensorFlow Ops @@ -75,7 +74,7 @@ public CategoricalCrossentropy(Ops tf) { /** * Creates a categorical cross entropy Loss using {@link #FROM_LOGITS_DEFAULT} for fromLogits, * {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a Loss Reduction of {@link - * Reduction#AUTO}, and an axis of {@link #DEFAULT_AXIS} + * Loss#REDUCTION_DEFAULT}, and an axis of {@link #DEFAULT_AXIS} * * @param tf the TensorFlow Ops * @param name the name of this loss @@ -111,7 +110,7 @@ public CategoricalCrossentropy(Ops tf, String name, Reduction reduction) { /** * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, * {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing, a Loss Reduction of {@link - * Reduction#AUTO}, and an axis of {@link #DEFAULT_AXIS} + * Loss#REDUCTION_DEFAULT}, and an axis of {@link #DEFAULT_AXIS} * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values @@ -122,7 +121,7 @@ public CategoricalCrossentropy(Ops tf, boolean fromLogits) { /** * Creates a categorical cross entropy Loss using {@link #LABEL_SMOOTHING_DEFAULT} for - * labelSmoothing, a Loss Reduction of {@link Reduction#AUTO}, and a channel axis of {@link + * labelSmoothing, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and a channel axis of {@link * #DEFAULT_AXIS} * * @param tf the TensorFlow Ops @@ -135,21 +134,20 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { /** * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name, - * a Loss Reduction of {@link Reduction#AUTO}, and a channel axis of {@link #DEFAULT_AXIS} + * a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS} * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When 0, no smoothing occurs. When > 0, we compute the - * loss between the predicted labels and a smoothed version of the true labels, where the - * smoothing squeezes the labels towards 0.5. Larger values of label_smoothing correspond to - * heavier smoothing. + * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the + * confidence on label values are relaxed. e.g. label_smoothing=0.2 means that we will use a + * value of 0.1 for label 0 and 0.9 for label 1 */ public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using a Loss Reduction of {@link Reduction#AUTO}, + * Creates a categorical cross entropy Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, * and a channel axis of {@link #DEFAULT_AXIS} * * @param tf the TensorFlow Ops @@ -194,6 +192,7 @@ public CategoricalCrossentropy( * @param reduction Type of Reduction to apply to loss. * @param axis The channels axis. axis=-1 corresponds to data format `Channels Last' * and axis=1 corresponds to data format 'Channels First'. + * @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1. */ public CategoricalCrossentropy( Ops tf, @@ -203,17 +202,53 @@ public CategoricalCrossentropy( Reduction reduction, int axis) { super(tf, name, reduction); + if(labelSmoothing < 0 || labelSmoothing > 1) + throw new IllegalArgumentException("labelSmoothing must be >= 0. and <= 1, found " + labelSmoothing); this.fromLogits = fromLogits; this.labelSmoothing = labelSmoothing; this.axis = axis; } - /** {@inheritDoc} */ + /** + * Generates an Operand that calculates the loss. + * + * If run in Graph mode, the computation will throw {@link org.tensorflow.exceptions.TFInvalidArgumentException} + * if the predictions values are outside the range o [0. to 1.]. In Eager Mode, this call + * will throw {@link IllegalArgumentException}, if the predictions values are outside the range o [0. to 1.] + * + * @param labels the truth values or labels + * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. + * @param sampleWeights Optional sample_weight acts as a coefficient for the loss. If a scalar is + * provided, then the loss is simply scaled by the given value. If sample_weight is a tensor + * of size [batch_size], then the total loss for each sample of the batch is rescaled by the + * corresponding element in the sample_weight vector. If the shape of sample_weight is + * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of + * predictions is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss + * functions reduce by 1 dimension, usually axis=-1.) + * @param The data type of the predictions, sampleWeights and loss. + * @param The data type of the labels. + * @return the loss + * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. + */ @Override public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { + Operand lPredictions; + if (!fromLogits) { + // add predictions range check for 0 - 1 + lPredictions = + LossesImpl.rangeCheck( + getTF(), + "predictions range check [0-1]", + predictions, + getTF().dtypes.cast(getTF().constant(0), predictions.asOutput().dataType()), + getTF().dtypes.cast(getTF().constant(1), predictions.asOutput().dataType())); + + } else { + lPredictions = predictions; + } Operand losses = - Losses.categoricalCrossentropy(getTF(), labels, predictions, fromLogits, labelSmoothing, axis); + Losses.categoricalCrossentropy(getTF(), labels, lPredictions, fromLogits, labelSmoothing, axis); return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index 2b1f1e044f9..6417a6e3673 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -54,7 +54,7 @@ public class CategoricalHinge extends Loss { /** * Creates a Categorical Hinge Loss using {@link Class#getSimpleName()} as the loss name and a - * Loss Reduction of {@link Reduction#AUTO} + * Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java index 26b5fdb6ee9..5d6a882665e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java @@ -8,11 +8,12 @@ /** * Computes the cosine similarity between labels and predictions. * - *

Note that it is a negative quantity between -1 and 0, where 0 indicates orthogonality and - * values closer to -1 indicate greater similarity. This makes it usable as a loss function in a - * setting where you try to maximize the proximity between predictions and targets. If either labels - * or predictions is a zero vector, cosine similarity will be 0 regardless of the proximity between - * predictions and targets. + *

Note that it is a number between -1 and 1. When it is a negative number between -1 and 0, 0 + * indicates orthogonality and values closer to -1indicate greater similarity. The values closer to + * 1 indicate greater dissimilarity. This makes it usable as a loss function in a setting where you + * try to maximize the proximity between predictions and targets. If either labels or predictions is + * a zero vector, cosine similarity will be 0 regardless of the proximity between predictions and + * targets. * *

loss = -sum(l2Norm(labels) * l2Norm(predictions)) * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index ad23901d2f1..0bfe3d63b9a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -53,7 +53,7 @@ public class Hinge extends Loss { /** * Creates a Hinge Loss using {@link Class#getSimpleName()} as the loss name and a Loss Reduction - * of {@link Reduction#AUTO} + * of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops */ @@ -82,11 +82,43 @@ public Hinge(Ops tf, String name, Reduction reduction) { super(tf, name, reduction); } - /** {@inheritDoc} */ + /** + * Generates an Operand that calculates the loss. + * + *

If run in Graph mode, the computation will throw {@link + * org.tensorflow.exceptions.TFInvalidArgumentException} if the label values are not in the set + * [-1., 0., 1.]. In Eager Mode, this call will throw {@link IllegalArgumentException}, if the + * label values are not in the set [-1., 0., 1.]. + * + * @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be + * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. + * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. + * @param sampleWeights Optional sample_weight acts as a coefficient for the loss. If a scalar is + * provided, then the loss is simply scaled by the given value. If sample_weight is a tensor + * of size [batch_size], then the total loss for each sample of the batch is rescaled by the + * corresponding element in the sample_weight vector. If the shape of sample_weight is + * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of + * predictions is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss + * functions reduce by 1 dimension, usually axis=-1.) + * @param The data type of the predictions, sampleWeights and loss. + * @param The data type of the labels. + * @return the loss + * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. + */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.hinge(getTF(), labels, predictions); + Operand labels, Operand predictions, Operand sampleWeights) { + Operand tLabels = predictions.asOutput().dataType() == labels.asOutput().dataType() ? + (Operand)labels : + tf.dtypes.cast(labels, predictions.asOutput().dataType()); + tLabels = LossesImpl.valueCheck( + getTF(), + "labels value check [-1, 0, 1]", + tLabels, + getTF().dtypes.cast(getTF().constant(new int[] { -1, 0, 1}), + predictions.asOutput().dataType())); + + Operand losses = Losses.hinge(getTF(), tLabels, predictions); return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index 261a0439f83..baeb8c97033 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -62,7 +62,7 @@ public class Huber extends Loss { /** * Creates a Huber Loss using {@link Class#getSimpleName()} as the loss name, {@link - * #DELTA_DEFAULT} as the delta and a Loss Reduction of {@link Reduction#AUTO} + * #DELTA_DEFAULT} as the delta and a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops */ @@ -72,7 +72,7 @@ public Huber(Ops tf) { /** * Creates a Huber Loss using {@link #DELTA_DEFAULT} as the delta and a Loss Reduction of {@link - * Reduction#AUTO} + * Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java index 6c9371ff8fd..80d11203ca4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -53,7 +53,7 @@ public class KLDivergence extends Loss { /** * Creates a Kullback Leibler Divergence Loss using {@link Class#getSimpleName()} as the loss name - * and a Loss Reduction of {@link Reduction#AUTO} + * and a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index a0e99180ef8..da6992ec776 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -51,7 +51,7 @@ public class LogCosh extends Loss { /** * Creates a LogCosh Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Reduction#AUTO} + * Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops */ @@ -60,7 +60,7 @@ public LogCosh(Ops tf) { } /** - * Creates a LogCosh Loss using a Loss Reduction of {@link Reduction#AUTO} + * Creates a LogCosh Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops */ @@ -79,7 +79,7 @@ public LogCosh(Ops tf, Reduction reduction) { } /** - * Creates a Kullback Leibler Divergence Loss + * Creates a LogCosh Loss * * @param tf the TensorFlow Ops * @param name the name of the loss diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index 445ebf99565..b56f77e9be0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -5,12 +5,14 @@ import org.tensorflow.types.family.TNumber; public abstract class Loss { + public static final Reduction REDUCTION_DEFAULT = Reduction.AUTO; + protected final Ops tf; protected final Reduction reduction; /** * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link - * Reduction#AUTO} + * Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops */ @@ -19,10 +21,10 @@ protected Loss(Ops tf) { } /** - * Creates a Loss using a Loss Reduction of {@link Reduction#AUTO} + * Creates a Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops - * @param name the name of this Loss + * @param name the name of this Loss, if null the name will be {@link Class#getSimpleName()}. */ protected Loss(Ops tf, String name) { this(tf, name, Reduction.AUTO); @@ -32,7 +34,7 @@ protected Loss(Ops tf, String name) { * Creates a Loss * * @param tf the TensorFlow Ops - * @param name the name of this loss + * @param name the name of this loss, if null the name will be {@link Class#getSimpleName()}. * @param reduction Type of Reduction to apply to the loss. */ protected Loss(Ops tf, String name, Reduction reduction) { @@ -54,7 +56,7 @@ public Operand call(Operand labels, } /** - * Calculates the loss + * Generates an Operand that calculates the loss. * * @param labels the truth values or labels * @param predictions the predictions diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index cb6baa8c8f9..b606cc04f12 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -2,8 +2,8 @@ import org.tensorflow.DataType; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.impl.LossesImpl; import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.losses.impl.LossesImpl; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.op.core.ReduceAll; @@ -82,7 +82,7 @@ public static Operand meanSquaredError public static Operand meanAbsolutePercentageError( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = tf.dtypes.cast(labels,dataType); + Operand tLabels = tf.dtypes.cast(labels, dataType); LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -171,16 +171,19 @@ private static Operand binaryCrossentropyHelper( return tf.nn.sigmoidCrossEntropyWithLogits(target, output); } + /* TODO - skip this loggic for now. It requires walking back the inputs which is not yet possible if (!(output instanceof Variable) && (!tf.scope().env().isEager())) { - // TODO - this does not work, cannot walk back, work around is only go back 1. - // output = backtrackIdentity(output); - if (output.op().type().equals(Sigmoid.OP_NAME)) { - if (output.op().numOutputs() != 1) - throw new IllegalArgumentException("output can only have 1 output"); - output = output.op().output(0); - return tf.nn.sigmoidCrossEntropyWithLogits(target, output); - } + // TODO - this does not work + // TODO output = backtrackIdentity(output); + // TODO if (output.op().type().equals(Sigmoid.OP_NAME)) { + // TODO if (output.op().numInputess() != 1) + // TODO throw new IllegalArgumentException("output can only have 1 output"); + // TODO output = output.op().inout(0); + // TODO return tf.nn.sigmoidCrossEntropyWithLogits(target, output); + // TODO} } + */ + DataType dataType = output.asOutput().dataType(); Operand one = tf.dtypes.cast(tf.constant(1), dataType); Operand epsilonConst = tf.dtypes.cast(tf.constant(EPSILON), dataType); @@ -205,10 +208,9 @@ private static Operand binaryCrossentropyHelper( * @param labels true targets * @param predictions the predictions * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When 0, no smoothing occurs. When > 0, compute the - * loss between the predicted labels and a smoothed version of the true labels, where the - * smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing correspond to - * heavier smoothing. + * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the + * confidence on label values are relaxed. e.g. label_smoothing=0.2 means that we will use a + * value of 0.1 for label 0 and 0.9 for label 1 * @param axis the * @param the data type of the predictions and labels * @return the categorical crossentropy loss. @@ -232,7 +234,9 @@ public static Operand categoricalCross if (fromLogits) { return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, -1); } + /* TODO if (!(predictions instanceof Variable) && (!tf.scope().env().isEager())) { + // TODO output = backtrackIdentity(output); doesn't seem to work with Java version. if (predictions.op().type().equals("Softmax")) { if (predictions.op().numOutputs() != 1) @@ -241,6 +245,8 @@ public static Operand categoricalCross return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, -1); } } + */ + Operand one = tf.dtypes.cast(tf.constant(1), dataType); Operand epsilonConst = tf.dtypes.cast(tf.constant(EPSILON), dataType); Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); @@ -262,7 +268,7 @@ public static Operand categoricalCross * Computes the categorical hinge loss between labels and predictions. * * @param tf the TensorFlow Ops - * @param labels true targets, values are expected to be 0 or 1. + * @param labels true targets, values are expected to be 0 or 1. * @param predictions the predictions * @param the data type of the predictions and labels * @return the categorical hinge loss @@ -495,8 +501,10 @@ public static Operand sparseCategorica Operand one = tf.dtypes.cast(tf.constant(1), dataType); Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); + /* TODO need ability to walk back inputs if (!fromLogits && !(predictions instanceof Variable) && (!tf.scope().env().isEager())) { // TODO output = backtrackIdentity(output); doesn't seem to work with Java version. + /* TODO if (predictions.op().type().equals(Softmax.OP_NAME)) { // When softmax activation function is used for output operation, we // use logits from the softmax function directly to compute loss in order @@ -506,7 +514,9 @@ public static Operand sparseCategorica // TODO output = output.op.inputs[0] fromLogits = true; } + } + */ if (!fromLogits) { predictions = tf.clipByValue(predictions, epsilonConst, oneMinusEpsilonConst); @@ -531,15 +541,15 @@ public static Operand sparseCategorica boolean updateShape = labelsRank != predictionsRank - 1; if (updateShape) { // TODO check to see if this is right - Shape newShape = labelsShape.take(labelsRank-1); + Shape newShape = labelsShape.take(labelsRank - 1); iLabels = tf.reshape(iLabels, tf.constant(newShape)); // flatten one dimension predictions = tf.reshape( predictions, - tf.constant(new long[] {-1L, predictionsShape.size(predictionsShape.numDimensions() - 1)})); + tf.constant( + new long[] {-1L, predictionsShape.size(predictionsShape.numDimensions() - 1)})); } - @SuppressWarnings("unchecked") Operand loss = tf.nn.sparseSoftmaxCrossEntropyWithLogits(iLabels, predictions); if (updateShape && predictionsRank >= 3) { @@ -577,7 +587,27 @@ public static Operand squaredHinge( tf.constant(-1)); } - // private methods + // private methods/** + // * Calculates the loss + // * + // * @param labels the truth values or labels + // * @param predictions the predictions + // * @param sampleWeights Optional sample_weight acts as a coefficient for the loss. If a scalar + // is + // * provided, then the loss is simply scaled by the given value. If sample_weight is a + // tensor + // * of size [batch_size], then the total loss for each sample of the batch is rescaled by + // the + // * corresponding element in the sample_weight vector. If the shape of sample_weight is + // * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element + // of + // * predictions is scaled by the corresponding value of sample_weight. (Note on dN-1: all + // loss + // * functions reduce by 1 dimension, usually axis=-1.) + // * @param The data type of the predictions, sampleWeights and loss. + // * @param The data type of the labels. + // * @return the loss + // * /** * Smooths binary labels @@ -604,10 +634,9 @@ private static Operand smoothLabelsBinaryX( * * @param tf the TensorFlow Ops * @param labels true targets - * @param labelSmoothing A number in the range [0, 1]. When 0, no smoothing occurs. When > 0, - * compute the loss between the predicted labels and a smoothed version of the true labels, - * where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing - * correspond to heavier smoothing. + * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the + * confidence on label values are relaxed. e.g. label_smoothing=0.2 means that we will use a + * value of 0.1 for label 0 and 0.9 for label 1 * @param the data type of the labels * @return the smoothed categorical labels */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index 6aa39218ac8..d2a297de4ed 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -50,7 +50,7 @@ public class MeanAbsoluteError extends Loss { /** * Creates a MeanAbsoluteError Loss using {@link Class#getSimpleName()} as the loss name and a - * Loss Reduction of {@link Reduction#AUTO} + * Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 73bb62e4686..7c5a776e483 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -50,7 +50,7 @@ public class MeanAbsolutePercentageError extends Loss { /** * Creates a MeanAbsolutePercentageError Loss using {@link Class#getSimpleName()} as the loss name - * and a Loss Reduction of {@link Reduction#AUTO} + * and a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index 6e5160cd288..5aff273be13 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -50,7 +50,7 @@ public class MeanSquaredError extends Loss { /** * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Reduction#AUTO} + * Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index 8b12c2b525e..2efdf56db78 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -50,7 +50,7 @@ public class MeanSquaredLogarithmicError extends Loss { /** * Creates a MeanSquaredError Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Reduction#AUTO} + * Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java index 66dd356ab33..a7a9fb04609 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java @@ -51,7 +51,7 @@ public class Poisson extends Loss { /** * Creates a Poisson Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Reduction#AUTO} + * Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops */ @@ -60,7 +60,7 @@ public Poisson(Ops tf) { } /** - * Creates a Poisson Loss using a Loss Reduction of {@link Reduction#AUTO} + * Creates a Poisson Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index 7776e23e40b..7636cb8923d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -58,14 +58,13 @@ public class SparseCategoricalCrossentropy extends Loss { public static final boolean FROM_LOGITS_DEFAULT = false; public static final int AXIS_DEFAULT = -1; - public static final Reduction REDUCTION_DEFAULT = Reduction.AUTO; private final boolean fromLogits; private final int axis; /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Reduction#AUTO}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops */ @@ -74,7 +73,7 @@ public SparseCategoricalCrossentropy(Ops tf) { } /** - * Creates a SparseCategoricalCrossentropy loss using a Loss Reduction of {@link Reduction#AUTO}, + * Creates a SparseCategoricalCrossentropy loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, * and fromLogits={@link #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops @@ -108,7 +107,7 @@ public SparseCategoricalCrossentropy(Ops tf, String name, Reduction reduction) { } /** - * Creates a SparseCategoricalCrossentropy using a Loss Reduction of {@link Reduction#AUTO}, and + * Creates a SparseCategoricalCrossentropy using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and * fromLogits={@link #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops @@ -121,7 +120,7 @@ public SparseCategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Reduction#AUTO} and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} and fromLogits={@link #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values @@ -159,12 +158,46 @@ public SparseCategoricalCrossentropy( this.axis = axis; } - /** {@inheritDoc} */ + /** + * Generates an Operand the calculates the loss. + * + * If run in Graph mode, the computation will throw {@link org.tensorflow.exceptions.TFInvalidArgumentException} + * if the predictions values are outside the range o [0. to 1.]. In Eager Mode, this call + * will throw {@link IllegalArgumentException}, if the predictions values are outside the range o [0. to 1.] + * + * @param labels the truth values or labels + * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. + * @param sampleWeights Optional sample_weight acts as a coefficient for the loss. If a scalar is + * provided, then the loss is simply scaled by the given value. If sample_weight is a tensor + * of size [batch_size], then the total loss for each sample of the batch is rescaled by the + * corresponding element in the sample_weight vector. If the shape of sample_weight is + * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of + * predictions is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss + * functions reduce by 1 dimension, usually axis=-1.) + * @param The data type of the predictions, sampleWeights and loss. + * @param The data type of the labels. + * @return the loss + * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. + */ @Override public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { + Operand lPredictions; + if (!fromLogits) { + // add predictions range check for 0 - 1 + lPredictions = + LossesImpl.rangeCheck( + getTF(), + "predictions range check [0-1]", + predictions, + getTF().dtypes.cast(getTF().constant(0), predictions.asOutput().dataType()), + getTF().dtypes.cast(getTF().constant(1), predictions.asOutput().dataType())); + + } else { + lPredictions = predictions; + } Operand losses = - Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); + Losses.sparseCategoricalCrossentropy(getTF(), labels, lPredictions, fromLogits, axis); return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index 611b93a1a98..9f0b75bf78b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -54,7 +54,7 @@ public class SquaredHinge extends Loss { /** * Creates a Squared Hinge Loss using {@link Class#getSimpleName()} as the loss name and a Loss - * Reduction of {@link Reduction#AUTO} + * Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops */ @@ -83,11 +83,42 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { super(tf, name, reduction); } - /** {@inheritDoc} */ + /** + * Generates an Operand that calculates the loss. + * + *

If run in Graph mode, the computation will throw {@link + * org.tensorflow.exceptions.TFInvalidArgumentException} if the label values are not in the set + * [-1., 0., 1.]. In Eager Mode, this call will throw {@link IllegalArgumentException}, if the + * label values are not in the set [-1., 0., 1.]. + * + * @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be + * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. + * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. + * @param sampleWeights Optional sample_weight acts as a coefficient for the loss. If a scalar is + * provided, then the loss is simply scaled by the given value. If sample_weight is a tensor + * of size [batch_size], then the total loss for each sample of the batch is rescaled by the + * corresponding element in the sample_weight vector. If the shape of sample_weight is + * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of + * predictions is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss + * functions reduce by 1 dimension, usually axis=-1.) + * @param The data type of the predictions, sampleWeights and loss. + * @param The data type of the labels. + * @return the loss + * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. + */ @Override public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { - Operand losses = Losses.squaredHinge(getTF(), labels, predictions); + Operand tLabels = predictions.asOutput().dataType() == labels.asOutput().dataType() ? + (Operand)labels : + tf.dtypes.cast(labels, predictions.asOutput().dataType()); + tLabels = LossesImpl.valueCheck( + getTF(), + "labels value check [-1, 0, 1]", + tLabels, + getTF().dtypes.cast(getTF().constant(new int[] { -1, 0, 1}), + predictions.asOutput().dataType())); + Operand losses = Losses.squaredHinge(getTF(), tLabels, predictions); return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java index 4a276d6804f..d77f513bb06 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java @@ -5,13 +5,21 @@ import org.tensorflow.framework.losses.Reduction; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.AssertThat; import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.op.core.SetDiff1d; import org.tensorflow.op.core.Squeeze; +import org.tensorflow.types.TBool; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; +import java.util.Arrays; import java.util.Collections; +/** + * These are helper methods for Losses and will be module private when Java modularity is applied to + * TensorFlow Java. These methods should not be used outside of the Loss package. + */ public class LossesImpl { /** @@ -300,4 +308,95 @@ public static Operand allAxes(Ops tf, Operand op) return tf.range(tf.constant(0), tf.rank(op), tf.constant(1)); } } + + /** + * Perform an inclusive range check on the values + * + * @param tf the TensorFlow Ops + * @param prefix A String prefix to include in the error message + * @param values the values to check + * @param minValue the minimum value + * @param maxValue the maximum value + * @param the datatype for the values + * @return the values possibly with control dependencies if the TensorFlow Ops represents a Graph + * Session + * @throws IllegalArgumentException if the TensorFlow Ops represents an Eager Session + */ + public static Operand rangeCheck( + Ops tf, String prefix, Operand values, Operand minValue, Operand maxValue) { + Operand allDims = allAxes(tf, values); + Operand cond = + tf.math.logicalAnd( + tf.reduceAll(tf.math.greaterEqual(values, minValue), allDims), + tf.reduceAll(tf.math.lessEqual(values, maxValue), allDims)); + // Graph and Eager mode need to be handled differently, control dependencies are not allowed in + // Eager mode + if (tf.scope().env().isGraph()) { + AssertThat assertThat = + tf.assertThat( + cond, + Arrays.asList( + tf.constant(prefix), + tf.constant(": values out of range, "), + tf.constant("minimum = "), + minValue, + tf.constant(", maximum = "), + maxValue)); + Ops ltf = + tf.withSubScope("rangeCheck") + .withControlDependencies(Collections.singletonList(assertThat)); + return ltf.identity(values); + } else if (!cond.asOutput().data().getBoolean()) + throw new IllegalArgumentException(String.format("%s : values out of range", prefix)); + else return values; + } + + /** + * Checks to see if all the values are in the allowed values set. Running the operand in Graph + * mode will throw {@link org.tensorflow.exceptions.TFInvalidArgumentException}, if at least one + * value is not in the allowed values set. In Eager mode, this method will throw an {@link + * IllegalArgumentException} if at least one value is not in the allowed values set. + * + * @param tf The TensorFlow Ops + * @param prefix A String prefix to include in the error message + * @param values the values to check + * @param allowedValues the allowed values + * @param the data type for values and allowed values + * @return the values possibly with control dependencies if the TensorFlow Ops represents a Graph + * Session + * @throws IllegalArgumentException if the Session is in Eager mode and at least one value is not + * in the allowed values set + */ + public static Operand valueCheck( + Ops tf, String prefix, Operand values, Operand allowedValues) { + Operand flatValues = + tf.reshape(values, tf.constant(Shape.of(values.asOutput().shape().size()))); + SetDiff1d diff = tf.setDiff1d(flatValues, allowedValues, TInt32.DTYPE); + long diffSize = diff.out().asOutput().shape().size(); + + if (diffSize != Shape.UNKNOWN_SIZE) { + if (diffSize != 0) { // at least 1 value in the diff did not match the allowed values. + throw new IllegalArgumentException(String.format("%s : values not in value set,", prefix)); + } else return values; + } else { // use dynamic shape + Operand cond = tf.math.equal(tf.shape.size(tf.shape(diff.out())), tf.constant(0)); + // Graph and Eager mode need to be handled differently, control dependencies are not allowed + // in Eager mode + if (tf.scope().env().isGraph()) { + AssertThat assertThat = + tf.assertThat( + cond, + Arrays.asList( + tf.constant(prefix), + tf.constant(": values not in value set, values = "), + values)); + Ops ltf = + tf.withSubScope("valueCheck") + .withControlDependencies(Collections.singletonList(assertThat)); + return ltf.identity(values); + } else if (!cond.asOutput().data().getBoolean()) + throw new IllegalArgumentException(String.format("%s : values not in value set", prefix)); + else return values; + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java index 83f474f3f88..86401f03f5d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java @@ -7,6 +7,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; +import static org.junit.jupiter.api.Assertions.assertThrows; + public class BinaryCrossentropyTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; @@ -30,7 +32,8 @@ public void testAllCorrectUnweighted() { -100.0f, 100.0f, -100.0f, -100.0f, -100.0f, 100.0f }; - Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); + Operand logits = + tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); instance = new BinaryCrossentropy(tf, true); loss = instance.call(yTrue, logits); @@ -38,6 +41,32 @@ public void testAllCorrectUnweighted() { } } + @Test + public void testInvalidPredictionsRange() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Class catchClass = + tfMode == TestSession.Mode.EAGER + ? IllegalArgumentException.class + : org.tensorflow.exceptions.TFInvalidArgumentException.class; + assertThrows( + catchClass, + () -> { + Ops tf = testSession.getTF(); + BinaryCrossentropy instance = new BinaryCrossentropy(tf); + float[] trueArray = {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}; + float[] predArray = {2f, 1f, -1f, 0f}; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); + + Operand loss = instance.call(yTrue, yPred); + testSession.run(loss); + }); + } + } + /** Test of call method, of class BinaryCrossentropy. */ @Test public void testUnweighted() { @@ -60,7 +89,8 @@ public void testUnweighted() { 100.0f, 100.0f, -100.0f }; Operand yTrue1 = tf.reshape(tf.constant(trueArray1), tf.constant(Shape.of(2, 3))); - Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); + Operand logits = + tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); instance = new BinaryCrossentropy(tf, true); loss = instance.call(yTrue1, logits); expected = 33.33333f; @@ -91,7 +121,8 @@ public void testScalarWeighted() { 100.0f, 100.0f, -100.0f }; Operand yTrue1 = tf.reshape(tf.constant(trueArray1), tf.constant(Shape.of(2, 3))); - Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); + Operand logits = + tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); instance = new BinaryCrossentropy(tf, true); loss = instance.call(yTrue1, logits, sampleWeight); expected = 76.66667f; @@ -124,7 +155,8 @@ public void testSampleWeighted() { }; float[] sampleWeightArray1 = {4f, 3f}; Operand yTrue1 = tf.reshape(tf.constant(trueArray1), tf.constant(Shape.of(2, 3))); - Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); + Operand logits = + tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight1 = tf.constant(sampleWeightArray1); instance = new BinaryCrossentropy(tf, true); loss = instance.call(yTrue1, logits, sampleWeight1); @@ -146,7 +178,8 @@ public void testNoReduction() { 100.0f, 100.0f, -100.0f }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); + Operand logits = + tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); BinaryCrossentropy instance = new BinaryCrossentropy( tf, true, BinaryCrossentropy.LABEL_SMOOTHING_DEFAULT, Reduction.NONE); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java index f1bf8f0b2be..0e273fd6a8c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java @@ -6,9 +6,12 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; +import static org.junit.jupiter.api.Assertions.assertThrows; + public class CategoricalCrossentropyTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; @@ -26,30 +29,64 @@ public void testAllCorrectUnweighted() { 0L, 0L, 1L }; float[] predArray = { - 1.f, 0.f, 0.f, - 0.f, 1.f, 0.f, - 0.f, 0.f, 1.F + 1.F, 0.F, 0.F, + 0.F, 1.F, 0.F, + 0.F, 0.F, 1.F }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); Operand loss = instance.call(yTrue, yPred); - float expected = 0f; + float expected = 0F; testSession.evaluate(expected, loss); // Test with logits. float[] logitsArray = { - 10.f, 0.f, 0.f, - 0.f, 10.f, 0.f, - 0.f, 0.f, 10.F + 10.F, 0.F, 0.F, + 0.F, 10.F, 0.F, + 0.F, 0.F, 10.F }; yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); instance = new CategoricalCrossentropy(tf, true); loss = instance.call(yTrue, logits); - testSession.setEpsilon(1e-3f); - testSession.evaluate(0.0f, loss); + testSession.setEpsilon(1e-3F); + testSession.evaluate(0.0F, loss); + } + } + + @Test + public void testInvalidPredictionsRange() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Class catchClass = + tfMode == TestSession.Mode.EAGER + ? IllegalArgumentException.class + : org.tensorflow.exceptions.TFInvalidArgumentException.class; + assertThrows( + catchClass, + () -> { + Ops tf = testSession.getTF(); + CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); + float[] trueArray = { + 1L, 0L, 0L, + 0L, 1L, 0L, + 0L, 0L, 1L + }; + float[] predArray = { + -1.F, 0.F, 0.F, + 0.F, 1.F, 0.F, + 0.F, 0.F, 1.F + }; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); + + Operand loss = instance.call(yTrue, yPred); + testSession.run(loss); + }); } } @@ -62,27 +99,27 @@ public void testUnweighted() { CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1}; float[] predArray = { - .9f, .05f, .05f, - .5f, .89f, .6f, - .05f, .01f, .94f + .9F, .05F, .05F, + .5F, .89F, .6F, + .05F, .01F, .94F }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand loss = instance.call(yTrue, yPred); - float expected = 0.32396814f; + float expected = 0.32396814F; testSession.evaluate(expected, loss); // Test with logits. float[] logitsArray = { - 8.f, 1.f, 1.f, - 0.f, 9.f, 1.f, - 2.f, 3.f, 5.F + 8.F, 1.F, 1.F, + 0.F, 9.F, 1.F, + 2.F, 3.F, 5.F }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); instance = new CategoricalCrossentropy(tf, true); loss = instance.call(yTrue, logits); - expected = 0.0573755f; + expected = 0.0573755F; testSession.evaluate(expected, loss); } } @@ -100,30 +137,30 @@ public void testScalarWeighted() { 0, 0, 1 }; float[] predArray = { - .9f, .05f, .05f, - .5f, .89f, .6f, - .05f, .01f, .94f + .9F, .05F, .05F, + .5F, .89F, .6F, + .05F, .01F, .94F }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); - Operand sampleWeight = tf.constant(2.3f); + Operand sampleWeight = tf.constant(2.3F); CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = .7451267f; + float expected = .7451267F; testSession.evaluate(expected, loss); // Test with logits. float[] logitsArray = { - 8.f, 1.f, 1.f, - 0.f, 9.f, 1.f, - 2.f, 3.f, 5.F + 8.F, 1.F, 1.F, + 0.F, 9.F, 1.F, + 2.F, 3.F, 5.F }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); instance = new CategoricalCrossentropy(tf, true); loss = instance.call(yTrue, logits, sampleWeight); - expected = 0.13196386f; + expected = 0.13196386F; testSession.evaluate(expected, loss); } } @@ -134,36 +171,36 @@ public void testSsampleWeighted() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); - float[] sampeWeightArray = {1.2f, 3.4f, 5.6f}; + float[] sampeWeightArray = {1.2F, 3.4F, 5.6F}; int[] trueArray = { 1, 0, 0, 0, 1, 0, 0, 0, 1 }; float[] predArray = { - .9f, .05f, .05f, - .5f, .89f, .6f, - .05f, .01f, .94f + .9F, .05F, .05F, + .5F, .89F, .6F, + .05F, .01F, .94F }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand sampleWeight = tf.reshape(tf.constant(sampeWeightArray), tf.constant(Shape.of(3, 1))); Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 1.0696f; + float expected = 1.0696F; testSession.evaluate(expected, loss); // Test with logits. float[] logitsArray = { - 8.f, 1.f, 1.f, - 0.f, 9.f, 1.f, - 2.f, 3.f, 5.F + 8.F, 1.F, 1.F, + 0.F, 9.F, 1.F, + 2.F, 3.F, 5.F }; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); instance = new CategoricalCrossentropy(tf, true); loss = instance.call(yTrue, logits, sampleWeight); - expected = 0.31829f; + expected = 0.31829F; testSession.evaluate(expected, loss); } } @@ -177,17 +214,17 @@ public void testNoReduction() { // Test with logits. int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1}; float[] logitsArray = { - 8.f, 1.f, 1.f, - 0.f, 9.f, 1.f, - 2.f, 3.f, 5.F + 8.F, 1.F, 1.F, + 0.F, 9.F, 1.F, + 2.F, 3.F, 5.F }; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); CategoricalCrossentropy instance = - new CategoricalCrossentropy(tf, true, 0.0f, Reduction.NONE); + new CategoricalCrossentropy(tf, true, 0.0F, Reduction.NONE); Operand loss = instance.call(yTrue, logits); - Float[] expected = {0.001822f, 0.000459f, 0.169846f}; + Float[] expected = {0.001822F, 0.000459F, 0.169846F}; testSession.evaluate(expected, loss); } } @@ -197,16 +234,16 @@ public void testLabelSmoothing() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - float labelSmoothing = 0.1f; + float labelSmoothing = 0.1F; int[] trueArray = {1, 0, 0}; - float[] logitsArray = {100.0f, -100.0f, -100.0f}; + float[] logitsArray = {100.0F, -100.0F, -100.0F}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(1, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(1, 3))); CategoricalCrossentropy instance = new CategoricalCrossentropy(tf, true, labelSmoothing); Operand loss = instance.call(yTrue, logits); - float expected = 400.0f * labelSmoothing / 3.0f; + float expected = 400.0F * labelSmoothing / 3.0F; testSession.evaluate(expected, loss); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java index 1f13d0392d7..8951883a0a3 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java @@ -7,102 +7,129 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; +import static org.junit.jupiter.api.Assertions.assertThrows; + public class HingeTest { - private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + /** Test of call method, of class Hinge. */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Hinge instance = new Hinge(tf); + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand loss = instance.call(yTrue, yPred); + float expected = 0.50625f; + testSession.evaluate(expected, loss); + } + } - /** - * Test of call method, of class Hinge. - */ - @Test - public void testUnweighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); - float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; - float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred); - float expected = 0.50625f; - testSession.evaluate(expected, loss); - } - } + @Test + public void testInvalidLabelValue() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Class catchClass = + tfMode == TestSession.Mode.EAGER + ? IllegalArgumentException.class + : org.tensorflow.exceptions.TFInvalidArgumentException.class; + assertThrows( + catchClass, + () -> { + Ops tf = testSession.getTF(); + Hinge instance = new Hinge(tf); + float[] trueArray = {2f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand loss = instance.call(yTrue, yPred); + testSession.run(loss); + }); + } + } - /** - * Test of call method, of class Hinge. - */ - @Test - public void testScalarWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); - float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; - float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 1.164375f; - testSession.evaluate(expected, loss); + /** Test of call method, of class Hinge. */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Hinge instance = new Hinge(tf); + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 1.164375f; + testSession.evaluate(expected, loss); - // todo Verify we get the same output when the same input is given - } - } + // todo Verify we get the same output when the same input is given + } + } - @Test - public void testSampleWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); - float[] sampleArray = {1.2f, 3.4f}; - float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; - float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 1.06125f; - testSession.evaluate(expected, loss); - } - } + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Hinge instance = new Hinge(tf); + float[] sampleArray = {1.2f, 3.4f}; + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 1.06125f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testZeroWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf); - float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; - float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 0f; - testSession.evaluate(expected, loss); - } - } + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Hinge instance = new Hinge(tf); + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand sampleWeight = tf.constant(0.F); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testTimestepWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Hinge instance = new Hinge(tf, Reduction.AUTO); - float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f, 1f, 3f}; - float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; - float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4, 1))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4, 1))); - Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Hinge instance = new Hinge(tf, Reduction.AUTO); + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f, 1f, 3f}; + float[] trueArray = {0f, 1f, 0f, 1f, 0f, 0f, 1f, 1f}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4, 1))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4, 1))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 4))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 2.0125f; - testSession.evaluate(expected, loss); - } - } + float expected = 2.0125f; + testSession.evaluate(expected, loss); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java index 89b96bad198..75109310a5a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java @@ -9,6 +9,8 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; +import static org.junit.jupiter.api.Assertions.assertThrows; + public class SparseCategoricalCrossentropyTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; @@ -47,6 +49,33 @@ public void testAllCorrectUnweighted() { } } + @Test + public void testInvalidPredictionsRange() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Class catchClass = + tfMode == TestSession.Mode.EAGER + ? IllegalArgumentException.class + : org.tensorflow.exceptions.TFInvalidArgumentException.class; + assertThrows( + catchClass, + () -> { + Ops tf = testSession.getTF(); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); + int[] trueArray = {0, 1, 2}; + float[] predArray = { + 1.9f, .05f, .05f, + .5f, .89f, .6f, + .05f, .01f, .94f + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 1))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); + Operand loss = instance.call(yTrue, yPred); + testSession.run(loss); + }); + } + } + /** Test of call method, of class SparseCategoricalCrossentropy. */ @Test public void testUnweighted() { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java index 19236f50749..ca5bd3eb759 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java @@ -7,6 +7,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; +import static org.junit.jupiter.api.Assertions.assertThrows; + public class SquaredHingeTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; @@ -29,6 +31,31 @@ public void testUnweighted() { } } + @Test + public void testInvalidLabelValue() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Class catchClass = + tfMode == TestSession.Mode.EAGER + ? IllegalArgumentException.class + : org.tensorflow.exceptions.TFInvalidArgumentException.class; + assertThrows( + catchClass, + () -> { + Ops tf = testSession.getTF(); + SquaredHinge instance = new SquaredHinge(tf); + float[] trueArray = {0, 2, 0, 1, 0, 0, 1, 1}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand loss = instance.call(yTrue, yPred); + testSession.run(loss); + }); + } + } + /** * Test of call method, of class SquaredHinge. */ From 951443b6cba9e42911ca2cfae05bee920d5ff229 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 27 Oct 2020 12:31:14 -0400 Subject: [PATCH 12/26] Fix unused imports and add @SuppressWarnings("unchecked") for casts. --- .../main/java/org/tensorflow/framework/losses/Hinge.java | 1 + .../main/java/org/tensorflow/framework/losses/Losses.java | 7 ++----- .../main/java/org/tensorflow/framework/losses/Poisson.java | 1 - .../java/org/tensorflow/framework/losses/SquaredHinge.java | 1 + 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index 0bfe3d63b9a..8e45423e20f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -108,6 +108,7 @@ public Hinge(Ops tf, String name, Reduction reduction) { @Override public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { + @SuppressWarnings("unchecked") Operand tLabels = predictions.asOutput().dataType() == labels.asOutput().dataType() ? (Operand)labels : tf.dtypes.cast(labels, predictions.asOutput().dataType()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index b606cc04f12..b55550b904a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -9,11 +9,8 @@ import org.tensorflow.op.core.ReduceAll; import org.tensorflow.op.core.ReduceMax; import org.tensorflow.op.core.ReduceSum; -import org.tensorflow.op.core.Variable; import org.tensorflow.op.math.Mean; -import org.tensorflow.op.math.Sigmoid; import org.tensorflow.op.math.Softplus; -import org.tensorflow.op.nn.Softmax; import org.tensorflow.types.TBool; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; @@ -167,9 +164,9 @@ public static Operand binaryCrossentro */ private static Operand binaryCrossentropyHelper( Ops tf, Operand target, Operand output, boolean fromLogits) { - if (fromLogits) { + if (fromLogits) return tf.nn.sigmoidCrossEntropyWithLogits(target, output); - } + /* TODO - skip this loggic for now. It requires walking back the inputs which is not yet possible if (!(output instanceof Variable) && (!tf.scope().env().isEager())) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java index a7a9fb04609..4c724fbfc5b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java @@ -4,7 +4,6 @@ import org.tensorflow.framework.losses.impl.LossesImpl; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; /** * Computes the Poisson loss between labels and predictions. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index 9f0b75bf78b..e7e140fcd38 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -109,6 +109,7 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { @Override public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { + @SuppressWarnings("unchecked") Operand tLabels = predictions.asOutput().dataType() == labels.asOutput().dataType() ? (Operand)labels : tf.dtypes.cast(labels, predictions.asOutput().dataType()); From ebac9e84264db5b1ee101c6cd1b4966a77b9756f Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 29 Oct 2020 13:54:49 -0400 Subject: [PATCH 13/26] Add copyright --- .../framework/losses/BinaryCrossentropy.java | 15 +++++++++++++++ .../framework/losses/CategoricalCrossentropy.java | 15 +++++++++++++++ .../framework/losses/CategoricalHinge.java | 15 +++++++++++++++ .../framework/losses/CosineSimilarity.java | 15 +++++++++++++++ .../org/tensorflow/framework/losses/Hinge.java | 15 +++++++++++++++ .../org/tensorflow/framework/losses/Huber.java | 15 +++++++++++++++ .../tensorflow/framework/losses/KLDivergence.java | 15 +++++++++++++++ .../org/tensorflow/framework/losses/LogCosh.java | 15 +++++++++++++++ .../org/tensorflow/framework/losses/Loss.java | 15 +++++++++++++++ .../org/tensorflow/framework/losses/Losses.java | 15 +++++++++++++++ .../framework/losses/MeanAbsoluteError.java | 15 +++++++++++++++ .../losses/MeanAbsolutePercentageError.java | 15 +++++++++++++++ .../framework/losses/MeanSquaredError.java | 15 +++++++++++++++ .../losses/MeanSquaredLogarithmicError.java | 15 +++++++++++++++ .../org/tensorflow/framework/losses/Poisson.java | 15 +++++++++++++++ .../tensorflow/framework/losses/Reduction.java | 15 +++++++++++++++ .../losses/SparseCategoricalCrossentropy.java | 15 +++++++++++++++ .../tensorflow/framework/losses/SquaredHinge.java | 15 +++++++++++++++ .../framework/losses/impl/LossTuple.java | 15 +++++++++++++++ .../framework/losses/impl/LossesImpl.java | 15 +++++++++++++++ 20 files changed, 300 insertions(+) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index d194f0843dc..a4d89c1019c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 1550042d8b5..5f265f0eabd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index 6417a6e3673..def66715869 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java index 5d6a882665e..0813148e17a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index 8e45423e20f..a1e6fe08443 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index baeb8c97033..d4aef1bb2e7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java index 80d11203ca4..6433c753c7c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index da6992ec776..c573ebe9192 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index b56f77e9be0..4b1feb03d16 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index b55550b904a..02c8d879464 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.tensorflow.DataType; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index d2a297de4ed..c4062330ea0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 7c5a776e483..0e8d1c76a39 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index 5aff273be13..fe06ba6f853 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index 2efdf56db78..c2b48a0a111 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java index 4c724fbfc5b..273b80824b1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java index 1e4573118c5..26e151c1e81 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index 7636cb8923d..115831cceff 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index e7e140fcd38..8e091015535 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java index 596fb31c0d5..76ce6ff4d43 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses.impl; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java index d77f513bb06..9ef2522b6a5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses.impl; import org.tensorflow.DataType; From d8f3254e7bf8e0eef7a8b715c805f9d378bc10ba Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 29 Oct 2020 14:04:39 -0400 Subject: [PATCH 14/26] Add CastHelper and used that for all casts --- .../framework/losses/BinaryCrossentropy.java | 6 +- .../losses/CategoricalCrossentropy.java | 5 +- .../tensorflow/framework/losses/Hinge.java | 5 +- .../tensorflow/framework/losses/Losses.java | 96 ++++++++++--------- .../losses/SparseCategoricalCrossentropy.java | 5 +- .../framework/losses/SquaredHinge.java | 5 +- .../framework/losses/impl/LossesImpl.java | 10 +- .../framework/utils/CastHelper.java | 43 +++++++++ 8 files changed, 114 insertions(+), 61 deletions(-) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index a4d89c1019c..26ddce312e6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -20,6 +20,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Computes the cross-entropy loss between true labels and predicted labels. * @@ -214,8 +216,8 @@ public Operand call( getTF(), "predictions range check [0-1]", predictions, - getTF().dtypes.cast(getTF().constant(0), predictions.asOutput().dataType()), - getTF().dtypes.cast(getTF().constant(1), predictions.asOutput().dataType())); + cast(getTF(), getTF().constant(0), predictions.asOutput().dataType()), + cast(getTF(), getTF().constant(1), predictions.asOutput().dataType())); } else { lPredictions = predictions; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 5f265f0eabd..5728e20fab2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -19,6 +19,7 @@ import org.tensorflow.framework.losses.impl.LossesImpl; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; /** * Computes the crossentropy loss between the labels and predictions. @@ -256,8 +257,8 @@ public Operand call( getTF(), "predictions range check [0-1]", predictions, - getTF().dtypes.cast(getTF().constant(0), predictions.asOutput().dataType()), - getTF().dtypes.cast(getTF().constant(1), predictions.asOutput().dataType())); + cast(getTF(), getTF().constant(0), predictions.asOutput().dataType()), + cast(getTF(), getTF().constant(1), predictions.asOutput().dataType())); } else { lPredictions = predictions; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index a1e6fe08443..b74c2fe3bd8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -19,6 +19,7 @@ import org.tensorflow.framework.losses.impl.LossesImpl; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; /** * Computes the hinge loss between labels and predictions. @@ -126,12 +127,12 @@ public Operand call( @SuppressWarnings("unchecked") Operand tLabels = predictions.asOutput().dataType() == labels.asOutput().dataType() ? (Operand)labels : - tf.dtypes.cast(labels, predictions.asOutput().dataType()); + cast(tf, labels, predictions.asOutput().dataType()); tLabels = LossesImpl.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, - getTF().dtypes.cast(getTF().constant(new int[] { -1, 0, 1}), + cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.asOutput().dataType())); Operand losses = Losses.hinge(getTF(), tLabels, predictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 02c8d879464..ff0b513c4af 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -30,6 +30,8 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** Built-in loss functions. */ public class Losses { @@ -50,7 +52,7 @@ public class Losses { */ public static Operand meanAbsoluteError( Ops tf, Operand labels, Operand predictions) { - Operand tLabels = tf.dtypes.cast(labels, predictions.asOutput().dataType()); + Operand tLabels = cast(tf, labels, predictions.asOutput().dataType()); LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -72,7 +74,7 @@ public static Operand meanAbsoluteErro */ public static Operand meanSquaredError( Ops tf, Operand labels, Operand predictions) { - Operand tLabels = tf.dtypes.cast(labels, predictions.asOutput().dataType()); + Operand tLabels = cast(tf, labels, predictions.asOutput().dataType()); LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -94,7 +96,7 @@ public static Operand meanSquaredError public static Operand meanAbsolutePercentageError( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = tf.dtypes.cast(labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -103,9 +105,9 @@ public static Operand meanAbsolutePerc tf.math.div( tf.math.sub(tLabels, predictions), tf.math.maximum( - tf.math.abs(tLabels), tf.dtypes.cast(tf.constant(EPSILON), dataType)))); + tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), dataType)))); return tf.math.mul( - tf.dtypes.cast(tf.constant(100), dataType), tf.math.mean(diff, tf.constant(-1))); + cast(tf, tf.constant(100), dataType), tf.math.mean(diff, tf.constant(-1))); } /** @@ -123,13 +125,13 @@ public static Operand meanAbsolutePerc public static Operand meanSquaredLogarithmicError( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = tf.dtypes.cast(labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); - Operand epsilonConst = tf.dtypes.cast(tf.constant(EPSILON), dataType); - Operand one = tf.dtypes.cast(tf.constant(1), dataType); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); + Operand one = cast(tf, tf.constant(1), dataType); Operand firstLog = tf.math.log(tf.math.add(tf.math.maximum(predictions, epsilonConst), one)); Operand secondLog = tf.math.log(tf.math.add(tf.math.maximum(tLabels, epsilonConst), one)); @@ -154,7 +156,7 @@ public static Operand meanSquaredLogar public static Operand binaryCrossentropy( Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = tf.dtypes.cast(labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -197,8 +199,8 @@ private static Operand binaryCrossentropyHelper( */ DataType dataType = output.asOutput().dataType(); - Operand one = tf.dtypes.cast(tf.constant(1), dataType); - Operand epsilonConst = tf.dtypes.cast(tf.constant(EPSILON), dataType); + Operand one = cast(tf, tf.constant(1), dataType); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); output = tf.clipByValue(output, epsilonConst, oneMinusEpsilonConst); @@ -235,7 +237,7 @@ public static Operand categoricalCross float labelSmoothing, int axis) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = tf.dtypes.cast(labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -259,8 +261,8 @@ public static Operand categoricalCross } */ - Operand one = tf.dtypes.cast(tf.constant(1), dataType); - Operand epsilonConst = tf.dtypes.cast(tf.constant(EPSILON), dataType); + Operand one = cast(tf, tf.constant(1), dataType); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); predictions = tf.math.div( @@ -288,12 +290,12 @@ public static Operand categoricalCross public static Operand categoricalHinge( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = tf.dtypes.cast(labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand one = tf.dtypes.cast(tf.constant(1), dataType); - Operand zero = tf.dtypes.cast(tf.constant(0), dataType); + Operand one = cast(tf, tf.constant(1), dataType); + Operand zero = cast(tf, tf.constant(0), dataType); Operand pos = tf.reduceSum( @@ -330,7 +332,7 @@ public static Operand categoricalHinge public static Operand cosineSimilarity( Ops tf, Operand labels, Operand predictions, int axis) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = tf.dtypes.cast(labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); @@ -357,12 +359,12 @@ public static Operand cosineSimilarity public static Operand hinge( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = tf.dtypes.cast(labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand one = tf.dtypes.cast(tf.constant(1), dataType); - Operand zero = tf.dtypes.cast(tf.constant(0), dataType); + Operand one = cast(tf, tf.constant(1), dataType); + Operand zero = cast(tf, tf.constant(0), dataType); tLabels = maybeConvertLabels(tf, tLabels); @@ -393,14 +395,14 @@ public static Operand hinge( public static Operand huber( Ops tf, Operand labels, Operand predictions, float delta) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = tf.dtypes.cast(labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); Operand error = tf.math.sub(predictions, tLabels); - Operand deltaConst = tf.dtypes.cast(tf.constant(delta), dataType); - Operand point5 = tf.dtypes.cast(tf.constant(0.5), dataType); + Operand deltaConst = cast(tf, tf.constant(delta), dataType); + Operand point5 = cast(tf, tf.constant(0.5), dataType); Operand absError = tf.math.abs(error); Operand quadratic = tf.math.minimum(absError, deltaConst); Operand linear = tf.math.sub(absError, quadratic); @@ -424,12 +426,12 @@ public static Operand huber( public static Operand kullbackLeiblerDivergence( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = tf.dtypes.cast(labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand one = tf.dtypes.cast(tf.constant(1), dataType); - Operand epsilonConst = tf.dtypes.cast(tf.constant(EPSILON), dataType); + Operand one = cast(tf, tf.constant(1), dataType); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); tLabels = tf.clipByValue(tLabels, epsilonConst, one); predictions = tf.clipByValue(predictions, epsilonConst, one); @@ -454,12 +456,12 @@ public static Operand kullbackLeiblerD public static Operand logCosh( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = tf.dtypes.cast(labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand minusTwo = tf.dtypes.cast(tf.constant(-2), dataType); - Operand two = tf.dtypes.cast(tf.constant(2), dataType); + Operand minusTwo = cast(tf, tf.constant(-2), dataType); + Operand two = cast(tf, tf.constant(2), dataType); Operand diff = tf.math.sub(predictions, tLabels); Softplus softplus = tf.math.softplus(tf.math.mul(minusTwo, diff)); @@ -482,11 +484,11 @@ public static Operand logCosh( public static Operand poisson( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = tf.dtypes.cast(labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand epsilonConst = tf.dtypes.cast(tf.constant(EPSILON), dataType); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); return tf.math.mean( tf.math.sub( @@ -509,8 +511,8 @@ public static Operand poisson( public static Operand sparseCategoricalCrossentropy( Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { DataType dataType = predictions.asOutput().dataType(); - Operand epsilonConst = tf.dtypes.cast(tf.constant(EPSILON), dataType); - Operand one = tf.dtypes.cast(tf.constant(1), dataType); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); + Operand one = cast(tf, tf.constant(1), dataType); Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); /* TODO need ability to walk back inputs @@ -545,7 +547,7 @@ public static Operand sparseCategorica predictions = tf.linalg.transpose(predictions, tf.constant(axisNew)); } - Operand iLabels = tf.dtypes.cast(labels, TInt64.DTYPE); + Operand iLabels = cast(tf, labels, TInt64.DTYPE); // Try to adjust the shape so that rank of labels = rank of logits - 1. Shape labelsShape = labels.asOutput().shape(); @@ -586,12 +588,12 @@ public static Operand sparseCategorica public static Operand squaredHinge( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = tf.dtypes.cast(labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand one = tf.dtypes.cast(tf.constant(1), dataType); - Operand zero = tf.dtypes.cast(tf.constant(0), dataType); + Operand one = cast(tf, tf.constant(1), dataType); + Operand zero = cast(tf, tf.constant(0), dataType); tLabels = maybeConvertLabels(tf, tLabels); return tf.math.mean( @@ -636,8 +638,8 @@ public static Operand squaredHinge( private static Operand smoothLabelsBinaryX( Ops tf, Operand labels, float labelSmoothing) { DataType dataType = labels.asOutput().dataType(); - Operand oneMinusSmoothing = tf.dtypes.cast(tf.constant(1.f - labelSmoothing), dataType); - Operand halfSmoothing = tf.dtypes.cast(tf.constant(0.5F * labelSmoothing), dataType); + Operand oneMinusSmoothing = cast(tf, tf.constant(1.f - labelSmoothing), dataType); + Operand halfSmoothing = cast(tf, tf.constant(0.5F * labelSmoothing), dataType); return tf.math.add(tf.math.mul(labels, oneMinusSmoothing), halfSmoothing); } @@ -655,11 +657,11 @@ private static Operand smoothLabelsBinaryX( private static Operand smoothLabelsCatX( Ops tf, Operand labels, float labelSmoothing) { DataType dataType = labels.asOutput().dataType(); - Operand smoothing = tf.dtypes.cast(tf.constant(labelSmoothing), dataType); + Operand smoothing = cast(tf, tf.constant(labelSmoothing), dataType); Shape labelsShape = labels.asOutput().shape(); int numDims = labelsShape.numDimensions(); - Operand numClasses = tf.dtypes.cast(tf.constant(labelsShape.size(numDims - 1)), dataType); - Operand oneMinusSmoothing = tf.dtypes.cast(tf.constant(1.f - labelSmoothing), dataType); + Operand numClasses = cast(tf, tf.constant(labelsShape.size(numDims - 1)), dataType); + Operand oneMinusSmoothing = cast(tf, tf.constant(1.f - labelSmoothing), dataType); return tf.math.add(tf.math.mul(labels, oneMinusSmoothing), tf.math.div(smoothing, numClasses)); } @@ -678,7 +680,7 @@ public static Operand l2Normalize(Ops tf, Operand x, i Operand invNorm = tf.math.rsqrt( tf.math.maximum( - squareSum, tf.dtypes.cast(tf.constant(1e-12F), x.asOutput().dataType()))); + squareSum, cast(tf, tf.constant(1e-12F), x.asOutput().dataType()))); return tf.math.mul(x, invNorm); } @@ -693,9 +695,9 @@ public static Operand l2Normalize(Ops tf, Operand x, i private static Operand maybeConvertLabels(Ops tf, Operand labels) { DataType dataType = labels.asOutput().dataType(); - Operand one = tf.dtypes.cast(tf.constant(1), dataType); - Operand zero = tf.dtypes.cast(tf.constant(0), dataType); - Operand two = tf.dtypes.cast(tf.constant(2), dataType); + Operand one = cast(tf, tf.constant(1), dataType); + Operand zero = cast(tf, tf.constant(0), dataType); + Operand two = cast(tf, tf.constant(2), dataType); Operand areZeros = tf.math.equal(labels, zero); Operand areOnes = tf.math.equal(labels, one); Operand isBinary = diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index 115831cceff..3064a66ae86 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -19,6 +19,7 @@ import org.tensorflow.framework.losses.impl.LossesImpl; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; /** * Computes the crossentropy loss between labels and predictions. @@ -205,8 +206,8 @@ public Operand call( getTF(), "predictions range check [0-1]", predictions, - getTF().dtypes.cast(getTF().constant(0), predictions.asOutput().dataType()), - getTF().dtypes.cast(getTF().constant(1), predictions.asOutput().dataType())); + cast(getTF(), getTF().constant(0), predictions.asOutput().dataType()), + cast(getTF(), getTF().constant(1), predictions.asOutput().dataType())); } else { lPredictions = predictions; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index 8e091015535..a7bdfa71c01 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -19,6 +19,7 @@ import org.tensorflow.framework.losses.impl.LossesImpl; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; /** * Computes the squared hinge loss between labels and predictions. @@ -127,12 +128,12 @@ public Operand call( @SuppressWarnings("unchecked") Operand tLabels = predictions.asOutput().dataType() == labels.asOutput().dataType() ? (Operand)labels : - tf.dtypes.cast(labels, predictions.asOutput().dataType()); + cast(tf, labels, predictions.asOutput().dataType()); tLabels = LossesImpl.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, - getTF().dtypes.cast(getTF().constant(new int[] { -1, 0, 1}), + cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.asOutput().dataType())); Operand losses = Losses.squaredHinge(getTF(), tLabels, predictions); return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java index 9ef2522b6a5..2ccb48fff72 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java @@ -31,6 +31,8 @@ import java.util.Arrays; import java.util.Collections; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * These are helper methods for Losses and will be module private when Java modularity is applied to * TensorFlow Java. These methods should not be used outside of the Loss package. @@ -251,15 +253,15 @@ public static Operand computeWeightedLoss( Ops tf, Operand loss, Reduction reduction, Operand sampleWeight) { DataType dataType = loss.asOutput().dataType(); if (sampleWeight == null) { - sampleWeight = tf.dtypes.cast(tf.constant(1), dataType); + sampleWeight = cast(tf, tf.constant(1), dataType); } LossTuple result = squeezeOrExpandDimensions(tf, null, loss, sampleWeight); loss = result.getTarget(); sampleWeight = result.getSampleWeights(); - Operand weighted_losses = tf.math.mul(loss, tf.dtypes.cast(sampleWeight, dataType)); + Operand weighted_losses = tf.math.mul(loss, cast(tf, sampleWeight, dataType)); loss = reduceWeightedLoss(tf, weighted_losses, reduction); - return tf.dtypes.cast(loss, dataType); + return cast(tf, loss, dataType); } /** @@ -300,7 +302,7 @@ public static Operand safeMean( Ops tf, Operand losses, long numElements) { Operand totalLoss = tf.reduceSum(losses, allAxes(tf, losses)); return tf.math.divNoNan( - totalLoss, tf.dtypes.cast(tf.constant(numElements), losses.asOutput().dataType())); + totalLoss, cast(tf, tf.constant(numElements), losses.asOutput().dataType())); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java new file mode 100644 index 00000000000..aec75e6078a --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow.framework.utils; + +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TType; + +/** A helper class for casting an Operand */ +public class CastHelper { + + /** + * Casts an operand to the desired type. + * + * @param tf The TensorFlow Ops + * @param value the value to be cast + * @param requiredType the required data type + * @param the required data type + * @param the original data type of the value + * @return the value cast to the required data type. + */ + @SuppressWarnings("unchecked") + public static Operand cast( + Ops tf, Operand value, DataType requiredType) { + return (value.asOutput().dataType() == requiredType) + ? (Operand) value + : tf.dtypes.cast(value, requiredType); + } +} From 02573b594ca552371b8f42fa9e53c019143e6931 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 9 Nov 2020 13:18:15 -0500 Subject: [PATCH 15/26] Fix JavaDoc, change snake case to camel case. --- .../framework/losses/BinaryCrossentropy.java | 8 ++++---- .../framework/losses/CategoricalCrossentropy.java | 8 ++++---- .../java/org/tensorflow/framework/losses/Hinge.java | 10 +++++----- .../java/org/tensorflow/framework/losses/Huber.java | 2 +- .../java/org/tensorflow/framework/losses/LogCosh.java | 2 +- .../java/org/tensorflow/framework/losses/Loss.java | 8 ++++---- .../java/org/tensorflow/framework/losses/Losses.java | 8 ++++---- .../losses/SparseCategoricalCrossentropy.java | 8 ++++---- .../org/tensorflow/framework/losses/SquaredHinge.java | 8 ++++---- .../tensorflow/framework/losses/impl/LossesImpl.java | 5 ++--- 10 files changed, 33 insertions(+), 34 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index 26ddce312e6..470e9b4ca51 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -193,12 +193,12 @@ public BinaryCrossentropy( * * @param labels the truth values or labels * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. - * @param sampleWeights Optional sample_weight acts as a coefficient for the loss. If a scalar is - * provided, then the loss is simply scaled by the given value. If sample_weight is a tensor + * @param sampleWeights Optional SampleWeights acts as a coefficient for the loss. If a scalar is + * provided, then the loss is simply scaled by the given value. If SampleWeights is a tensor * of size [batch_size], then the total loss for each sample of the batch is rescaled by the - * corresponding element in the sample_weight vector. If the shape of sample_weight is + * corresponding element in the SampleWeights vector. If the shape of SampleWeights is * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of - * predictions is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss + * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. * @param The data type of the labels. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 5728e20fab2..1f234b6bb0a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -234,12 +234,12 @@ public CategoricalCrossentropy( * * @param labels the truth values or labels * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. - * @param sampleWeights Optional sample_weight acts as a coefficient for the loss. If a scalar is - * provided, then the loss is simply scaled by the given value. If sample_weight is a tensor + * @param sampleWeights Optional SampleWeights acts as a coefficient for the loss. If a scalar is + * provided, then the loss is simply scaled by the given value. If SampleWeights is a tensor * of size [batch_size], then the total loss for each sample of the batch is rescaled by the - * corresponding element in the sample_weight vector. If the shape of sample_weight is + * corresponding element in the SampleWeights vector. If the shape of SampleWeights is * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of - * predictions is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss + * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. * @param The data type of the labels. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index b74c2fe3bd8..2cd9c75d1b2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -109,12 +109,12 @@ public Hinge(Ops tf, String name, Reduction reduction) { * @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. - * @param sampleWeights Optional sample_weight acts as a coefficient for the loss. If a scalar is - * provided, then the loss is simply scaled by the given value. If sample_weight is a tensor + * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is + * provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor * of size [batch_size], then the total loss for each sample of the batch is rescaled by the - * corresponding element in the sample_weight vector. If the shape of sample_weight is - * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of - * predictions is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss + * corresponding element in the SampleWeights vector. If the shape of SampleWeights is + * [batch_size, d0, .. dN-1] (or can be broadcast to this shape), then each loss element of + * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. * @param The data type of the labels. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index d4aef1bb2e7..5ee39efbdfa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -23,7 +23,7 @@ /** * Computes the Huber loss between labels and predictions. * - *

For each value x in error = y_true - y_pred: + *

For each value x in error = labels - predictions: * *

  *     loss = 0.5 * x^2                  if |x| <= d
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java
index c573ebe9192..850b1f66a24 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java
@@ -24,7 +24,7 @@
  * Computes Computes the logarithm of the hyperbolic cosine of the prediction error.
  *
  * 

logcosh = log((exp(x) + exp(-x))/2), where x is the error - * predictions - y_true. + * predictions - labels. * *

Standalone usage: * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index 4b1feb03d16..b9a08ad2b0f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -75,12 +75,12 @@ public Operand call(Operand labels, * * @param labels the truth values or labels * @param predictions the predictions - * @param sampleWeights Optional sample_weight acts as a coefficient for the loss. If a scalar is - * provided, then the loss is simply scaled by the given value. If sample_weight is a tensor + * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is + * provided, then the loss is simply scaled by the given value. If SampleWeights is a tensor * of size [batch_size], then the total loss for each sample of the batch is rescaled by the - * corresponding element in the sample_weight vector. If the shape of sample_weight is + * corresponding element in the SampleWeights vector. If the shape of SampleWeights is * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of - * predictions is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss + * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. * @param The data type of the labels. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index ff0b513c4af..85b272fe0a0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -606,16 +606,16 @@ public static Operand squaredHinge( // * // * @param labels the truth values or labels // * @param predictions the predictions - // * @param sampleWeights Optional sample_weight acts as a coefficient for the loss. If a scalar + // * @param sampleWeights Optional SampleWeights acts as a coefficient for the loss. If a scalar // is - // * provided, then the loss is simply scaled by the given value. If sample_weight is a + // * provided, then the loss is simply scaled by the given value. If SampleWeights is a // tensor // * of size [batch_size], then the total loss for each sample of the batch is rescaled by // the - // * corresponding element in the sample_weight vector. If the shape of sample_weight is + // * corresponding element in the SampleWeights vector. If the shape of SampleWeights is // * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element // of - // * predictions is scaled by the corresponding value of sample_weight. (Note on dN-1: all + // * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all // loss // * functions reduce by 1 dimension, usually axis=-1.) // * @param The data type of the predictions, sampleWeights and loss. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index 3064a66ae86..953b693c9f2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -183,12 +183,12 @@ public SparseCategoricalCrossentropy( * * @param labels the truth values or labels * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. - * @param sampleWeights Optional sample_weight acts as a coefficient for the loss. If a scalar is - * provided, then the loss is simply scaled by the given value. If sample_weight is a tensor + * @param sampleWeights Optional SampleWeights acts as a coefficient for the loss. If a scalar is + * provided, then the loss is simply scaled by the given value. If SampleWeights is a tensor * of size [batch_size], then the total loss for each sample of the batch is rescaled by the - * corresponding element in the sample_weight vector. If the shape of sample_weight is + * corresponding element in the SampleWeights vector. If the shape of SampleWeights is * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of - * predictions is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss + * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. * @param The data type of the labels. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index a7bdfa71c01..cf2097c8de4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -110,12 +110,12 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { * @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. - * @param sampleWeights Optional sample_weight acts as a coefficient for the loss. If a scalar is - * provided, then the loss is simply scaled by the given value. If sample_weight is a tensor + * @param sampleWeights Optional SampleWeights acts as a coefficient for the loss. If a scalar is + * provided, then the loss is simply scaled by the given value. If SampleWeights is a tensor * of size [batch_size], then the total loss for each sample of the batch is rescaled by the - * corresponding element in the sample_weight vector. If the shape of sample_weight is + * corresponding element in the SampleWeights vector. If the shape of SampleWeights is * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of - * predictions is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss + * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. * @param The data type of the labels. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java index 2ccb48fff72..92acafc9039 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java @@ -100,7 +100,6 @@ public static LossTuple squeezeOrExpandDimensions( if (labelsRank != Shape.UNKNOWN_SIZE && predictionsRank != Shape.UNKNOWN_SIZE) { // Use static rank for 'label' and 'prediction'. if (predictionsRank - labelsRank != 1 || predictionsShape.size(-1) == 1) { - // label, prediction = confusion_matrix.remove_squeezable_dimensions(label, prediction) lossTuple = removeSqueezableDimensions(tf, labels, predictions); } } else { // use dynamic rank @@ -259,8 +258,8 @@ public static Operand computeWeightedLoss( loss = result.getTarget(); sampleWeight = result.getSampleWeights(); - Operand weighted_losses = tf.math.mul(loss, cast(tf, sampleWeight, dataType)); - loss = reduceWeightedLoss(tf, weighted_losses, reduction); + Operand weightedLosses = tf.math.mul(loss, cast(tf, sampleWeight, dataType)); + loss = reduceWeightedLoss(tf, weightedLosses, reduction); return cast(tf, loss, dataType); } From 0bf49fe3203eb5f810ea09e0322fd36b6945856c Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 11 Nov 2020 12:01:22 -0500 Subject: [PATCH 16/26] Change class LossesImpl to LossesHelper --- .../framework/losses/BinaryCrossentropy.java | 6 ++-- .../losses/CategoricalCrossentropy.java | 6 ++-- .../framework/losses/CategoricalHinge.java | 4 +-- .../framework/losses/CosineSimilarity.java | 4 +-- .../tensorflow/framework/losses/Hinge.java | 6 ++-- .../tensorflow/framework/losses/Huber.java | 4 +-- .../framework/losses/KLDivergence.java | 4 +-- .../tensorflow/framework/losses/LogCosh.java | 4 +-- .../tensorflow/framework/losses/Losses.java | 30 +++++++++---------- .../framework/losses/MeanAbsoluteError.java | 4 +-- .../losses/MeanAbsolutePercentageError.java | 4 +-- .../framework/losses/MeanSquaredError.java | 4 +-- .../losses/MeanSquaredLogarithmicError.java | 4 +-- .../tensorflow/framework/losses/Poisson.java | 4 +-- .../losses/SparseCategoricalCrossentropy.java | 6 ++-- .../framework/losses/SquaredHinge.java | 6 ++-- .../{LossesImpl.java => LossesHelper.java} | 6 ++-- 17 files changed, 53 insertions(+), 53 deletions(-) rename tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/{LossesImpl.java => LossesHelper.java} (99%) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index 470e9b4ca51..a226170513a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -16,7 +16,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -212,7 +212,7 @@ public Operand call( if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = - LossesImpl.rangeCheck( + LossesHelper.rangeCheck( getTF(), "predictions range check [0-1]", predictions, @@ -225,6 +225,6 @@ public Operand call( Operand losses = Losses.binaryCrossentropy(getTF(), labels, lPredictions, fromLogits, labelSmoothing); - return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 1f234b6bb0a..e6665ddc086 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -16,7 +16,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -253,7 +253,7 @@ public Operand call( if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = - LossesImpl.rangeCheck( + LossesHelper.rangeCheck( getTF(), "predictions range check [0-1]", predictions, @@ -265,6 +265,6 @@ public Operand call( } Operand losses = Losses.categoricalCrossentropy(getTF(), labels, lPredictions, fromLogits, labelSmoothing, axis); - return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index def66715869..c60628fd22e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -16,7 +16,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -103,6 +103,6 @@ public CategoricalHinge(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.categoricalHinge(getTF(), labels, predictions); - return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java index 0813148e17a..f6dc5c7e5fd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java @@ -16,7 +16,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -175,6 +175,6 @@ public CosineSimilarity(Ops tf, String name, int axis, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.cosineSimilarity(getTF(), labels, predictions, axis); - return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index 2cd9c75d1b2..31b155a76a2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -16,7 +16,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -128,7 +128,7 @@ public Operand call( Operand tLabels = predictions.asOutput().dataType() == labels.asOutput().dataType() ? (Operand)labels : cast(tf, labels, predictions.asOutput().dataType()); - tLabels = LossesImpl.valueCheck( + tLabels = LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, @@ -136,6 +136,6 @@ public Operand call( predictions.asOutput().dataType())); Operand losses = Losses.hinge(getTF(), tLabels, predictions); - return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index 5ee39efbdfa..69487405148 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -16,7 +16,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -134,6 +134,6 @@ public Huber(Ops tf, String name, float delta, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.huber(getTF(), labels, predictions, delta); - return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java index 6433c753c7c..52af31c820b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -16,7 +16,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -103,6 +103,6 @@ public KLDivergence(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); - return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index 850b1f66a24..97ca2b99cff 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -16,7 +16,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -109,6 +109,6 @@ public LogCosh(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.logCosh(getTF(), labels, predictions); - return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 85b272fe0a0..ba641d19362 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -18,7 +18,7 @@ import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.LossTuple; -import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.op.core.ReduceAll; @@ -53,7 +53,7 @@ public class Losses { public static Operand meanAbsoluteError( Ops tf, Operand labels, Operand predictions) { Operand tLabels = cast(tf, labels, predictions.asOutput().dataType()); - LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); return tf.math.mean( @@ -75,7 +75,7 @@ public static Operand meanAbsoluteErro public static Operand meanSquaredError( Ops tf, Operand labels, Operand predictions) { Operand tLabels = cast(tf, labels, predictions.asOutput().dataType()); - LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); return tf.math.mean(tf.math.squaredDifference(predictions, tLabels), tf.constant(-1)); @@ -97,7 +97,7 @@ public static Operand meanAbsolutePerc Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = cast(tf, labels, dataType); - LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); Operand diff = @@ -126,7 +126,7 @@ public static Operand meanSquaredLogar Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = cast(tf, labels, dataType); - LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -157,7 +157,7 @@ public static Operand binaryCrossentro Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = cast(tf, labels, dataType); - LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -238,7 +238,7 @@ public static Operand categoricalCross int axis) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = cast(tf, labels, dataType); - LossTuple ops = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -291,7 +291,7 @@ public static Operand categoricalHinge Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = cast(tf, labels, dataType); - LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); Operand one = cast(tf, tf.constant(1), dataType); @@ -333,7 +333,7 @@ public static Operand cosineSimilarity Ops tf, Operand labels, Operand predictions, int axis) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = cast(tf, labels, dataType); - LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); @@ -360,7 +360,7 @@ public static Operand hinge( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = cast(tf, labels, dataType); - LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); Operand one = cast(tf, tf.constant(1), dataType); @@ -396,7 +396,7 @@ public static Operand huber( Ops tf, Operand labels, Operand predictions, float delta) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = cast(tf, labels, dataType); - LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); @@ -427,7 +427,7 @@ public static Operand kullbackLeiblerD Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = cast(tf, labels, dataType); - LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); Operand one = cast(tf, tf.constant(1), dataType); @@ -457,7 +457,7 @@ public static Operand logCosh( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = cast(tf, labels, dataType); - LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); Operand minusTwo = cast(tf, tf.constant(-2), dataType); @@ -485,7 +485,7 @@ public static Operand poisson( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = cast(tf, labels, dataType); - LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); @@ -589,7 +589,7 @@ public static Operand squaredHinge( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); Operand tLabels = cast(tf, labels, dataType); - LossTuple lossTuple = LossesImpl.squeezeOrExpandDimensions(tf, tLabels, predictions, null); + LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); Operand one = cast(tf, tf.constant(1), dataType); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index c4062330ea0..3de27026944 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -16,7 +16,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -99,6 +99,6 @@ public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsoluteError(getTF(), labels, predictions); - return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 0e8d1c76a39..009c22f1b8a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -16,7 +16,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -99,6 +99,6 @@ public MeanAbsolutePercentageError(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsolutePercentageError(getTF(), labels, predictions); - return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index fe06ba6f853..e218bb38616 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -16,7 +16,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -99,6 +99,6 @@ public MeanSquaredError(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredError(getTF(), labels, predictions); - return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index c2b48a0a111..c673ca292d6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -16,7 +16,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -99,6 +99,6 @@ public MeanSquaredLogarithmicError(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); - return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java index 273b80824b1..6e61502fca7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java @@ -16,7 +16,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -108,6 +108,6 @@ public Poisson(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.poisson(getTF(), labels, predictions); - return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index 953b693c9f2..3086b6200a2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -16,7 +16,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -202,7 +202,7 @@ public Operand call( if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = - LossesImpl.rangeCheck( + LossesHelper.rangeCheck( getTF(), "predictions range check [0-1]", predictions, @@ -214,6 +214,6 @@ public Operand call( } Operand losses = Losses.sparseCategoricalCrossentropy(getTF(), labels, lPredictions, fromLogits, axis); - return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index cf2097c8de4..25c75e7732d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -16,7 +16,7 @@ package org.tensorflow.framework.losses; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.impl.LossesImpl; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -129,13 +129,13 @@ public Operand call( Operand tLabels = predictions.asOutput().dataType() == labels.asOutput().dataType() ? (Operand)labels : cast(tf, labels, predictions.asOutput().dataType()); - tLabels = LossesImpl.valueCheck( + tLabels = LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.asOutput().dataType())); Operand losses = Losses.squaredHinge(getTF(), tLabels, predictions); - return LossesImpl.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); + return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java similarity index 99% rename from tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java rename to tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 92acafc9039..3dc1c66be4d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesImpl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -34,10 +34,10 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * These are helper methods for Losses and will be module private when Java modularity is applied to - * TensorFlow Java. These methods should not be used outside of the Loss package. + * These are helper methods for Losses and Metrics and will be module private when Java modularity is applied to + * TensorFlow Java. These methods should not be used outside of the losses and metrics packages. */ -public class LossesImpl { +public class LossesHelper { /** * Squeeze or expand last dimension if needed with a sampleWeights of one. From 0eae9ee1b2ae6aaccf6c9216bbb90f8bcda0a9a6 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 12 Nov 2020 10:09:38 -0500 Subject: [PATCH 17/26] Remove commented out JavaDoc --- .../tensorflow/framework/losses/Losses.java | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index ba641d19362..0d1f5497e42 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -601,27 +601,6 @@ public static Operand squaredHinge( tf.constant(-1)); } - // private methods/** - // * Calculates the loss - // * - // * @param labels the truth values or labels - // * @param predictions the predictions - // * @param sampleWeights Optional SampleWeights acts as a coefficient for the loss. If a scalar - // is - // * provided, then the loss is simply scaled by the given value. If SampleWeights is a - // tensor - // * of size [batch_size], then the total loss for each sample of the batch is rescaled by - // the - // * corresponding element in the SampleWeights vector. If the shape of SampleWeights is - // * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element - // of - // * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all - // loss - // * functions reduce by 1 dimension, usually axis=-1.) - // * @param The data type of the predictions, sampleWeights and loss. - // * @param The data type of the labels. - // * @return the loss - // * /** * Smooths binary labels From b211937c946a67c6f3830e70bdccf97a54cd8051 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 13 Nov 2020 09:56:15 -0500 Subject: [PATCH 18/26] Changed method name from smoothLabelsBinaryX to smoothBinaryLabels, smoothLabelsCatX to smoothCategoricalLabels. Added clarification oin JavaDoc for cosineSimilarity to describe the difference between the mathematical definition for cosine similarity and the loss definition. --- .../tensorflow/framework/losses/Losses.java | 124 +++++++++--------- 1 file changed, 61 insertions(+), 63 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 0d1f5497e42..6b7c07d4ec0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -52,7 +52,7 @@ public class Losses { */ public static Operand meanAbsoluteError( Ops tf, Operand labels, Operand predictions) { - Operand tLabels = cast(tf, labels, predictions.asOutput().dataType()); + Operand tLabels = cast(tf, labels, predictions.asOutput().dataType()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -74,7 +74,7 @@ public static Operand meanAbsoluteErro */ public static Operand meanSquaredError( Ops tf, Operand labels, Operand predictions) { - Operand tLabels = cast(tf, labels, predictions.asOutput().dataType()); + Operand tLabels = cast(tf, labels, predictions.asOutput().dataType()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -96,7 +96,7 @@ public static Operand meanSquaredError public static Operand meanAbsolutePercentageError( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); @@ -104,10 +104,8 @@ public static Operand meanAbsolutePerc tf.math.abs( tf.math.div( tf.math.sub(tLabels, predictions), - tf.math.maximum( - tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), dataType)))); - return tf.math.mul( - cast(tf, tf.constant(100), dataType), tf.math.mean(diff, tf.constant(-1))); + tf.math.maximum(tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), dataType)))); + return tf.math.mul(cast(tf, tf.constant(100), dataType), tf.math.mean(diff, tf.constant(-1))); } /** @@ -125,13 +123,13 @@ public static Operand meanAbsolutePerc public static Operand meanSquaredLogarithmicError( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); - Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); - Operand one = cast(tf, tf.constant(1), dataType); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); + Operand one = cast(tf, tf.constant(1), dataType); Operand firstLog = tf.math.log(tf.math.add(tf.math.maximum(predictions, epsilonConst), one)); Operand secondLog = tf.math.log(tf.math.add(tf.math.maximum(tLabels, epsilonConst), one)); @@ -156,13 +154,13 @@ public static Operand meanSquaredLogar public static Operand binaryCrossentropy( Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); if (labelSmoothing != 0.0f) { - tLabels = smoothLabelsBinaryX(tf, tLabels, labelSmoothing); + tLabels = smoothBinaryLabels(tf, tLabels, labelSmoothing); } Operand bce = binaryCrossentropyHelper(tf, tLabels, predictions, fromLogits); return tf.math.mean(bce, tf.constant(-1)); @@ -181,9 +179,7 @@ public static Operand binaryCrossentro */ private static Operand binaryCrossentropyHelper( Ops tf, Operand target, Operand output, boolean fromLogits) { - if (fromLogits) - return tf.nn.sigmoidCrossEntropyWithLogits(target, output); - + if (fromLogits) return tf.nn.sigmoidCrossEntropyWithLogits(target, output); /* TODO - skip this loggic for now. It requires walking back the inputs which is not yet possible if (!(output instanceof Variable) && (!tf.scope().env().isEager())) { @@ -199,8 +195,8 @@ private static Operand binaryCrossentropyHelper( */ DataType dataType = output.asOutput().dataType(); - Operand one = cast(tf, tf.constant(1), dataType); - Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); + Operand one = cast(tf, tf.constant(1), dataType); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); output = tf.clipByValue(output, epsilonConst, oneMinusEpsilonConst); @@ -237,13 +233,13 @@ public static Operand categoricalCross float labelSmoothing, int axis) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); tLabels = ops.getLabels(); if (labelSmoothing != 0.0f) { - tLabels = smoothLabelsCatX(tf, tLabels, labelSmoothing); + tLabels = smoothCategoricalLabels(tf, tLabels, labelSmoothing); } if (fromLogits) { return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, -1); @@ -261,8 +257,8 @@ public static Operand categoricalCross } */ - Operand one = cast(tf, tf.constant(1), dataType); - Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); + Operand one = cast(tf, tf.constant(1), dataType); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); predictions = tf.math.div( @@ -290,12 +286,12 @@ public static Operand categoricalCross public static Operand categoricalHinge( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand one = cast(tf, tf.constant(1), dataType); - Operand zero = cast(tf, tf.constant(0), dataType); + Operand one = cast(tf, tf.constant(1), dataType); + Operand zero = cast(tf, tf.constant(0), dataType); Operand pos = tf.reduceSum( @@ -313,11 +309,15 @@ public static Operand categoricalHinge /** * Computes the cosine similarity loss between labels and predictions. * - *

Note that it is a number between -1 and 1. When it is a negative number between -1 and 0, 0 - * indicates orthogonality and values closer to -1 indicate greater similarity. The values closer - * to 1 indicate greater dissimilarity. This makes it usable as a loss function in a setting where - * you try to maximize the proximity between predictions and targets. If either labels or - * predictions is a zero vector, cosine similarity will be 0 regardless of the proximity between + *

Note that it is a number between -1 and 1, which is different from + * the mathematical definition of cosine similarity where 1 represents similar + * vectors, and 0 represents dissimilar vectors. In this function, the numbers are + * inverted in a range of -1 to 1. When it is a negative number between + * -1 and 0, 0 indicates orthogonality and values closer to + * -1 indicate greater similarity. The values closer to 1 indicate + * greater dissimilarity. This makes it usable as a loss function in a setting where you try to + * maximize the proximity between predictions and targets. If either labels or predictions is a + * zero vector, cosine similarity will be 0 regardless of the proximity between * predictions and targets. * *

loss = -sum(l2Norm(labels) * l2Norm(predictions)) @@ -332,7 +332,7 @@ public static Operand categoricalHinge public static Operand cosineSimilarity( Ops tf, Operand labels, Operand predictions, int axis) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); @@ -359,12 +359,12 @@ public static Operand cosineSimilarity public static Operand hinge( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand one = cast(tf, tf.constant(1), dataType); - Operand zero = cast(tf, tf.constant(0), dataType); + Operand one = cast(tf, tf.constant(1), dataType); + Operand zero = cast(tf, tf.constant(0), dataType); tLabels = maybeConvertLabels(tf, tLabels); @@ -395,14 +395,14 @@ public static Operand hinge( public static Operand huber( Ops tf, Operand labels, Operand predictions, float delta) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); Operand error = tf.math.sub(predictions, tLabels); - Operand deltaConst = cast(tf, tf.constant(delta), dataType); - Operand point5 = cast(tf, tf.constant(0.5), dataType); + Operand deltaConst = cast(tf, tf.constant(delta), dataType); + Operand point5 = cast(tf, tf.constant(0.5), dataType); Operand absError = tf.math.abs(error); Operand quadratic = tf.math.minimum(absError, deltaConst); Operand linear = tf.math.sub(absError, quadratic); @@ -426,12 +426,12 @@ public static Operand huber( public static Operand kullbackLeiblerDivergence( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand one = cast(tf, tf.constant(1), dataType); - Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); + Operand one = cast(tf, tf.constant(1), dataType); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); tLabels = tf.clipByValue(tLabels, epsilonConst, one); predictions = tf.clipByValue(predictions, epsilonConst, one); @@ -456,12 +456,12 @@ public static Operand kullbackLeiblerD public static Operand logCosh( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand minusTwo = cast(tf, tf.constant(-2), dataType); - Operand two = cast(tf, tf.constant(2), dataType); + Operand minusTwo = cast(tf, tf.constant(-2), dataType); + Operand two = cast(tf, tf.constant(2), dataType); Operand diff = tf.math.sub(predictions, tLabels); Softplus softplus = tf.math.softplus(tf.math.mul(minusTwo, diff)); @@ -484,11 +484,11 @@ public static Operand logCosh( public static Operand poisson( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); return tf.math.mean( tf.math.sub( @@ -511,8 +511,8 @@ public static Operand poisson( public static Operand sparseCategoricalCrossentropy( Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { DataType dataType = predictions.asOutput().dataType(); - Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); - Operand one = cast(tf, tf.constant(1), dataType); + Operand epsilonConst = cast(tf, tf.constant(EPSILON), dataType); + Operand one = cast(tf, tf.constant(1), dataType); Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); /* TODO need ability to walk back inputs @@ -547,7 +547,7 @@ public static Operand sparseCategorica predictions = tf.linalg.transpose(predictions, tf.constant(axisNew)); } - Operand iLabels = cast(tf, labels, TInt64.DTYPE); + Operand iLabels = cast(tf, labels, TInt64.DTYPE); // Try to adjust the shape so that rank of labels = rank of logits - 1. Shape labelsShape = labels.asOutput().shape(); @@ -588,12 +588,12 @@ public static Operand sparseCategorica public static Operand squaredHinge( Ops tf, Operand labels, Operand predictions) { DataType dataType = predictions.asOutput().dataType(); - Operand tLabels = cast(tf, labels, dataType); + Operand tLabels = cast(tf, labels, dataType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - Operand one = cast(tf, tf.constant(1), dataType); - Operand zero = cast(tf, tf.constant(0), dataType); + Operand one = cast(tf, tf.constant(1), dataType); + Operand zero = cast(tf, tf.constant(0), dataType); tLabels = maybeConvertLabels(tf, tLabels); return tf.math.mean( @@ -601,7 +601,6 @@ public static Operand squaredHinge( tf.constant(-1)); } - /** * Smooths binary labels * @@ -614,11 +613,11 @@ public static Operand squaredHinge( * @param the data type of the labels * @return the smoothed binary labels */ - private static Operand smoothLabelsBinaryX( + private static Operand smoothBinaryLabels( Ops tf, Operand labels, float labelSmoothing) { DataType dataType = labels.asOutput().dataType(); - Operand oneMinusSmoothing = cast(tf, tf.constant(1.f - labelSmoothing), dataType); - Operand halfSmoothing = cast(tf, tf.constant(0.5F * labelSmoothing), dataType); + Operand oneMinusSmoothing = cast(tf, tf.constant(1.f - labelSmoothing), dataType); + Operand halfSmoothing = cast(tf, tf.constant(0.5F * labelSmoothing), dataType); return tf.math.add(tf.math.mul(labels, oneMinusSmoothing), halfSmoothing); } @@ -633,14 +632,14 @@ private static Operand smoothLabelsBinaryX( * @param the data type of the labels * @return the smoothed categorical labels */ - private static Operand smoothLabelsCatX( + private static Operand smoothCategoricalLabels( Ops tf, Operand labels, float labelSmoothing) { DataType dataType = labels.asOutput().dataType(); - Operand smoothing = cast(tf, tf.constant(labelSmoothing), dataType); + Operand smoothing = cast(tf, tf.constant(labelSmoothing), dataType); Shape labelsShape = labels.asOutput().shape(); int numDims = labelsShape.numDimensions(); - Operand numClasses = cast(tf, tf.constant(labelsShape.size(numDims - 1)), dataType); - Operand oneMinusSmoothing = cast(tf, tf.constant(1.f - labelSmoothing), dataType); + Operand numClasses = cast(tf, tf.constant(labelsShape.size(numDims - 1)), dataType); + Operand oneMinusSmoothing = cast(tf, tf.constant(1.f - labelSmoothing), dataType); return tf.math.add(tf.math.mul(labels, oneMinusSmoothing), tf.math.div(smoothing, numClasses)); } @@ -658,8 +657,7 @@ public static Operand l2Normalize(Ops tf, Operand x, i tf.reduceSum(tf.math.square(x), tf.constant(axis), ReduceSum.keepDims(Boolean.TRUE)); Operand invNorm = tf.math.rsqrt( - tf.math.maximum( - squareSum, cast(tf, tf.constant(1e-12F), x.asOutput().dataType()))); + tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.asOutput().dataType()))); return tf.math.mul(x, invNorm); } @@ -674,9 +672,9 @@ public static Operand l2Normalize(Ops tf, Operand x, i private static Operand maybeConvertLabels(Ops tf, Operand labels) { DataType dataType = labels.asOutput().dataType(); - Operand one = cast(tf, tf.constant(1), dataType); - Operand zero = cast(tf, tf.constant(0), dataType); - Operand two = cast(tf, tf.constant(2), dataType); + Operand one = cast(tf, tf.constant(1), dataType); + Operand zero = cast(tf, tf.constant(0), dataType); + Operand two = cast(tf, tf.constant(2), dataType); Operand areZeros = tf.math.equal(labels, zero); Operand areOnes = tf.math.equal(labels, one); Operand isBinary = From 3e0669e03b4c2a5bab5b4ffc0e2387dc0adccefb Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 13 Nov 2020 09:56:54 -0500 Subject: [PATCH 19/26] Fixed JavaDoc for labelSmoothing --- .../losses/CategoricalCrossentropy.java | 55 ++++++++++--------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index e6665ddc086..3306d16b1a5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -19,6 +19,7 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; + import static org.tensorflow.framework.utils.CastHelper.cast; /** @@ -137,8 +138,8 @@ public CategoricalCrossentropy(Ops tf, boolean fromLogits) { /** * Creates a categorical cross entropy Loss using {@link #LABEL_SMOOTHING_DEFAULT} for - * labelSmoothing, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and a channel axis of {@link - * #DEFAULT_AXIS} + * labelSmoothing, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and a channel axis of + * {@link #DEFAULT_AXIS} * * @param tf the TensorFlow Ops * @param name the name of this loss @@ -169,10 +170,9 @@ public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When 0, no smoothing occurs. When > 0, we compute the - * loss between the predicted labels and a smoothed version of the true labels, where the - * smoothing squeezes the labels towards 0.5. Larger values of label_smoothing correspond to - * heavier smoothing. + * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the + * confidence on label values are relaxed. e.g. label_smoothing=0.2 means that we will use a + * value of 0.1 for label 0 and 0.9 for label 1 */ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) { this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); @@ -184,10 +184,9 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float la * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When 0, no smoothing occurs. When > 0, we compute the - * loss between the predicted labels and a smoothed version of the true labels, where the - * smoothing squeezes the labels towards 0.5. Larger values of label_smoothing correspond to - * heavier smoothing. + * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the + * confidence on label values are relaxed. e.g. label_smoothing=0.2 means that we will use a + * alue of 0.1 for label 0 and 0.9 for label 1 * @param reduction Type of Reduction to apply to loss. */ public CategoricalCrossentropy( @@ -201,10 +200,9 @@ public CategoricalCrossentropy( * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When 0, no smoothing occurs. When > 0, we compute the - * loss between the predicted labels and a smoothed version of the true labels, where the - * smoothing squeezes the labels towards 0.5. Larger values of label_smoothing correspond to - * heavier smoothing. + * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the + * confidence on label values are relaxed. e.g. label_smoothing=0.2 means that we will use a + * value of 0.1 for label 0 and 0.9 for label 1 * @param reduction Type of Reduction to apply to loss. * @param axis The channels axis. axis=-1 corresponds to data format `Channels Last' * and axis=1 corresponds to data format 'Channels First'. @@ -218,8 +216,9 @@ public CategoricalCrossentropy( Reduction reduction, int axis) { super(tf, name, reduction); - if(labelSmoothing < 0 || labelSmoothing > 1) - throw new IllegalArgumentException("labelSmoothing must be >= 0. and <= 1, found " + labelSmoothing); + if (labelSmoothing < 0 || labelSmoothing > 1) + throw new IllegalArgumentException( + "labelSmoothing must be >= 0. and <= 1, found " + labelSmoothing); this.fromLogits = fromLogits; this.labelSmoothing = labelSmoothing; this.axis = axis; @@ -228,9 +227,10 @@ public CategoricalCrossentropy( /** * Generates an Operand that calculates the loss. * - * If run in Graph mode, the computation will throw {@link org.tensorflow.exceptions.TFInvalidArgumentException} - * if the predictions values are outside the range o [0. to 1.]. In Eager Mode, this call - * will throw {@link IllegalArgumentException}, if the predictions values are outside the range o [0. to 1.] + *

If run in Graph mode, the computation will throw {@link + * org.tensorflow.exceptions.TFInvalidArgumentException} if the predictions values are outside the + * range o [0. to 1.]. In Eager Mode, this call will throw {@link IllegalArgumentException}, if + * the predictions values are outside the range o [0. to 1.] * * @param labels the truth values or labels * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. @@ -248,23 +248,24 @@ public CategoricalCrossentropy( */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = - LossesHelper.rangeCheck( - getTF(), - "predictions range check [0-1]", - predictions, - cast(getTF(), getTF().constant(0), predictions.asOutput().dataType()), - cast(getTF(), getTF().constant(1), predictions.asOutput().dataType())); + LossesHelper.rangeCheck( + getTF(), + "predictions range check [0-1]", + predictions, + cast(getTF(), getTF().constant(0), predictions.asOutput().dataType()), + cast(getTF(), getTF().constant(1), predictions.asOutput().dataType())); } else { lPredictions = predictions; } Operand losses = - Losses.categoricalCrossentropy(getTF(), labels, lPredictions, fromLogits, labelSmoothing, axis); + Losses.categoricalCrossentropy( + getTF(), labels, lPredictions, fromLogits, labelSmoothing, axis); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } From 914f16f4473512c8b5ef9df8ca43074b82d3edd0 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 13 Nov 2020 09:57:43 -0500 Subject: [PATCH 20/26] Fixed JavaDoc to change label_smoothing to labelSmoothing. --- .../tensorflow/framework/losses/BinaryCrossentropy.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index a226170513a..56b06cce14c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -125,7 +125,7 @@ public BinaryCrossentropy(Ops tf, String name, boolean fromLogits) { * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, - * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing + * where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing * correspond to heavier smoothing. */ public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { @@ -140,7 +140,7 @@ public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, - * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing + * where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing * correspond to heavier smoothing. */ public BinaryCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) { @@ -154,7 +154,7 @@ public BinaryCrossentropy(Ops tf, String name, boolean fromLogits, float labelSm * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, - * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing + * where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing * correspond to heavier smoothing. * @param reduction Type of Reduction to apply to the loss. */ @@ -170,7 +170,7 @@ public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing, Redu * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, - * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing + * where the smoothing squeezes the labels towards 0.5. Larger values of labelSmoothing * correspond to heavier smoothing. * @param reduction Type of Reduction to apply to the loss. * @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1. From 7eefbb7f197c731a7d304d055fd242d1acd9835f Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 13 Nov 2020 09:58:19 -0500 Subject: [PATCH 21/26] Fix formatting --- .../framework/losses/BinaryCrossentropy.java | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index 56b06cce14c..6e3e7fa321c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -178,8 +178,9 @@ public BinaryCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing, Redu public BinaryCrossentropy( Ops tf, String name, boolean fromLogits, float labelSmoothing, Reduction reduction) { super(tf, name, reduction); - if(labelSmoothing < 0 || labelSmoothing > 1) - throw new IllegalArgumentException("labelSmoothing must be >= 0. and <= 1, found " + labelSmoothing); + if (labelSmoothing < 0 || labelSmoothing > 1) + throw new IllegalArgumentException( + "labelSmoothing must be >= 0. and <= 1, found " + labelSmoothing); this.fromLogits = fromLogits; this.labelSmoothing = labelSmoothing; } @@ -187,9 +188,10 @@ public BinaryCrossentropy( /** * Generates an Operand that calculates the loss. * - * If run in Graph mode, the computation will throw {@link org.tensorflow.exceptions.TFInvalidArgumentException} - * if the predictions values are outside the range o [0. to 1.]. In Eager Mode, this call - * will throw {@link IllegalArgumentException}, if the predictions values are outside the range o [0. to 1.] + *

If run in Graph mode, the computation will throw {@link + * org.tensorflow.exceptions.TFInvalidArgumentException} if the predictions values are outside the + * range o [0. to 1.]. In Eager Mode, this call will throw {@link IllegalArgumentException}, if + * the predictions values are outside the range o [0. to 1.] * * @param labels the truth values or labels * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. From b87ad16118442643b845bb4e24a0145eea0056fb Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 13 Nov 2020 13:11:21 -0500 Subject: [PATCH 22/26] replace label_smoothing with labelSmoothing. fix typo error in JavaDoc comment --- .../framework/losses/CategoricalCrossentropy.java | 10 +++++----- .../java/org/tensorflow/framework/losses/Losses.java | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 3306d16b1a5..522446bec4d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -156,7 +156,7 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. label_smoothing=0.2 means that we will use a + * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a * value of 0.1 for label 0 and 0.9 for label 1 */ public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { @@ -171,7 +171,7 @@ public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. label_smoothing=0.2 means that we will use a + * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a * value of 0.1 for label 0 and 0.9 for label 1 */ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) { @@ -185,8 +185,8 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float la * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. label_smoothing=0.2 means that we will use a - * alue of 0.1 for label 0 and 0.9 for label 1 + * confidence on label values are relaxed. e.g. x=0.2 means that we will use a + * value of 0.1 for label 0 and 0.9 for label 1 * @param reduction Type of Reduction to apply to loss. */ public CategoricalCrossentropy( @@ -201,7 +201,7 @@ public CategoricalCrossentropy( * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. label_smoothing=0.2 means that we will use a + * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a * value of 0.1 for label 0 and 0.9 for label 1 * @param reduction Type of Reduction to apply to loss. * @param axis The channels axis. axis=-1 corresponds to data format `Channels Last' diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 6b7c07d4ec0..3b61b0f8eab 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -219,7 +219,7 @@ private static Operand binaryCrossentropyHelper( * @param predictions the predictions * @param fromLogits Whether to interpret predictions as a tensor of logit values * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. label_smoothing=0.2 means that we will use a + * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a * value of 0.1 for label 0 and 0.9 for label 1 * @param axis the * @param the data type of the predictions and labels @@ -627,7 +627,7 @@ private static Operand smoothBinaryLabels( * @param tf the TensorFlow Ops * @param labels true targets * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. label_smoothing=0.2 means that we will use a + * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a * value of 0.1 for label 0 and 0.9 for label 1 * @param the data type of the labels * @return the smoothed categorical labels From c43cd21165c67d1972bc693a5d4a9ccdb49395eb Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 16 Nov 2020 18:17:39 -0500 Subject: [PATCH 23/26] Add copyright to test cases --- .../losses/BinaryCrossentropyTest.java | 17 +- .../losses/CategoricalCrossentropyTest.java | 68 ++-- .../losses/CategoricalHingeTest.java | 232 +++++++------ .../losses/CosineSimilarityTest.java | 15 + .../framework/losses/HingeTest.java | 17 +- .../framework/losses/HuberTest.java | 15 + .../framework/losses/KLDivergenceTest.java | 192 +++++----- .../framework/losses/LogCoshTest.java | 191 +++++----- .../losses/MeanAbsoluteErrorTest.java | 15 + .../MeanAbsolutePercentageErrorTest.java | 15 + .../losses/MeanSquaredErrorTest.java | 15 + .../MeanSquaredLogarithmicErrorTest.java | 328 +++++++++--------- .../framework/losses/PoissonTest.java | 191 +++++----- .../SparseCategoricalCrossentropyTest.java | 53 ++- .../framework/losses/SquaredHingeTest.java | 201 ++++++----- 15 files changed, 892 insertions(+), 673 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java index 86401f03f5d..0928a4ad10c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; @@ -45,7 +60,7 @@ public void testAllCorrectUnweighted() { public void testInvalidPredictionsRange() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Class catchClass = + Class catchClass = tfMode == TestSession.Mode.EAGER ? IllegalArgumentException.class : org.tensorflow.exceptions.TFInvalidArgumentException.class; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java index 0e273fd6a8c..36af5bd51ad 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; @@ -6,7 +21,6 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; @@ -60,33 +74,33 @@ public void testAllCorrectUnweighted() { public void testInvalidPredictionsRange() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Class catchClass = - tfMode == TestSession.Mode.EAGER - ? IllegalArgumentException.class - : org.tensorflow.exceptions.TFInvalidArgumentException.class; + Class catchClass = + tfMode == TestSession.Mode.EAGER + ? IllegalArgumentException.class + : org.tensorflow.exceptions.TFInvalidArgumentException.class; assertThrows( - catchClass, - () -> { - Ops tf = testSession.getTF(); - CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); - float[] trueArray = { - 1L, 0L, 0L, - 0L, 1L, 0L, - 0L, 0L, 1L - }; - float[] predArray = { - -1.F, 0.F, 0.F, - 0.F, 1.F, 0.F, - 0.F, 0.F, 1.F - }; - Operand yTrue = - tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); - Operand yPred = - tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - - Operand loss = instance.call(yTrue, yPred); - testSession.run(loss); - }); + catchClass, + () -> { + Ops tf = testSession.getTF(); + CategoricalCrossentropy instance = new CategoricalCrossentropy(tf); + float[] trueArray = { + 1L, 0L, 0L, + 0L, 1L, 0L, + 0L, 0L, 1L + }; + float[] predArray = { + -1.F, 0.F, 0.F, + 0.F, 1.F, 0.F, + 0.F, 0.F, 1.F + }; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); + + Operand loss = instance.call(yTrue, yPred); + testSession.run(loss); + }); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java index b455d58740b..8923a299774 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; @@ -9,123 +24,116 @@ import org.tensorflow.types.TInt32; public class CategoricalHingeTest { - private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - /** - * Test of call method, of class CategoricalHinge. - */ - @Test - public void testReductionNone() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf, Reduction.NONE); - int[] trueArray = {1, 9, 2, -5}; - float[] predArray = {4f, 8f, 12f, 8f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred); - Float[] expected = {0.0f, 65.0f}; - testSession.evaluate(expected, loss); - } - } + /** Test of call method, of class CategoricalHinge. */ + @Test + public void testReductionNone() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + CategoricalHinge instance = new CategoricalHinge(tf, Reduction.NONE); + int[] trueArray = {1, 9, 2, -5}; + float[] predArray = {4f, 8f, 12f, 8f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); + Operand loss = instance.call(yTrue, yPred); + Float[] expected = {0.0f, 65.0f}; + testSession.evaluate(expected, loss); + } + } - /** - * Test of call method, of class CategoricalHinge. - */ - @Test - public void testUnweighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf); - int[] trueArray = {1, 9, 2, -5}; - float[] predArray = {4f, 8f, 12f, 8f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred); - float expected = 32.5f; - testSession.evaluate(expected, loss); - } - } + /** Test of call method, of class CategoricalHinge. */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + CategoricalHinge instance = new CategoricalHinge(tf); + int[] trueArray = {1, 9, 2, -5}; + float[] predArray = {4f, 8f, 12f, 8f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 2))); + Operand loss = instance.call(yTrue, yPred); + float expected = 32.5f; + testSession.evaluate(expected, loss); + } + } - /** - * Test of call method, of class CategoricalHinge. - */ - @Test - public void testScalarWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf); - int[] trueArray = {1, 9, 2, -5, -2, 6}; - float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 83.95f; - testSession.evaluate(expected, loss); + /** Test of call method, of class CategoricalHinge. */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + CategoricalHinge instance = new CategoricalHinge(tf); + int[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 83.95f; + testSession.evaluate(expected, loss); - Operand loss2 = instance.call(yTrue, yPred, sampleWeight); - testSession.evaluate(loss, loss2); - } - } + Operand loss2 = instance.call(yTrue, yPred, sampleWeight); + testSession.evaluate(loss, loss2); + } + } - @Test - public void testSampleWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf); - int[] trueArray = {1, 9, 2, -5, -2, 6}; - float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - float[] weightsNp = {1.2f, 3.4f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.reshape(tf.constant(weightsNp), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 124.1f; - testSession.evaluate(expected, loss); - } - } - - @Test - public void testZeroWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf); - int[] trueArray = {1, 9, 2, -5, -2, 6}; - float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.constant(0f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 0f; - testSession.evaluate(expected, loss); - - } - } - - @Test - public void testTimestepWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - CategoricalHinge instance = new CategoricalHinge(tf); - int[] trueArray = {1, 9, 2, -5, -2, 6}; - float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - float[] weightsNp = {3, 6, 5, 0, 4, 2}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); - Operand sampleWeight = tf.reshape(tf.constant(weightsNp), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 4.0f; - testSession.evaluate(expected, loss); - - } - } + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + CategoricalHinge instance = new CategoricalHinge(tf); + int[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] weightsNp = {1.2f, 3.4f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = + tf.reshape(tf.constant(weightsNp), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 124.1f; + testSession.evaluate(expected, loss); + } + } + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + CategoricalHinge instance = new CategoricalHinge(tf); + int[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(0f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + CategoricalHinge instance = new CategoricalHinge(tf); + int[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] weightsNp = {3, 6, 5, 0, 4, 2}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Operand sampleWeight = + tf.reshape(tf.constant(weightsNp), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 4.0f; + testSession.evaluate(expected, loss); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java index ca7aea553d1..dea995ad987 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java index 8951883a0a3..3e86b303b83 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; @@ -34,7 +49,7 @@ public void testUnweighted() { public void testInvalidLabelValue() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Class catchClass = + Class catchClass = tfMode == TestSession.Mode.EAGER ? IllegalArgumentException.class : org.tensorflow.exceptions.TFInvalidArgumentException.class; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java index d7acab0126f..9ab11e0cf0e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java index d40b4286f3d..7875a63f1d9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; @@ -8,99 +23,98 @@ import org.tensorflow.types.TFloat32; public class KLDivergenceTest { - private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - /** - * Test of call method, of class KLDivergence. - */ - @Test - public void testUnweighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf); - float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; - float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); - float expected = 0.5960738398643668f; - testSession.evaluate(expected, loss); - } - } + /** Test of call method, of class KLDivergence. */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + KLDivergence instance = new KLDivergence(tf); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; + float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred); + float expected = 0.5960738398643668f; + testSession.evaluate(expected, loss); + } + } - /** - * Test of call method, of class KLDivergence. - */ - @Test - public void testScalarWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf); - float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; - float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 1.3709698316880434f; - testSession.evaluate(expected, loss); - } - } + /** Test of call method, of class KLDivergence. */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + KLDivergence instance = new KLDivergence(tf); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; + float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 1.3709698316880434f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testSampleWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf); - float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; - float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; - float[] sampleArray = {1.2f, 3.4f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 2.0075711736936492f; - testSession.evaluate(expected, loss); - } - } + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + KLDivergence instance = new KLDivergence(tf); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; + float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; + float[] sampleArray = {1.2f, 3.4f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 2.0075711736936492f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testZeroWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf); - float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; - float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 0f; - testSession.evaluate(expected, loss); - } - } + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + KLDivergence instance = new KLDivergence(tf); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; + float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(0.F); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testTimestepWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - KLDivergence instance = new KLDivergence(tf, Reduction.AUTO); - float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; - float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; - float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); - Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - - float expected = 0.2495994912084345f; - testSession.evaluate(expected, loss); - } - } + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + KLDivergence instance = new KLDivergence(tf, Reduction.AUTO); + float[] predArray = {.4f, .9f, .12f, .36f, .3f, .4f}; + float[] trueArray = {.5f, .8f, .12f, .7f, .43f, .8f}; + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0.2495994912084345f; + testSession.evaluate(expected, loss); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java index 0828471062a..1d06669b731 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; @@ -8,98 +23,98 @@ import org.tensorflow.types.TFloat32; public class LogCoshTest { - private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - /** - * Test of call method, of class LogCosh. - */ - @Test - public void testUnweighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf); - float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; - float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); - float expected = 4.829245330860459f; - testSession.evaluate(expected, loss); - } - } + /** Test of call method, of class LogCosh. */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + LogCosh instance = new LogCosh(tf); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred); + float expected = 4.829245330860459f; + testSession.evaluate(expected, loss); + } + } - /** - * Test of call method, of class LogCosh. - */ - @Test - public void testScalarWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf); - float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; - float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 11.107264260979056f; - testSession.evaluate(expected, loss); - } - } + /** Test of call method, of class LogCosh. */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + LogCosh instance = new LogCosh(tf); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 11.107264260979056f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testSampleWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf); - float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; - float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; - float[] sampleArray = {1.2f, 3.4f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 12.001114667519486f; - testSession.evaluate(expected, loss); - } - } + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + LogCosh instance = new LogCosh(tf); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {1.2f, 3.4f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 12.001114667519486f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testZeroWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf); - float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; - float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 0f; - testSession.evaluate(expected, loss); - } - } + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + LogCosh instance = new LogCosh(tf); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(0.F); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testTimestepWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - LogCosh instance = new LogCosh(tf, Reduction.AUTO); - float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; - float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; - float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); - Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + LogCosh instance = new LogCosh(tf, Reduction.AUTO); + float[] predArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 11.653484271934046f; - testSession.evaluate(expected, loss); - } - } + float expected = 11.653484271934046f; + testSession.evaluate(expected, loss); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java index 91747900c2c..f069091b7a2 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Assertions; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java index 5c2521f900c..debf6b967b9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java index 02ef7621e38..b2c7ee46837 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Assertions; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java index 66caa215219..71c840f7284 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Assertions; @@ -9,171 +24,172 @@ import org.tensorflow.types.TFloat32; public class MeanSquaredLogarithmicErrorTest { - private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - /** - * Test of call method, of class MeanSquaredLogarithmicError. - */ - @Test - public void testAllCorrectUnweighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); - float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yTrue); - float expected = 0.0f; - testSession.evaluate(expected, loss); - } - } + /** Test of call method, of class MeanSquaredLogarithmicError. */ + @Test + public void testAllCorrectUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yTrue); + float expected = 0.0f; + testSession.evaluate(expected, loss); + } + } - /** - * Test of call method, of class MeanSquaredLogarithmicError. - */ - @Test - public void testUnweighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); - float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; - float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); - float expected = 1.4370421f; - testSession.evaluate(expected, loss); - } - } + /** Test of call method, of class MeanSquaredLogarithmicError. */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred); + float expected = 1.4370421f; + testSession.evaluate(expected, loss); + } + } - /** - * Test of call method, of class MeanSquaredLogarithmicError. - */ - @Test - public void testScalarWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); - float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; - float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 3.3051968f; - testSession.evaluate(expected, loss); - } - } + /** Test of call method, of class MeanSquaredLogarithmicError. */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 3.3051968f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testSampleWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); - float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; - float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - float[] sampleArray = {1.2f, 3.4f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 3.7856376f; - testSession.evaluate(expected, loss); - } - } + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {1.2f, 3.4f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 3.7856376f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testZeroWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); - float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; - float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 0f; - testSession.evaluate(expected, loss); - } - } + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(0.F); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testTimestepWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.AUTO); - float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; - float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); - Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.AUTO); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 2.647374f; - testSession.evaluate(expected, loss); - } - } + float expected = 2.647374f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testInvalidSampleWeight() { - for (TestSession.Mode tfMode : tfModes) - Assertions.assertThrows( - IllegalArgumentException.class, - () -> { + @Test + public void testInvalidSampleWeight() { + for (TestSession.Mode tfMode : tfModes) + Assertions.assertThrows( + IllegalArgumentException.class, + () -> { try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); - float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; - float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - float[] sampleArray = {3f, 6f, 5f, 0f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); - Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 2))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 83f / 6f; - testSession.evaluate(expected, loss); - }}); - } + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {3f, 6f, 5f, 0f}; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 2))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 83f / 6f; + testSession.evaluate(expected, loss); + } + }); + } - @Test - public void testNoReduction() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.NONE); - float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; - float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - Float[] expected = {2.3006392f, 4.3097544f}; - testSession.evaluate(expected, loss); - } - } + @Test + public void testNoReduction() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.NONE); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + Float[] expected = {2.3006392f, 4.3097544f}; + testSession.evaluate(expected, loss); + } + } - @Test - public void testSumReduction() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.SUM); - float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; - float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - Float[] expected = {6.6103935f}; - testSession.evaluate(expected, loss); - } - } - + @Test + public void testSumReduction() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError(tf, Reduction.SUM); + float[] trueArray = {1f, 9f, 2f, -5f, -2f, 6f}; + float[] predArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + Float[] expected = {6.6103935f}; + testSession.evaluate(expected, loss); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java index 0a086a37b96..e156c786a81 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; @@ -8,98 +23,98 @@ import org.tensorflow.types.TFloat32; public class PoissonTest { - private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - /** - * Test of call method, of class Poisson. - */ - @Test - public void testUnweighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf); - float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; - float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred); - float expected = -3.306581945521002f; - testSession.evaluate(expected, loss); - } - } + /** Test of call method, of class Poisson. */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Poisson instance = new Poisson(tf); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred); + float expected = -3.306581945521002f; + testSession.evaluate(expected, loss); + } + } - /** - * Test of call method, of class Poisson. - */ - @Test - public void testScalarWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf); - float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; - float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = -7.605138474698304f; - testSession.evaluate(expected, loss); - } - } + /** Test of call method, of class Poisson. */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Poisson instance = new Poisson(tf); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = -7.605138474698304f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testSampleWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf); - float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; - float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; - float[] sampleArray = {1.2f, 3.4f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = -6.147338926788071f; - testSession.evaluate(expected, loss); - } - } + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Poisson instance = new Poisson(tf); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {1.2f, 3.4f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = -6.147338926788071f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testZeroWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf); - float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; - float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 0f; - testSession.evaluate(expected, loss); - } - } + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Poisson instance = new Poisson(tf); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(0.F); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testTimestepWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Poisson instance = new Poisson(tf, Reduction.AUTO); - float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; - float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; - float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); - Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Poisson instance = new Poisson(tf, Reduction.AUTO); + float[] predArray = {1f, 9f, 2f, 5f, 2f, 6f}; + float[] trueArray = {4f, 8f, 12f, 8f, 1f, 3f}; + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f}; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3, 1))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3, 1))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 3))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = -12.263126013890561f; - testSession.evaluate(expected, loss); - } - } + float expected = -12.263126013890561f; + testSession.evaluate(expected, loss); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java index 75109310a5a..26f4050a52d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; @@ -54,25 +69,27 @@ public void testInvalidPredictionsRange() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { Class catchClass = - tfMode == TestSession.Mode.EAGER - ? IllegalArgumentException.class - : org.tensorflow.exceptions.TFInvalidArgumentException.class; + tfMode == TestSession.Mode.EAGER + ? IllegalArgumentException.class + : org.tensorflow.exceptions.TFInvalidArgumentException.class; assertThrows( - catchClass, - () -> { - Ops tf = testSession.getTF(); - SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); - int[] trueArray = {0, 1, 2}; - float[] predArray = { - 1.9f, .05f, .05f, - .5f, .89f, .6f, - .05f, .01f, .94f - }; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 1))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); - Operand loss = instance.call(yTrue, yPred); - testSession.run(loss); - }); + catchClass, + () -> { + Ops tf = testSession.getTF(); + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy(tf); + int[] trueArray = {0, 1, 2}; + float[] predArray = { + 1.9f, .05f, .05f, + .5f, .89f, .6f, + .05f, .01f, .94f + }; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 1))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); + Operand loss = instance.call(yTrue, yPred); + testSession.run(loss); + }); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java index ca5bd3eb759..58da30dc598 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; @@ -10,32 +25,30 @@ import static org.junit.jupiter.api.Assertions.assertThrows; public class SquaredHingeTest { - private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - /** - * Test of call method, of class SquaredHinge. - */ - @Test - public void testUnweighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf); - float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; - float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred); - float expected = 0.364062f; - testSession.evaluate(expected, loss); - } - } + /** Test of call method, of class SquaredHinge. */ + @Test + public void testUnweighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + SquaredHinge instance = new SquaredHinge(tf); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand loss = instance.call(yTrue, yPred); + float expected = 0.364062f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testInvalidLabelValue() { + @Test + public void testInvalidLabelValue() { for (TestSession.Mode tfMode : tfModes) try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Class catchClass = + Class catchClass = tfMode == TestSession.Mode.EAGER ? IllegalArgumentException.class : org.tensorflow.exceptions.TFInvalidArgumentException.class; @@ -53,80 +66,82 @@ public void testInvalidLabelValue() { Operand loss = instance.call(yTrue, yPred); testSession.run(loss); }); - } - } + } + } - /** - * Test of call method, of class SquaredHinge. - */ - @Test - public void testScalarWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf); - float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; - float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand sampleWeight = tf.constant(2.3f); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 0.8373437f; - testSession.evaluate(expected, loss); - } - } + /** Test of call method, of class SquaredHinge. */ + @Test + public void testScalarWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + SquaredHinge instance = new SquaredHinge(tf); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand sampleWeight = tf.constant(2.3f); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0.8373437f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testSampleWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf); - float[] sampleArray = {1.2f, 3.4f}; - float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; - float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 0.7043125f; - testSession.evaluate(expected, loss); - } - } + @Test + public void testSampleWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + SquaredHinge instance = new SquaredHinge(tf); + float[] sampleArray = {1.2f, 3.4f}; + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 1))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0.7043125f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testZeroWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf); - float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; - float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Operand sampleWeight = tf.constant(0.F); - Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 0f; - testSession.evaluate(expected, loss); - } - } + @Test + public void testZeroWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + SquaredHinge instance = new SquaredHinge(tf); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Operand sampleWeight = tf.constant(0.F); + Operand loss = instance.call(yTrue, yPred, sampleWeight); + float expected = 0f; + testSession.evaluate(expected, loss); + } + } - @Test - public void testTimestepWeighted() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - SquaredHinge instance = new SquaredHinge(tf, Reduction.AUTO); - float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; - float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; - Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4, 1))); - Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4, 1))); - float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f, 1f, 3f}; - Operand sampleWeight = tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 4))); - Operand loss = instance.call(yTrue, yPred, sampleWeight); + @Test + public void testTimestepWeighted() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + SquaredHinge instance = new SquaredHinge(tf, Reduction.AUTO); + float[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; + Operand yTrue = + tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4, 1))); + Operand yPred = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4, 1))); + float[] sampleArray = {3f, 6f, 5f, 0f, 4f, 2f, 1f, 3f}; + Operand sampleWeight = + tf.reshape(tf.constant(sampleArray), tf.constant(Shape.of(2, 4))); + Operand loss = instance.call(yTrue, yPred, sampleWeight); - float expected = 1.54250000f; - testSession.evaluate(expected, loss); - } - } + float expected = 1.54250000f; + testSession.evaluate(expected, loss); + } + } } From 4d9fd24b809fd4141e61ca504b32b251a993cf8c Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 16 Nov 2020 18:36:54 -0500 Subject: [PATCH 24/26] Fix copyright to attribute TensorFlow Authors. --- .../framework/losses/BinaryCrossentropy.java | 29 +++++++++---------- .../losses/CategoricalCrossentropy.java | 29 +++++++++---------- .../framework/losses/CategoricalHinge.java | 29 +++++++++---------- .../framework/losses/CosineSimilarity.java | 29 +++++++++---------- .../tensorflow/framework/losses/Hinge.java | 29 +++++++++---------- .../tensorflow/framework/losses/Huber.java | 29 +++++++++---------- .../framework/losses/KLDivergence.java | 29 +++++++++---------- .../tensorflow/framework/losses/LogCosh.java | 29 +++++++++---------- .../org/tensorflow/framework/losses/Loss.java | 29 +++++++++---------- .../tensorflow/framework/losses/Losses.java | 29 +++++++++---------- .../framework/losses/MeanAbsoluteError.java | 29 +++++++++---------- .../losses/MeanAbsolutePercentageError.java | 29 +++++++++---------- .../framework/losses/MeanSquaredError.java | 29 +++++++++---------- .../losses/MeanSquaredLogarithmicError.java | 29 +++++++++---------- .../tensorflow/framework/losses/Poisson.java | 29 +++++++++---------- .../framework/losses/Reduction.java | 29 +++++++++---------- .../losses/SparseCategoricalCrossentropy.java | 29 +++++++++---------- .../framework/losses/SquaredHinge.java | 29 +++++++++---------- .../framework/losses/impl/LossTuple.java | 29 +++++++++---------- .../framework/losses/impl/LossesHelper.java | 29 +++++++++---------- .../losses/BinaryCrossentropyTest.java | 29 +++++++++---------- .../losses/CategoricalCrossentropyTest.java | 29 +++++++++---------- .../losses/CategoricalHingeTest.java | 29 +++++++++---------- .../losses/CosineSimilarityTest.java | 29 +++++++++---------- .../framework/losses/HingeTest.java | 29 +++++++++---------- .../framework/losses/HuberTest.java | 29 +++++++++---------- .../framework/losses/KLDivergenceTest.java | 29 +++++++++---------- .../framework/losses/LogCoshTest.java | 29 +++++++++---------- .../losses/MeanAbsoluteErrorTest.java | 29 +++++++++---------- .../MeanAbsolutePercentageErrorTest.java | 29 +++++++++---------- .../losses/MeanSquaredErrorTest.java | 29 +++++++++---------- .../MeanSquaredLogarithmicErrorTest.java | 29 +++++++++---------- .../framework/losses/PoissonTest.java | 29 +++++++++---------- .../SparseCategoricalCrossentropyTest.java | 29 +++++++++---------- .../framework/losses/SquaredHingeTest.java | 29 +++++++++---------- 35 files changed, 490 insertions(+), 525 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index 6e3e7fa321c..5346e9acb20 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 522446bec4d..bd92f20a3e2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index c60628fd22e..f592c19f8bb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java index f6dc5c7e5fd..137c7025c04 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index 31b155a76a2..5fdfd4c9b96 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index 69487405148..6d3e3f0c2ac 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java index 52af31c820b..8cf3db8d518 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index 97ca2b99cff..1669669a768 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index b9a08ad2b0f..a014d41f2cb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 3b61b0f8eab..7a633ede2bf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.tensorflow.DataType; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index 3de27026944..a2d5d5f8efc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 009c22f1b8a..49133df610b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index e218bb38616..2a6c2be885e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index c673ca292d6..2604e226b81 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java index 6e61502fca7..c43be4f2821 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java index 26e151c1e81..87ea43c6c3a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Reduction.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index 3086b6200a2..49c31aadc84 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index 25c75e7732d..e97a1e61138 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java index 76ce6ff4d43..2104937a979 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses.impl; import org.tensorflow.Operand; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 3dc1c66be4d..463296a1f50 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses.impl; import org.tensorflow.DataType; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java index 0928a4ad10c..d2128b80839 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/BinaryCrossentropyTest.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java index 36af5bd51ad..13b287de3cd 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java index 8923a299774..b0d0442b3c7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalHingeTest.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java index dea995ad987..8350d1403ed 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CosineSimilarityTest.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java index 3e86b303b83..4770511207e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HingeTest.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java index 9ab11e0cf0e..d1751f223a1 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/HuberTest.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java index 7875a63f1d9..d57b61b18dd 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/KLDivergenceTest.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java index 1d06669b731..c4347b3fccb 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/LogCoshTest.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java index f069091b7a2..3498c6d53aa 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsoluteErrorTest.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Assertions; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java index debf6b967b9..7816a8a288a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanAbsolutePercentageErrorTest.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java index b2c7ee46837..1a971f0492b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredErrorTest.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Assertions; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java index 71c840f7284..558f9c84659 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicErrorTest.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Assertions; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java index e156c786a81..55c59ca5ac6 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/PoissonTest.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java index 26f4050a52d..a6a0ff35c78 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropyTest.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java index 58da30dc598..57a012bbe9d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/SquaredHingeTest.java @@ -1,18 +1,17 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ package org.tensorflow.framework.losses; import org.junit.jupiter.api.Test; From d56d8d9dbfb8d1d8cfc4b829ea1e3b3bfe93478d Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 16 Nov 2020 18:45:07 -0500 Subject: [PATCH 25/26] Fix typo on broadcast in JavaDoc --- .../main/java/org/tensorflow/framework/losses/SquaredHinge.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index e97a1e61138..182ce592e55 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -113,7 +113,7 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { * provided, then the loss is simply scaled by the given value. If SampleWeights is a tensor * of size [batch_size], then the total loss for each sample of the batch is rescaled by the * corresponding element in the SampleWeights vector. If the shape of SampleWeights is - * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of + * [batch_size, d0, .. dN-1] (or can be broadcast to this shape), then each loss element of * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. From 744e32463c4aa8def4456fac4bcec53536a04fa4 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 16 Nov 2020 18:46:12 -0500 Subject: [PATCH 26/26] Fix typo on broadcast in JavaDoc --- .../org/tensorflow/framework/losses/BinaryCrossentropy.java | 2 +- .../tensorflow/framework/losses/CategoricalCrossentropy.java | 2 +- .../src/main/java/org/tensorflow/framework/losses/Loss.java | 2 +- .../framework/losses/SparseCategoricalCrossentropy.java | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index 5346e9acb20..effdf990f71 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -198,7 +198,7 @@ public BinaryCrossentropy( * provided, then the loss is simply scaled by the given value. If SampleWeights is a tensor * of size [batch_size], then the total loss for each sample of the batch is rescaled by the * corresponding element in the SampleWeights vector. If the shape of SampleWeights is - * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of + * [batch_size, d0, .. dN-1] (or can be broadcast to this shape), then each loss element of * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index bd92f20a3e2..7701ebfb806 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -237,7 +237,7 @@ public CategoricalCrossentropy( * provided, then the loss is simply scaled by the given value. If SampleWeights is a tensor * of size [batch_size], then the total loss for each sample of the batch is rescaled by the * corresponding element in the SampleWeights vector. If the shape of SampleWeights is - * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of + * [batch_size, d0, .. dN-1] (or can be broadcast to this shape), then each loss element of * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index a014d41f2cb..ae33d5dfa37 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -78,7 +78,7 @@ public Operand call(Operand labels, * provided, then the loss is simply scaled by the given value. If SampleWeights is a tensor * of size [batch_size], then the total loss for each sample of the batch is rescaled by the * corresponding element in the SampleWeights vector. If the shape of SampleWeights is - * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of + * [batch_size, d0, .. dN-1] (or can be broadcast to this shape), then each loss element of * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index 49c31aadc84..5586a4da889 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -186,7 +186,7 @@ public SparseCategoricalCrossentropy( * provided, then the loss is simply scaled by the given value. If SampleWeights is a tensor * of size [batch_size], then the total loss for each sample of the batch is rescaled by the * corresponding element in the SampleWeights vector. If the shape of SampleWeights is - * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of + * [batch_size, d0, .. dN-1] (or can be broadcast to this shape), then each loss element of * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss.