diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index 870f4972c3c..273d2bb6c5e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -31,7 +31,6 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; @@ -88,7 +87,7 @@ * * @param The data type for the metric result */ -public class AUC extends Metric { +public class AUC extends BaseMetric { /** Default Fuzz factor. */ public static final float EPSILON = 1e-7f; @@ -109,14 +108,12 @@ public class AUC extends Metric { private final String falsePositivesName; private final String trueNegativesName; private final String falseNegativesName; - private final Map> initializers = new HashMap<>(); private final Class type; - + private final Zeros zeros = new Zeros<>(); /** The size of the label dimension. */ private Integer numLabels; private Operand labelWeights; - /** * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. * @@ -124,7 +121,6 @@ public class AUC extends Metric { * class dimension within each example. */ private Variable truePositives; - /** * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. * @@ -132,7 +128,6 @@ public class AUC extends Metric { * class dimension within each example. */ private Variable falsePositives; - /** * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. * @@ -140,7 +135,6 @@ public class AUC extends Metric { * class dimension within each example. */ private Variable trueNegatives; - /** * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. * @@ -149,7 +143,8 @@ public class AUC extends Metric { */ private Variable falseNegatives; - private boolean initialized; + private Shape variableShape; + private Shape shape; /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, @@ -157,14 +152,12 @@ public class AUC extends Metric { * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, * {@code false} for multiLabel, and {@code null} for labelWeights. * - * @param tf The TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the confusion matrix variables. */ - public AUC(Ops tf, long seed, Class type) { + public AUC(long seed, Class type) { this( - tf, null, DEFAULT_NUM_THRESHOLDS, AUCCurve.ROC, @@ -182,15 +175,13 @@ public AUC(Ops tf, long seed, Class type) { * AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, {@code * false} for multiLabel, and {@code null} for labelWeights. * - * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the confusion matrix variables. 
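A minimal usage sketch for the refactored API in this file (editorial, not part of the patch): the metric is now constructed without an Ops handle, and a Graph-backed Ops is passed on each call, with init(tf) running lazily on first use. The single type parameter, the result(Ops, Class) signature, the sample values, and the wrapper class name are assumptions read from later hunks of this diff.

import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.framework.metrics.AUC;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

public class AucUsageSketch {
  public static void main(String[] args) {
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);
      // No Ops argument any more: only the seed and the confusion-matrix variable type.
      AUC<TFloat32> auc = new AUC<>(1001L, TFloat32.class);

      Operand<TFloat32> labels = tf.constant(new float[] {0f, 0f, 1f, 1f});
      Operand<TFloat32> predictions = tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f});

      // The Graph-backed Ops is supplied per call; init(tf) builds the variables lazily.
      Op update = auc.updateState(tf, labels, predictions, null);
      Operand<TFloat32> value = auc.result(tf, TFloat32.class);
      // Run `update` in a Session before fetching `value`.
    }
  }
}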
*/ - public AUC(Ops tf, String name, long seed, Class type) { + public AUC(String name, long seed, Class type) { this( - tf, name, DEFAULT_NUM_THRESHOLDS, AUCCurve.ROC, @@ -208,16 +199,14 @@ public AUC(Ops tf, String name, long seed, Class type) { * summation method, {@code null} for thresholds, {@code false} for multiLabel, and {@code null} * for labelWeights. * - * @param tf The TensorFlow Ops * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values * must be > 1. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the confusion matrix variables. */ - public AUC(Ops tf, int numThresholds, long seed, Class type) { + public AUC(int numThresholds, long seed, Class type) { this( - tf, null, numThresholds, AUCCurve.ROC, @@ -235,16 +224,14 @@ public AUC(Ops tf, int numThresholds, long seed, Class type) { * summation method, {@code null} for numThresholds, {@code false} for multiLabel, and {@code * null} for labelWeights. * - * @param tf The TensorFlow Ops * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, * the numThresholds parameter is ignored. Values should be in [0, 1]. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the confusion matrix variables. */ - public AUC(Ops tf, float[] thresholds, long seed, Class type) { + public AUC(float[] thresholds, long seed, Class type) { this( - tf, null, DEFAULT_NUM_THRESHOLDS, AUCCurve.ROC, @@ -261,7 +248,6 @@ public AUC(Ops tf, float[] thresholds, long seed, Class type) { * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, * {@code false} for multiLabel, and {@code null} for labelWeights. * - * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values * must be > 1. @@ -269,9 +255,8 @@ public AUC(Ops tf, float[] thresholds, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the confusion matrix variables. */ - public AUC(Ops tf, String name, int numThresholds, long seed, Class type) { + public AUC(String name, int numThresholds, long seed, Class type) { this( - tf, name, numThresholds, AUCCurve.ROC, @@ -289,7 +274,6 @@ public AUC(Ops tf, String name, int numThresholds, long seed, Class type) { * method, {@link #DEFAULT_NUM_THRESHOLDS} num thresholds, {@code false} for multiLabel, and * {@code null} for labelWeights. * - * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, * the numThresholds parameter is ignored. Values should be in [0, 1]. @@ -297,9 +281,8 @@ public AUC(Ops tf, String name, int numThresholds, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the confusion matrix variables. 
*/ - public AUC(Ops tf, String name, float[] thresholds, long seed, Class type) { + public AUC(String name, float[] thresholds, long seed, Class type) { this( - tf, name, DEFAULT_NUM_THRESHOLDS, AUCCurve.ROC, @@ -316,7 +299,6 @@ public AUC(Ops tf, String name, float[] thresholds, long seed, Class type) { * the summation method, {@code null} for thresholds, {@code false} for multiLabel, and {@code * null} for labelWeights. * - * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values * must be > 1. @@ -326,9 +308,8 @@ public AUC(Ops tf, String name, float[] thresholds, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the confusion matrix variables. */ - public AUC(Ops tf, String name, int numThresholds, AUCCurve curve, long seed, Class type) { + public AUC(String name, int numThresholds, AUCCurve curve, long seed, Class type) { this( - tf, name, numThresholds, curve, @@ -345,7 +326,6 @@ public AUC(Ops tf, String name, int numThresholds, AUCCurve curve, long seed, Cl * AUCSummationMethod#INTERPOLATION} for the summation method, {@link #DEFAULT_NUM_THRESHOLDS} num * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. * - * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, * the numThresholds parameter is ignored. Values should be in [0, 1]. @@ -355,9 +335,8 @@ public AUC(Ops tf, String name, int numThresholds, AUCCurve curve, long seed, Cl * will always produce the same random tensor for a given shape and data type. * @param type the data type for the confusion matrix variables. */ - public AUC(Ops tf, String name, float[] thresholds, AUCCurve curve, long seed, Class type) { + public AUC(String name, float[] thresholds, AUCCurve curve, long seed, Class type) { this( - tf, name, DEFAULT_NUM_THRESHOLDS, curve, @@ -374,7 +353,6 @@ public AUC(Ops tf, String name, float[] thresholds, AUCCurve curve, long seed, C * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, * {@code false} for multiLabel, and {@code null} for labelWeights. * - * @param tf The TensorFlow Ops * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values * must be > 1. * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link @@ -383,9 +361,8 @@ public AUC(Ops tf, String name, float[] thresholds, AUCCurve curve, long seed, C * will always produce the same random tensor for a given shape and data type. * @param type the data type for the confusion matrix variables. */ - public AUC(Ops tf, int numThresholds, AUCCurve curve, long seed, Class type) { + public AUC(int numThresholds, AUCCurve curve, long seed, Class type) { this( - tf, null, numThresholds, curve, @@ -402,7 +379,6 @@ public AUC(Ops tf, int numThresholds, AUCCurve curve, long seed, Class type) * AUCSummationMethod#INTERPOLATION} for the summation method, {@code false} for multiLabel, and * {@code null} for labelWeights. * - * @param tf The TensorFlow Ops * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, * the numThresholds parameter is ignored. Values should be in [0, 1]. 
* @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link @@ -411,9 +387,8 @@ public AUC(Ops tf, int numThresholds, AUCCurve curve, long seed, Class type) * will always produce the same random tensor for a given shape and data type. * @param type the data type for the confusion matrix variables. */ - public AUC(Ops tf, float[] thresholds, AUCCurve curve, long seed, Class type) { + public AUC(float[] thresholds, AUCCurve curve, long seed, Class type) { this( - tf, null, DEFAULT_NUM_THRESHOLDS, curve, @@ -429,7 +404,6 @@ public AUC(Ops tf, float[] thresholds, AUCCurve curve, long seed, Class type) * Creates an AUC (Area under the curve) metric. using {@link #DEFAULT_NAME} for the metric name,, * {@code null} for thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. * - * @param tf The TensorFlow Ops * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values * must be > 1. * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link @@ -440,13 +414,12 @@ public AUC(Ops tf, float[] thresholds, AUCCurve curve, long seed, Class type) * @param type the data type for the confusion matrix variables. */ public AUC( - Ops tf, int numThresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class type) { - this(tf, null, numThresholds, curve, summationMethod, null, false, null, seed, type); + this(null, numThresholds, curve, summationMethod, null, false, null, seed, type); } /** @@ -454,7 +427,6 @@ public AUC( * {@code null} for numThresholds, {@code false} for multiLabel, and {@code null} for * labelWeights. * - * @param tf The TensorFlow Ops * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, * the numThresholds parameter is ignored. Values should be in [0, 1]. * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link @@ -465,30 +437,18 @@ public AUC( * @param type the data type for the confusion matrix variables. */ public AUC( - Ops tf, float[] thresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class type) { - this( - tf, - null, - DEFAULT_NUM_THRESHOLDS, - curve, - summationMethod, - thresholds, - false, - null, - seed, - type); + this(null, DEFAULT_NUM_THRESHOLDS, curve, summationMethod, thresholds, false, null, seed, type); } /** * Creates an AUC (Area under the curve) metric. using {@code null} for thresholds, {@code false} * for multiLabel, and {@code null} for labelWeights. * - * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values * must be > 1. @@ -500,21 +460,19 @@ public AUC( * @param type the data type for the confusion matrix variables. */ public AUC( - Ops tf, String name, int numThresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class type) { - this(tf, name, numThresholds, curve, summationMethod, null, false, null, seed, type); + this(name, numThresholds, curve, summationMethod, null, false, null, seed, type); } /** * Creates an AUC (Area under the curve) metric. using {@code null} for the numThresholds, {@code * false} for multiLabel, and {@code null} for labelWeights. 
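A short sketch for the numThresholds/curve/summationMethod overload shown just above (values illustrative; thresholds, multiLabel and labelWeights fall back to the documented defaults):

AUC<TFloat32> prAuc =
    new AUC<>(200, AUCCurve.PR, AUCSummationMethod.INTERPOLATION, 1001L, TFloat32.class);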
* - * @param tf The TensorFlow Ops * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, * the numThresholds parameter is ignored. Values should be in [0, 1]. @@ -526,30 +484,18 @@ public AUC( * @param type the data type for the confusion matrix variables. */ public AUC( - Ops tf, String name, float[] thresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class type) { - this( - tf, - name, - DEFAULT_NUM_THRESHOLDS, - curve, - summationMethod, - thresholds, - false, - null, - seed, - type); + this(name, DEFAULT_NUM_THRESHOLDS, curve, summationMethod, thresholds, false, null, seed, type); } /** * Creates an AUC (Area under the curve) metric. * - * @param tf The TensorFlow Ops * @param name the name of the metric, if name is null then use {@link #DEFAULT_NAME}. * @param numThresholds the number of thresholds to use when discretizing the roc curve. This * includes the bracketing 0 and 1 thresholds, so the value must be ≥ 2. @@ -576,7 +522,6 @@ public AUC( * a threshold value is less than 0 or greater than 1. */ public AUC( - Ops tf, String name, int numThresholds, AUCCurve curve, @@ -586,7 +531,7 @@ public AUC( Operand labelWeights, long seed, Class type) { - super(tf, name == null ? DEFAULT_NAME : name, seed); + super(name == null ? DEFAULT_NAME : name, seed); truePositivesName = getVariableName(TRUE_POSITIVES); falsePositivesName = getVariableName(FALSE_POSITIVES); trueNegativesName = getVariableName(TRUE_NEGATIVES); @@ -630,82 +575,71 @@ public AUC( // Handle multilabel arguments. - if (labelWeights != null) { - // assert that labelWeights are non-negative. - - this.labelWeights = labelWeights; - Op checks = - tf.withSubScope("AUC") - .assertThat( - tf.math.greaterEqual(labelWeights, cast(tf, tf.constant(0), labelWeights.type())), - Collections.singletonList( - tf.constant("All values of labelWeights must be non-negative."))); - - Ops ltf = - tf.withSubScope("updateState").withControlDependencies(Collections.singletonList(checks)); - - this.labelWeights = ltf.identity(this.labelWeights); - } + this.labelWeights = labelWeights; if (multiLabel) { numLabels = null; } } - /** - * Initialize truePositives, falsePositives, trueNegatives, and falseNegatives variables, given - * the shape of the data. - * - * @param shape the prediction shape if called from updateState, otherwise null - */ - @SuppressWarnings("unchecked") - private Map> build(Shape shape) { - Shape variableShape; - if (initialized) { - return Collections.EMPTY_MAP; - } - Ops tf = getTF(); - - if (isMultiLabel()) { - if (shape == null) { - throw new IllegalArgumentException("For multiLabel, a shape must be provided"); + /** {@inheritDoc} */ + @Override + protected void init(Ops tf) { + checkIsGraph(tf); + if (shape != null && !isInitialized()) { + setTF(tf); + if (labelWeights != null) { + // assert that labelWeights are non-negative. 
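The non-negativity assertion on labelWeights now runs here in init(Ops), i.e. on first use, rather than in the constructor. A hedged sketch of the full constructor supplying label weights (argument order read from the delegating constructors above; name and values illustrative):

AUC<TFloat32> weightedAuc =
    new AUC<>(
        "weighted_auc",                       // name
        200,                                  // numThresholds
        AUCCurve.ROC,
        AUCSummationMethod.INTERPOLATION,
        null,                                 // thresholds: derived from numThresholds
        false,                                // multiLabel
        tf.constant(new float[] {1f, 0.5f}),  // labelWeights, validated lazily in init(tf)
        1001L,
        TFloat32.class);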
+ + Op checks = + tf.withSubScope("AUC") + .assertThat( + tf.math.greaterEqual( + labelWeights, cast(tf, tf.constant(0), labelWeights.type())), + Collections.singletonList( + tf.constant("All values of labelWeights must be non-negative."))); + + Ops ltf = + tf.withSubScope("updateState") + .withControlDependencies(Collections.singletonList(checks)); + + this.labelWeights = ltf.identity(this.labelWeights); + } + if (isMultiLabel()) { + if (shape == null) { + throw new IllegalArgumentException("For multiLabel, a shape must be provided"); + } + if (shape.numDimensions() != 2) + throw new IllegalArgumentException( + String.format( + "labels must have rank=2 when multiLabel is true. Found rank %d.", + shape.numDimensions())); + numLabels = (int) shape.size(1); + variableShape = Shape.of(numThresholds, numLabels); + } else { + variableShape = Shape.of(numThresholds); } - if (shape.numDimensions() != 2) - throw new IllegalArgumentException( - String.format( - "labels must have rank=2 when multiLabel is true. Found rank %d.", - shape.numDimensions())); - numLabels = (int) shape.size(1); - variableShape = Shape.of(numThresholds, numLabels); - } else { - variableShape = Shape.of(numThresholds); - } - // Create metric variables - Zeros zeros = new Zeros<>(); - Operand zero = zeros.call(tf, tf.constant(variableShape), type); - if (truePositives == null) { - truePositives = tf.withName(getTruePositivesName()).withInitScope().variable(zero); - initializers.put(ConfusionMatrixEnum.TRUE_POSITIVES, tf.assign(truePositives, zero)); - } + // Create metric variables - if (falsePositives == null) { - falsePositives = tf.withName(getFalsePositivesName()).withInitScope().variable(zero); - initializers.put(ConfusionMatrixEnum.FALSE_POSITIVES, tf.assign(falsePositives, zero)); - } + Operand zero = zeros.call(tf, tf.constant(variableShape), type); + if (truePositives == null) { + truePositives = tf.withName(getTruePositivesName()).withInitScope().variable(zero); + } - if (trueNegatives == null) { - trueNegatives = tf.withName(getTrueNegativesName()).withInitScope().variable(zero); - initializers.put(ConfusionMatrixEnum.TRUE_NEGATIVES, tf.assign(trueNegatives, zero)); - } + if (falsePositives == null) { + falsePositives = tf.withName(getFalsePositivesName()).withInitScope().variable(zero); + } - if (falseNegatives == null) { - falseNegatives = tf.withName(getFalseNegativesName()).withInitScope().variable(zero); - initializers.put(ConfusionMatrixEnum.FALSE_NEGATIVES, tf.assign(falseNegatives, zero)); - } + if (trueNegatives == null) { + trueNegatives = tf.withName(getTrueNegativesName()).withInitScope().variable(zero); + } - initialized = true; - return initializers; + if (falseNegatives == null) { + falseNegatives = tf.withName(getFalseNegativesName()).withInitScope().variable(zero); + } + setInitialized(true); + } } /** @@ -721,20 +655,20 @@ private Map> build(Shape shape) { * @return a List of Operations to update the metric state */ @Override - @SuppressWarnings("unchecked") public List updateStateList( + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { - Ops tf = getTF(); + if (shape == null) { + shape = predictions.shape(); + } + init(tf); Operand tLabels = cast(tf, labels, type); Operand tPredictions = cast(tf, predictions, type); Operand tSampleWeights = sampleWeights != null ? 
cast(tf, sampleWeights, type) : null; List updateOperations = new ArrayList<>(); - Map> varInitializers = Collections.EMPTY_MAP; - if (!initialized) { - varInitializers = build(tPredictions.shape()); - } + if (isMultiLabel() || getLabelWeights() != null) { // labels should have shape (number of examples, number of labels). List> symbols = new ArrayList<>(); @@ -767,7 +701,6 @@ public List updateStateList( MetricsHelper.updateConfusionMatrixVariables( tf, confusionMatrix, - varInitializers, tLabels, tPredictions, tf.constant(thresholds), @@ -785,8 +718,8 @@ public List updateStateList( * @param input the input * @return the input with all positive numbers. */ - private Operand positive(Operand input) { - return getTF().math.maximum(input, cast(getTF(), getTF().constant(0), input.type())); + private Operand positive(Ops tf, Operand input) { + return tf.math.maximum(input, cast(tf, tf.constant(0), input.type())); } /** @@ -795,8 +728,8 @@ private Operand positive(Operand input) { * @param input the input * @return the truth value of whether {@code input > 0}, element-wise. */ - private Operand isPositive(Operand input) { - return getTF().math.greater(input, cast(getTF(), getTF().constant(0), input.type())); + private Operand isPositive(Ops tf, Operand input) { + return tf.math.greater(input, cast(tf, tf.constant(0), input.type())); } /** @@ -807,9 +740,8 @@ private Operand isPositive(Operand input) { * @param size the size of the slice * @return the slice */ - private Operand slice(Operand input, int begin, int size) { - return getTF() - .slice(input, getTF().constant(new int[] {begin}), getTF().constant(new int[] {size})); + private Operand slice(Ops tf, Operand input, int begin, int size) { + return tf.slice(input, tf.constant(new int[] {begin}), tf.constant(new int[] {size})); } /** @@ -863,38 +795,37 @@ private Operand slice(Operand input, int begin, int size) { * @see The Relationship Between * Precision-Recall and ROC Curves - Davis & Goadrich 2006 */ - private Operand interpolatePRAuc() { + private Operand interpolatePRAuc(Ops tf) { // truePositives[:self.numThresholds - 1] - Ops tf = getTF(); - Operand tp0 = slice(truePositives, 0, getNumThresholds() - 1); + Operand tp0 = slice(tf, truePositives, 0, getNumThresholds() - 1); // truePositives[1:] - Operand tp1 = slice(truePositives, 1, -1); + Operand tp1 = slice(tf, truePositives, 1, -1); Operand dTP = tf.math.sub(tp0, tp1); Operand p = tf.math.add(truePositives, falsePositives); - Operand p0 = slice(p, 0, getNumThresholds() - 1); - Operand p1 = slice(p, 1, -1); + Operand p0 = slice(tf, p, 0, getNumThresholds() - 1); + Operand p1 = slice(tf, p, 1, -1); Operand dP = tf.math.sub(p0, p1); - Operand precisionSlope = tf.math.divNoNan(dTP, positive(dP)); + Operand precisionSlope = tf.math.divNoNan(dTP, positive(tf, dP)); Operand intercept = tf.math.sub(tp1, tf.math.mul(precisionSlope, p1)); Operand safePRatio = tf.select( - tf.math.logicalAnd(isPositive(p0), isPositive(p1)), - tf.math.divNoNan(p0, positive(p1)), + tf.math.logicalAnd(isPositive(tf, p0), isPositive(tf, p1)), + tf.math.divNoNan(p0, positive(tf, p1)), tf.onesLike(p1)); - Operand fn1 = slice(falseNegatives, 1, -1); + Operand fn1 = slice(tf, falseNegatives, 1, -1); Operand aucTotalPos = tf.math.mul( precisionSlope, tf.math.add(dTP, tf.math.mul(intercept, tf.math.log(safePRatio)))); - Operand prAucIncrement = tf.math.divNoNan(aucTotalPos, positive(tf.math.add(tp1, fn1))); + Operand prAucIncrement = tf.math.divNoNan(aucTotalPos, positive(tf, tf.math.add(tp1, fn1))); if 
(isMultiLabel()) { Operand byLabelAuc = tf.reduceSum(prAucIncrement, tf.constant(0)); @@ -914,13 +845,12 @@ private Operand interpolatePRAuc() { /** {@inheritDoc} */ @Override - public Operand result() { - + public Operand result(Ops tf, Class resultType) { + init(tf); if (getCurve() == AUCCurve.PR && getSummationMethod() == AUCSummationMethod.INTERPOLATION) { // This use case is different and is handled separately. - return interpolatePRAuc(); + return cast(tf, interpolatePRAuc(tf), resultType); } - Ops tf = getTF(); Operand x; Operand y; Operand recall = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); @@ -940,19 +870,22 @@ public Operand result() { // Find the rectangle heights based on `summationMethod`. // y[:self.numThresholds - 1] - Operand ySlice1 = slice(y, 0, getNumThresholds() - 1); + Operand ySlice1 = slice(tf, y, 0, getNumThresholds() - 1); // y[1:] - Operand ySlice2 = slice(y, 1, -1); + Operand ySlice2 = slice(tf, y, 1, -1); Operand heights; switch (getSummationMethod()) { case INTERPOLATION: + //noinspection SuspiciousNameCombination heights = tf.math.div(tf.math.add(ySlice1, ySlice2), cast(tf, tf.constant(2), y.type())); break; case MINORING: + //noinspection SuspiciousNameCombination heights = tf.math.minimum(ySlice1, ySlice2); break; case MAJORING: + //noinspection SuspiciousNameCombination heights = tf.math.maximum(ySlice1, ySlice2); break; default: @@ -962,33 +895,51 @@ public Operand result() { if (isMultiLabel()) { Operand riemannTerms = - tf.math.mul(tf.math.sub(slice(x, 0, getNumThresholds() - 1), slice(x, 1, -1)), heights); + tf.math.mul( + tf.math.sub(slice(tf, x, 0, getNumThresholds() - 1), slice(tf, x, 1, -1)), heights); Operand byLabelAuc = tf.reduceSum(riemannTerms, tf.constant(0)); if (getLabelWeights() == null) { - return MetricsHelper.mean(tf, byLabelAuc); + return cast(tf, MetricsHelper.mean(tf, byLabelAuc), resultType); } else { // Weighted average of the label AUCs. 
- return tf.math.divNoNan( - tf.reduceSum( - tf.math.mul(byLabelAuc, getLabelWeights()), allAxes(tf, getLabelWeights())), - tf.reduceSum(getLabelWeights(), allAxes(tf, getLabelWeights()))); + return cast( + tf, + tf.math.divNoNan( + tf.reduceSum( + tf.math.mul(byLabelAuc, getLabelWeights()), allAxes(tf, getLabelWeights())), + tf.reduceSum(getLabelWeights(), allAxes(tf, getLabelWeights()))), + resultType); } } else { - Operand slice1 = slice(x, 0, getNumThresholds() - 1); - Operand slice2 = slice(x, 1, -1); + Operand slice1 = slice(tf, x, 0, getNumThresholds() - 1); + Operand slice2 = slice(tf, x, 1, -1); Operand sub = tf.math.sub(slice1, slice2); Operand operand = tf.math.mul(sub, heights); - return tf.reduceSum(operand, allAxes(tf, operand)); + return cast(tf, tf.reduceSum(operand, allAxes(tf, operand)), resultType); } } /** {@inheritDoc} */ @Override - public Op resetStates() { - List updateOperations = new ArrayList<>(initializers.values()); - return getTF().withSubScope("resetStates").withControlDependencies(updateOperations).noOp(); + public Op resetStates(Ops tf) { + init(tf); + Operand zero = zeros.call(tf, tf.constant(variableShape), type); + List controlList = new ArrayList<>(); + if (truePositives != null) { + controlList.add(tf.assign(truePositives, zero)); + } + if (falsePositives != null) { + controlList.add(tf.assign(falsePositives, zero)); + } + if (trueNegatives != null) { + controlList.add(tf.assign(trueNegatives, zero)); + } + if (falseNegatives != null) { + controlList.add(tf.assign(falseNegatives, zero)); + } + return tf.withControlDependencies(controlList).noOp(); } /** @return the numThresholds */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java index 14f45020739..f0324e4daa5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java @@ -16,10 +16,11 @@ import static org.tensorflow.framework.utils.CastHelper.cast; +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.framework.metrics.impl.MetricsHelper; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; @@ -36,31 +37,29 @@ * * @param The data type for the metric result */ -public class Accuracy extends MeanMetricWrapper implements LossMetric { +public class Accuracy extends MeanBaseMetricWrapper implements LossMetric { /** * Creates an Accuracy Metric using {@link Class#getSimpleName()} for the metric name * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Accuracy(Ops tf, long seed, Class type) { - this(tf, null, seed, type); + public Accuracy(long seed, Class type) { + this(null, seed, type); } /** * Creates an Accuracy Metric * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param seed the seed for random number generation. 
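A short sketch of the rewritten resetStates above: instead of replaying stored initializer ops, it re-assigns a zeros tensor to each non-null confusion-matrix variable, and it now takes the Ops handle explicitly:

Op reset = auc.resetStates(tf);
// Run `reset` in the Session between evaluation passes; later updateState(tf, ...)
// calls then accumulate into freshly zeroed variables.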
An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Accuracy(Ops tf, String name, long seed, Class type) { - super(tf, name, seed, type); + public Accuracy(String name, long seed, Class type) { + super(name, seed, type); setLoss(this); } @@ -68,18 +67,24 @@ public Accuracy(Ops tf, String name, long seed, Class type) { * Calculates how often predictions equals labels. {@code labels} and {@code predictions} must * have compatible shapes, see {@link Shape @isCompatibleWith}. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels the truth values or labels * @param predictions the predictions - * @throws IllegalArgumentException if predictions and labels shapes are not compatible. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return the loss */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tLabels = cast(tf, labels, getInternalType()); + Operand tPredictions = cast(tf, predictions, getInternalType()); LossTuple tuple = - MetricsHelper.raggedAssertCompatibleAndGetFlatValues(getTF(), tLabels, tPredictions); + MetricsHelper.raggedAssertCompatibleAndGetFlatValues(tf, tLabels, tPredictions); tLabels = tuple.getLabels(); tPredictions = tuple.getTarget(); @@ -91,6 +96,6 @@ public Operand call( } // cast TBool to result type - return cast(getTF(), getTF().math.equal(tLabels, tPredictions), getResultType()); + return cast(tf, tf.math.equal(tLabels, tPredictions), resultType); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BaseMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BaseMetric.java new file mode 100644 index 00000000000..0605f0f5ef5 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BaseMetric.java @@ -0,0 +1,259 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import java.util.Collections; +import java.util.List; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Base class for Metrics */ +public abstract class BaseMetric implements Metric { + + /** The seed for random number generation */ + private final long seed; + + private String name; + + private boolean initialized; + + private Ops tf; + + /** + * Creates a Metric with a name of {@link Class#getSimpleName()} + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + protected BaseMetric(long seed) { + this(null, seed); + } + + /** + * Creates a Metric + * + * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + protected BaseMetric(String name, long seed) { + + this.seed = seed; + this.name = name != null ? name : this.getClass().getSimpleName(); + } + + /** + * Creates a List of Operations to update the metric state based on input values. + * + *
<p>
This is an empty implementation that should be overridden in a subclass, if needed. + * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. + * @param values the inputs to be passed to update state, this may not be null + * @param sampleWeights sample weights to be applied to the values, may be null. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. + * @return a List of Operations to update the metric state + */ + @SuppressWarnings({"unchecked", "unused"}) + @Override + public List updateStateList( + Ops tf, Operand values, Operand sampleWeights) { + checkIsGraph(tf); + return Collections.EMPTY_LIST; + } + + /** + * Creates a List of Operations to update the metric state based on labels and predictions. + * + *
<p>
This is an empty implementation that should be overridden in a subclass, if needed. + * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. + * @param labels the labels + * @param predictions the predictions + * @param sampleWeights sample weights to be applied to the metric values, may be null. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. + * @return a List of Operations to update the metric state + */ + @Override + @SuppressWarnings({"unchecked", "unused"}) + public List updateStateList( + Ops tf, + Operand labels, + Operand predictions, + Operand sampleWeights) { + checkIsGraph(tf); + return Collections.EMPTY_LIST; + } + + /** + * Creates a NoOp Operation with control dependencies to update the metric state + * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. + * @param values the inputs to be passed to update state, this may not be null + * @param sampleWeights sample weights to be applied to the values, may be null. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. + * @return the Operation to update the metric state + */ + public final Op updateState( + Ops tf, Operand values, Operand sampleWeights) { + checkIsGraph(tf); + List controlOps = updateStateList(tf, values, sampleWeights); + return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); + } + + /** + * Creates a NoOp Operation with control dependencies to update the metric state + * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. + * @param labels the labels + * @param predictions the predictions + * @param sampleWeights sample weights to be applied to the metric values, may be null. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. + * @return the Operation to update the metric state + */ + public final Op updateState( + Ops tf, + Operand labels, + Operand predictions, + Operand sampleWeights) { + List controlOps = updateStateList(tf, labels, predictions, sampleWeights); + return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); + } + + /** + * Calls update state once, followed by a call to get the result + * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. + * @param values the inputs to be passed to update state, this may not be null + * @param sampleWeights sample weights to be applied to the values, may be null. + * @param The data type for the metric result + * @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. + * @return the result, possibly with control dependencies + */ + @Override + public final Operand callOnce( + Ops tf, + Operand values, + Operand sampleWeights, + Class type) { + checkIsGraph(tf); + List controlOps = updateStateList(tf, values, sampleWeights); + Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps); + return ltf.identity(result(ltf, type)); + } + + /** + * Gets a formatted name for a variable, in the form {@link #name} + "_" + varName. + * + * @param varName the base name for the variable + * @return the formatted variable name + */ + protected String getVariableName(String varName) { + return String.format("%s_%s", this.name, varName); + } + + /** + * The name for this metric. Defaults to {@link Class#getSimpleName()}. + * + *
<p>
Gets the name of this metric. + * + * @return the name of this metric + */ + public String getName() { + return name; + } + + /** + * Sets the metric name + * + * @param name the metric name + */ + public void setName(String name) { + this.name = name; + } + + /** + * Gets the random number generator seed value + * + * @return the random number generator seed value + */ + public long getSeed() { + return seed; + } + + /** + * Initialize the TensorFlow Ops + * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. + * @throws IllegalArgumentException if the TensorFlow Ops does not have a Graph environment, + */ + protected abstract void init(Ops tf); + + /** + * Gets the TensorFlow Ops for this metric + * + * @return the TensorFlow Ops for this metric. + */ + protected Ops getTF() { + return tf; + } + + /** + * Sets the TensorFlow Ops for this metric. + * + *
<p>
This should be set from the {@link #init(Ops)} implementation. + * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. + */ + protected void setTF(Ops tf) { + checkIsGraph(tf); + this.tf = tf; + } + + /** + * Checks whether the Metric is initialized or not. + * + * @return true if the Metric has been initialized. + */ + public boolean isInitialized() { + return initialized; + } + + /** + * Sets the initialized indicator + * + * @param initialized the initialized indicator + */ + protected void setInitialized(boolean initialized) { + this.initialized = initialized; + } + + /** + * Checks if the TensorFlow Ops encapsulates a {@link Graph} environment. + * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. + */ + protected void checkIsGraph(Ops tf) { + if (!tf.scope().env().isGraph()) { + throw new IllegalArgumentException( + "The Ops environment is not a Graph, Graph is required for metrics."); + } + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java index c27bf1b2acf..b230c76b111 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java @@ -16,9 +16,10 @@ import static org.tensorflow.framework.utils.CastHelper.cast; +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -33,8 +34,8 @@ * * @param The data type for the metric result */ -public class BinaryAccuracy extends MeanMetricWrapper - implements LossMetric { +public class BinaryAccuracy extends MeanBaseMetricWrapper + implements LossMetric { /** the default threshold value for deciding whether prediction values are 1 or 0 */ public static final float DEFAULT_THRESHOLD = 0.5f; @@ -45,40 +46,37 @@ public class BinaryAccuracy extends MeanMetricWrapper * Creates a BinaryAccuracy Metric using {@link Class#getSimpleName()} for the metric name and * {@link #DEFAULT_THRESHOLD} for the threshold value. * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public BinaryAccuracy(Ops tf, long seed, Class type) { - this(tf, null, DEFAULT_THRESHOLD, seed, type); + public BinaryAccuracy(long seed, Class type) { + this(null, DEFAULT_THRESHOLD, seed, type); } /** * Creates a BinaryAccuracy Metric using {@link Class#getSimpleName()} for the metric name * - * @param tf the TensorFlow Ops * @param threshold a threshold for deciding whether prediction values are 1 or 0 * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
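To illustrate the checkIsGraph guard in BaseMetric above: any metric built on this class is expected to reject an eager Ops scope. A hedged fragment (the exception message is quoted from checkIsGraph; the triggering call is left commented out):

try (EagerSession eager = EagerSession.create()) {
  Ops eagerTf = Ops.create(eager);
  Accuracy<TFloat32> accuracy = new Accuracy<>(1001L, TFloat32.class);
  // accuracy.updateState(eagerTf, labels, predictions, null) should throw
  // IllegalArgumentException: "The Ops environment is not a Graph, Graph is required for metrics."
}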
* @param type the data type for the variables */ - public BinaryAccuracy(Ops tf, float threshold, long seed, Class type) { - this(tf, null, threshold, seed, type); + public BinaryAccuracy(float threshold, long seed, Class type) { + this(null, threshold, seed, type); } /** * Creates a BinaryAccuracy Metric * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param threshold a threshold for deciding whether prediction values are 1 or 0 * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public BinaryAccuracy(Ops tf, String name, float threshold, long seed, Class type) { - super(tf, name, seed, type); + public BinaryAccuracy(String name, float threshold, long seed, Class type) { + super(name, seed, type); this.threshold = threshold; setLoss(this); } @@ -86,19 +84,24 @@ public BinaryAccuracy(Ops tf, String name, float threshold, long seed, Class /** * Calculates how often predictions match binary labels. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return Binary accuracy values. shape = {@code [batch_size, d0, .. dN-1]} */ @Override - public Operand call( - Operand labels, Operand predictions) { - - Operand tPredictions = cast(getTF(), predictions, getResultType()); - Operand thresholdCast = cast(getTF(), getTF().constant(threshold), getResultType()); - tPredictions = - cast(getTF(), getTF().math.greater(tPredictions, thresholdCast), getResultType()); - Operand tLabels = cast(getTF(), labels, getResultType()); - return cast(getTF(), getTF().math.equal(tLabels, tPredictions), getResultType()); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tPredictions = cast(tf, predictions, getInternalType()); + Operand thresholdCast = cast(tf, tf.constant(threshold), getInternalType()); + tPredictions = cast(tf, tf.math.greater(tPredictions, thresholdCast), getInternalType()); + Operand tLabels = cast(tf, labels, getInternalType()); + return cast(tf, tf.math.equal(tLabels, tPredictions), resultType); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 57a6f75375d..d306a00f70d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -14,15 +14,16 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; 
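A usage fragment for the BinaryAccuracy changes just above, with the decision threshold fixed at construction and the Graph-backed Ops passed per call (threshold and operands are illustrative):

BinaryAccuracy<TFloat32> binAcc = new BinaryAccuracy<>(0.75f, 1001L, TFloat32.class);
Operand<TFloat32> labels = tf.constant(new float[] {1f, 1f, 0f, 0f});
Operand<TFloat32> predictions = tf.constant(new float[] {0.9f, 0.6f, 0.4f, 0.8f});
Op update = binAcc.updateState(tf, labels, predictions, null);
Operand<TFloat32> value = binAcc.result(tf, TFloat32.class);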
-import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. * @@ -31,16 +32,31 @@ * * @param The data type for the metric result */ -public class BinaryCrossentropy extends MeanMetricWrapper - implements LossMetric { +public class BinaryCrossentropy extends MeanBaseMetricWrapper + implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; + /** + * Creates a BinaryCrossentropy metric where name is {@link Class#getSimpleName()}. + * + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. + * @param labelSmoothing value used to smooth labels, When 0, no smoothing occurs. When > 0, + * compute the loss between the predicted labels and a smoothed version of the true labels, + * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing + * correspond to heavier smoothing. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public BinaryCrossentropy(boolean fromLogits, float labelSmoothing, long seed, Class type) { + this(null, fromLogits, labelSmoothing, seed, type); + } /** * Creates a BinaryCrossentropy metric * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a * probability distribution. @@ -53,8 +69,8 @@ public class BinaryCrossentropy extends MeanMetricWrapper * @param type the type for the variables and result */ public BinaryCrossentropy( - Ops tf, String name, boolean fromLogits, float labelSmoothing, long seed, Class type) { - super(tf, name, seed, type); + String name, boolean fromLogits, float labelSmoothing, long seed, Class type) { + super(name, seed, type); setLoss(this); this.fromLogits = fromLogits; this.labelSmoothing = labelSmoothing; @@ -63,16 +79,23 @@ public BinaryCrossentropy( /** * Computes the binary crossentropy loss between labels and predictions. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels the truth values or labels, has the same shape as predictions and shape = {@code * [batch_size, d0, .. dN]}. * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return Binary crossentropy loss value. shape = {@code [batch_size, d0, .. dN-1]}. 
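A fragment exercising the new no-name BinaryCrossentropy constructor added above (fromLogits/labelSmoothing values are illustrative; labels and logits are assumed to be TFloat32 operands built from the same Graph-backed tf):

BinaryCrossentropy<TFloat32> bce =
    new BinaryCrossentropy<>(true /* fromLogits */, 0.1f /* labelSmoothing */, 1001L, TFloat32.class);
Op update = bce.updateState(tf, labels, logits, null);
Operand<TFloat32> meanLoss = bce.result(tf, TFloat32.class);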
*/ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.binaryCrossentropy(getTF(), tLabels, tPredictions, fromLogits, labelSmoothing); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tLabels = cast(tf, labels, resultType); + Operand tPredictions = cast(tf, predictions, resultType); + return Losses.binaryCrossentropy(tf, tLabels, tPredictions, fromLogits, labelSmoothing); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java index 70dfebc508d..19547612503 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java @@ -16,9 +16,10 @@ import static org.tensorflow.framework.utils.CastHelper.cast; +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.op.core.OneHot; import org.tensorflow.types.TInt64; @@ -43,32 +44,30 @@ * * @param The data type for the metric result */ -public class CategoricalAccuracy extends MeanMetricWrapper - implements LossMetric { +public class CategoricalAccuracy extends MeanBaseMetricWrapper + implements LossMetric { /** * Creates a CategoricalAccuracy metric, using {@link Class#getSimpleName()} for the metric name * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public CategoricalAccuracy(Ops tf, long seed, Class type) { - this(tf, null, seed, type); + public CategoricalAccuracy(long seed, Class type) { + this(null, seed, type); } /** * Creates a CategoricalAccuracy metric * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public CategoricalAccuracy(Ops tf, String name, long seed, Class type) { - super(tf, name, seed, type); + public CategoricalAccuracy(String name, long seed, Class type) { + super(name, seed, type); super.setLoss(this); } @@ -79,16 +78,23 @@ public CategoricalAccuracy(Ops tf, String name, long seed, Class type) { * rather than as labels. If necessary, use {@link Ops#oneHot} to expand {@code labels} as a * vector. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels One-hot ground truth values. * @param predictions tThe prediction values. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return Categorical accuracy values. 
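A fragment for the CategoricalAccuracy changes in this hunk; labels are expected as one-hot vectors, as the Javadoc above notes (values illustrative):

CategoricalAccuracy<TFloat32> catAcc = new CategoricalAccuracy<>(1001L, TFloat32.class);
Operand<TFloat32> oneHotLabels = tf.constant(new float[][] {{0f, 0f, 1f}, {0f, 1f, 0f}});
Operand<TFloat32> probs = tf.constant(new float[][] {{0.1f, 0.2f, 0.7f}, {0.6f, 0.3f, 0.1f}});
Op update = catAcc.updateState(tf, oneHotLabels, probs, null);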
*/ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand trueMax = getTF().math.argMax(labels, getTF().constant(-1)); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand trueMax = tf.math.argMax(labels, tf.constant(-1)); - Operand predMax = getTF().math.argMax(predictions, getTF().constant(-1)); - return cast(getTF(), getTF().math.equal(trueMax, predMax), getResultType()); + Operand predMax = tf.math.argMax(predictions, tf.constant(-1)); + return cast(tf, tf.math.equal(trueMax, predMax), resultType); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index fa7c1a1a626..0390660de1b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -16,10 +16,11 @@ import static org.tensorflow.framework.utils.CastHelper.cast; +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -33,12 +34,54 @@ * * @param The data type for the metric result */ -public class CategoricalCrossentropy extends MeanMetricWrapper - implements LossMetric { +public class CategoricalCrossentropy extends MeanBaseMetricWrapper + implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; private final int axis; + /** + * Creates a CategoricalCrossentropy metric that computes the crossentropy metric between the + * labels and predictions using {@link Class#getSimpleName()} for the metric name + * + *
<p>
Uses a {@link Losses#CHANNELS_LAST} for the channel axis. + * + * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to + * a probability distribution. + * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, + * meaning the confidence on label values are relaxed. e.g. {@code labelSmoothing=0.2} means + * that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 } for label + * {@code 1} + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public CategoricalCrossentropy( + boolean fromLogits, float labelSmoothing, long seed, Class type) { + this(null, fromLogits, labelSmoothing, Losses.CHANNELS_LAST, seed, type); + } + + /** + * Creates a CategoricalCrossentropy metric that computes the crossentropy metric between the + * labels and predictions using {@link Class#getSimpleName()} for the metric name. + * + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. + * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, + * meaning the confidence on label values are relaxed. e.g. {@code labelSmoothing=0.2} means + * that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 } for label + * {@code 1} + * @param axis Int specifying the channels axis. {@code axis={@link Losses#CHANNELS_LAST}} + * corresponds to data format {@code channels_last}, and {@code axis={@link + * Losses#CHANNELS_FIRST}} corresponds to data format {@code channels_first}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public CategoricalCrossentropy( + boolean fromLogits, float labelSmoothing, int axis, long seed, Class type) { + this(null, fromLogits, labelSmoothing, axis, seed, type); + } /** * Creates a CategoricalCrossentropy metric that computes the crossentropy metric between the @@ -46,7 +89,6 @@ public class CategoricalCrossentropy extends MeanMetricWrappe * *
<p>
Uses a {@link Losses#CHANNELS_LAST} for the channel axis. * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to * a probability distribution. @@ -59,15 +101,14 @@ public class CategoricalCrossentropy extends MeanMetricWrappe * @param type the type for the variables and result */ public CategoricalCrossentropy( - Ops tf, String name, boolean fromLogits, float labelSmoothing, long seed, Class type) { - this(tf, name, fromLogits, labelSmoothing, Losses.CHANNELS_LAST, seed, type); + String name, boolean fromLogits, float labelSmoothing, long seed, Class type) { + this(name, fromLogits, labelSmoothing, Losses.CHANNELS_LAST, seed, type); } /** * Creates a CategoricalCrossentropy metric that computes the crossentropy metric between the * labels and predictions. * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a * probability distribution. @@ -83,14 +124,8 @@ public CategoricalCrossentropy( * @param type the type for the variables and result */ public CategoricalCrossentropy( - Ops tf, - String name, - boolean fromLogits, - float labelSmoothing, - int axis, - long seed, - Class type) { - super(tf, name, seed, type); + String name, boolean fromLogits, float labelSmoothing, int axis, long seed, Class type) { + super(name, seed, type); setLoss(this); this.fromLogits = fromLogits; this.labelSmoothing = labelSmoothing; @@ -100,16 +135,23 @@ public CategoricalCrossentropy( /** * Computes the crossentropy loss between the labels and predictions. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels the truth values or labels, of one-hot true targets, same shape as predictions * @param predictions the predictions + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return Categorical crossentropy loss value. 
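A fragment for the new CategoricalCrossentropy convenience constructors above, using the explicit channels-axis overload (values illustrative; oneHotLabels/probs as in the earlier sketch):

CategoricalCrossentropy<TFloat32> cce =
    new CategoricalCrossentropy<>(
        false /* fromLogits */, 0.0f /* labelSmoothing */, Losses.CHANNELS_LAST, 1001L, TFloat32.class);
Op update = cce.updateState(tf, oneHotLabels, probs, null);
Operand<TFloat32> meanLoss = cce.result(tf, TFloat32.class);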
*/ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tLabels = cast(tf, labels, resultType); + Operand tPredictions = cast(tf, predictions, resultType); return Losses.categoricalCrossentropy( - getTF(), tLabels, tPredictions, fromLogits, labelSmoothing, axis); + tf, tLabels, tPredictions, fromLogits, labelSmoothing, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index 1f6d0fd002c..b7246e66790 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -14,49 +14,66 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A Metric that computes the categorical hinge loss metric between labels and predictions. * * @param The data type for the metric result */ -public class CategoricalHinge extends MeanMetricWrapper - implements LossMetric { +public class CategoricalHinge extends MeanBaseMetricWrapper + implements LossMetric { + /** + * Creates a CategoricalHinge metric using {@link Class#getSimpleName()} for the metric name. + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public CategoricalHinge(long seed, Class type) { + this(null, seed, type); + } /** * Creates a CategoricalHinge metric * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ - public CategoricalHinge(Ops tf, String name, long seed, Class type) { - super(tf, name, seed, type); + public CategoricalHinge(String name, long seed, Class type) { + super(name, seed, type); setLoss(this); } /** * Computes the categorical hinge metric between {@code labels} and @{code predictions}. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels the truth values or labels, labels values are expected to be 0 or 1. * @param predictions the predictions + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return Categorical hinge loss values. 
*/ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.categoricalHinge(getTF(), tLabels, tPredictions); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tLabels = cast(tf, labels, resultType); + Operand tPredictions = cast(tf, predictions, resultType); + return Losses.categoricalHinge(tf, tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 230286a738f..c02caa4d8fd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -14,15 +14,16 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the cosine similarity metric between labels and predictions. * @@ -38,50 +39,83 @@ * @param The data type for the metric result. * @see Cosine Similarity */ -public class CosineSimilarity extends MeanMetricWrapper - implements LossMetric { +public class CosineSimilarity extends MeanBaseMetricWrapper + implements LossMetric { public static final int DEFAULT_AXIS = -1; private final int[] axis; + /** + * Creates a metric that computes the cosine similarity metric between labels and predictions with + * a default axis, {@link #DEFAULT_AXIS} and using {@link Class#getSimpleName()} for the metric + * name. + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public CosineSimilarity(long seed, Class type) { + this(null, DEFAULT_AXIS, seed, type); + } + + /** + * Creates a metric that computes the cosine similarity metric between labels and predictions + * using {@link Class#getSimpleName()} for the metric name. + * + * @param axis The dimension along which the cosine similarity is computed. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public CosineSimilarity(int axis, long seed, Class type) { + this(null, new int[] {axis}, seed, type); + } + /** + * Creates a CosineSimilarity metric using {@link Class#getSimpleName()} for the metric name. + * + * @param axis The dimension along which the cosine similarity is computed. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the type for the variables and result + */ + public CosineSimilarity(int[] axis, long seed, Class type) { + this(null, axis, seed, type); + } /** * Creates a metric that computes the cosine similarity metric between labels and predictions with * a default axis, {@link #DEFAULT_AXIS} * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ - public CosineSimilarity(Ops tf, String name, long seed, Class type) { - this(tf, name, DEFAULT_AXIS, seed, type); + public CosineSimilarity(String name, long seed, Class type) { + this(name, DEFAULT_AXIS, seed, type); } /** * Creates a metric that computes the cosine similarity metric between labels and predictions. * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param axis The dimension along which the cosine similarity is computed. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ - public CosineSimilarity(Ops tf, String name, int axis, long seed, Class type) { - this(tf, name, new int[] {axis}, seed, type); + public CosineSimilarity(String name, int axis, long seed, Class type) { + this(name, new int[] {axis}, seed, type); } /** * Creates a CosineSimilarity metric * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param axis The dimension along which the cosine similarity is computed. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ - public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class type) { - super(tf, name, seed, type); + public CosineSimilarity(String name, int[] axis, long seed, Class type) { + super(name, seed, type); this.axis = axis; setLoss(this); } @@ -89,17 +123,24 @@ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class typ /** * Computes the cosine similarity loss between labels and predictions. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels the truth values or labels * @param predictions the predictions + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. 
* @return the cosine similarity loss */ @Override - public Operand call( - Operand labels, Operand predictions) { + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); // NOTE: metrics.CosineSimilarity is Losses.cosineSimilarity, // while losses.CosineSimilarity is the negative of Losses.cosineSimilarity - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.cosineSimilarity(getTF(), tLabels, tPredictions, axis); + Operand tLabels = cast(tf, labels, resultType); + Operand tPredictions = cast(tf, predictions, resultType); + return Losses.cosineSimilarity(tf, tLabels, tPredictions, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java index 9f957ee6c17..6f121fd307f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java @@ -16,7 +16,6 @@ import org.tensorflow.framework.metrics.impl.ConfusionMatrixConditionCount; import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; /** @@ -37,19 +36,17 @@ public class FalseNegatives extends ConfusionMatrixConditionC * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name and a * default threshold of {@link #DEFAULT_THRESHOLD}. * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public FalseNegatives(Ops tf, long seed, Class type) { - this(tf, null, DEFAULT_THRESHOLD, seed, type); + public FalseNegatives(long seed, Class type) { + this(null, DEFAULT_THRESHOLD, seed, type); } /** * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name * - * @param tf the TensorFlow Ops * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is * {@code true}, below is {@code false}). One metric value is generated for each threshold @@ -58,14 +55,13 @@ public FalseNegatives(Ops tf, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public FalseNegatives(Ops tf, float threshold, long seed, Class type) { - this(tf, null, new float[] {threshold}, seed, type); + public FalseNegatives(float threshold, long seed, Class type) { + this(null, new float[] {threshold}, seed, type); } /** * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name * - * @param tf the TensorFlow Ops * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is * {@code true}, below is {@code false}). One metric value is generated for each threshold @@ -74,27 +70,25 @@ public FalseNegatives(Ops tf, float threshold, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. 
* @param type the data type for the variables */ - public FalseNegatives(Ops tf, float[] thresholds, long seed, Class type) { - this(tf, null, thresholds, seed, type); + public FalseNegatives(float[] thresholds, long seed, Class type) { + this(null, thresholds, seed, type); } /** * Creates a FalseNegatives metric, using a default threshold of {@link #DEFAULT_THRESHOLD}. * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public FalseNegatives(Ops tf, String name, long seed, Class type) { - this(tf, name, DEFAULT_THRESHOLD, seed, type); + public FalseNegatives(String name, long seed, Class type) { + this(name, DEFAULT_THRESHOLD, seed, type); } /** * Creates a FalseNegatives metric * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is @@ -104,14 +98,13 @@ public FalseNegatives(Ops tf, String name, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public FalseNegatives(Ops tf, String name, float threshold, long seed, Class type) { - this(tf, name, new float[] {threshold}, seed, type); + public FalseNegatives(String name, float threshold, long seed, Class type) { + this(name, new float[] {threshold}, seed, type); } /** * Creates a FalseNegatives metric * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is @@ -121,7 +114,7 @@ public FalseNegatives(Ops tf, String name, float threshold, long seed, Class * will always produce the same random tensor for a given shape and data type. 
* @param type the data type for the variables */ - public FalseNegatives(Ops tf, String name, float[] thresholds, long seed, Class type) { - super(tf, name, ConfusionMatrixEnum.FALSE_NEGATIVES, thresholds, seed, type); + public FalseNegatives(String name, float[] thresholds, long seed, Class type) { + super(name, ConfusionMatrixEnum.FALSE_NEGATIVES, thresholds, seed, type); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java index a3d585dea0f..a072c53fced 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java @@ -16,7 +16,6 @@ import org.tensorflow.framework.metrics.impl.ConfusionMatrixConditionCount; import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; /** @@ -37,19 +36,17 @@ public class FalsePositives extends ConfusionMatrixConditionC * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name and a * default threshold of {@link #DEFAULT_THRESHOLD}. * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public FalsePositives(Ops tf, long seed, Class type) { - this(tf, null, DEFAULT_THRESHOLD, seed, type); + public FalsePositives(long seed, Class type) { + this(null, DEFAULT_THRESHOLD, seed, type); } /** * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name * - * @param tf the TensorFlow Ops * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is * {@code true}, below is {@code false}). One metric value is generated for each threshold @@ -58,14 +55,13 @@ public FalsePositives(Ops tf, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public FalsePositives(Ops tf, float threshold, long seed, Class type) { - this(tf, null, new float[] {threshold}, seed, type); + public FalsePositives(float threshold, long seed, Class type) { + this(null, new float[] {threshold}, seed, type); } /** * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name * - * @param tf the TensorFlow Ops * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is * {@code true}, below is {@code false}). One metric value is generated for each threshold @@ -74,27 +70,25 @@ public FalsePositives(Ops tf, float threshold, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public FalsePositives(Ops tf, float[] thresholds, long seed, Class type) { - this(tf, null, thresholds, seed, type); + public FalsePositives(float[] thresholds, long seed, Class type) { + this(null, thresholds, seed, type); } /** * Creates a FalsePositives metric, using a default threshold of {@link #DEFAULT_THRESHOLD}. 
* - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public FalsePositives(Ops tf, String name, long seed, Class type) { - this(tf, name, DEFAULT_THRESHOLD, seed, type); + public FalsePositives(String name, long seed, Class type) { + this(name, DEFAULT_THRESHOLD, seed, type); } /** * Creates a FalsePositives metric * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is @@ -104,14 +98,13 @@ public FalsePositives(Ops tf, String name, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public FalsePositives(Ops tf, String name, float threshold, long seed, Class type) { - this(tf, name, new float[] {threshold}, seed, type); + public FalsePositives(String name, float threshold, long seed, Class type) { + this(name, new float[] {threshold}, seed, type); } /** * Creates a FalsePositives metric * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is @@ -121,7 +114,7 @@ public FalsePositives(Ops tf, String name, float threshold, long seed, Class * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public FalsePositives(Ops tf, String name, float[] thresholds, long seed, Class type) { - super(tf, name, ConfusionMatrixEnum.FALSE_POSITIVES, thresholds, seed, type); + public FalsePositives(String name, float[] thresholds, long seed, Class type) { + super(name, ConfusionMatrixEnum.FALSE_POSITIVES, thresholds, seed, type); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index a2d110867b8..7ce3622099a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -14,48 +14,65 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the hinge loss metric between labels and predictions. * * @param The data type for the metric result. 
*/ -public class Hinge extends MeanMetricWrapper implements LossMetric { +public class Hinge extends MeanBaseMetricWrapper implements LossMetric { + /** + * Creates a Hinge metric using {@link Class#getSimpleName()} for the metric name. + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public Hinge(long seed, Class type) { + this(null, seed, type); + } /** * Creates a Hinge metric * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ - public Hinge(Ops tf, String name, long seed, Class type) { - super(tf, name, seed, type); + public Hinge(String name, long seed, Class type) { + super(name, seed, type); setLoss(this); } /** * Computes the hinge loss between labels and predictions. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return the hinge loss between labels and predictions. */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.hinge(getTF(), tLabels, tPredictions); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tLabels = cast(tf, labels, resultType); + Operand tPredictions = cast(tf, predictions, resultType); + return Losses.hinge(tf, tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index 155a891ccc2..b97a0a4355e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -14,34 +14,44 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the Kullback-Leibler divergence loss metric between labels and * predictions. * * @param The data type for the metric result. 
*/ -public class KLDivergence extends MeanMetricWrapper implements LossMetric { +public class KLDivergence extends MeanBaseMetricWrapper + implements LossMetric { + /** + * Creates a KLDivergence metric using {@link Class#getSimpleName()} for the metric name. + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public KLDivergence(long seed, Class type) { + this(null, seed, type); + } /** * Creates a KLDivergence metric * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ - public KLDivergence(Ops tf, String name, long seed, Class type) { - super(tf, name, seed, type); + public KLDivergence(String name, long seed, Class type) { + super(name, seed, type); setLoss(this); } @@ -53,10 +63,14 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { * @return the loss with shape {@code [batch_size, d0, .. dN-1]} */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.kullbackLeiblerDivergence(getTF(), tLabels, tPredictions); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tLabels = cast(tf, labels, resultType); + Operand tPredictions = cast(tf, predictions, resultType); + return Losses.kullbackLeiblerDivergence(tf, tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index 786847d4b32..79e5e99d3c5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -14,49 +14,67 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the logarithm of the hyperbolic cosine of the prediction error metric * between labels and predictions. * * @param The data type for the metric result. */ -public class LogCoshError extends MeanMetricWrapper implements LossMetric { +public class LogCoshError extends MeanBaseMetricWrapper + implements LossMetric { + /** + * Creates a LogCoshError metric using {@link Class#getSimpleName()} for the metric name. + * + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public LogCoshError(long seed, Class type) { + this(null, seed, type); + } /** * Creates a LogCoshError metric * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ - public LogCoshError(Ops tf, String name, long seed, Class type) { - super(tf, name, seed, type); + public LogCoshError(String name, long seed, Class type) { + super(name, seed, type); setLoss(this); } /** * Calculates the Logarithm of the hyperbolic cosine of the prediction error. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels Ground truth values, shape = {@code [batch_size, d0, .. dN]}. * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. - * @return Logcosh error values, shape = {@code [batch_size, d0, .. dN-1]}. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. + * @return the Logcosh error values, shape = {@code [batch_size, d0, .. dN-1]}. */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.logCosh(getTF(), tLabels, tPredictions); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tLabels = cast(tf, labels, resultType); + Operand tPredictions = cast(tf, predictions, resultType); + return Losses.logCosh(tf, tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java index 8902b329bcc..2fa85de9c10 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.metrics; import org.tensorflow.framework.metrics.impl.Reduce; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; /** @@ -24,17 +23,27 @@ * @param The data type for the metric result */ public class Mean extends Reduce { + /** + * Creates a Reducible Metric with a metric reduction of {@link MetricReduction#WEIGHTED_MEAN} and using + * {@link Class#getSimpleName()} for the metric name. + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public Mean(long seed, Class type) { + this(null, seed, type); + } /** * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#SUM} * - * @param tf the TensorFlow Ops * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type.
* @param type the type for the variables and result */ - protected Mean(Ops tf, String name, long seed, Class type) { - super(tf, name, MetricReduction.WEIGHTED_MEAN, seed, type); + public Mean(String name, long seed, Class type) { + super(name, MetricReduction.WEIGHTED_MEAN, seed, type); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index b38d0a809e1..2fe18c132b6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -14,49 +14,66 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the mean of absolute difference between labels and predictions. * * @param The data type for the metric result. */ -public class MeanAbsoluteError extends MeanMetricWrapper - implements LossMetric { +public class MeanAbsoluteError extends MeanBaseMetricWrapper + implements LossMetric { + /** + * Creates a Mean Absolute Error metric using {@link Class#getSimpleName()} for the metric name. + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public MeanAbsoluteError(long seed, Class type) { + this(null, seed, type); + } /** * Creates a Mean Absolute Error metric * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ - public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { - super(tf, name, seed, type); + public MeanAbsoluteError(String name, long seed, Class type) { + super(name, seed, type); setLoss(this); } /** * Computes the mean absolute error loss between labels and predictions. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return Mean absolute error values, shape = {@code [batch_size, d0, .. dN-1]}. 
*/ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.meanAbsoluteError(getTF(), tLabels, tPredictions); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tLabels = cast(tf, labels, resultType); + Operand tPredictions = cast(tf, predictions, resultType); + return Losses.meanAbsoluteError(tf, tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 22bcd0ab0eb..bd777057210 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -14,49 +14,67 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the mean of absolute difference between labels and predictions. * * @param The data type for the metric result. */ -public class MeanAbsolutePercentageError extends MeanMetricWrapper - implements LossMetric { +public class MeanAbsolutePercentageError extends MeanBaseMetricWrapper + implements LossMetric { + + /** + * Creates a Mean Absolute Percentage Error metric using {@link Class#getSimpleName()} for the metric name. + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public MeanAbsolutePercentageError(long seed, Class type) { + this(null, seed, type); + } /** * Creates a Mean Absolute Error metric * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ - public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type) { - super(tf, name, seed, type); + public MeanAbsolutePercentageError(String name, long seed, Class type) { + super(name, seed, type); setLoss(this); } /** * Computes the mean absolute percentage error loss between labels and predictions. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment.
* @return Mean absolute percentage error values, shape = {@code [batch_size, d0, .. dN-1]}. */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.meanAbsolutePercentageError(getTF(), tLabels, tPredictions); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tLabels = cast(tf, labels, resultType); + Operand tPredictions = cast(tf, predictions, resultType); + return Losses.meanAbsolutePercentageError(tf, tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java index 6495379c4c4..fba4c5e00cc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -19,6 +19,7 @@ import java.util.Collections; import java.util.List; +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.Zeros; import org.tensorflow.framework.metrics.impl.MetricsHelper; @@ -43,99 +44,94 @@ * * @param The data type for the metric result */ -public class MeanIoU extends Metric { +public class MeanIoU extends BaseMetric { public static final String TOTAL_CONFUSION_MATRIX = "TOTAL_CONFUSION_MATRIX"; private final String totalCMName; private final Class type; + + private final Zeros zeros = new Zeros<>(); /** * The possible number of labels the prediction task can have. This value must be provided, since * a confusion matrix of dimension = [numClasses, numClasses] will be allocated. */ private final long numClasses; - private Variable totalConfusionMatrix; + private final Shape variableShape; private Assign initializer; + private Variable totalConfusionMatrix; /** * Creates a metric MeanIoU, using name as {@link Class#getSimpleName()} * - * @param tf the TensorFlow Ops * @param numClasses The possible number of labels the prediction task can have * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - protected MeanIoU(Ops tf, long numClasses, long seed, Class type) { - this(tf, null, numClasses, seed, type); + protected MeanIoU(long numClasses, long seed, Class type) { + this(null, numClasses, seed, type); } /** * Creates a MeanIoU metric * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param numClasses The possible number of labels the prediction task can have * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
* @param type the data type for the variables */ - protected MeanIoU(Ops tf, String name, long numClasses, long seed, Class type) { - super(tf, name, seed); + protected MeanIoU(String name, long numClasses, long seed, Class type) { + super(name, seed); this.type = type; this.totalCMName = this.getVariableName(TOTAL_CONFUSION_MATRIX); this.numClasses = numClasses; - init(); + variableShape = Shape.of(numClasses, numClasses); } - private void init() { - Shape variableShape = Shape.of(numClasses, numClasses); - - if (totalConfusionMatrix == null) { - Zeros zeros = new Zeros<>(); - totalConfusionMatrix = - getTF() - .withName(totalCMName) - .withInitScope() - .variable(zeros.call(getTF(), getTF().constant(variableShape), type)); - initializer = - getTF() - .assign( - totalConfusionMatrix, zeros.call(getTF(), getTF().constant(variableShape), type)); + /** {@inheritDoc} */ + @Override + protected void init(Ops tf) { + checkIsGraph(tf); + if (!isInitialized()) { + setTF(tf); + Operand zeroOp = zeros.call(tf, tf.constant(variableShape), type); + totalConfusionMatrix = tf.withName(totalCMName).withInitScope().variable(zeroOp); + initializer = tf.assign(totalConfusionMatrix, zeroOp); + setInitialized(true); } } /** {@inheritDoc} */ @Override - public Op resetStates() { - return initializer; - } - - /** - * Gets the initializer for the totalConfusionMatrix variable - * - * @return the initializer for the totalConfusionMatrix variable - */ - public Assign getInitializer() { - return initializer; + public Op resetStates(Ops tf) { + init(tf); + return tf.withName(totalCMName) + .assign(totalConfusionMatrix, zeros.call(tf, tf.constant(variableShape), type)); } /** * Accumulates the confusion matrix statistics. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels the labels * @param predictions the predictions * @param sampleWeights Optional weighting of each example. Defaults to 1, if null. Rank is either * 0, or the same rank as labels, and must be broadcastable to labels. * @return the Operands that updates totalConfusionMatrix variable + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @throws IllegalArgumentException if the weights rank is not 0, and weights rank @{code !=} * labels rank, and if the predictions size is not equal to the labels size */ @Override public List updateStateList( + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { + init(tf); if (sampleWeights != null) { long weightsRank = sampleWeights.shape().numDimensions(); long labelsRank = labels.shape().numDimensions(); @@ -158,30 +154,30 @@ public List updateStateList( labelsSize, predictionsSize)); } - Operand tLabels = cast(getTF(), labels, type); + Operand tLabels = cast(tf, labels, type); if (tLabels.shape().numDimensions() > 1) { - tLabels = getTF().shape.flatten(tLabels); + tLabels = tf.shape.flatten(tLabels); } - Operand tPredictions = cast(getTF(), predictions, type); + Operand tPredictions = cast(tf, predictions, type); if (tPredictions.shape().numDimensions() > 1) { - tPredictions = getTF().shape.flatten(tPredictions); + tPredictions = tf.shape.flatten(tPredictions); } - Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, type) : null; + Operand tSampleWeights = sampleWeights != null ? 
cast(tf, sampleWeights, type) : null; if (tSampleWeights != null && tSampleWeights.shape().numDimensions() > 1) { - tSampleWeights = getTF().shape.flatten(tSampleWeights); + tSampleWeights = tf.shape.flatten(tSampleWeights); } // Accumulate the prediction to current confusion matrix. Operand currentCM = MetricsHelper.confusionMatrix( - getTF(), tLabels, tPredictions, getTF().constant(numClasses), tSampleWeights, type); - return Collections.singletonList(getTF().assignAdd(totalConfusionMatrix, currentCM)); + tf, tLabels, tPredictions, tf.constant(numClasses), tSampleWeights, type); + return Collections.singletonList(tf.assignAdd(totalConfusionMatrix, currentCM)); } /** {@inheritDoc} */ @Override - public Operand result() { - Ops tf = getTF(); + public Operand result(Ops tf, Class resultType) { + init(tf); Operand sumOverRow = tf.reduceSum(totalConfusionMatrix, tf.constant(0)); Operand sumOverCol = tf.reduceSum(totalConfusionMatrix, tf.constant(1)); Operand truePositives = @@ -202,6 +198,6 @@ public Operand result() { Operand iou = tf.math.divNoNan(truePositives, denominator); Operand iouSum = tf.reduceSum(iou, allAxes(tf, iou)); - return tf.math.divNoNan(iouSum, numValidEntries); + return cast(tf, tf.math.divNoNan(iouSum, numValidEntries), resultType); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java index 915d281e44b..50bc52b7d2f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java @@ -17,6 +17,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; import java.util.List; +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.framework.losses.impl.LossesHelper; @@ -38,115 +39,134 @@ */ public class MeanRelativeError extends Mean { private Operand normalizer; + private float[] normalizerFloat; + private double[] normalizerDouble; /** * Creates a MeanRelativeError metric using {@link Class#getSimpleName()} as the name * - * @param tf the TensorFlow Ops * @param normalizer The normalizer values with same shape as predictions. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ - protected MeanRelativeError(Ops tf, float[] normalizer, long seed, Class type) { - this(tf, null, cast(tf, tf.constant(normalizer), type), seed, type); + protected MeanRelativeError(float[] normalizer, long seed, Class type) { + this(null, normalizer, seed, type); } /** * Creates a MeanRelativeError metric * - * @param tf the TensorFlow Ops * @param name the name of the metric. If null, name defaults to {@link Class#getSimpleName()}. * @param normalizer The normalizer values with same shape as predictions. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
* @param type the type for the variables and result */ - protected MeanRelativeError(Ops tf, String name, float[] normalizer, long seed, Class type) { - this(tf, name, cast(tf, tf.constant(normalizer), type), seed, type); + protected MeanRelativeError(String name, float[] normalizer, long seed, Class type) { + super(name, seed, type); + this.normalizerFloat = normalizer; } /** * Creates a MeanRelativeError metric using {@link Class#getSimpleName()} as the name * - * @param tf the TensorFlow Ops * @param normalizer The normalizer values with same shape as predictions. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ - protected MeanRelativeError(Ops tf, double[] normalizer, long seed, Class type) { - this(tf, null, cast(tf, tf.constant(normalizer), type), seed, type); + protected MeanRelativeError(double[] normalizer, long seed, Class type) { + this(null, normalizer, seed, type); } /** * Creates a MeanRelativeError metric * - * @param tf the TensorFlow Ops * @param name the name of the metric. If null, name defaults to {@link Class#getSimpleName()}. * @param normalizer The normalizer values with same shape as predictions. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ - protected MeanRelativeError(Ops tf, String name, double[] normalizer, long seed, Class type) { - this(tf, name, cast(tf, tf.constant(normalizer), type), seed, type); + protected MeanRelativeError(String name, double[] normalizer, long seed, Class type) { + super(name, seed, type); + this.normalizerDouble = normalizer; } /** * Creates a MeanRelativeError metric using {@link Class#getSimpleName()} as the name * - * @param tf the TensorFlow Ops * @param normalizer The normalizer values with same shape as predictions. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ - protected MeanRelativeError(Ops tf, Operand normalizer, long seed, Class type) { - this(tf, null, normalizer, seed, type); + protected MeanRelativeError(Operand normalizer, long seed, Class type) { + this(null, normalizer, seed, type); } /** * Creates a MeanRelativeError metric * - * @param tf the TensorFlow ops * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param normalizer The normalizer values with same shape as predictions. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
* @param type the type for the variables and result */ - protected MeanRelativeError( - Ops tf, String name, Operand normalizer, long seed, Class type) { - super(tf, name, seed, type); + protected MeanRelativeError(String name, Operand normalizer, long seed, Class type) { + super(name, seed, type); this.normalizer = normalizer; } + /** {@inheritDoc} */ + @Override + protected void init(Ops tf) { + checkIsGraph(tf); + if (!isInitialized()) { + super.init(tf); + if (normalizer == null) { + if (normalizerDouble != null && normalizerDouble.length > 0) { + normalizer = cast(tf, tf.constant(normalizerDouble), getInternalType()); + } else if (normalizerFloat != null && normalizerFloat.length > 0) { + normalizer = cast(tf, tf.constant(normalizerFloat), getInternalType()); + } + } + setInitialized(true); + } + } + /** * Accumulates metric statistics. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels The ground truth values. * @param predictions The predicted values. Must be the same shape as the normalizer. * @param sampleWeights Optional weighting of each example. A null value defaults to 1. Can be an * {@code Operand} whose rank is either 0, or the same rank as {@code labels}, and must be * broadcastable to {@code labels}. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return a List of Operations to update the metric state */ @Override public List updateStateList( + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); + init(tf); + Operand tLabels = cast(tf, labels, getInternalType()); + Operand tPredictions = cast(tf, predictions, getInternalType()); Operand tSampleWeights = - sampleWeights != null ? cast(getTF(), sampleWeights, getResultType()) : null; + sampleWeights != null ?
cast(tf, sampleWeights, getInternalType()) : null; - LossTuple tuple = LossesHelper.squeezeOrExpandDimensions(getTF(), tLabels, tPredictions); + LossTuple tuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, tPredictions); tPredictions = tuple.getTarget(); tLabels = tuple.getLabels(); - tuple = LossesHelper.removeSqueezableDimensions(getTF(), normalizer, tPredictions); + tuple = LossesHelper.removeSqueezableDimensions(tf, normalizer, tPredictions); normalizer = tuple.getLabels(); tPredictions = tuple.getTarget(); @@ -157,12 +177,9 @@ public List updateStateList( tPredictions.shape(), tLabels.shape())); Operand relativeErrors = - getTF() - .math - .divNoNan( - getTF().math.abs(getTF().math.sub(tLabels, tPredictions)), this.getNormalizer()); + tf.math.divNoNan(tf.math.abs(tf.math.sub(tLabels, tPredictions)), this.getNormalizer()); - return super.updateStateList(relativeErrors, tSampleWeights); + return super.updateStateList(tf, relativeErrors, tSampleWeights); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index fd8be29875e..d3f5bea03bb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -14,15 +14,16 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the mean of absolute difference between labels and predictions. * @@ -41,35 +42,52 @@ * * @param The data type for the metric result. */ -public class MeanSquaredError extends MeanMetricWrapper - implements LossMetric { +public class MeanSquaredError extends MeanBaseMetricWrapper + implements LossMetric { + + /** + * Creates a Mean Squared Error metric using {@link Class#getSimpleName()} for the metric name. + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public MeanSquaredError(long seed, Class type) { + this(null, seed, type); + } /** * Creates a Mean Absolute Error metric * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ - public MeanSquaredError(Ops tf, String name, long seed, Class type) { - super(tf, name, seed, type); + public MeanSquaredError(String name, long seed, Class type) { + super(name, seed, type); setLoss(this); } /** * Computes the mean squared error between the labels and predictions.
* + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels the truth values or labels. Must be the same shape as predictions. * @param predictions the predictions + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return Computes the mean squared error between the labels and predictions. */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.meanSquaredError(getTF(), tLabels, tPredictions); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tLabels = cast(tf, labels, resultType); + Operand tPredictions = cast(tf, predictions, resultType); + return Losses.meanSquaredError(tf, tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index 4728cbab12f..34512ef1c34 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -14,49 +14,67 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the mean of absolute difference between labels and predictions. * * @param The data type for the metric result. */ -public class MeanSquaredLogarithmicError extends MeanMetricWrapper - implements LossMetric { +public class MeanSquaredLogarithmicError extends MeanBaseMetricWrapper + implements LossMetric { + + /** + * Creates a Mean Absolute Error metric using {@link Class#getSimpleName()} for the metric name. + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public MeanSquaredLogarithmicError(long seed, Class type) { + this(null, seed, type); + } /** * Creates a Mean Absolute Error metric * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ - public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type) { - super(tf, name, seed, type); + public MeanSquaredLogarithmicError(String name, long seed, Class type) { + super(name, seed, type); setLoss(this); } /** * Computes the mean squared logarithmic error between labels and predictions. 
* + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return Mean squared logarithmic error values, shape = {@code [batch_size, d0, .. dN-1]}. */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.meanSquaredLogarithmicError(getTF(), tLabels, tPredictions); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tLabels = cast(tf, labels, resultType); + Operand tPredictions = cast(tf, predictions, resultType); + return Losses.meanSquaredLogarithmicError(tf, tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java index fa86125dbe7..3b4f4e9a73a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java @@ -35,7 +35,7 @@ * * @param The data type for the metric result */ -public class MeanTensor extends Metric { +public class MeanTensor extends BaseMetric { public static final String TOTAL = "total"; public static final String COUNT = "count"; private final String totalName; @@ -46,59 +46,50 @@ public class MeanTensor extends Metric { private Variable count; private Assign totalInitializer; private Assign countInitializer; - private boolean initialized; /** * Creates a MeanTensor metric, using {@link Class#getSimpleName()} as the name * - * @param tf the TensorFlow ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public MeanTensor(Ops tf, long seed, Class type) { - this(tf, null, seed, type); + public MeanTensor(long seed, Class type) { + this(null, seed, type); } /** * Creates a MeanTensor metric * - * @param tf the TensorFlow ops * @param name the name of this metric, if null then {@link Class#getSimpleName()} is used * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public MeanTensor(Ops tf, String name, long seed, Class type) { - super(tf, name, seed); + public MeanTensor(String name, long seed, Class type) { + super(name, seed); this.type = type; this.totalName = this.getVariableName(TOTAL); this.countName = this.getVariableName(COUNT); } - /** - * Creates the Operations that initialize the total and count variables. 
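MeanTensor now defers variable creation until the first update, capturing the element shape from the first values passed in (see the reworked init and updateStateList below). A rough sketch under that assumption, with made-up values and Session execution omitted:

    import org.tensorflow.Graph;
    import org.tensorflow.Operand;
    import org.tensorflow.framework.metrics.MeanTensor;
    import org.tensorflow.op.Op;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    public class MeanTensorSketch {
      public static void main(String[] args) {
        try (Graph graph = new Graph()) {
          Ops tf = Ops.create(graph);
          MeanTensor<TFloat32> meanTensor = new MeanTensor<>(1001L, TFloat32.class);
          // The first call fixes the variable shape to [4] and creates total/count.
          Op update1 = meanTensor.updateState(tf, tf.constant(new float[] {0f, 1f, 2f, 3f}), null);
          // Later calls must use the same shape, otherwise updateStateList throws
          // IllegalArgumentException.
          Op update2 = meanTensor.updateState(tf, tf.constant(new float[] {4f, 5f, 6f, 7f}), null);
          Operand<TFloat32> perElementMean = meanTensor.result(tf, TFloat32.class);
        }
      }
    }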
- * - * @param shape the shape of the variables - * @return true if the variables need initialization, otherwise false; - */ - private boolean init(Shape shape) { - if (!initialized) { - this.shape = shape; + /** {@inheritDoc} */ + @Override + protected void init(Ops tf) { + checkIsGraph(tf); + if (!isInitialized() && shape != null) { + setTF(tf); Zeros zeros = new Zeros<>(); - Operand zero = zeros.call(getTF(), getTF().constant(shape), type); + Operand zero = zeros.call(tf, tf.constant(shape), type); if (total == null) { - total = getTF().withName(totalName).withInitScope().variable(zero); - totalInitializer = getTF().assign(total, zero); + total = tf.withName(totalName).withInitScope().variable(zero); + totalInitializer = tf.assign(total, zero); } if (count == null) { - count = getTF().withName(countName).withInitScope().variable(zero); - countInitializer = getTF().assign(count, zero); + count = tf.withName(countName).withInitScope().variable(zero); + countInitializer = tf.assign(count, zero); } - this.initialized = true; - return true; - } else { - return false; + setInitialized(true); } } @@ -113,19 +104,19 @@ private boolean init(Shape shape) { */ @Override public List updateStateList( - Operand values, Operand sampleWeights) { - Ops tf = getTF(); + Ops tf, Operand values, Operand sampleWeights) { + if (shape == null) { + shape = values.shape(); + } + init(tf); Operand tValues = cast(tf, values, type); Operand tSampleWeights = sampleWeights == null ? null : cast(tf, sampleWeights, type); - // update the shape if it is the first call. - boolean needsInitialization = init(values.shape()); - - if (!this.shape.equals(values.shape())) { + if (!shape.equals(values.shape())) { throw new IllegalArgumentException( String.format( "MeanTensor input values must always have the same shape. Expected shape (set during the first call): %s. Got %s", - this.shape.toString(), values.shape().toString())); + shape.toString(), values.shape().toString())); } Operand numValues = tf.onesLike(tValues); @@ -153,12 +144,7 @@ public List updateStateList( tValues = tf.math.mul(tValues, tSampleWeights); } - List controlOpsPre = new ArrayList<>(); - if (needsInitialization) { - controlOpsPre.add(countInitializer); - controlOpsPre.add(totalInitializer); - } - Ops tf1 = tf.withSubScope("variables").withControlDependencies(controlOpsPre); + Ops tf1 = tf.withSubScope("MeanTensor.variables"); List controlOps = new ArrayList<>(); controlOps.add(tf1.assignAdd(this.count, numValues)); @@ -168,12 +154,13 @@ public List updateStateList( /** {@inheritDoc} */ @Override - public Operand result() { - if (!this.initialized) { - throw new IllegalStateException( - "MeanTensor does not have any result yet. 
Please use `.update_state(value)` before retrieving the result."); + public Operand result(Ops tf, Class resultType) { + init(tf); + if (!isInitialized()) { + return cast(tf, tf.constant(0), resultType); + } else { + return cast(tf, tf.math.divNoNan(total, count), resultType); } - return getTF().math.divNoNan(total, count); } /** @return the total */ @@ -188,10 +175,15 @@ public Variable getCount() { /** {@inheritDoc} */ @Override - public Op resetStates() { - List controlOpsPre = new ArrayList<>(); - controlOpsPre.add(countInitializer); - controlOpsPre.add(totalInitializer); - return getTF().withSubScope("resetStates").withControlDependencies(controlOpsPre).noOp(); + public Op resetStates(Ops tf) { + init(tf); + if (!isInitialized()) { + return tf.withSubScope("resetStates").noOp(); + } else { + List controlOpsPre = new ArrayList<>(); + controlOpsPre.add(countInitializer); + controlOpsPre.add(totalInitializer); + return tf.withSubScope("resetStates").withControlDependencies(controlOpsPre).noOp(); + } } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 468919e696d..c8c1df607c2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,182 +14,109 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import java.util.List; +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import java.util.Collections; -import java.util.List; - -/** - * Base class for Metrics - * - * @param The data type for the metric result - */ -public abstract class Metric { - - /** The TensorFlow Ops */ - private final Ops tf; - - /** The seed for random number generation */ - private final long seed; - - /** The name for this metric. Defaults to {@link Class#getSimpleName()}. */ - private final String name; - - /** - * Creates a Metric with a name of {@link Class#getSimpleName()} - * - * @param tf the TensorFlow Ops - * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and data type. - */ - protected Metric(Ops tf, long seed) { - this(tf, null, seed); - } - - /** - * Creates a Metric - * - * @param tf the TensorFlow Ops - * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. - * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and data type. - */ - protected Metric(Ops tf, String name, long seed) { - if (!tf.scope().env().isGraph()) { - throw new IllegalArgumentException("Metrics are required to execute in Graph mode."); - } - this.seed = seed; - this.name = name != null ? 
name : this.getClass().getSimpleName(); - this.tf = tf.withName(this.getClass().getSimpleName()); - } +/** Interface for metrics */ +interface Metric { /** * Creates a List of Operations to update the metric state based on input values. * - *

This is an empty implementation that should be overridden in a subclass, if needed. - * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. encapsulating a {@link + * Graph} environment. * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. * @return a List of Operations to update the metric state */ - @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList( - Operand values, Operand sampleWeights) { - return Collections.EMPTY_LIST; - } + List updateStateList( + Ops tf, Operand values, Operand sampleWeights); /** * Creates a List of Operations to update the metric state based on labels and predictions. * *

This is an empty implementation that should be overridden in a sub class, if needed. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels the labels * @param predictions the predictions * @param sampleWeights sample weights to be applied to values, may be null. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. * @return a List of Operations to update the metric state */ - @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList( + List updateStateList( + Ops tf, Operand labels, Operand predictions, - Operand sampleWeights) { - return Collections.EMPTY_LIST; - } + Operand sampleWeights); + + /** + * Gets the current result of the metric + * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. + * @param type the data type for the result + * @param the date type for the result + * @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. + * @return the result, possibly with control dependencies + */ + Operand result(Ops tf, Class type); + + /** + * Resets any state variables to their initial values + * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. + * @return the operation for doing the reset + */ + Op resetStates(Ops tf); /** * Creates a NoOp Operation with control dependencies to update the metric state * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. * @return the Operation to update the metric state */ - public final Op updateState( - Operand values, Operand sampleWeights) { - List controlOps = updateStateList(values, sampleWeights); - return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); - } + Op updateState( + Ops tf, Operand values, Operand sampleWeights); /** * Creates a NoOp Operation with control dependencies to update the metric state * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels the labels * @param predictions the predictions * @param sampleWeights sample weights to be applied to values, may be null. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. * @return the Operation to update the metric state */ - public final Op updateState( + Op updateState( + Ops tf, Operand labels, Operand predictions, - Operand sampleWeights) { - List controlOps = updateStateList(labels, predictions, sampleWeights); - return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); - } - - /** - * Gets the current result of the metric - * - * @return the result, possibly with control dependencies - */ - public abstract Operand result(); - - /** - * Resets any state variables to their initial values - * - * @return the control operation for doing the reset - */ - public abstract Op resetStates(); + Operand sampleWeights); /** * Calls update state once, followed by a call to get the result * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. 
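To make the per-call Ops convention concrete, a rough lifecycle sketch using MeanSquaredLogarithmicError: updateState ops accumulate into the metric variables, result reads them, and resetStates clears them. Class names, values, and generics below are illustrative; Session execution is omitted.

    import java.util.ArrayList;
    import java.util.List;
    import org.tensorflow.Graph;
    import org.tensorflow.Operand;
    import org.tensorflow.framework.metrics.MeanSquaredLogarithmicError;
    import org.tensorflow.op.Op;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    public class MetricLifecycleSketch {
      public static void main(String[] args) {
        try (Graph graph = new Graph()) {
          Ops tf = Ops.create(graph);
          MeanSquaredLogarithmicError<TFloat32> msle =
              new MeanSquaredLogarithmicError<>(1001L, TFloat32.class);
          List<Op> updates = new ArrayList<>();
          // One updateState op per batch; all of them accumulate into the same variables.
          updates.add(msle.updateState(tf, tf.constant(new float[] {1f, 2f}),
              tf.constant(new float[] {1.1f, 1.9f}), null));
          updates.add(msle.updateState(tf, tf.constant(new float[] {3f, 4f}),
              tf.constant(new float[] {2.8f, 4.2f}), null));
          Operand<TFloat32> epochResult = msle.result(tf, TFloat32.class);
          // Clears the accumulated state before the next epoch.
          Op reset = msle.resetStates(tf);
        }
      }
    }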
* @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. + * @param type the data type for the result + * @param the date type for the result + * @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. * @return the result, possibly with control dependencies */ - public final Operand callOnce( - Operand values, Operand sampleWeights) { - List controlOps = updateStateList(values, sampleWeights); - Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps); - return ltf.identity(result()); - } - - /** - * Gets a formatted name for a variable, in the form {@link #name} + "_" + varName. - * - * @param varName the base name for the variable - * @return the formatted variable name - */ - protected String getVariableName(String varName) { - return String.format("%s_%s", this.name, varName); - } - - /** - * Gets the TensorFlow Ops - * - * @return the TensorFlow Ops - */ - public Ops getTF() { - return tf; - } - - /** - * Gets the name of this metric. - * - * @return the name of this metric - */ - public String getName() { - return name; - } - - /** - * Gets the random number generator seed value - * - * @return the random number generator seed value - */ - public long getSeed() { - return seed; - } + Operand callOnce( + Ops tf, + Operand values, + Operand sampleWeights, + Class type); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index a33750ac3f6..beb86ee5c0f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -14,6 +14,9 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; @@ -21,8 +24,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** Static methods for computing metrics. */ public class Metrics { @@ -41,10 +42,12 @@ public class Metrics { * //m.shape().toString == "[2]" * * - * @param tf the TensorFlow Ops. + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.. * @param labels the ground truth values. * @param predictions The prediction values. * @param k Number of top elements to look at for computing accuracy. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @param the data type for the predictions and results * @return the Operand for the Top K categorical accuracy value. */ @@ -71,15 +74,16 @@ public static Operand topKCategoricalAccuracy( * //m.shape().toString == "[2]" * * - * @param tf the TensorFlow Ops. + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.. * @param labels the ground truth values. * @param predictions The prediction values. * @param k Number of top elements to look at for computing accuracy. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @param the data type for the predictions and results * @param the data type ofr the labels. 
* @return the Operand for the Sparse top K categorical accuracy value. */ - @SuppressWarnings("unchecked") public static Operand sparseTopKCategoricalAccuracy( Ops tf, Operand labels, Operand predictions, int k) { Operand tLabels = cast(tf, labels, predictions.type()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 2e4bde8ec55..706eb4fc385 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -14,48 +14,66 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the poisson loss metric between labels and predictions. * * @param The data type for the metric result. */ -public class Poisson extends MeanMetricWrapper implements LossMetric { +public class Poisson extends MeanBaseMetricWrapper implements LossMetric { + + /** + * Creates a Poisson metric using {@link Class#getSimpleName()} for the metric name. + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public Poisson(long seed, Class type) { + this(null, seed, type); + } /** * Creates a Poisson metric * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ - public Poisson(Ops tf, String name, long seed, Class type) { - super(tf, name, seed, type); + public Poisson(String name, long seed, Class type) { + super(name, seed, type); setLoss(this); } /** * Computes the Poisson loss between labels and predictions. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return Poisson loss value, shape = {@code [batch_size, d0, .. dN-1]}. 
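The Metrics helpers above change only in documentation and the removal of an unneeded @SuppressWarnings; for reference, a short usage sketch of sparseTopKCategoricalAccuracy with illustrative shapes and generics:

    import org.tensorflow.Graph;
    import org.tensorflow.Operand;
    import org.tensorflow.framework.metrics.Metrics;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    public class TopKSketch {
      public static void main(String[] args) {
        try (Graph graph = new Graph()) {
          Ops tf = Ops.create(graph);
          // Labels are class indices, predictions are per-class scores; k = 2.
          Operand<TFloat32> labels = tf.constant(new float[] {2f, 1f});
          Operand<TFloat32> predictions =
              tf.constant(new float[][] {{0.1f, 0.8f, 0.1f}, {0.05f, 0.95f, 0f}});
          Operand<TFloat32> acc =
              Metrics.sparseTopKCategoricalAccuracy(tf, labels, predictions, 2);
        }
      }
    }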
*/ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.poisson(getTF(), tLabels, tPredictions); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tLabels = cast(tf, labels, resultType); + Operand tPredictions = cast(tf, predictions, resultType); + return Losses.poisson(tf, tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java index d81030ebedb..c1df3ac952e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java @@ -17,7 +17,6 @@ import static org.tensorflow.framework.utils.CastHelper.cast; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -54,7 +53,7 @@ * * @param The data type for the metric result */ -public class Precision extends Metric { +public class Precision extends BaseMetric { public static final String TRUE_POSITIVES = "TRUE_POSITIVES"; public static final String FALSE_POSITIVES = "FALSE_POSITIVES"; public static final float DEFAULT_THRESHOLD = 0.5f; @@ -66,6 +65,7 @@ public class Precision extends Metric { private final String falsePositivesName; private final Class type; private final List initializers = new ArrayList<>(); + private final Zeros zeros = new Zeros<>(); private Variable truePositives; private Variable falsePositives; @@ -73,35 +73,32 @@ public class Precision extends Metric { * Creates a Precision Metric with a name of {@link Class#getSimpleName()} and no topK or classId * values and with a threshold of {@link #DEFAULT_THRESHOLD}. * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Precision(Ops tf, long seed, Class type) { - this(tf, null, null, null, null, seed, type); + public Precision(long seed, Class type) { + this(null, null, null, null, seed, type); } /** * Creates a Precision Metric with no topK or classId values with a threshold of {@link * #DEFAULT_THRESHOLD}. * - * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Precision(Ops tf, String name, long seed, Class type) { - this(tf, name, null, null, null, seed, type); + public Precision(String name, long seed, Class type) { + this(name, null, null, null, seed, type); } /** * Creates a Precision Metric with a name of {@link Class#getSimpleName()} and no topK or classId * values. * - * @param tf the TensorFlow Ops * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the * threshold is true, below is false). One metric value is generated for each threshold value. 
@@ -109,15 +106,14 @@ public Precision(Ops tf, String name, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Precision(Ops tf, float threshold, long seed, Class type) { - this(tf, null, new float[] {threshold}, null, null, seed, type); + public Precision(float threshold, long seed, Class type) { + this(null, new float[] {threshold}, null, null, seed, type); } /** * Creates a Precision Metric with a name of {@link Class#getSimpleName()} and no topK or classId * values. * - * @param tf the TensorFlow Ops * @param thresholds Optional threshold values in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold @@ -126,14 +122,13 @@ public Precision(Ops tf, float threshold, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Precision(Ops tf, float[] thresholds, long seed, Class type) { - this(tf, null, thresholds, null, null, seed, type); + public Precision(float[] thresholds, long seed, Class type) { + this(null, thresholds, null, null, seed, type); } /** * Creates a Precision Metric with no topK or classId values. * - * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is compared @@ -143,14 +138,13 @@ public Precision(Ops tf, float[] thresholds, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Precision(Ops tf, String name, float threshold, long seed, Class type) { - this(tf, name, new float[] {threshold}, null, null, seed, type); + public Precision(String name, float threshold, long seed, Class type) { + this(name, new float[] {threshold}, null, null, seed, type); } /** * Creates a Precision Metric with no topK or classId values. * - * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. * @param thresholds Optional threshold values in the range {@code [0, 1]}. A threshold is @@ -161,14 +155,13 @@ public Precision(Ops tf, String name, float threshold, long seed, Class type) * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Precision(Ops tf, String name, float[] thresholds, long seed, Class type) { - this(tf, name, thresholds, null, null, seed, type); + public Precision(String name, float[] thresholds, long seed, Class type) { + this(name, thresholds, null, null, seed, type); } /** * Creates a Precision Metric with a name of {@link Class#getSimpleName()} * - * @param tf the TensorFlow Ops * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the * threshold is true, below is false). One metric value is generated for each threshold value. @@ -180,15 +173,13 @@ public Precision(Ops tf, String name, float[] thresholds, long seed, Class ty * will always produce the same random tensor for a given shape and data type. 
* @param type the data type for the variables */ - public Precision( - Ops tf, float threshold, Integer topK, Integer classId, long seed, Class type) { - this(tf, null, new float[] {threshold}, topK, classId, seed, type); + public Precision(float threshold, Integer topK, Integer classId, long seed, Class type) { + this(null, new float[] {threshold}, topK, classId, seed, type); } /** * Creates a Precision Metric with a name of {@link Class#getSimpleName()} * - * @param tf the TensorFlow Ops * @param thresholds Optional threshold values in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold @@ -201,15 +192,13 @@ public Precision( * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Precision( - Ops tf, float[] thresholds, Integer topK, Integer classId, long seed, Class type) { - this(tf, null, thresholds, topK, classId, seed, type); + public Precision(float[] thresholds, Integer topK, Integer classId, long seed, Class type) { + this(null, thresholds, topK, classId, seed, type); } /** * Creates a Precision Metric. * - * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is compared @@ -224,20 +213,13 @@ public Precision( * @param type the data type for the variables */ public Precision( - Ops tf, - String name, - float threshold, - Integer topK, - Integer classId, - long seed, - Class type) { - this(tf, name, new float[] {threshold}, topK, classId, seed, type); + String name, float threshold, Integer topK, Integer classId, long seed, Class type) { + this(name, new float[] {threshold}, topK, classId, seed, type); } /** * Creates a Precision Metric. * - * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. * @param thresholds Optional threshold values in the range {@code [0, 1]}. A threshold is @@ -253,14 +235,8 @@ public Precision( * @param type the data type for the variables */ public Precision( - Ops tf, - String name, - float[] thresholds, - Integer topK, - Integer classId, - long seed, - Class type) { - super(tf, name, seed); + String name, float[] thresholds, Integer topK, Integer classId, long seed, Class type) { + super(name, seed); this.type = type; this.truePositivesName = this.getVariableName(TRUE_POSITIVES); this.falsePositivesName = this.getVariableName(FALSE_POSITIVES); @@ -268,23 +244,25 @@ public Precision( this.thresholds = thresholds == null ? 
new float[] {defaultThreshold} : thresholds; this.topK = topK; this.classId = classId; - - init(); } - /** Initializes the variables */ - private void init() { - Ops tf = getTF(); - Zeros zeros = new Zeros<>(); - Operand zero = zeros.call(tf, tf.constant(Shape.of(thresholds.length)), type); - - if (this.truePositives == null) { - this.truePositives = tf.withName(truePositivesName).withInitScope().variable(zero); - initializers.add(tf.assign(truePositives, zero)); - } - if (this.falsePositives == null) { - this.falsePositives = tf.withName(falsePositivesName).withInitScope().variable(zero); - initializers.add(tf.assign(falsePositives, zero)); + /** {@inheritDoc} */ + @Override + protected void init(Ops tf) { + checkIsGraph(tf); + if (!isInitialized()) { + setTF(tf); + Operand zero = zeros.call(tf, tf.constant(Shape.of(thresholds.length)), type); + + if (this.truePositives == null) { + this.truePositives = tf.withName(truePositivesName).withInitScope().variable(zero); + initializers.add(tf.assign(truePositives, zero)); + } + if (this.falsePositives == null) { + this.falsePositives = tf.withName(falsePositivesName).withInitScope().variable(zero); + initializers.add(tf.assign(falsePositives, zero)); + } + setInitialized(true); } } @@ -299,12 +277,12 @@ private void init() { * @return a List of Operations to update the metric state. */ @Override - @SuppressWarnings("unchecked") public List updateStateList( + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { - Ops tf = getTF(); + init(tf); Map> confusionMatrix = new HashMap<>(); confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, truePositives); confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, falsePositives); @@ -313,11 +291,10 @@ public List updateStateList( Operand tLabels = cast(tf, labels, type); Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null; - return new ArrayList( + return new ArrayList<>( MetricsHelper.updateConfusionMatrixVariables( tf, confusionMatrix, - Collections.EMPTY_MAP, tLabels, tPredictions, tf.constant(thresholds), @@ -330,23 +307,27 @@ public List updateStateList( /** {@inheritDoc} */ @Override - public Operand result() { - Ops tf = getTF(); + public Operand result(Ops tf, Class resultType) { + init(tf); Operand result = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); - return thresholds.length == 1 - ? tf.reshape( - tf.slice( - result, - tf.expandDims(tf.constant(0), tf.constant(0)), - tf.expandDims(tf.constant(1), tf.constant(0))), - tf.constant(Shape.scalar())) - : result; + return cast( + tf, + thresholds.length == 1 + ? 
tf.reshape( + tf.slice( + result, + tf.expandDims(tf.constant(0), tf.constant(0)), + tf.expandDims(tf.constant(1), tf.constant(0))), + tf.constant(Shape.scalar())) + : result, + resultType); } /** {@inheritDoc} */ @Override - public Op resetStates() { - return getTF().withSubScope("resetStates").withControlDependencies(initializers).noOp(); + public Op resetStates(Ops tf) { + init(tf); + return tf.withSubScope("resetStates").withControlDependencies(initializers).noOp(); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java index a5285ff6b2d..f6c2251ddaf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java @@ -42,7 +42,6 @@ public class PrecisionAtRecall extends SensitivitySpecificity * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()} and {@link * #DEFAULT_NUM_THRESHOLDS} for the number of thresholds * - * @param tf The TensorFlow Ops * @param recall the recall. A scalar value in range [0, 1] * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -50,15 +49,14 @@ public class PrecisionAtRecall extends SensitivitySpecificity * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range * [0-1]. */ - public PrecisionAtRecall(Ops tf, float recall, long seed, Class type) { - this(tf, null, recall, DEFAULT_NUM_THRESHOLDS, seed, type); + public PrecisionAtRecall(float recall, long seed, Class type) { + this(null, recall, DEFAULT_NUM_THRESHOLDS, seed, type); } /** * Creates a PrecisionRecall metric with {@link #DEFAULT_NUM_THRESHOLDS} for the number of * thresholds * - * @param tf The TensorFlow Ops * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} * @param recall the recall. A scalar value in range [0, 1] * @param seed the seed for random number generation. An initializer created with a given seed @@ -67,14 +65,13 @@ public PrecisionAtRecall(Ops tf, float recall, long seed, Class type) { * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range * [0-1]. */ - public PrecisionAtRecall(Ops tf, String name, float recall, long seed, Class type) { - this(tf, name, recall, DEFAULT_NUM_THRESHOLDS, seed, type); + public PrecisionAtRecall(String name, float recall, long seed, Class type) { + this(name, recall, DEFAULT_NUM_THRESHOLDS, seed, type); } /** * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()}. * - * @param tf The TensorFlow Ops * @param recall the recall. A scalar value in range [0, 1] * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given * recall. @@ -84,14 +81,13 @@ public PrecisionAtRecall(Ops tf, String name, float recall, long seed, Class * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range * [0-1]. */ - public PrecisionAtRecall(Ops tf, float recall, int numThresholds, long seed, Class type) { - this(tf, null, recall, numThresholds, seed, type); + public PrecisionAtRecall(float recall, int numThresholds, long seed, Class type) { + this(null, recall, numThresholds, seed, type); } /** * Creates a PrecisionRecall metric. 
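A rough sketch of the rewritten Precision: per the result implementation above, the value is cast to the requested result type and collapsed to a scalar only when a single threshold is configured. Values and generics are illustrative; Session execution is omitted.

    import org.tensorflow.Graph;
    import org.tensorflow.Operand;
    import org.tensorflow.framework.metrics.Precision;
    import org.tensorflow.op.Op;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    public class PrecisionSketch {
      public static void main(String[] args) {
        try (Graph graph = new Graph()) {
          Ops tf = Ops.create(graph);
          // Two thresholds -> result has shape [2]; a single threshold would yield a scalar.
          Precision<TFloat32> precision =
              new Precision<>(new float[] {0.3f, 0.7f}, 1001L, TFloat32.class);
          Op update = precision.updateState(tf,
              tf.constant(new float[] {0f, 1f, 1f, 0f}),
              tf.constant(new float[] {0.2f, 0.8f, 0.6f, 0.4f}), null);
          Operand<TFloat32> result = precision.result(tf, TFloat32.class);
        }
      }
    }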
* - * @param tf The TensorFlow Ops * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} * @param recall the recall. A scalar value in range [0, 1] * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given @@ -102,9 +98,8 @@ public PrecisionAtRecall(Ops tf, float recall, int numThresholds, long seed, Cla * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range * [0-1]. */ - public PrecisionAtRecall( - Ops tf, String name, float recall, int numThresholds, long seed, Class type) { - super(tf, name, numThresholds, seed, type); + public PrecisionAtRecall(String name, float recall, int numThresholds, long seed, Class type) { + super(name, numThresholds, seed, type); if (recall < 0f || recall > 1f) throw new IllegalArgumentException("recall must be in the range [0, 1]."); this.recall = recall; @@ -112,9 +107,8 @@ public PrecisionAtRecall( /** {@inheritDoc} */ @Override - public Operand result() { - Ops tf = getTF(); - + public Operand result(Ops tf, Class resultType) { + init(tf); Operand div = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); Operand sub = tf.math.sub(div, cast(tf, tf.constant(recall), getType())); Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); @@ -122,7 +116,7 @@ public Operand result() { Operand trueSlice = tf.slice(truePositives, minIndex, tf.constant(new int[] {1})); Operand falseSlice = tf.slice(falsePositives, minIndex, tf.constant(new int[] {1})); - return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)); + return cast(tf, tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)), resultType); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java index 7cdd01f0c56..4becc6e8908 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java @@ -17,10 +17,10 @@ import static org.tensorflow.framework.utils.CastHelper.cast; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.Zeros; import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; @@ -52,7 +52,7 @@ * * @param The data type for the metric result */ -public class Recall extends Metric { +public class Recall extends BaseMetric { public static final float DEFAULT_THRESHOLD = 0.5f; public static final String TRUE_POSITIVES = "TRUE_POSITIVES"; public static final String FALSE_NEGATIVES = "FALSE_NEGATIVES"; @@ -71,35 +71,32 @@ public class Recall extends Metric { * Creates a Recall metric with a name of {@link Class#getSimpleName()}, and topK and classId set * to null, and thresholds set to {@link #DEFAULT_THRESHOLD} * - * @param tf The TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
* @param type the data type for the variables */ - public Recall(Ops tf, long seed, Class type) { - this(tf, null, null, null, null, seed, type); + public Recall(long seed, Class type) { + this(null, null, null, null, seed, type); } /** * Creates a Recall metric with topK and classId set to null and thresholds set to {@link * #DEFAULT_THRESHOLD}. * - * @param tf The TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Recall(Ops tf, String name, long seed, Class type) { - this(tf, name, null, null, null, seed, type); + public Recall(String name, long seed, Class type) { + this(name, null, null, null, seed, type); } /** * Creates a Recall metric with a name of {@link Class#getSimpleName()}, and topK and classId set * to null. * - * @param tf The TensorFlow Ops * @param threshold A threshold is compared with prediction values to determine the truth value of * predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to * {@link #DEFAULT_THRESHOLD}. @@ -107,15 +104,14 @@ public Recall(Ops tf, String name, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Recall(Ops tf, float threshold, long seed, Class type) { - this(tf, null, threshold, null, null, seed, type); + public Recall(float threshold, long seed, Class type) { + this(null, threshold, null, null, seed, type); } /** * Creates a Recall metric with a name of {@link Class#getSimpleName()}, and topK and classId set * to null. * - * @param tf The TensorFlow Ops * @param thresholds A threshold is compared with prediction values to determine the truth value * of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults * to {@link #DEFAULT_THRESHOLD}. @@ -123,14 +119,13 @@ public Recall(Ops tf, float threshold, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Recall(Ops tf, float[] thresholds, long seed, Class type) { - this(tf, null, thresholds, null, null, seed, type); + public Recall(float[] thresholds, long seed, Class type) { + this(null, thresholds, null, null, seed, type); } /** * Creates a Recall metric with topK and classId set to null. * - * @param tf The TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. * @param threshold A threshold is compared with prediction values to determine the truth value of @@ -140,14 +135,13 @@ public Recall(Ops tf, float[] thresholds, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Recall(Ops tf, String name, float threshold, long seed, Class type) { - this(tf, name, threshold, null, null, seed, type); + public Recall(String name, float threshold, long seed, Class type) { + this(name, threshold, null, null, seed, type); } /** * Creates a Recall metric with topK and classId set to null. * - * @param tf The TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. 
* @param thresholds A threshold is compared with prediction values to determine the truth value @@ -157,15 +151,14 @@ public Recall(Ops tf, String name, float threshold, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Recall(Ops tf, String name, float[] thresholds, long seed, Class type) { - this(tf, name, thresholds, null, null, seed, type); + public Recall(String name, float[] thresholds, long seed, Class type) { + this(name, thresholds, null, null, seed, type); } /** * Creates a Recall metric with a name of {@link Class#getSimpleName()} and using a threshold * value of {@link #DEFAULT_THRESHOLD}. * - * @param tf The TensorFlow Ops * @param topK An optional value specifying the top-k predictions to consider when calculating * precision. * @param classId Optional Integer class ID for which we want binary metrics. This must be in the @@ -174,14 +167,13 @@ public Recall(Ops tf, String name, float[] thresholds, long seed, Class type) * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Recall(Ops tf, Integer topK, Integer classId, long seed, Class type) { - this(tf, null, null, topK, classId, seed, type); + public Recall(Integer topK, Integer classId, long seed, Class type) { + this(null, null, topK, classId, seed, type); } /** * Creates a Recall metric using a threshold value of {@link #DEFAULT_THRESHOLD}. * - * @param tf The TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. * @param topK An optional value specifying the top-k predictions to consider when calculating @@ -192,14 +184,13 @@ public Recall(Ops tf, Integer topK, Integer classId, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Recall(Ops tf, String name, Integer topK, Integer classId, long seed, Class type) { - this(tf, name, null, topK, classId, seed, type); + public Recall(String name, Integer topK, Integer classId, long seed, Class type) { + this(name, null, topK, classId, seed, type); } /** * Creates a Recall metric with a name of {@link Class#getSimpleName()} * - * @param tf The TensorFlow Ops * @param threshold A threshold is compared with prediction values to determine the truth value of * predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to * {@link #DEFAULT_THRESHOLD}. @@ -211,14 +202,13 @@ public Recall(Ops tf, String name, Integer topK, Integer classId, long seed, Cla * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Recall(Ops tf, float threshold, Integer topK, Integer classId, long seed, Class type) { - this(tf, null, new float[] {threshold}, topK, classId, seed, type); + public Recall(float threshold, Integer topK, Integer classId, long seed, Class type) { + this(null, new float[] {threshold}, topK, classId, seed, type); } /** * Creates a Recall metric with a name of {@link Class#getSimpleName()} * - * @param tf The TensorFlow Ops * @param thresholds A threshold is compared with prediction values to determine the truth value * of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults * to {@link #DEFAULT_THRESHOLD}. 
@@ -230,15 +220,13 @@ public Recall(Ops tf, float threshold, Integer topK, Integer classId, long seed, * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public Recall( - Ops tf, float[] thresholds, Integer topK, Integer classId, long seed, Class type) { - this(tf, null, thresholds, topK, classId, seed, type); + public Recall(float[] thresholds, Integer topK, Integer classId, long seed, Class type) { + this(null, thresholds, topK, classId, seed, type); } /** * Creates a Recall metric. * - * @param tf The TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. * @param threshold A threshold is compared with prediction values to determine the truth value of @@ -253,20 +241,13 @@ public Recall( * @param type the data type for the variables */ public Recall( - Ops tf, - String name, - float threshold, - Integer topK, - Integer classId, - long seed, - Class type) { - this(tf, name, new float[] {threshold}, topK, classId, seed, type); + String name, float threshold, Integer topK, Integer classId, long seed, Class type) { + this(name, new float[] {threshold}, topK, classId, seed, type); } /** * Creates a Recall metric. * - * @param tf The TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. * @param thresholds A threshold is compared with prediction values to determine the truth value @@ -281,14 +262,8 @@ public Recall( * @param type the data type for the variables */ public Recall( - Ops tf, - String name, - float[] thresholds, - Integer topK, - Integer classId, - long seed, - Class type) { - super(tf, name, seed); + String name, float[] thresholds, Integer topK, Integer classId, long seed, Class type) { + super(name, seed); this.type = type; this.truePositivesName = this.getVariableName(TRUE_POSITIVES); this.falseNegativesName = this.getVariableName(FALSE_NEGATIVES); @@ -297,51 +272,58 @@ public Recall( this.thresholds = thresholds == null ? 
new float[] {defaultThreshold} : thresholds; this.topK = topK; this.classId = classId; - - init(); } - /** Initializes the Variables */ - private void init() { - Ops tf = getTF(); - Zeros zeros = new Zeros<>(); - Operand zero = zeros.call(tf, tf.constant(Shape.of(this.thresholds.length)), type); - if (truePositives == null) { - - truePositives = tf.withName(truePositivesName).withInitScope().variable(zero); - initializers.add(tf.assign(truePositives, zero)); - } - - if (this.falseNegatives == null) { - - falseNegatives = tf.withName(falseNegativesName).withInitScope().variable(zero); - initializers.add(tf.assign(falseNegatives, zero)); + /** {@inheritDoc} */ + @Override + protected void init(Ops tf) { + checkIsGraph(tf); + if (!isInitialized()) { + setTF(tf); + Zeros zeros = new Zeros<>(); + Operand zero = zeros.call(tf, tf.constant(Shape.of(this.thresholds.length)), type); + if (truePositives == null) { + + truePositives = tf.withName(truePositivesName).withInitScope().variable(zero); + initializers.add(tf.assign(truePositives, zero)); + } + + if (this.falseNegatives == null) { + + falseNegatives = tf.withName(falseNegativesName).withInitScope().variable(zero); + initializers.add(tf.assign(falseNegatives, zero)); + } + setInitialized(true); } } /** {@inheritDoc} */ @Override - public Op resetStates() { - return getTF().withSubScope("resetStates").withControlDependencies(initializers).noOp(); + public Op resetStates(Ops tf) { + init(tf); + return tf.withSubScope("resetStates").withControlDependencies(initializers).noOp(); } /** * Accumulates true positive and false negative statistics. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. The TensorFlow Ops * @param labels the labels The ground truth values, with the same dimensions as predictions. Will * be cast to {@link TBool}. * @param predictions the predictions, each element must be in the range {@code [0, 1]}. * @param sampleWeights Optional weighting of each example. Defaults to 1. Rank is either 0, or * * the same rank as labels, and must be broadcastable to labels. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return a List of Operations to update the metric state. */ @Override - @SuppressWarnings("unchecked") public List updateStateList( + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { - Ops tf = getTF(); + init(tf); Map> confusionMatrix = new HashMap<>(); confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.truePositives); confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, this.falseNegatives); @@ -353,7 +335,6 @@ public List updateStateList( return MetricsHelper.updateConfusionMatrixVariables( tf, confusionMatrix, - Collections.EMPTY_MAP, tLabels, tPredictions, tf.constant(thresholds), @@ -365,13 +346,16 @@ public List updateStateList( } @Override - public Operand result() { - Ops tf = getTF(); + public Operand result(Ops tf, Class resultType) { + init(tf); Operand result = tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); - return this.thresholds.length == 1 - ? tf.slice(result, tf.constant(new int[] {0}), tf.constant(new int[1])) - : result; + return cast( + tf, + this.thresholds.length == 1 + ? 
tf.slice(result, tf.constant(new int[] {0}), tf.constant(new int[1])) + : result, + resultType); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java index 2386087e8a2..be821681704 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java @@ -47,7 +47,6 @@ public class RecallAtPrecision extends SensitivitySpecificity * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()} and {@link * #DEFAULT_NUM_THRESHOLDS} for the number of thresholds * - * @param tf The TensorFlow Ops * @param precision the precision. A scalar value in range [0, 1] * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -55,15 +54,14 @@ public class RecallAtPrecision extends SensitivitySpecificity * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range * [0-1]. */ - public RecallAtPrecision(Ops tf, float precision, long seed, Class type) { - this(tf, null, precision, DEFAULT_NUM_THRESHOLDS, seed, type); + public RecallAtPrecision(float precision, long seed, Class type) { + this(null, precision, DEFAULT_NUM_THRESHOLDS, seed, type); } /** * Creates a PrecisionRecall metric with {@link #DEFAULT_NUM_THRESHOLDS} for the number of * thresholds * - * @param tf The TensorFlow Ops * @param name the name of the metric. If null, defaults to {@link Class#getSimpleName()} * @param precision the precision. A scalar value in range [0, 1] * @param seed the seed for random number generation. An initializer created with a given seed @@ -72,14 +70,13 @@ public RecallAtPrecision(Ops tf, float precision, long seed, Class type) { * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range * [0-1]. */ - public RecallAtPrecision(Ops tf, String name, float precision, long seed, Class type) { - this(tf, name, precision, DEFAULT_NUM_THRESHOLDS, seed, type); + public RecallAtPrecision(String name, float precision, long seed, Class type) { + this(name, precision, DEFAULT_NUM_THRESHOLDS, seed, type); } /** * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()}. * - * @param tf The TensorFlow Ops * @param precision the precision. A scalar value in range [0, 1] * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given * recall. @@ -89,14 +86,13 @@ public RecallAtPrecision(Ops tf, String name, float precision, long seed, Class< * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range * [0-1]. */ - public RecallAtPrecision(Ops tf, float precision, int numThresholds, long seed, Class type) { - this(tf, null, precision, numThresholds, seed, type); + public RecallAtPrecision(float precision, int numThresholds, long seed, Class type) { + this(null, precision, numThresholds, seed, type); } /** * Creates a PrecisionRecall metric. * - * @param tf The TensorFlow Ops * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} * @param precision the precision. A scalar value in range [0, 1] * @param numThresholds Defaults to 200. 
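Recall follows the same pattern; a sketch using the topK/classId constructor shown above, with illustrative values and generics, and Session execution omitted:

    import org.tensorflow.Graph;
    import org.tensorflow.Operand;
    import org.tensorflow.framework.metrics.Recall;
    import org.tensorflow.op.Op;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    public class RecallSketch {
      public static void main(String[] args) {
        try (Graph graph = new Graph()) {
          Ops tf = Ops.create(graph);
          // Consider only the top-2 predictions and report recall for class id 1.
          Recall<TFloat32> recall = new Recall<>(2, 1, 1001L, TFloat32.class);
          Op update = recall.updateState(tf,
              tf.constant(new float[][] {{0f, 1f, 0f}, {1f, 0f, 0f}}),
              tf.constant(new float[][] {{0.1f, 0.7f, 0.2f}, {0.6f, 0.3f, 0.1f}}), null);
          Operand<TFloat32> result = recall.result(tf, TFloat32.class);
        }
      }
    }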
The number of thresholds to use for matching the given @@ -108,8 +104,8 @@ public RecallAtPrecision(Ops tf, float precision, int numThresholds, long seed, * [0-1]. */ public RecallAtPrecision( - Ops tf, String name, float precision, int numThresholds, long seed, Class type) { - super(tf, name, numThresholds, seed, type); + String name, float precision, int numThresholds, long seed, Class type) { + super(name, numThresholds, seed, type); if (precision < 0f || precision > 1f) throw new IllegalArgumentException("recall must be in the range [0, 1]."); this.precision = precision; @@ -117,9 +113,8 @@ public RecallAtPrecision( /** {@inheritDoc} */ @Override - public Operand result() { - Ops tf = getTF(); - + public Operand result(Ops tf, Class resultType) { + init(tf); Operand precisions = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); Operand recalls = @@ -130,10 +125,13 @@ public Operand result() { Operand feasibleExists = tf.math.greater(tf.size(feasible), tf.constant(0)); Operand gather = tf.expandDims(tf.gather(recalls, feasible, tf.constant(0)), tf.constant(0)); - return tf.select( - feasibleExists, - tf.reduceMax(gather, allAxes(tf, gather)), - cast(tf, tf.constant(0), getType())); + return cast( + tf, + tf.select( + feasibleExists, + tf.reduceMax(gather, allAxes(tf, gather)), + cast(tf, tf.constant(0), getType())), + resultType); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java index 8b0b06e788d..65a2ce2e687 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java @@ -35,27 +35,25 @@ public class RootMeanSquaredError extends Mean { /** * Creates a RootMeanSquaredError metric with a name of {@link Class#getSimpleName()} * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public RootMeanSquaredError(Ops tf, long seed, Class type) { - this(tf, null, seed, type); + public RootMeanSquaredError(long seed, Class type) { + this(null, seed, type); } /** * Creates a RootMeanSquaredError metric * - * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
* @param type the data type for the variables */ - public RootMeanSquaredError(Ops tf, String name, long seed, Class type) { - super(tf, name, seed, type); + public RootMeanSquaredError(String name, long seed, Class type) { + super(name, seed, type); } /** @@ -69,28 +67,30 @@ public RootMeanSquaredError(Ops tf, String name, long seed, Class type) { */ @Override public List updateStateList( + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { - - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); + init(tf); + Operand tLabels = cast(tf, labels, getInternalType()); + Operand tPredictions = cast(tf, predictions, getInternalType()); Operand tSampleWeights = - sampleWeights != null ? cast(getTF(), sampleWeights, getResultType()) : null; + sampleWeights != null ? cast(tf, sampleWeights, getInternalType()) : null; - LossTuple ops = LossesHelper.squeezeOrExpandDimensions(getTF(), tLabels, tPredictions); + LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, tPredictions); tPredictions = ops.getTarget(); tLabels = ops.getLabels(); Operand errorSquared = - cast(getTF(), getTF().math.squaredDifference(tPredictions, tLabels), getResultType()); + cast(tf, tf.math.squaredDifference(tPredictions, tLabels), getInternalType()); - return super.updateStateList(errorSquared, tSampleWeights); + return super.updateStateList(tf, errorSquared, tSampleWeights); } /** {@inheritDoc} */ @Override - public Operand result() { - return getTF().math.sqrt(getTF().math.divNoNan(this.total, this.count)); + public Operand result(Ops tf, Class resultType) { + init(tf); + return cast(tf, tf.math.sqrt(tf.math.divNoNan(this.total, this.count)), resultType); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java index 3892af920e9..0e25c42e6fa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java @@ -51,7 +51,6 @@ public class SensitivityAtSpecificity extends SensitivitySpec * Creates a SpecificityAtSensitivity metric with a name of {@link Class#getSimpleName()} and * {@link #DEFAULT_NUM_THRESHOLDS} for the number of thresholds * - * @param tf The TensorFlow Ops * @param specificity the specificity. A scalar value in range [0, 1] * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -59,15 +58,14 @@ public class SensitivityAtSpecificity extends SensitivitySpec * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range * [0-1]. */ - public SensitivityAtSpecificity(Ops tf, float specificity, long seed, Class type) { - this(tf, null, specificity, DEFAULT_NUM_THRESHOLDS, seed, type); + public SensitivityAtSpecificity(float specificity, long seed, Class type) { + this(null, specificity, DEFAULT_NUM_THRESHOLDS, seed, type); } /** * Creates a SpecificityAtSensitivity metric with {@link #DEFAULT_NUM_THRESHOLDS} for the number * of thresholds * - * @param tf The TensorFlow Ops * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} * @param specificity the specificity. 
A scalar value in range [0, 1] * @param seed the seed for random number generation. An initializer created with a given seed @@ -76,15 +74,13 @@ public SensitivityAtSpecificity(Ops tf, float specificity, long seed, Class t * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range * [0-1]. */ - public SensitivityAtSpecificity( - Ops tf, String name, float specificity, long seed, Class type) { - this(tf, name, specificity, DEFAULT_NUM_THRESHOLDS, seed, type); + public SensitivityAtSpecificity(String name, float specificity, long seed, Class type) { + this(name, specificity, DEFAULT_NUM_THRESHOLDS, seed, type); } /** * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()}. * - * @param tf The TensorFlow Ops * @param specificity the specificity. A scalar value in range [0, 1] * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given * specificity. @@ -94,15 +90,13 @@ public SensitivityAtSpecificity( * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range * [0-1]. */ - public SensitivityAtSpecificity( - Ops tf, float specificity, int numThresholds, long seed, Class type) { - this(tf, null, specificity, numThresholds, seed, type); + public SensitivityAtSpecificity(float specificity, int numThresholds, long seed, Class type) { + this(null, specificity, numThresholds, seed, type); } /** * Creates a PrecisionRecall metric. * - * @param tf The TensorFlow Ops * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} * @param specificity the specificity. A scalar value in range [0, 1] * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given @@ -114,8 +108,8 @@ public SensitivityAtSpecificity( * [0-1]. 
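
Sketch of the new result(Ops, Class) read-out, assuming (from the cast(..., resultType) calls above) that the requested result type may differ from the metric's internal variable type; `tf`, `labels`, and `predictions` are the Graph-backed operands from the earlier sketch.

    RootMeanSquaredError<TFloat32> rmse = new RootMeanSquaredError<>(1001L, TFloat32.class);
    List<Op> updates = rmse.updateStateList(tf, labels, predictions, null);
    Operand<TFloat32> asFloat = rmse.result(tf, TFloat32.class);
    Operand<TFloat64> asDouble = rmse.result(tf, TFloat64.class); // cast applied at read time
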
*/ public SensitivityAtSpecificity( - Ops tf, String name, float specificity, int numThresholds, long seed, Class type) { - super(tf, name, numThresholds, seed, type); + String name, float specificity, int numThresholds, long seed, Class type) { + super(name, numThresholds, seed, type); if (specificity < 0f || specificity > 1f) throw new IllegalArgumentException("specificity must be in the range [0, 1]."); this.specificity = specificity; @@ -123,8 +117,8 @@ public SensitivityAtSpecificity( /** {@inheritDoc} */ @Override - public Operand result() { - Ops tf = getTF(); + public Operand result(Ops tf, Class resultType) { + init(tf); Operand specificities = tf.math.divNoNan(trueNegatives, tf.math.add(trueNegatives, falsePositives)); Operand sub = tf.math.sub(specificities, cast(tf, tf.constant(specificity), getType())); @@ -133,7 +127,7 @@ public Operand result() { Operand trueSlice = tf.slice(truePositives, minIndex, tf.constant(new int[] {1})); Operand falseSlice = tf.slice(falseNegatives, minIndex, tf.constant(new int[] {1})); - return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)); + return cast(tf, tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)), resultType); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java index 10d33c31508..e1f097994e0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java @@ -17,9 +17,10 @@ import static org.tensorflow.framework.utils.CastHelper.cast; import java.util.Collections; +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Squeeze; @@ -64,48 +65,52 @@ * * @param The data type for the metric result */ -public class SparseCategoricalAccuracy extends MeanMetricWrapper - implements LossMetric { +public class SparseCategoricalAccuracy extends MeanBaseMetricWrapper + implements LossMetric { /** * Creates a SparseCategoricalAccuracy metric, using name of {@link Class#getSimpleName()}. * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type The data type for the metric result */ - public SparseCategoricalAccuracy(Ops tf, long seed, Class type) { - this(tf, null, seed, type); + public SparseCategoricalAccuracy(long seed, Class type) { + this(null, seed, type); } /** * Creates a SparseCategoricalAccuracy metric. * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null use {@link Class#getSimpleName()} * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type of the metric result. 
*/ - public SparseCategoricalAccuracy(Ops tf, String name, long seed, Class type) { - super(tf, name, seed, type); + public SparseCategoricalAccuracy(String name, long seed, Class type) { + super(name, seed, type); super.setLoss(this); } /** * Calculates how often predictions matches integer labels. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels Integer ground truth values. * @param predictions the predictions + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return Sparse categorical accuracy values. */ @Override - public Operand call( - Operand labels, Operand predictions) { - - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tLabels = cast(tf, labels, getInternalType()); + Operand tPredictions = cast(tf, predictions, getInternalType()); Shape predShape = predictions.asOutput().shape(); Shape labelsShape = labels.asOutput().shape(); long predictionsRank = predShape.numDimensions(); @@ -115,15 +120,12 @@ public Operand call( if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE && labelsShape.size((int) labelsRank - 1) == 1) { - tLabels = getTF().squeeze(tLabels, Squeeze.axis(Collections.singletonList(labelsRank - 1L))); + tLabels = tf.squeeze(tLabels, Squeeze.axis(Collections.singletonList(labelsRank - 1L))); } Operand argMaxPred = - cast( - getTF(), - getTF().math.argMax(tPredictions, getTF().constant(-1L), TInt64.class), - getResultType()); + cast(tf, tf.math.argMax(tPredictions, tf.constant(-1L), TInt64.class), getInternalType()); - Equal equals = getTF().math.equal(tLabels, argMaxPred); - return getTF().dtypes.cast(equals, getResultType()); + Equal equals = tf.math.equal(tLabels, argMaxPred); + return cast(tf, equals, resultType); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index 04555d85b66..501bdd81770 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -14,15 +14,16 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the sparse categorical cross-entropy loss between true labels and * predicted labels. * @@ -32,16 +33,29 @@ * * @param The data type for the metric result.
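
Sketch of SparseCategoricalAccuracy under the call(tf, labels, predictions, resultType) contract shown above: labels carry class indices with a trailing dimension of 1, which is squeezed before comparison with argMax(predictions, -1). Values and the `tf` scope are illustrative, as in the earlier sketches.

    SparseCategoricalAccuracy<TFloat32> acc = new SparseCategoricalAccuracy<>(1001L, TFloat32.class);
    Operand<TInt32> labels = tf.constant(new int[][] {{2}, {1}});              // shape (2, 1)
    Operand<TFloat32> predictions = tf.constant(new float[][] {
        {0.1f, 0.1f, 0.8f},
        {0.2f, 0.7f, 0.1f}});                                                  // shape (2, 3)
    List<Op> updates = acc.updateStateList(tf, labels, predictions, null);
    Operand<TFloat32> accuracy = acc.result(tf, TFloat32.class);               // 1.0 once the updates run
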
*/ -public class SparseCategoricalCrossentropy extends MeanMetricWrapper - implements LossMetric { +public class SparseCategoricalCrossentropy extends MeanBaseMetricWrapper + implements LossMetric { private final boolean fromLogits; private final int axis; + /** + * Creates a SparseCategoricalCrossentropy metric using {@link Class#getSimpleName()} for the + * metric name. + * + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. + * @param axis The dimension along which the entropy is computed. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public SparseCategoricalCrossentropy(boolean fromLogits, int axis, long seed, Class type) { + this(null, fromLogits, axis, seed, type); + } /** * Creates a SparseCategoricalCrossentropy metric * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a * probability distribution. @@ -51,8 +65,8 @@ public class SparseCategoricalCrossentropy extends MeanMetric * @param type the type for the variables and result */ public SparseCategoricalCrossentropy( - Ops tf, String name, boolean fromLogits, int axis, long seed, Class type) { - super(tf, name, seed, type); + String name, boolean fromLogits, int axis, long seed, Class type) { + super(name, seed, type); setLoss(this); this.fromLogits = fromLogits; this.axis = axis; @@ -61,15 +75,25 @@ public SparseCategoricalCrossentropy( /** * Calculates how often predictions matches integer labels. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels Integer ground truth values. * @param predictions the predictions + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return Sparse categorical accuracy values. 
*/ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.sparseCategoricalCrossentropy(getTF(), tLabels, tPredictions, fromLogits, axis); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tLabels = cast(tf, labels, getInternalType()); + Operand tPredictions = cast(tf, predictions, getInternalType()); + return cast( + tf, + Losses.sparseCategoricalCrossentropy(tf, tLabels, tPredictions, fromLogits, axis), + resultType); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java index 29dc91298d3..e4bc6fcda2a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java @@ -14,51 +14,75 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * Computes how often integer targets are in the top `K` predictions. * * @param The data type for the metric result */ -public class SparseTopKCategoricalAccuracy extends MeanMetricWrapper - implements LossMetric { +public class SparseTopKCategoricalAccuracy extends MeanBaseMetricWrapper + implements LossMetric { public static final int DEFAULT_K = 5; /** Number of top elements to look at for computing accuracy. */ private final int k; + /** + * Creates a SparseTopKCategoricalAccuracy metric using {@link #DEFAULT_K} for the number of top + * elements using {@link Class#getSimpleName()} for the metric name. + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the date type for the result + */ + public SparseTopKCategoricalAccuracy(long seed, Class type) { + this(null, DEFAULT_K, seed, type); + } + + /** + * Creates a SparseTopKCategoricalAccuracy metric using {@link Class#getSimpleName()} for the + * metric name. + * + * @param k Number of top elements to look at for computing accuracy. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the date type for the result + */ + public SparseTopKCategoricalAccuracy(int k, long seed, Class type) { + this(null, k, seed, type); + } + /** * Creates a SparseTopKCategoricalAccuracy metric using {@link #DEFAULT_K} for the number of top * elements. * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. 
An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the date type for the result */ - public SparseTopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { - this(tf, name, DEFAULT_K, seed, type); + public SparseTopKCategoricalAccuracy(String name, long seed, Class type) { + this(name, DEFAULT_K, seed, type); } /** * Creates a SparseTopKCategoricalAccuracy metric. * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param k Number of top elements to look at for computing accuracy. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the date type for the result */ - public SparseTopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { - super(tf, name, seed, type); + public SparseTopKCategoricalAccuracy(String name, int k, long seed, Class type) { + super(name, seed, type); this.k = k; setLoss(this); } @@ -66,15 +90,23 @@ public SparseTopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Clas /** * Computes how often integer targets are in the top {@code K} predictions. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels the truth values or labels * @param predictions the predictions + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return Sparse top K categorical accuracy value. */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Metrics.sparseTopKCategoricalAccuracy(getTF(), tLabels, tPredictions, k); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tLabels = cast(tf, labels, getInternalType()); + Operand tPredictions = cast(tf, predictions, getInternalType()); + return cast( + tf, Metrics.sparseTopKCategoricalAccuracy(tf, tLabels, tPredictions, k), resultType); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java index aa8eeb062b3..03f7b0fc472 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java @@ -50,7 +50,6 @@ public class SpecificityAtSensitivity extends SensitivitySpec * Creates a SpecificityAtSensitivity metric with a name of {@link Class#getSimpleName()} and * {@link #DEFAULT_NUM_THRESHOLDS} for the number of thresholds * - * @param tf The TensorFlow Ops * @param sensitivity the sensitivity. A scalar value in range [0, 1] * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -58,15 +57,14 @@ public class SpecificityAtSensitivity extends SensitivitySpec * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range * [0-1]. 
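
The new name-less convenience constructors above line up as follows (seed value and names illustrative):

    SparseTopKCategoricalAccuracy<TFloat32> top5 =
        new SparseTopKCategoricalAccuracy<>(1001L, TFloat32.class);             // k = DEFAULT_K (5)
    SparseTopKCategoricalAccuracy<TFloat32> top3 =
        new SparseTopKCategoricalAccuracy<>(3, 1001L, TFloat32.class);          // explicit k
    SparseTopKCategoricalAccuracy<TFloat32> named =
        new SparseTopKCategoricalAccuracy<>("top3_acc", 3, 1001L, TFloat32.class);
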
*/ - public SpecificityAtSensitivity(Ops tf, float sensitivity, long seed, Class type) { - this(tf, null, sensitivity, DEFAULT_NUM_THRESHOLDS, seed, type); + public SpecificityAtSensitivity(float sensitivity, long seed, Class type) { + this(null, sensitivity, DEFAULT_NUM_THRESHOLDS, seed, type); } /** * Creates a SpecificityAtSensitivity metric with {@link #DEFAULT_NUM_THRESHOLDS} for the number * of thresholds * - * @param tf The TensorFlow Ops * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} * @param sensitivity the sensitivity. A scalar value in range [0, 1] * @param seed the seed for random number generation. An initializer created with a given seed @@ -75,15 +73,13 @@ public SpecificityAtSensitivity(Ops tf, float sensitivity, long seed, Class t * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range * [0-1]. */ - public SpecificityAtSensitivity( - Ops tf, String name, float sensitivity, long seed, Class type) { - this(tf, name, sensitivity, DEFAULT_NUM_THRESHOLDS, seed, type); + public SpecificityAtSensitivity(String name, float sensitivity, long seed, Class type) { + this(name, sensitivity, DEFAULT_NUM_THRESHOLDS, seed, type); } /** * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()}. * - * @param tf The TensorFlow Ops * @param sensitivity the sensitivity. A scalar value in range [0, 1] * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given * sensitivity. @@ -93,15 +89,13 @@ public SpecificityAtSensitivity( * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range * [0-1]. */ - public SpecificityAtSensitivity( - Ops tf, float sensitivity, int numThresholds, long seed, Class type) { - this(tf, null, sensitivity, numThresholds, seed, type); + public SpecificityAtSensitivity(float sensitivity, int numThresholds, long seed, Class type) { + this(null, sensitivity, numThresholds, seed, type); } /** * Creates a PrecisionRecall metric. * - * @param tf The TensorFlow Ops * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} * @param sensitivity the sensitivity. A scalar value in range [0, 1] * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given @@ -113,8 +107,8 @@ public SpecificityAtSensitivity( * [0-1]. 
*/ public SpecificityAtSensitivity( - Ops tf, String name, float sensitivity, int numThresholds, long seed, Class type) { - super(tf, name, numThresholds, seed, type); + String name, float sensitivity, int numThresholds, long seed, Class type) { + super(name, numThresholds, seed, type); if (sensitivity < 0f || sensitivity > 1f) throw new IllegalArgumentException("sensitivity must be in the range [0, 1]."); this.sensitivity = sensitivity; @@ -122,9 +116,8 @@ public SpecificityAtSensitivity( /** {@inheritDoc} */ @Override - public Operand result() { - - Ops tf = getTF(); + public Operand result(Ops tf, Class resultType) { + init(tf); Operand sensitivities = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); @@ -134,7 +127,7 @@ public Operand result() { Operand trueSlice = tf.slice(trueNegatives, minIndex, tf.constant(new int[] {1})); Operand falseSlice = tf.slice(falsePositives, minIndex, tf.constant(new int[] {1})); - return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)); + return cast(tf, tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)), resultType); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index e2ff208b8f5..f29aac4e200 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -14,50 +14,69 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the squared hinge loss metric between labels and predictions. * * @param The data type for the metric result. */ -public class SquaredHinge extends MeanMetricWrapper implements LossMetric { +public class SquaredHinge extends MeanBaseMetricWrapper + implements LossMetric { + + /** + * Creates a SquaredHinge metric using {@link Class#getSimpleName()} for the metric name. + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public SquaredHinge(long seed, Class type) { + this(null, seed, type); + } /** * Creates a SquaredHinge metric * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
* @param type the type for the variables and result */ - public SquaredHinge(Ops tf, String name, long seed, Class type) { - super(tf, name, seed, type); + public SquaredHinge(String name, long seed, Class type) { + super(name, seed, type); setLoss(this); } /** * Computes the squared hinge loss between labels and predictions. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels The ground truth values. {@code labels} values are expected to be -1 or 1. If * binary (0 or 1) labels are provided we will convert them to -1 or 1. shape = {@code * [batch_size, d0, .. dN]}. * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return Squared hinge loss values. shape = {@code [batch_size, d0, .. dN-1]}. */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.squaredHinge(getTF(), tLabels, tPredictions); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tLabels = cast(tf, labels, getInternalType()); + Operand tPredictions = cast(tf, predictions, getInternalType()); + return cast(tf, Losses.squaredHinge(tf, tLabels, tPredictions), resultType); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java index bcb1d7b9a36..4010e760d1c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.metrics; import org.tensorflow.framework.metrics.impl.Reduce; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; /** @@ -34,25 +33,23 @@ public class Sum extends Reduce { /** * Creates a Sum metric with a name of {@link Class#getSimpleName()} * - * @param tf The TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ - public Sum(Ops tf, long seed, Class type) { - super(tf, null, MetricReduction.SUM, seed, type); + public Sum(long seed, Class type) { + super(null, MetricReduction.SUM, seed, type); } /** * Creates a Sum metric. * - * @param tf The TensorFlow Ops * @param name the name of the metric instance. If null, defaults to {@link Class#getSimpleName()} * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
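
Sketch of SquaredHinge under the new signatures; per the javadoc above, labels are expected in {-1, 1} and binary {0, 1} labels are converted internally. Values are illustrative and `tf` is the Graph-backed Ops from the earlier sketches.

    SquaredHinge<TFloat32> hinge = new SquaredHinge<>(1001L, TFloat32.class);
    Operand<TFloat32> labels = tf.constant(new float[][] {{0, 1}, {0, 0}});
    Operand<TFloat32> predictions = tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}});
    List<Op> updates = hinge.updateStateList(tf, labels, predictions, null);
    Operand<TFloat32> loss = hinge.result(tf, TFloat32.class);
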
* @param type the type for the variables and result */ - public Sum(Ops tf, String name, long seed, Class type) { - super(tf, name, MetricReduction.SUM, seed, type); + public Sum(String name, long seed, Class type) { + super(name, MetricReduction.SUM, seed, type); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java index b630be5bcc2..7d6586325b2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java @@ -16,9 +16,10 @@ import static org.tensorflow.framework.utils.CastHelper.cast; +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.metrics.impl.LossMetric; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -27,38 +28,62 @@ * * @param The data type for the metric result */ -public class TopKCategoricalAccuracy extends MeanMetricWrapper - implements LossMetric { +public class TopKCategoricalAccuracy extends MeanBaseMetricWrapper + implements LossMetric { public static final int DEFAULT_K = 5; /** Number of top elements to look at for computing accuracy. */ private final int k; + /** + * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for {@code k}, Number of top + * elements to look at for computing accuracy and using {@link Class#getSimpleName()} for the + * metric name. + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type The data type for the metric result + */ + public TopKCategoricalAccuracy(long seed, Class type) { + this(null, DEFAULT_K, seed, type); + } + + /** + * Creates a TopKCategoricalAccuracy metric using {@link Class#getSimpleName()} for the metric + * name. + * + * @param k Number of top elements to look at for computing accuracy. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type The data type for the metric result + */ + public TopKCategoricalAccuracy(int k, long seed, Class type) { + this(null, k, seed, type); + } + /** * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for {@code k}, Number of top * elements to look at for computing accuracy. * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type The data type for the metric result */ - public TopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { - this(tf, name, DEFAULT_K, seed, type); + public TopKCategoricalAccuracy(String name, long seed, Class type) { + this(name, DEFAULT_K, seed, type); } /** * Creates a TopKCategoricalAccuracy metric * - * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. 
* @param k Number of top elements to look at for computing accuracy. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type The data type for the metric result */ - public TopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { - super(tf, name, seed, type); + public TopKCategoricalAccuracy(String name, int k, long seed, Class type) { + super(name, seed, type); this.k = k; setLoss(this); } @@ -66,15 +91,22 @@ public TopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class t /** * Computes how often targets are in the top {@code K} predictions. * + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param labels the truth values or labels * @param predictions the predictions + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return Top K categorical accuracy value. */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Metrics.topKCategoricalAccuracy(getTF(), tLabels, tPredictions, k); + public Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType) { + init(tf); + Operand tLabels = cast(tf, labels, getInternalType()); + Operand tPredictions = cast(tf, predictions, getInternalType()); + return cast(tf, Metrics.topKCategoricalAccuracy(tf, tLabels, tPredictions, k), resultType); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java index fd6b95df6d2..4110d988bcc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java @@ -16,7 +16,6 @@ import org.tensorflow.framework.metrics.impl.ConfusionMatrixConditionCount; import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; /** @@ -37,19 +36,17 @@ public class TrueNegatives extends ConfusionMatrixConditionCo * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name and a * default threshold of {@link #DEFAULT_THRESHOLD}. * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public TrueNegatives(Ops tf, long seed, Class type) { - this(tf, null, DEFAULT_THRESHOLD, seed, type); + public TrueNegatives(long seed, Class type) { + this(null, DEFAULT_THRESHOLD, seed, type); } /** * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name * - * @param tf the TensorFlow Ops * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is * {@code true}, below is {@code false}). One metric value is generated for each threshold @@ -58,14 +55,13 @@ public TrueNegatives(Ops tf, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. 
* @param type the data type for the variables */ - public TrueNegatives(Ops tf, float threshold, long seed, Class type) { - this(tf, null, new float[] {threshold}, seed, type); + public TrueNegatives(float threshold, long seed, Class type) { + this(null, new float[] {threshold}, seed, type); } /** * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name * - * @param tf the TensorFlow Ops * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is * {@code true}, below is {@code false}). One metric value is generated for each threshold @@ -74,27 +70,25 @@ public TrueNegatives(Ops tf, float threshold, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public TrueNegatives(Ops tf, float[] thresholds, long seed, Class type) { - this(tf, null, thresholds, seed, type); + public TrueNegatives(float[] thresholds, long seed, Class type) { + this(null, thresholds, seed, type); } /** * Creates a TrueNegatives metric, using a default threshold of {@link #DEFAULT_THRESHOLD}. * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public TrueNegatives(Ops tf, String name, long seed, Class type) { - this(tf, name, DEFAULT_THRESHOLD, seed, type); + public TrueNegatives(String name, long seed, Class type) { + this(name, DEFAULT_THRESHOLD, seed, type); } /** * Creates a TrueNegatives metric * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is @@ -104,14 +98,13 @@ public TrueNegatives(Ops tf, String name, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public TrueNegatives(Ops tf, String name, float threshold, long seed, Class type) { - this(tf, name, new float[] {threshold}, seed, type); + public TrueNegatives(String name, float threshold, long seed, Class type) { + this(name, new float[] {threshold}, seed, type); } /** * Creates a TrueNegatives metric * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is @@ -121,7 +114,7 @@ public TrueNegatives(Ops tf, String name, float threshold, long seed, Class t * will always produce the same random tensor for a given shape and data type. 
* @param type the data type for the variables */ - public TrueNegatives(Ops tf, String name, float[] thresholds, long seed, Class type) { - super(tf, name, ConfusionMatrixEnum.TRUE_NEGATIVES, thresholds, seed, type); + public TrueNegatives(String name, float[] thresholds, long seed, Class type) { + super(name, ConfusionMatrixEnum.TRUE_NEGATIVES, thresholds, seed, type); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java index 90fe9142014..7df1a54a345 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java @@ -16,7 +16,6 @@ import org.tensorflow.framework.metrics.impl.ConfusionMatrixConditionCount; import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; -import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; /** @@ -37,19 +36,17 @@ public class TruePositives extends ConfusionMatrixConditionCo * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name and a * default threshold of {@link #DEFAULT_THRESHOLD}. * - * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public TruePositives(Ops tf, long seed, Class type) { - this(tf, null, DEFAULT_THRESHOLD, seed, type); + public TruePositives(long seed, Class type) { + this(null, DEFAULT_THRESHOLD, seed, type); } /** * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name * - * @param tf the TensorFlow Ops * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is * {@code true}, below is {@code false}). One metric value is generated for each threshold @@ -58,14 +55,13 @@ public TruePositives(Ops tf, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public TruePositives(Ops tf, float threshold, long seed, Class type) { - this(tf, null, new float[] {threshold}, seed, type); + public TruePositives(float threshold, long seed, Class type) { + this(null, new float[] {threshold}, seed, type); } /** * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name * - * @param tf the TensorFlow Ops * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is * {@code true}, below is {@code false}). One metric value is generated for each threshold @@ -74,27 +70,25 @@ public TruePositives(Ops tf, float threshold, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public TruePositives(Ops tf, float[] thresholds, long seed, Class type) { - this(tf, null, thresholds, seed, type); + public TruePositives(float[] thresholds, long seed, Class type) { + this(null, thresholds, seed, type); } /** * Creates a TruePositives metric, using a default threshold of {@link #DEFAULT_THRESHOLD}. 
* - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public TruePositives(Ops tf, String name, long seed, Class type) { - this(tf, name, DEFAULT_THRESHOLD, seed, type); + public TruePositives(String name, long seed, Class type) { + this(name, DEFAULT_THRESHOLD, seed, type); } /** * Creates a TruePositives metric * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is @@ -104,14 +98,13 @@ public TruePositives(Ops tf, String name, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables */ - public TruePositives(Ops tf, String name, float threshold, long seed, Class type) { - this(tf, name, new float[] {threshold}, seed, type); + public TruePositives(String name, float threshold, long seed, Class type) { + this(name, new float[] {threshold}, seed, type); } /** * Creates a TruePositives metric * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is @@ -121,7 +114,7 @@ public TruePositives(Ops tf, String name, float threshold, long seed, Class t * will always produce the same random tensor for a given shape and data type. 
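
Sketch of the threshold-based counters above: one count is accumulated per threshold, so a three-threshold TruePositives yields a result of shape (3). Threshold values are illustrative; `tf`, `labels`, and `predictions` are as in the earlier sketches.

    TruePositives<TFloat32> tp =
        new TruePositives<>(new float[] {0.3f, 0.5f, 0.8f}, 1001L, TFloat32.class);
    List<Op> updates = tp.updateStateList(tf, labels, predictions, null);
    Operand<TFloat32> counts = tp.result(tf, TFloat32.class);   // one count per threshold
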
* @param type the data type for the variables */ - public TruePositives(Ops tf, String name, float[] thresholds, long seed, Class type) { - super(tf, name, ConfusionMatrixEnum.TRUE_POSITIVES, thresholds, seed, type); + public TruePositives(String name, float[] thresholds, long seed, Class type) { + super(name, ConfusionMatrixEnum.TRUE_POSITIVES, thresholds, seed, type); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java index cbff958fc6f..5ab22480760 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java @@ -21,11 +21,10 @@ import java.util.List; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.Zeros; -import org.tensorflow.framework.metrics.Metric; +import org.tensorflow.framework.metrics.BaseMetric; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TNumber; @@ -35,21 +34,20 @@ * * @param The data type for the metric result */ -public abstract class ConfusionMatrixConditionCount extends Metric { +public abstract class ConfusionMatrixConditionCount extends BaseMetric { public static final String ACCUMULATOR = "accumulator"; public static final float DEFAULT_THRESHOLD = 0.5f; private final ConfusionMatrixEnum confusionMatrixCond; private final float[] thresholds; private final String accumulatorName; private final Class type; + private final Zeros zeros = new Zeros<>(); private Variable accumulator; - private Assign initializer; /** * Creates a ConfusionMatrixConditionCount type of Metric, using a threshold of {@link * #DEFAULT_THRESHOLD} * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param confusionMatrixCond the confusion matrix condition to calculate * @param seed the seed for random number generation. An initializer created with a given seed @@ -57,13 +55,12 @@ public abstract class ConfusionMatrixConditionCount extends M * @param type the data type for the variables */ public ConfusionMatrixConditionCount( - Ops tf, String name, ConfusionMatrixEnum confusionMatrixCond, long seed, Class type) { - this(tf, name, confusionMatrixCond, DEFAULT_THRESHOLD, seed, type); + String name, ConfusionMatrixEnum confusionMatrixCond, long seed, Class type) { + this(name, confusionMatrixCond, DEFAULT_THRESHOLD, seed, type); } /** * Creates a ConfusionMatrixConditionCount type of Metric * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param confusionMatrixCond the confusion matrix condition to calculate * @param threshold a threshold value in {@code [0, 1]}. 
A threshold is compared with prediction @@ -74,19 +71,17 @@ public ConfusionMatrixConditionCount( * @param type the data type for the variables */ public ConfusionMatrixConditionCount( - Ops tf, String name, ConfusionMatrixEnum confusionMatrixCond, float threshold, long seed, Class type) { - this(tf, name, confusionMatrixCond, new float[] {threshold}, seed, type); + this(name, confusionMatrixCond, new float[] {threshold}, seed, type); } /** * Creates a ConfusionMatrixConditionCount type of Metric * - * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param confusionMatrixCond the confusion matrix condition to calculate * @param thresholds threshold values in {@code [0, 1]}. A threshold is compared with prediction @@ -97,41 +92,32 @@ public ConfusionMatrixConditionCount( * @param type the data type for the variables */ public ConfusionMatrixConditionCount( - Ops tf, String name, ConfusionMatrixEnum confusionMatrixCond, float[] thresholds, long seed, Class type) { - super(tf, name, seed); + super(name, seed); accumulatorName = this.getVariableName(ACCUMULATOR); this.type = type; this.confusionMatrixCond = confusionMatrixCond; this.thresholds = thresholds; - init(); - } - - /** Initialize the metric */ - private void init() { - Shape variableShape = Shape.of(this.thresholds.length); - - Zeros zeros = new Zeros<>(); - accumulator = - getTF() - .withName(getAccumulatorName()) - .withInitScope() - .variable(zeros.call(getTF(), getTF().constant(variableShape), type)); - initializer = - getTF().assign(accumulator, zeros.call(getTF(), getTF().constant(variableShape), type)); } - /** - * Gets the initializer for the accumulator variable - * - * @return the initializer for the accumulator variable - */ - public Assign getInitializer() { - return initializer; + /** {@inheritDoc} */ + @Override + protected void init(Ops tf) { + checkIsGraph(tf); + if (!isInitialized()) { + setTF(tf); + Shape variableShape = Shape.of(this.thresholds.length); + + accumulator = + tf.withName(getAccumulatorName()) + .withInitScope() + .variable(zeros.call(tf.withInitScope(), tf.constant(variableShape), type)); + setInitialized(true); + } } /** @@ -145,18 +131,18 @@ public Assign getInitializer() { */ @Override public List updateStateList( + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { - Ops tf = getTF(); + init(tf); Operand tLabels = cast(tf, labels, type); Operand tPredictions = cast(tf, predictions, type); Operand tSampleWeights = sampleWeights != null ? 
cast(tf, sampleWeights, type) : null; return new ArrayList<>( MetricsHelper.updateConfusionMatrixVariables( - getTF(), + tf, Collections.singletonMap(confusionMatrixCond, accumulator), - Collections.singletonMap(confusionMatrixCond, initializer), tLabels, tPredictions, tf.constant(thresholds), @@ -169,14 +155,17 @@ public List updateStateList( /** {@inheritDoc} */ @Override - public Operand result() { - return getTF().identity(accumulator); + public Operand result(Ops tf, Class type) { + init(tf); + return cast(tf, tf.identity(accumulator), type); } /** {@inheritDoc} */ @Override - public Op resetStates() { - return initializer; + public Op resetStates(Ops tf) { + init(tf); + return tf.withName(accumulatorName) + .assign(accumulator, zeros.call(tf, tf.constant(accumulator.shape()), type)); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index 76c21aebefc..3729cdd7265 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -15,21 +15,25 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; +import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** - * Interface for Metrics that wrap AbstractLoss functions. - * - * @param The data type of the predictions. - */ -public interface LossMetric { +/** Interface for Metrics that wrap AbstractLoss functions. */ +public interface LossMetric { /** * Calculates the weighted loss between {@code labels} and {@code predictions} * + * @param tf the TensorFlow Ops * @param labels the truth values or labels * @param predictions the predictions + * @param resultType the data type for the result + * @param The data type of the predictions. * @return the loss */ - Operand call(Operand labels, Operand predictions); + Operand call( + Ops tf, + Operand labels, + Operand predictions, + Class resultType); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanBaseMetricWrapper.java similarity index 82% rename from tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java rename to tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanBaseMetricWrapper.java index d9f4bb60cba..e53368cfda9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanBaseMetricWrapper.java @@ -34,22 +34,21 @@ * * @param The data type for the metric result */ -public class MeanMetricWrapper extends Mean { +public class MeanBaseMetricWrapper extends Mean { /** The loss function interface */ - protected LossMetric loss; + protected LossMetric loss; /** * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#WEIGHTED_MEAN} * - * @param tf the TensorFlow Ops * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
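
Sketch of the reworked reset path above: resetStates(tf) now returns an assign op that writes zeros back into the accumulator and must be executed (for example as a Session target) before the counts are reused. Names and values are illustrative.

    TrueNegatives<TFloat32> tn = new TrueNegatives<>(0.5f, 1001L, TFloat32.class);
    List<Op> updates = tn.updateStateList(tf, labels, predictions, null);
    Op reset = tn.resetStates(tf);   // when run, re-zeroes the accumulator variable
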
* @param type the type for the variables and result */ - protected MeanMetricWrapper(Ops tf, String name, long seed, Class type) { - super(tf, name, seed, type); + protected MeanBaseMetricWrapper(String name, long seed, Class type) { + super(name, seed, type); } /** @@ -57,7 +56,7 @@ protected MeanMetricWrapper(Ops tf, String name, long seed, Class type) { * * @return the loss function. */ - public LossMetric getLoss() { + public LossMetric getLoss() { return loss; } @@ -66,7 +65,7 @@ public LossMetric getLoss() { * * @param loss the loss function. */ - protected void setLoss(LossMetric loss) { + protected void setLoss(LossMetric loss) { this.loss = loss; } @@ -84,9 +83,12 @@ protected void setLoss(LossMetric loss) { * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return a List of control operations that updates the Mean state variables. */ public List updateStateList( + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { @@ -94,11 +96,12 @@ public List updateStateList( throw new IllegalArgumentException("missing required inputs for labels and predictions"); } - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); + init(tf); + Operand tLabels = cast(tf, labels, getInternalType()); + Operand tPredictions = cast(tf, predictions, getInternalType()); - Operand losses = loss.call(tLabels, tPredictions); + Operand losses = loss.call(tf, tLabels, tPredictions, getInternalType()); - return super.updateStateList(cast(getTF(), losses, predictions.type()), sampleWeights); + return super.updateStateList(tf, cast(tf, losses, predictions.type()), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 7d265ef7651..70a81da8d1e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -24,15 +24,16 @@ import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicLong; +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException; +import org.tensorflow.framework.op.FrameworkOps; import org.tensorflow.framework.utils.SparseTensor; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.OneHot; import org.tensorflow.op.core.Rank; import org.tensorflow.op.core.Squeeze; @@ -50,7 +51,7 @@ /** * These are helper methods for Metrics and will be module private when Java modularity is applied - * to TensorFlow Java. These methods should not be used outside of the metrics packages. + * to TensorFlow Java. These methods should not be used outside the metrics packages. 
*/ public class MetricsHelper { public static final float NEG_INF = -1e10f; @@ -64,12 +65,14 @@ public class MetricsHelper { * scalar, or the same rank as the target values, with each dimension either 1, or the same as the * corresponding values dimension. * - * @param tf the TensorFlow Ops + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param sampleWeights the sample weights. * @param values the values to which weights are applied. + * @param the data type for the parameters and result * @return {@code Operation} with control dependencies to ensure {@code sampleWeight} can be * broadcast to {@code values} - * @param the type of Operand + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @throws NotBroadcastableException If static checks determine {@code sampleWeights} has an * incorrect shape that prohibit broadcasting to {@code values} */ @@ -154,12 +157,14 @@ public static Op assertBroadcastable( /** * Gets an operand that tests if the shapes have the same rank and valid dimensions. * - * @param tf the TensorFlow Ops + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param weightsRank the operand for the rank of the sample weights * @param weightsShape the operand for the shape of the sample weights * @param valuesRank the operand for the rank of the sample weights * @param valuesShape the operand for the shape of the sample weights * @param the data type for the operands + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return a boolean operand to determine if the Shape is scalar or not. */ private static Operand canBroadcastNonscalarShapes( @@ -176,10 +181,12 @@ private static Operand canBroadcastNonscalarShapes( /** * Gets an operand that tests if the shapes have valid dimensions or not. * - * @param tf the TensorFlow Ops + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param weightsShape the operand for the shape of the sample weights * @param valuesShape the operand for the shape of the values * @param the data type for the operands + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return a boolean operand to determine if the shapes have valid dimensions or not. */ private static Operand canBroadcastDims( @@ -190,7 +197,8 @@ private static Operand canBroadcastDims( tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); Operand weightsShape2D = tf.expandDims(weightsShape, tf.constant(-1)); - Operand diffResult = SetsOps.difference(tf, weightsShape2D, validDims); + FrameworkOps fops = FrameworkOps.create(tf); + Operand diffResult = fops.sets.difference(weightsShape2D, validDims); Operand numInvalidDims = tf.size(diffResult); return tf.math.equal(tf.constant(0), numInvalidDims); } @@ -198,10 +206,12 @@ private static Operand canBroadcastDims( /** * Broadcast {@code weights} to the same shape as {@code values}. * - * @param tf the TensorFlow ops + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. the TensorFlow ops * @param weights Operand whose shape is broadcastable to {@code values}. * @param values Operand of any shape * @param the type of Operands + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. 
* @return {@code weights} broadcast to {@code values} shape */ public static Operand broadcastWeights( @@ -306,14 +316,12 @@ public static List assertShapes( * LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. {@code sampleWeight} is then * broadcast to the shape of {@code predictions}. * - * @param tf the TensorFlow Ops + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. * @param variablesToUpdate map with {@link ConfusionMatrixEnum} values as valid keys and * corresponding variables to update as values. If {@code multiLabel}, then the variable * shapes are (T, D), where T is the number of thresholds and D is the number of classes * (after slicing by {@code classIndex}, if provided). If {@code multiLabels}, then the * variable shapes are (T). - * @param varInitializers map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding initializer Operands to for {@code variablesToUpdate}. * @param labels the labels. Will be cast to {@link TBool}. Shape (N, Cx, L1?), where N is the * number of examples, Cx is zero or more class dimensions, and L1 is a potential extra * dimension of size 1 that would be squeezed. @@ -338,6 +346,8 @@ public static List assertShapes( * the 0th dimension (the examples dimension) of {@code predictions}. May be null. Must be * null if {@code multiLabel}. * @param the data type for the variables + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @throws IllegalArgumentException If {@code predictions} and {@code labels} have mismatched * shapes, or if {@code sampleWeight} is not null and its shape doesn't match {@code * predictions}, or if {@code multiLabel && labelWeights != null}.. @@ -347,7 +357,6 @@ public static List assertShapes( public static List updateConfusionMatrixVariables( Ops tf, Map> variablesToUpdate, - Map> varInitializers, Operand labels, Operand predictions, Operand thresholds, @@ -594,13 +603,7 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand[] op = loopVars.get(c); // op[0] = label, op[1] == prediction controlOps.add( - weightedAssignAdd( - tf, - op[0], - op[1], - weightsTiledF, - variablesToUpdate.get(c), - varInitializers.get(c))); + weightedAssignAdd(tf, op[0], op[1], weightsTiledF, variablesToUpdate.get(c))); } }); @@ -611,13 +614,14 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), * Creates an Operand that adds the values by taking the logical and of labels and predictions to * the specified confusion matrix variable. * - * @param tf The TensorFlow Ops + * @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. The TensorFlow Ops * @param labels the labels * @param predictions the predictions * @param weights the weights applied to the logical and result, may be null * @param variable the variable to update - * @param initializer the variable initializer to be applied to the variable, may be null. * @param the data type for the variable. + * @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph + * environment. * @return an Operand that updates the variable. 
*/ private static Operand weightedAssignAdd( @@ -625,8 +629,7 @@ private static Operand weightedAssignAdd( Operand labels, Operand predictions, Operand weights, - Variable variable, - Assign initializer) { + Variable variable) { Class type = variable.type(); Operand labelAndPred = cast(tf, tf.math.logicalAnd(labels, predictions), type); @@ -638,16 +641,7 @@ private static Operand weightedAssignAdd( // else: // sum across ND, leaving shape (T) Operand valueSum = tf.reduceSum(labelAndPred, tf.constant(1)); - Operand assignAdd; - if (initializer != null) { - Ops tfc = - tf.withSubScope("weightedAssignAdd") - .withControlDependencies(Collections.singletonList(initializer)); - assignAdd = tfc.assignAdd(variable, valueSum); - } else { - assignAdd = tf.assignAdd(variable, valueSum); - } - return assignAdd; + return tf.assignAdd(variable, valueSum); } /** @@ -656,7 +650,7 @@ private static Operand weightedAssignAdd( *
<p>
Used for computing top-k prediction values in dense labels (which has the same shape as * predictions) for recall and precision top-k metrics. * - * @param tf The TensorFlow Ops + * @param tf the TensorFlow Ops * @param x the tensor with any dimensions to filter * @param topK the number of values to keep. * @param the data type for x and the return value. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index b96d2dfa1d0..f304d81fc71 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -14,35 +14,29 @@ =======================================================================*/ package org.tensorflow.framework.metrics.impl; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.ArrayList; +import java.util.List; import org.tensorflow.Operand; import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.framework.losses.impl.LossesHelper; -import org.tensorflow.framework.metrics.Metric; +import org.tensorflow.framework.metrics.BaseMetric; import org.tensorflow.framework.metrics.MetricReduction; -import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TNumber; -import java.util.ArrayList; -import java.util.List; - -import static org.tensorflow.framework.utils.CastHelper.cast; - -/** - * Encapsulates metrics that perform a reduce operation on the metric values. - * - * @param The data type for the metric result - */ -public abstract class Reduce extends Metric { +/** Encapsulates metrics that perform a reduce operation on the metric values. */ +public abstract class Reduce extends BaseMetric { public static final String TOTAL = "total"; public static final String COUNT = "count"; protected final MetricReduction reduction; private final String totalName; private final String countName; - private final Class resultType; + private final Class internalType; /** the variable that holds the total of the metric values */ protected Variable total; /** @@ -54,55 +48,68 @@ public abstract class Reduce extends Metric { /** * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#SUM} * - * @param tf the TensorFlow Ops * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param resultType the type for the variables and result + * @param internalType the type for the internal variables */ - protected Reduce(Ops tf, String name, long seed, Class resultType) { - this(tf, name, MetricReduction.SUM, seed, resultType); + protected Reduce(String name, long seed, Class internalType) { + this(name, MetricReduction.SUM, seed, internalType); } /** - * @param tf The TensorFlow Ops * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param reduction The type of metric reduction to apply * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
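// A small worked example, not part of this change set, of the arithmetic a WEIGHTED_MEAN
// reduction performs with the total and count variables declared above: total accumulates
// sum(values * sampleWeights), count accumulates sum(sampleWeights), and the result is
// divNoNan(total, count). Plain Java; the values are chosen only for illustration.
public class WeightedMeanArithmetic {
  public static void main(String[] args) {
    float[] values = {1f, 3f};
    float[] weights = {0.5f, 0.2f};
    float total = 0f; // corresponds to the "total" variable
    float count = 0f; // corresponds to the "count" variable
    for (int i = 0; i < values.length; i++) {
      total += values[i] * weights[i]; // 1*0.5 + 3*0.2 = 1.1
      count += weights[i]; // 0.5 + 0.2 = 0.7
    }
    float result = count == 0f ? 0f : total / count; // divNoNan: 1.1 / 0.7 ≈ 1.5714286
    System.out.println(result);
  }
}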
- * @param resultType the type for the variables and result + * @param internalType the type for the internal variables */ - protected Reduce(Ops tf, String name, MetricReduction reduction, long seed, Class resultType) { - super(tf, name, seed); + protected Reduce(String name, MetricReduction reduction, long seed, Class internalType) { + super(name, seed); this.reduction = reduction; this.totalName = this.getVariableName(TOTAL); this.countName = this.getVariableName(COUNT); - this.resultType = resultType; - setupVars(); + this.internalType = internalType; } - /** Initializes the Variables */ - private void setupVars() { - if (total == null) { - total = getTF().withName(totalName).variable(Shape.scalar(), resultType); - } - if (reduction == MetricReduction.SUM_OVER_BATCH_SIZE - || reduction == MetricReduction.WEIGHTED_MEAN) { - if (count == null) { - count = getTF().withName(countName).variable(Shape.scalar(), resultType); + /** {@inheritDoc} */ + @Override + protected void init(Ops tf) { + checkIsGraph(tf); + + if (!isInitialized()) { + setTF(tf); + Operand zero = cast(tf, tf.constant(0), internalType); + if (total == null) { + total = tf.withInitScope().withName(totalName).variable(zero); + } + if (reduction == MetricReduction.SUM_OVER_BATCH_SIZE + || reduction == MetricReduction.WEIGHTED_MEAN) { + if (count == null) { + count = tf.withInitScope().withName(countName).variable(zero); + } } + setInitialized(true); } } /** {@inheritDoc} */ - public Op resetStates() { - List controls = new ArrayList<>(); + @Override + public Op resetStates(Ops tf) { + if (!isInitialized()) { + init(tf); + } + List operandList = new ArrayList<>(); if (total != null) { - controls.add(getTF().assign(total, cast(getTF(), getTF().constant(0), total.type()))); + operandList.add(tf.assign(total, cast(tf, tf.constant(0), total.type()))); } if (count != null) { - controls.add(getTF().assign(count, cast(getTF(), getTF().constant(0), count.type()))); + operandList.add(tf.assign(count, cast(tf, tf.constant(0), count.type()))); + } + if (operandList.size() == 1) { + return operandList.get(0); + } else { + return tf.withControlDependencies(operandList).noOp(); } - return getTF().withControlDependencies(controls).noOp(); } /** @@ -116,27 +123,27 @@ public Op resetStates() { */ @Override public List updateStateList( - Operand values, Operand sampleWeights) { + Ops tf, Operand values, Operand sampleWeights) { if (values == null) { throw new IllegalArgumentException("values is required."); } - Ops tf = getTF(); + init(tf); List updateOperations = new ArrayList<>(); // cast everything to match the variables Operand tSampleWeights = null; - Operand tValues = cast(tf, values, getResultType()); + Operand tValues = cast(tf, values, getInternalType()); if (sampleWeights != null) { - tSampleWeights = cast(getTF(), sampleWeights, getResultType()); + tSampleWeights = cast(tf, sampleWeights, getInternalType()); // Update dimensions of weights to match with values if possible. 
LossTuple tuple = - LossesHelper.squeezeOrExpandDimensions(getTF(), null, tValues, tSampleWeights); + LossesHelper.squeezeOrExpandDimensions(tf, null, tValues, tSampleWeights); tValues = tuple.getTarget(); tSampleWeights = tuple.getSampleWeights(); try { // Broadcast weights if possible - tSampleWeights = MetricsHelper.broadcastWeights(getTF(), tSampleWeights, tValues); + tSampleWeights = MetricsHelper.broadcastWeights(tf, tSampleWeights, tValues); } catch (IllegalArgumentException ex) { // reduce values to same ndim as weight array // if we get here we have static shapes with either @@ -150,19 +157,18 @@ public List updateStateList( int[] axes = new int[numAxes]; for (int i = 0; i < numAxes; i++) axes[i] = i + weightsRank; if (reduction == MetricReduction.SUM) { - tValues = getTF().reduceSum(tValues, getTF().constant(axes)); + tValues = tf.reduceSum(tValues, tf.constant(axes)); } else { - tValues = getTF().math.mean(tValues, getTF().constant(axes)); + tValues = tf.math.mean(tValues, tf.constant(axes)); } } } - tValues = getTF().math.mul(tValues, tSampleWeights); + tValues = tf.math.mul(tValues, tSampleWeights); } Operand weightedValueSum = - getTF().reduceSum(tValues, LossesHelper.allAxes(getTF(), tValues)); - Operand totalUpdate = - getTF().assignAdd(total, cast(getTF(), weightedValueSum, total.type())); + tf.reduceSum(tValues, LossesHelper.allAxes(tf, tValues)); + Operand totalUpdate = tf.assignAdd(total, cast(tf, weightedValueSum, total.type())); updateOperations.add(totalUpdate); Operand numValues; // Exit early if the reduction doesn't have a denominator. @@ -170,18 +176,17 @@ public List updateStateList( // Update `count` for reductions that require a denominator. switch (reduction) { case SUM_OVER_BATCH_SIZE: - numValues = cast(getTF(), getTF().constant(tValues.shape().size()), resultType); + numValues = cast(tf, tf.constant(tValues.shape().size()), internalType); break; case WEIGHTED_MEAN: if (tSampleWeights == null) { - numValues = cast(getTF(), getTF().constant(tValues.shape().size()), resultType); + numValues = cast(tf, tf.constant(tValues.shape().size()), internalType); } else { numValues = cast( - getTF(), - getTF() - .reduceSum(tSampleWeights, LossesHelper.allAxes(getTF(), tSampleWeights)), - resultType); + tf, + tf.reduceSum(tSampleWeights, LossesHelper.allAxes(tf, tSampleWeights)), + internalType); } break; default: @@ -189,7 +194,7 @@ public List updateStateList( String.format("reduction [%s] not implemented", reduction)); } - Operand totalCount = getTF().assignAdd(this.count, numValues); + Operand totalCount = tf.assignAdd(this.count, numValues); updateOperations.add(totalCount); } @@ -199,16 +204,16 @@ public List updateStateList( /** {@inheritDoc} */ @Override - public Operand result() { - Operand fResult; - + public Operand result(Ops tf, Class type) { + Operand fResult; + init(tf); switch (this.reduction) { case SUM: - fResult = getTF().identity(total); + fResult = cast(tf, tf.identity(total), type); break; case WEIGHTED_MEAN: case SUM_OVER_BATCH_SIZE: - fResult = getTF().math.divNoNan(total, cast(getTF(), count, resultType)); + fResult = cast(tf, tf.math.divNoNan(total, count), type); break; default: throw new UnsupportedOperationException( @@ -240,7 +245,7 @@ public Variable getCount() { * * @return the type for the variables */ - public Class getResultType() { - return resultType; + public Class getInternalType() { + return internalType; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java index 870579ad636..376ac4f4388 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java @@ -3,26 +3,24 @@ import static org.tensorflow.framework.utils.CastHelper.cast; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import org.tensorflow.Operand; import org.tensorflow.framework.initializers.Zeros; -import org.tensorflow.framework.metrics.Metric; +import org.tensorflow.framework.metrics.BaseMetric; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TNumber; /** * Abstract base class for computing sensitivity and specificity. * - * @param The data type for the metric result + * @param The data type for the metric result */ -public abstract class SensitivitySpecificityBase extends Metric { +public abstract class SensitivitySpecificityBase extends BaseMetric { public static final int DEFAULT_NUM_THRESHOLDS = 200; @@ -37,33 +35,30 @@ public abstract class SensitivitySpecificityBase extends Metr private final String falsePositivesName; private final String trueNegativesName; private final String falseNegativesName; - private final Class type; + private final Zeros zeros = new Zeros<>(); + private final Class internalType; protected Variable truePositives; protected Variable falsePositives; protected Variable trueNegatives; protected Variable falseNegatives; - private Assign truePositivesInitializer; - private Assign falsePositivesInitializer; - private Assign trueNegativesInitializer; - private Assign falseNegativesInitializer; - /** * Creates a SensitivitySpecificityBase Metric * - * @param tf the TensorFlow Ops * @param name the name of the metric instance, if null then {@link Class#getSimpleName()} is used * @param numThresholds The number of thresholds to use for matching the given recall. * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and data type. - * @param type the data type for the variables + * will always produce the same random tensor for a given shape and data type. + * @param internalType the data type for the internal variables * @throws IllegalArgumentException if numThresholds <= 0.
*/ protected SensitivitySpecificityBase( - Ops tf, String name, int numThresholds, long seed, Class type) { - super(tf, name, seed); - if (numThresholds <= 0) throw new IllegalArgumentException("numThresholds must be > 0."); - this.type = type; + String name, int numThresholds, long seed, Class internalType) { + super(name, seed); + if (numThresholds <= 0) { + throw new IllegalArgumentException("numThresholds must be > 0."); + } + this.internalType = internalType; this.truePositivesName = this.getVariableName(TRUE_POSITIVES); this.falsePositivesName = this.getVariableName(FALSE_POSITIVES); this.trueNegativesName = this.getVariableName(TRUE_NEGATIVES); @@ -80,60 +75,35 @@ protected SensitivitySpecificityBase( } this.thresholds[numThresholds - 1] = 1f; } - init(); } - /** Initializes the Variables */ - private void init() { - Ops tf = getTF(); - Zeros zeros = new Zeros<>(); - Shape varShape = Shape.of(numThresholds); - Operand zero = zeros.call(tf, tf.constant(varShape), type); - - if (this.getTruePositives() == null) { - - truePositives = tf.withName(truePositivesName).withInitScope().variable(zero); - truePositivesInitializer = tf.assign(truePositives, zero); - } - if (this.getFalsePositives() == null) { + /** {@inheritDoc} */ + @Override + protected void init(Ops tf) { + checkIsGraph(tf); + if (!isInitialized()) { + setTF(tf); + Shape varShape = Shape.of(numThresholds); + Operand zero = zeros.call(tf, tf.constant(varShape), internalType); - falsePositives = tf.withName(falsePositivesName).withInitScope().variable(zero); - falsePositivesInitializer = tf.assign(falsePositives, zero); - } - if (this.getTrueNegatives() == null) { + if (this.getTruePositives() == null) { - trueNegatives = tf.withInitScope().withName(trueNegativesName).variable(zero); - trueNegativesInitializer = tf.assign(trueNegatives, zero); - } - if (this.getFalseNegatives() == null) { + truePositives = tf.withName(truePositivesName).withInitScope().variable(zero); + } + if (this.getFalsePositives() == null) { - falseNegatives = tf.withInitScope().withName(falseNegativesName).variable(zero); - falseNegativesInitializer = tf.assign(falseNegatives, zero); - } - } + falsePositives = tf.withName(falsePositivesName).withInitScope().variable(zero); + } + if (this.getTrueNegatives() == null) { - /** - * Gets a control dependency Op to initialize all the variables - * - * @return a control dependency Op to initialize all the variables - */ - public Op initializeVariables() { - List varInitializers = new ArrayList<>(); + trueNegatives = tf.withInitScope().withName(trueNegativesName).variable(zero); + } + if (this.getFalseNegatives() == null) { - if (truePositivesInitializer != null) { - varInitializers.add(truePositivesInitializer); - } - if (falsePositivesInitializer != null) { - varInitializers.add(falsePositivesInitializer); - } - if (trueNegativesInitializer != null) { - varInitializers.add(trueNegativesInitializer); - } - if (falseNegativesInitializer != null) { - varInitializers.add(falseNegativesInitializer); + falseNegatives = tf.withInitScope().withName(falseNegativesName).variable(zero); + } + setInitialized(true); } - - return getTF().withControlDependencies(varInitializers).noOp(); } /** @@ -146,15 +116,16 @@ public Op initializeVariables() { * @return a List of Operations to update the metric state. 
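// A minimal usage sketch, not part of this change set, of the metric lifecycle after this
// refactoring: metrics are constructed without an Ops, the Ops is passed to
// updateState/result/resetStates, and init(Ops) creates the variables lazily on first use
// (Graph mode only). The labels/predictions constants are illustrative; the fetch via the core
// Session API at the end is only a sketch.
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.framework.metrics.AUC;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

public class MetricLifecycleSketch {
  public static void main(String[] args) {
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);
      AUC<TFloat32> auc = new AUC<>(200, 1001L, TFloat32.class); // no Ops in the constructor

      Operand<TFloat32> labels = tf.constant(new float[] {0f, 0f, 1f, 1f});
      Operand<TFloat32> predictions = tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f});

      Op update = auc.updateState(tf, labels, predictions, null); // init(tf) happens here
      Operand<TFloat32> result = auc.result(tf, TFloat32.class);
      Op reset = auc.resetStates(tf);

      try (Session session = new Session(graph)) {
        session.run(update);
        try (TFloat32 value = (TFloat32) session.runner().fetch(result).run().get(0)) {
          System.out.println("AUC = " + value.getFloat());
        }
        session.run(reset); // re-zeroes the confusion-matrix variables
      }
    }
  }
}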
*/ @Override - @SuppressWarnings("unchecked") public List updateStateList( + Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { - Ops tf = getTF(); - Operand tLabels = cast(tf, labels, type); - Operand tPredictions = cast(tf, predictions, type); - Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null; + init(tf); + Operand tLabels = cast(tf, labels, internalType); + Operand tPredictions = cast(tf, predictions, internalType); + Operand tSampleWeights = + sampleWeights != null ? cast(tf, sampleWeights, internalType) : null; Map> confusionMatrix = new HashMap<>(); confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, getTruePositives()); @@ -165,7 +136,6 @@ public List updateStateList( return MetricsHelper.updateConfusionMatrixVariables( tf, confusionMatrix, - Collections.EMPTY_MAP, tLabels, tPredictions, tf.constant(thresholds), @@ -178,8 +148,29 @@ public List updateStateList( /** {@inheritDoc} */ @Override - public Op resetStates() { - return initializeVariables(); + public Op resetStates(Ops tf) { + + Shape varShape = Shape.of(numThresholds); + Operand zero = zeros.call(tf, tf.constant(varShape), internalType); + + List controlList = new ArrayList<>(); + if (this.getTruePositives() != null) { + controlList.add(tf.withName(truePositivesName).assign(this.getTruePositives(), zero)); + } + if (this.getFalsePositives() != null) { + controlList.add(tf.withName(falsePositivesName).assign(this.getFalsePositives(), zero)); + } + if (this.getTrueNegatives() != null) { + controlList.add(tf.withName(trueNegativesName).assign(this.getTrueNegatives(), zero)); + } + if (this.getFalseNegatives() != null) { + controlList.add(tf.withName(falseNegativesName).assign(this.getFalseNegatives(), zero)); + } + if (controlList.size() == 1) { + return controlList.get(0); + } else { + return tf.withControlDependencies(controlList).noOp(); + } } /** @@ -273,11 +264,15 @@ public String getFalseNegativesName() { } /** - * Gets the type + * Gets the internalType * - * @return the type + * @return the internalType */ public Class getType() { - return type; + return internalType; + } + + public Class getInternalType() { + return internalType; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java deleted file mode 100644 index dd77a1be4aa..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-=======================================================================*/ -package org.tensorflow.framework.metrics.impl; - -import static org.tensorflow.framework.utils.CastHelper.cast; - -import org.tensorflow.Operand; -import org.tensorflow.op.Ops; -import org.tensorflow.op.SparseOps; -import org.tensorflow.op.sparse.DenseToDenseSetOperation; -import org.tensorflow.types.family.TNumber; - -/** Implementation of set operations */ -public class SetsOps { - - /** - * Computes set difference of elements in last dimension of {@code a} and {@code b} with {@code - * aMinusB} set to true. - * - *
<p>
All but the last dimension of {@code a} and {@code b} must match - * - * @param tf the TensorFlow Ops - * @param a The first operand representing set {@code a} - * @param b The other operand representing set {@code b} - * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last - * dimension the * same. Elements along the last dimension contain the results of the set - * operation. - */ - public static Operand difference(Ops tf, Operand a, Operand b) { - return difference(tf, a, b, true); - } - - /** - * Computes set difference of elements in last dimension of {@code a} and {@code b}. - * - *
<p>
All but the last dimension of {@code a} and {@code b} must match - * - * @param tf the TensorFlow Ops - * @param a The first operand representing set {@code a} - * @param b The other operand representing set {@code b} - * @param aMinusB whether to subtract b from a, vs vice versa. - * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last - * dimension the * same. Elements along the last dimension contain the results of the set - * operation. - */ - public static Operand difference( - Ops tf, Operand a, Operand b, boolean aMinusB) { - return setOperation(tf, a, b, aMinusB ? Operation.A_MINUS_B : Operation.B_MINUS_A); - } - - /** - * Computes set union of elements in last dimension of {@code a} and {@code b}. - * - * @param tf the TensorFlow Ops - * @param a The first operand representing set {@code a} - * @param b The other operand representing set {@code b} - * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last - * dimension the * same. Elements along the last dimension contain the results of the set - * operation. - */ - public static Operand union(Ops tf, Operand a, Operand b) { - return setOperation(tf, a, b, Operation.UNION); - } - - /** - * Computes set intersection of elements in last dimension of {@code a} and {@code b}. - * - * @param tf the TensorFlow Ops - * @param a The first operand representing set {@code a} - * @param b The other operand representing set {@code b} - * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last - * dimension the * same. Elements along the last dimension contain the results of the set - * operation. - */ - public static Operand intersection(Ops tf, Operand a, Operand b) { - return setOperation(tf, a, b, Operation.INTERSECTION); - } - - /** - * Compute set operation of elements in last dimension of {@code a} and {@code b}. - * - * @param tf the TensorFlow Ops - * @param a The first set operation operand - * @param b The other et operation operand - * @param setOperation The set operation to perform, {@link Operation}. - * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the last - * dimension the same. Elements along the last dimension contain the results of the set - * operation. 
- */ - public static Operand setOperation( - Ops tf, Operand a, Operand b, Operation setOperation) { - - DenseToDenseSetOperation setOperationResult = - tf.sparse.denseToDenseSetOperation( - a, b, setOperation.getSetOperation(), DenseToDenseSetOperation.validateIndices(true)); - - return tf.sparse.sparseToDense( - setOperationResult.resultIndices(), - setOperationResult.resultShape(), - setOperationResult.resultValues(), - cast(tf, tf.constant(0), a.type())); - } - - /** - * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops - * function {@link SparseOps#denseToDenseSetOperation} - */ - public enum Operation { - A_MINUS_B("a-b"), - B_MINUS_A("b-a"), - INTERSECTION("intersection"), - UNION("union"); - - private final String setOperation; - - Operation(String setOperation) { - this.setOperation = setOperation; - } - - /** - * Gets the set operation String value used to pass as the stringOperation value to {@link - * SparseOps#denseToDenseSetOperation} - * - * @return the set operation String value - */ - public String getSetOperation() { - return setOperation; - } - } -} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java index 2df90a841ee..79dc6e765ab 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java @@ -20,6 +20,7 @@ import java.util.Collections; import java.util.List; import org.tensorflow.Operand; +import org.tensorflow.framework.op.FrameworkOps; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -42,7 +43,7 @@ public class WeightsBroadcastOps { /** * Asserts that {@code weights} can be broadcast to {@code values} * - * @param tf the TensorFlow Ops + * @param tf the TensorFlow Ops. * @param weights the weights Operand * @param values Operand of values to which weights are applied. * @return {@code Operation} raising a tensorflow InvalidArgumentError if {@code weights} has @@ -118,7 +119,7 @@ public static Op assertBroadcastable( * Check to see that weights and values have the same rank, if they do, then check each * corresponding dim of each. * - * @param tf The TensorFlow Ops + * @param tf the TensorFlow Ops * @param weightsRank the rank operand for the weights * @param weightsShape the shape operand for the weights * @param valuesRank the rank operand for the values @@ -141,7 +142,7 @@ private static Operand hasValidNonscalarShape( * Checks that each dimension of the two shapes are the same size, or that the weight dimension * size is 1. * - * @param tf the TensorFlow Ops + * @param tf the TensorFlow Ops. * @param weightsShape the shape of the weights * @param valuesShape the shape of the values * @return a boolean Operand, true if all the dimensions of the two shapes are the same. 
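// A minimal sketch, not part of this change set, of the replacement for the removed SetsOps
// helper: set operations now live on FrameworkOps (fops.sets), as used by canBroadcastDims and
// hasValidDims above. The constant values are illustrative only.
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.framework.op.FrameworkOps;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TInt32;

public class SetsMigrationSketch {
  public static void main(String[] args) {
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);
      FrameworkOps fops = FrameworkOps.create(tf);

      Operand<TInt32> a = tf.constant(new int[][] {{1, 2, 3}});
      Operand<TInt32> b = tf.constant(new int[][] {{2, 3, 4}});

      // previously SetsOps.difference(tf, a, b): elements of a not present in b
      Operand<TInt32> aMinusB = fops.sets.difference(a, b);
      System.out.println(aMinusB.shape());
    }
  }
}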
@@ -155,7 +156,8 @@ private static Operand hasValidDims( tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); Operand weightsShape2d = tf.expandDims(weightsShape, tf.constant(-1)); - Operand invalidDims = SetsOps.difference(tf, weightsShape2d, validDims); + FrameworkOps fops = FrameworkOps.create(tf); + Operand invalidDims = fops.sets.difference(weightsShape2d, validDims); Operand numInvalidDims = tf.size(invalidDims, TInt32.class); return tf.math.equal(tf.constant(0), numInvalidDims); } @@ -168,7 +170,7 @@ private static Operand hasValidDims( * When computing a weighted average, use this function to broadcast {@code weights} before * summing them; e.g., {@code reduceSum(w * v) / reduceSum(_broadcast_weights(w, v))}. * - * @param tf the TensorFlow ops + * @param tf the TensorFlow Ops * @param weights Operand whose shape is able to be broadcast to {@code values} * @param values Tensor` of any shape * @param the type of Operand diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java index 7b6322d0f0d..f775b1873b2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java @@ -109,6 +109,10 @@ public final Scope scope() { * Returns an API that builds operations with the provided name prefix. * *
<p>{@link Scope#withSubScope(String)} + * + * @param childScopeName name for the new child scope + * @return a new FrameworkOps that uses the child sub scope + * @throws IllegalArgumentException if the name is invalid */ public FrameworkOps withSubScope(String childScopeName) { return new FrameworkOps(scope.withSubScope(childScopeName)); } @@ -118,6 +122,10 @@ public FrameworkOps withSubScope(String childScopeName) { * Returns an API that uses the provided name for an op. * *
<p>
{@link Scope#withName(String)} + * + * @param opName name for an operator in the returned scope + * @return a new FrameworkOps that uses opName for operations. + * @throws IllegalArgumentException if the name is invalid */ public FrameworkOps withName(String opName) { return new FrameworkOps(scope.withName(opName)); @@ -146,4 +154,20 @@ public FrameworkOps withDevice(DeviceSpec deviceSpec) { public FrameworkOps withControlDependencies(Iterable controls) { return new FrameworkOps(scope.withControlDependencies(controls)); } + + /** + * Returns an FrameworkOps that builds init operations. + * + *
<p>
Init operations will be initialized at session creation, will have their inputs (and control + * inputs) made init ops as well, and are ignored when used as control dependencies. Additionally, + * this scope ignores any control dependencies. + * + *
<p>
If an input can not be made an init op (i.e. a Placeholder), will throw an {@link + * IllegalStateException} on op creation. + * + * @return a FrameworkOps with a scope that builds init operations + */ + public FrameworkOps withInitScope() { + return new FrameworkOps(scope.withInitScope()); + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/linalg/MatMul.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/linalg/MatMul.java index d8beacaa32f..3b02461e0d1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/linalg/MatMul.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/linalg/MatMul.java @@ -21,7 +21,6 @@ import org.tensorflow.op.dtypes.Cast; import org.tensorflow.op.math.Conj; import org.tensorflow.op.sparse.SparseMatMul; -import org.tensorflow.op.train.BatchMatMul; import org.tensorflow.types.TBfloat16; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; @@ -234,14 +233,16 @@ public static Operand matmul( // use adjoint instead. Conj() is a noop for real matrices. if (transposeA) { a = Conj.create(scope, a); - adjointA = true; } if (transposeB) { b = Conj.create(scope, b); - adjointB = true; } - return BatchMatMul.create( - lscope, a, b, a.type(), BatchMatMul.adjX(adjointA), BatchMatMul.adjY(adjointB)); + return org.tensorflow.op.linalg.MatMul.create( + lscope, + a, + b, + org.tensorflow.op.linalg.MatMul.transposeA(transposeA), + org.tensorflow.op.linalg.MatMul.transposeB(transposeB)); } // Neither matmul nor sparse_matmul support adjoint, so we conjugate diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/TensorDot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/TensorDot.java index a222f64679e..bbe509414c3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/TensorDot.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/TensorDot.java @@ -396,6 +396,7 @@ private static Operand[] tensordotAxes( * *
<p>
In general, {@code order(c) = order(a) + order(b) - 2*len(axes[0])}. * + * @param scope current scope * @param a {@code Operand} of type {@code TFloat32} or {@code TFloat64}. * @param b {@code Operand} with the same type as {@code a}. * @param axis sum over the last N axes of a and the first N axes of b in order. If {@code @@ -447,6 +448,7 @@ public static Operand tensordot( * *

* + * @param scope current scope * @param a {@code Operand} of type {@code TFloat32} or {@code TFloat64}. * @param b {@code Operand} with the same type as {@code a}. * @param axes If axes is a scalar, sum over the last N axes of a and the first N axes of b in @@ -502,6 +504,7 @@ public static Operand tensordot( * *

* + * @param scope current scope * @param a {@code Operand} of type {@code TFloat32} or {@code TFloat64}. * @param b {@code Operand} with the same type as {@code a}. * @param axes the first and second row contain the set of unique integers specifying axes along @@ -555,6 +558,7 @@ public static Operand tensordot( * *

* + * @param scope current scope * @param a {@code Operand} of type {@code TFloat32} or {@code TFloat64}. * @param b {@code Operand} with the same type as {@code a}. * @param axes the first and second row contain the set of unique integers specifying axes along @@ -608,6 +612,7 @@ public static Operand tensordot( * *

* + * @param scope current scope * @param a {@code Operand} of type {@code TFloat32} or {@code TFloat64}. * @param b {@code Operand} with the same type as {@code a}. * @param aAxis axes for the a Operand diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/sets/Sets.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/sets/Sets.java index 0cbfa74e770..c91581632d8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/sets/Sets.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/sets/Sets.java @@ -31,6 +31,7 @@ public class Sets { * *
<p>
All but the last dimension of a and b must match * + * @param scope current scope * @param a The first operand representing set a * @param b The other operand representing set b * @param the data type for the sets @@ -47,6 +48,7 @@ public static Operand difference(Scope scope, Operand * *
<p>
All but the last dimension of a and b must match * + * @param scope current scope * @param a The first operand representing set a * @param b The other operand representing set b * @param aMinusB whether to subtract b from a, vs vice versa. @@ -63,6 +65,7 @@ public static Operand difference( /** * Computes set union of elements in last dimension of a and b. * + * @param scope current scope * @param a The first operand representing set a * @param b The other operand representing set b * @param the data type for the sets @@ -77,6 +80,7 @@ public static Operand union(Scope scope, Operand a, Op /** * Computes set intersection of elements in last dimension of a and b. * + * @param scope current scope * @param a The first operand representing set a * @param b The other operand representing set b * @param the data type for the sets @@ -92,6 +96,7 @@ public static Operand intersection( /** * Compute set operation of elements in last dimension of a and b. * + * @param scope current scope * @param a The first set operation operand * @param b The other et operation operand * @param setOperation The set operation to perform, {@link Operation}. diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java index ae40074f3f6..e4f486b8a08 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java @@ -45,20 +45,18 @@ public void testValueIsIdempotent() { Ops tf = session.getTF(); Operand yPred = tf.constant(predArray); Operand yTrue = tf.constant(trueArray); - AUC instance = new AUC<>(tf, numThresholds, 1001L, TFloat32.class); + AUC instance = new AUC<>(numThresholds, 1001L, TFloat32.class); - session.initialize(); - - Op update = instance.updateState(yTrue, yPred, null); + Op update = instance.updateState(tf, yTrue, yPred, null); for (int i = 0; i < 10; i++) { session.run(update); } - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); for (int i = 0; i < 10; i++) { - session.evaluate(result, instance.result()); + session.evaluate(result, instance.result(tf, TFloat32.class)); } } } @@ -75,7 +73,7 @@ public void testCumulative() { Ops tf = session.getTF(); Operand yPred = tf.constant(predArray); Operand yTrue = tf.constant(trueArray); - AUC instance = new AUC<>(tf, numThresholds, 1001L, TFloat32.class); + AUC instance = new AUC<>(numThresholds, 1001L, TFloat32.class); session.initialize(); @@ -85,7 +83,7 @@ public void testCumulative() { assertNull(instance.getFalseNegatives()); for (int i = 0; i < 3; i++) { - Op update = instance.updateState(yTrue, yPred, null); + Op update = instance.updateState(tf, yTrue, yPred, null); session.run(update); session.evaluate(tp[i], instance.getTruePositives()); session.evaluate(fp[i], instance.getFalsePositives()); @@ -94,9 +92,9 @@ public void testCumulative() { } // test reset - session.run(instance.resetStates()); + session.run(instance.resetStates(tf)); for (int i = 0; i < 3; i++) { - Op update = instance.updateState(yTrue, yPred, null); + Op update = instance.updateState(tf, yTrue, yPred, null); session.run(update); session.evaluate(tp[i], instance.getTruePositives()); session.evaluate(fp[i], instance.getFalsePositives()); @@ -110,7 +108,7 @@ public void testCumulative() { public void basicTestSampleWeight() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = 
session.getTF(); - AUC instance = new AUC<>(tf, numThresholds, 1001L, TFloat32.class); + AUC instance = new AUC<>(numThresholds, 1001L, TFloat32.class); assertEquals(numThresholds, instance.getNumThresholds()); float[] expectedThresholds = new float[] {-1e-7f, 0.5f, 1 + 1e-7f}; assertArrayEquals(expectedThresholds, instance.getThresholds(), epsilon); @@ -119,9 +117,9 @@ public void basicTestSampleWeight() { Operand yTrue = tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}); Operand sampleWeights = tf.constant(new float[] {1, 0, 0, 1}); - Op update = instance.updateState(yTrue, yPred, sampleWeights); + Op update = instance.updateState(tf, yTrue, yPred, sampleWeights); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(1.0f, result); } } @@ -131,12 +129,12 @@ public void testUnweightedAllCorrect() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Operand yTrue = cast(tf, tf.constant(this.trueArray), TFloat32.class); - AUC instance = new AUC<>(tf, this.numThresholds, 1001L, TFloat32.class); + AUC instance = new AUC<>(this.numThresholds, 1001L, TFloat32.class); session.initialize(); - Op update = instance.updateState(yTrue, yTrue, null); + Op update = instance.updateState(tf, yTrue, yTrue, null); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(1f, result); } @@ -148,11 +146,11 @@ public void testUnweighted() { Ops tf = session.getTF(); Operand yPred = tf.constant(this.predArray); Operand yTrue = tf.constant(this.trueArray); - AUC instance = new AUC<>(tf, this.numThresholds, 1001L, TFloat32.class); + AUC instance = new AUC<>(this.numThresholds, 1001L, TFloat32.class); session.initialize(); - Op update = instance.updateState(yTrue, yPred, null); + Op update = instance.updateState(tf, yTrue, yPred, null); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); // float expectedResult = (0.75f * 1 + 0.25f * 0); session.evaluate(0.75f, result); @@ -165,13 +163,13 @@ public void testManualThresholds() { Ops tf = session.getTF(); Operand yPred = tf.constant(this.predArray); Operand yTrue = tf.constant(this.trueArray); - AUC instance = new AUC<>(tf, new float[] {0.5f}, 1001L, TFloat32.class); + AUC instance = new AUC<>(new float[] {0.5f}, 1001L, TFloat32.class); float[] expectedThresholds = new float[] {-AUC.EPSILON, 0.5f, 1 + AUC.EPSILON}; assertArrayEquals(expectedThresholds, instance.getThresholds(), epsilon); session.initialize(); - Op update = instance.updateState(yTrue, yPred, null); + Op update = instance.updateState(tf, yTrue, yPred, null); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); // float expectedResult = (0.75f * 1 + 0.25f * 0); session.evaluate(0.75f, result); @@ -186,11 +184,11 @@ public void testWeightedRocInterpolation() { Operand yTrue = tf.constant(this.trueArray); Operand sampleWights = tf.constant(this.sampleWeight); - AUC instance = new AUC<>(tf, this.numThresholds, 1001L, TFloat32.class); + AUC instance = new AUC<>(this.numThresholds, 1001L, TFloat32.class); session.initialize(); - Op update = instance.updateState(yTrue, yPred, sampleWights); + Op update = instance.updateState(tf, yTrue, yPred, sampleWights); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); float expectedResult = 
(0.78571427f * 1 + 0.2857145f * 0); session.evaluate(expectedResult, result); @@ -207,16 +205,11 @@ public void testWeightedRocMajoring() { AUC instance = new AUC<>( - tf, - this.numThresholds, - AUCCurve.ROC, - AUCSummationMethod.MAJORING, - 1001L, - TFloat32.class); + this.numThresholds, AUCCurve.ROC, AUCSummationMethod.MAJORING, 1001L, TFloat32.class); session.initialize(); - Op update = instance.updateState(yTrue, yPred, sampleWights); + Op update = instance.updateState(tf, yTrue, yPred, sampleWights); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); float expectedResult = (1.0f + .5714285f * 0f); session.evaluate(expectedResult, result); @@ -233,16 +226,11 @@ public void testWeightedRocMinoring() { AUC instance = new AUC<>( - tf, - this.numThresholds, - AUCCurve.ROC, - AUCSummationMethod.MINORING, - 1001L, - TFloat32.class); + this.numThresholds, AUCCurve.ROC, AUCSummationMethod.MINORING, 1001L, TFloat32.class); session.initialize(); - Op update = instance.updateState(yTrue, yPred, sampleWights); + Op update = instance.updateState(tf, yTrue, yPred, sampleWights); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); float expectedResult = (0.5714285f + 0f * 0f); session.evaluate(expectedResult, result); @@ -259,16 +247,11 @@ public void testWeightedPrMajoring() { AUC instance = new AUC<>( - tf, - this.numThresholds, - AUCCurve.PR, - AUCSummationMethod.MAJORING, - 1001L, - TFloat32.class); + this.numThresholds, AUCCurve.PR, AUCSummationMethod.MAJORING, 1001L, TFloat32.class); session.initialize(); - Op update = instance.updateState(yTrue, yPred, sampleWights); + Op update = instance.updateState(tf, yTrue, yPred, sampleWights); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); float expectedResult = 0.4285715f + 0.5714285f; session.evaluate(expectedResult, result); } @@ -284,16 +267,11 @@ public void testWeightedPrMinoring() { AUC instance = new AUC<>( - tf, - this.numThresholds, - AUCCurve.PR, - AUCSummationMethod.MINORING, - 1001L, - TFloat32.class); + this.numThresholds, AUCCurve.PR, AUCSummationMethod.MINORING, 1001L, TFloat32.class); session.initialize(); - Op update = instance.updateState(yTrue, yPred, sampleWights); + Op update = instance.updateState(tf, yTrue, yPred, sampleWights); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); float expectedResult = 0.7f * 0.4285715f + 0f * 0.5714285f; session.evaluate(expectedResult, result); } @@ -307,12 +285,11 @@ public void testWeightedPrInterpolation() { Operand yTrue = tf.constant(this.trueArray); Operand sampleWights = tf.constant(this.sampleWeight); - AUC instance = - new AUC<>(tf, this.numThresholds, AUCCurve.PR, 1001L, TFloat32.class); + AUC instance = new AUC<>(this.numThresholds, AUCCurve.PR, 1001L, TFloat32.class); session.initialize(); - Op update = instance.updateState(yTrue, yPred, sampleWights); + Op update = instance.updateState(tf, yTrue, yPred, sampleWights); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); float expectedResult = 0.916613f; session.evaluate(expectedResult, result); } @@ -320,24 +297,13 @@ public void testWeightedPrInterpolation() { @Test public void testInvalidNumThresholds() { - assertThrows( - IllegalArgumentException.class, - () -> { - try (TestSession session = 
TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - - new AUC<>(tf, -1, 1001L, TFloat32.class); - } - }); + assertThrows(IllegalArgumentException.class, () -> new AUC<>(-1, 1001L, TFloat32.class)); } @Test public void testExtraDims() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - // logits = scipy.special.expit(-np.array([[[-10., 10., -10.], [10., -10., 10.]], - // [[-12., 12., -12.], [12., -12., 12.]]], - // dtype=np.float32)) float[][][] logitsArray = { { {9.99954602e-01f, 4.53978687e-05f, 9.99954602e-01f}, @@ -357,13 +323,26 @@ public void testExtraDims() { Operand logits = tf.constant(logitsArray); Operand labels = tf.constant(labelArray); - AUC instance = new AUC<>(tf, 1001L, TFloat32.class); + AUC instance = new AUC<>(1001L, TFloat32.class); session.initialize(); - Op update = instance.updateState(labels, logits, null); + Op update = instance.updateState(tf, labels, logits, null); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); float expectedResult = 0.5f; session.evaluate(expectedResult, result); } } + + /** Test that Eager mode throws IllegalArgument Exception */ + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + AUC instance = new AUC<>(1001L, TFloat32.class); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + assertThrows( + IllegalArgumentException.class, () -> instance.updateState(tf, yTrue, yPred, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AccuracyTest.java index 48cac95b8a6..1af3100641b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AccuracyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AccuracyTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -31,18 +33,17 @@ public class AccuracyTest { public void testCorrect() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Accuracy instance = new Accuracy<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Accuracy instance = new Accuracy<>(1001L, TFloat32.class); int[] trueArray = {1, 2, 3, 4}; float[] predArray = {1, 2, 3, 4}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 1))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(4, 1))); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, predictions, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(4F, total); session.evaluate(4, count); session.evaluate(1F, result); @@ -53,8 +54,7 @@ public void testCorrect() { public void testSampleWeight() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Accuracy instance = new Accuracy<>(tf, 1001L, 
TFloat32.class); - session.run(instance.resetStates()); + Accuracy instance = new Accuracy<>(1001L, TFloat32.class); float[] trueArray = {2, 1}; float[] predArray = {2, 0}; @@ -64,12 +64,12 @@ public void testSampleWeight() { Operand sampleWeight = tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); - Op op = instance.updateState(labels, predictions, sampleWeight); + Op op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(.5F, total); session.evaluate(.7, count); session.evaluate(0.71428573f, result); @@ -80,51 +80,59 @@ public void testSampleWeight() { public void testVariableState() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Accuracy instance = new Accuracy<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Accuracy instance = new Accuracy<>(1001L, TFloat32.class); float[] trueArray = {2, 1}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); Operand sampleWeight = tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); - Op op = instance.updateState(labels, labels, sampleWeight); + Op op = instance.updateState(tf, labels, labels, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(0.7F, total); session.evaluate(.7, count); session.evaluate(1.0F, result); // 2nd run - op = instance.updateState(labels, labels, sampleWeight); + op = instance.updateState(tf, labels, labels, sampleWeight); session.run(op); - result = instance.result(); + result = instance.result(tf, TFloat32.class); session.evaluate(1.4F, total); session.evaluate(1.4, count); session.evaluate(1.0F, result); // new instance same graph - instance = new Accuracy<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); - op = instance.updateState(labels, labels, sampleWeight); + instance = new Accuracy<>(1001L, TFloat32.class); + op = instance.updateState(tf, labels, labels, sampleWeight); session.run(op); total = instance.getTotal(); count = instance.getCount(); - result = instance.result(); + result = instance.result(tf, TFloat32.class); session.evaluate(0.7F, total); session.evaluate(.7, count); session.evaluate(1.0F, result); // reset variables - session.run(instance.resetStates()); - result = instance.result(); - op = instance.updateState(labels, labels, sampleWeight); + session.run(instance.resetStates(tf)); + result = instance.result(tf, TFloat32.class); + op = instance.updateState(tf, labels, labels, sampleWeight); session.run(op); session.evaluate(0.7F, total); session.evaluate(.7, count); session.evaluate(1.0F, result); } } + + /** Test that Eager mode throws IllegalArgumentException */ + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + Accuracy instance = new Accuracy<>(1001L, TFloat32.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java index d203815f4ab..691c2ab9cce 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -31,18 +33,17 @@ public class BinaryAccuracyTest { public void testCorrect() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryAccuracy instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); + BinaryAccuracy instance = new BinaryAccuracy<>(1001L, TFloat32.class); int[] trueArray = {1, 0}; float[] predArray = {1, 0}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 1))); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, predictions, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(2F, total); session.evaluate(2, count); session.evaluate(1F, result); @@ -53,18 +54,17 @@ public void testCorrect() { public void testPredictionSqueeze() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryAccuracy instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); + BinaryAccuracy instance = new BinaryAccuracy<>(1001L, TFloat32.class); int[] trueArray = {1, 0}; float[] predArray = {1, 1}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 1, 1))); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, predictions, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(2F, total); session.evaluate(4, count); session.evaluate(0.5F, result); @@ -75,8 +75,8 @@ public void testPredictionSqueeze() { public void testSampleWeight() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryAccuracy instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); + BinaryAccuracy instance = new BinaryAccuracy<>(1001L, TFloat32.class); + int[] trueArray = {1, 1}; float[] predArray = {1, 0}; @@ -86,11 +86,11 @@ public void testSampleWeight() { Operand sampleWeight = tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); - Op op = instance.updateState(labels, predictions, sampleWeight); + Op op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(0.5F, 
total); session.evaluate(.7, count); session.evaluate(0.71428573f, result); @@ -101,50 +101,50 @@ public void testSampleWeight() { public void testVariableState() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryAccuracy instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); + BinaryAccuracy instance = new BinaryAccuracy<>(1001L, TFloat32.class); + float[] trueArray = {2, 1}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); Operand sampleWeight = tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); - Op op = instance.updateState(labels, labels, sampleWeight); + Op op = instance.updateState(tf, labels, labels, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(0.2F, total); session.evaluate(.7, count); session.evaluate(0.2857143F, result); // 2nd run - op = instance.updateState(labels, labels, sampleWeight); + op = instance.updateState(tf, labels, labels, sampleWeight); session.run(op); - result = instance.result(); + result = instance.result(tf, TFloat32.class); session.evaluate(0.4F, total); session.evaluate(1.4, count); session.evaluate(0.2857143F, result); // new instance same graph - instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); - op = instance.updateState(labels, labels, sampleWeight); + instance = new BinaryAccuracy<>(1001L, TFloat32.class); + + op = instance.updateState(tf, labels, labels, sampleWeight); session.run(op); total = instance.getTotal(); count = instance.getCount(); - result = instance.result(); + result = instance.result(tf, TFloat32.class); session.evaluate(0.2F, total); session.evaluate(.7, count); session.evaluate(0.2857143F, result); // reset variables - session.run(instance.resetStates()); + session.run(instance.resetStates(tf)); session.evaluate(0.0, total); session.evaluate(0.0, count); - op = instance.updateState(labels, labels, sampleWeight); + op = instance.updateState(tf, labels, labels, sampleWeight); session.run(op); - result = instance.result(); + result = instance.result(tf, TFloat32.class); session.evaluate(0.2F, total); session.evaluate(.7, count); session.evaluate(0.2857143F, result); @@ -155,22 +155,32 @@ public void testVariableState() { public void testBinaryAccuracyAThreshold() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryAccuracy instance = new BinaryAccuracy<>(tf, 0.7f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + BinaryAccuracy instance = new BinaryAccuracy<>(0.7f, 1001L, TFloat32.class); + int[] trueArray = {1, 1, 0, 0}; float[] predArray = {0.9f, 0.6f, 0.4f, 0.8f}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 1))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(4, 1))); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, predictions, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(2F, total); session.evaluate(4, count); session.evaluate(0.5F, result); } } + + /** Test that Eager mode throws IllegalArgumentException */ +
@Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + BinaryAccuracy instance = new BinaryAccuracy<>(1001L, TFloat32.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java index be46bb5c282..8ca0c12462d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -33,18 +35,18 @@ public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); BinaryCrossentropy instance = - new BinaryCrossentropy<>(tf, "BCE_testUnweighted", false, 0, 1001L, TFloat64.class); - session.run(instance.resetStates()); + new BinaryCrossentropy<>("BCE_testUnweighted", false, 0, 1001L, TFloat64.class); + float[] trueArray = {1, 0, 1, 0}; float[] predictionArray = {1, 1, 1, 0}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); Operand yPrediction = tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 2))); - Op op = instance.updateState(labels, yPrediction, null); + Op op = instance.updateState(tf, labels, yPrediction, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(7.71247434F, total); session.evaluate(2, count); session.evaluate(3.85623717F, result); @@ -56,17 +58,17 @@ public void testUnweightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); BinaryCrossentropy instance = - new BinaryCrossentropy<>(tf, "BCE_testUnweightedLogits", true, 0, 1001L, TFloat64.class); - session.run(instance.resetStates()); + new BinaryCrossentropy<>("BCE_testUnweightedLogits", true, 0, 1001L, TFloat64.class); + float[] trueArray = {1, 0, 1, 0, 1, 1}; double[] logitsArray = {100.0, -100.0, 100.0, 100.0, 100.0, -100.0}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); - Op op = instance.updateState(labels, logits, null); + Op op = instance.updateState(tf, labels, logits, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(66.66667, total); session.evaluate(2, count); session.evaluate(33.333332, result); @@ -78,20 +80,20 @@ public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); BinaryCrossentropy instance = - new BinaryCrossentropy<>(tf, "BCE_testWeighted", false, 0, 1001L, TFloat32.class); - session.run(instance.resetStates()); + new 
BinaryCrossentropy<>("BCE_testWeighted", false, 0, 1001L, TFloat32.class); + int[] trueArray = {1, 0, 1, 0}; float[] predictionArray = {1, 1, 1, 0}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); Operand yPrediction = tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 2))); Operand sampleWeight = tf.constant(new float[] {1.5f, 2.f}); - Op op = instance.updateState(labels, yPrediction, sampleWeight); + Op op = instance.updateState(tf, labels, yPrediction, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(11.499929f, total); session.evaluate(3.5f, count); session.evaluate(3.285694f, result); @@ -103,20 +105,20 @@ public void testWeightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); BinaryCrossentropy instance = - new BinaryCrossentropy<>(tf, "BCE_testWeightedLogits", true, 0, 1001L, TFloat64.class); - session.run(instance.resetStates()); + new BinaryCrossentropy<>("BCE_testWeightedLogits", true, 0, 1001L, TFloat64.class); + float[] trueArray = {1, 0, 1, 0, 1, 1}; double[] logitsArray = {100.0, -100.0, 100.0, 100.0, 100.0, -100.0}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(new double[] {2, 2.5}); - Op op = instance.updateState(labels, logits, sampleWeight); + Op op = instance.updateState(tf, labels, logits, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(166.66666, total); session.evaluate(4.5, count); session.evaluate(37.037033, result); @@ -130,22 +132,35 @@ public void testLabelSmoothing() { float labelSmoothing = 0.1F; BinaryCrossentropy instance = new BinaryCrossentropy<>( - tf, "BCE_testWeightedLabS", true, labelSmoothing, 1001L, TFloat64.class); - session.run(instance.resetStates()); + "BCE_testWeightedLabS", true, labelSmoothing, 1001L, TFloat64.class); + float[] trueArray = {1, 0, 1}; double[] logitsArray = {100., -100., -100.}; Operand labels = tf.constant(trueArray); Operand logits = tf.constant(logitsArray); - Op op = instance.updateState(labels, logits, null); + Op op = instance.updateState(tf, labels, logits, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(35, total); session.evaluate(1, count); session.evaluate(35, result); } } + + /** Test that Eager mode throws IllegalArgumentException */ + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + float labelSmoothing = 0.1F; + BinaryCrossentropy instance = + new BinaryCrossentropy<>( + "BCE_testWeightedLabS", true, labelSmoothing, 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java index
aea2e4e0d6e..c57b39db158 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -31,8 +33,7 @@ public class CategoricalAccuracyTest { public void testCorrect() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalAccuracy instance = new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); + CategoricalAccuracy instance = new CategoricalAccuracy<>(1001L, TFloat32.class); int[] trueArray = { 0, 0, 1, 0, 1, 0 @@ -44,11 +45,11 @@ public void testCorrect() { Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, predictions, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(2F, total); session.evaluate(2, count); session.evaluate(1F, result); @@ -59,8 +60,7 @@ public void testCorrect() { public void testSampleWeight() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalAccuracy instance = new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); + CategoricalAccuracy instance = new CategoricalAccuracy<>(1001L, TFloat32.class); int[] trueArray = { 0, 0, 1, 0, 1, 0 @@ -75,11 +75,11 @@ public void testSampleWeight() { Operand sampleWeight = tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); - Op op = instance.updateState(labels, predictions, sampleWeight); + Op op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(0.7F, total); session.evaluate(.7, count); session.evaluate(1.0F, result); @@ -90,8 +90,7 @@ public void testSampleWeight() { public void testVariableState() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalAccuracy instance = new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); + CategoricalAccuracy instance = new CategoricalAccuracy<>(1001L, TFloat32.class); int[] trueArray = { 0, 0, 1, 0, 1, 0 @@ -107,29 +106,28 @@ public void testVariableState() { Operand sampleWeight = tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); - Op op = instance.updateState(labels, predictions, sampleWeight); + Op op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(0.7F, total); session.evaluate(.7, count); 
session.evaluate(1.0F, result); // 2nd run - op = instance.updateState(labels, predictions, sampleWeight); + op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); - result = instance.result(); + result = instance.result(tf, TFloat32.class); session.evaluate(1.4F, total); session.evaluate(1.4, count); session.evaluate(1.0F, result); // new instance same graph - instance = new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); - op = instance.updateState(labels, predictions, sampleWeight); + instance = new CategoricalAccuracy<>(1001L, TFloat32.class); + op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); - result = instance.result(); + result = instance.result(tf, TFloat32.class); total = instance.getTotal(); count = instance.getCount(); session.evaluate(0.7F, total); @@ -137,17 +135,27 @@ public void testVariableState() { session.evaluate(1.0F, result); // reset variables - session.run(instance.resetStates()); + session.run(instance.resetStates(tf)); session.evaluate(0, total); session.evaluate(0, count); session.evaluate(0, result); - op = instance.updateState(labels, predictions, sampleWeight); + op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); - result = instance.result(); + result = instance.result(tf, TFloat32.class); session.evaluate(0.7F, total); session.evaluate(.7, count); session.evaluate(1.0F, result); } } + + /** Test that Eager mode throws IllegalArgumentException */ + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + CategoricalAccuracy instance = new CategoricalAccuracy<>(1001L, TFloat32.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java index 34fc3eef884..fb77efa9fa7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -32,19 +34,17 @@ public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); CategoricalCrossentropy instance = - new CategoricalCrossentropy<>( - tf, "CCE_testUnweighted", false, 0, -1, 1001L, TFloat64.class); - session.run(instance.resetStates()); + new CategoricalCrossentropy<>("CCE_testUnweighted", false, 0, -1, 1001L, TFloat64.class); int[] trueArray = {0, 1, 0, 0, 0, 1}; double[] predArray = {0.05, 0.95, 0, 0.1, 0.8, 0.1}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, predictions, null); session.run(op); Variable total = instance.getTotal(); Variable
count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(2.3538785, total); session.evaluate(2, count); session.evaluate(1.1769392, result); @@ -57,18 +57,17 @@ public void testUnweightedLogits() { Ops tf = session.getTF(); CategoricalCrossentropy instance = new CategoricalCrossentropy<>( - tf, "CCE_testUnweightedLogits", true, 0, -1, 1001L, TFloat64.class); - session.run(instance.resetStates()); + "CCE_testUnweightedLogits", true, 0, -1, 1001L, TFloat64.class); int[] trueArray = {0, 1, 0, 0, 0, 1}; double[] predArray = {1, 9, 0, 1, 8, 1}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, predictions, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(7.0022807, total); session.evaluate(2, count); session.evaluate(3.5011404, result); @@ -80,20 +79,18 @@ public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); CategoricalCrossentropy instance = - new CategoricalCrossentropy<>( - tf, "CCE_testWeighted", false, 0, -1, 1001L, TFloat64.class); - session.run(instance.resetStates()); + new CategoricalCrossentropy<>("CCE_testWeighted", false, 0, -1, 1001L, TFloat64.class); int[] trueArray = {0, 1, 0, 0, 0, 1}; double[] predArray = {0.05f, 0.95f, 0f, 0.1f, 0.8f, 0.1f}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(new double[] {1.5f, 2.}); - Op op = instance.updateState(labels, predictions, sampleWeight); + Op op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(4.6821095, total); session.evaluate(3.5, count); session.evaluate(1.3377455, result); @@ -105,19 +102,18 @@ public void testWeightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); CategoricalCrossentropy instance = - new CategoricalCrossentropy<>(tf, "CCE_testWeighted", true, 0, -1, 1001L, TFloat64.class); - session.run(instance.resetStates()); + new CategoricalCrossentropy<>("CCE_testWeighted", true, 0, -1, 1001L, TFloat64.class); int[] trueArray = {0, 1, 0, 0, 0, 1}; double[] predArray = {1, 9, 0, 1, 8, 1}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(new double[] {1.5, 2.f}); - Op op = instance.updateState(labels, predictions, sampleWeight); + Op op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(14.004333, total); session.evaluate(3.5, count); session.evaluate(4.0012328, result); @@ -131,21 +127,31 @@ public void 
testLabelSmoothing() { float labelSmoothing = 0.1F; CategoricalCrossentropy instance = new CategoricalCrossentropy<>( - tf, "CCE_testWeighted", true, labelSmoothing, -1, 1001L, TFloat64.class); - session.run(instance.resetStates()); + "CCE_testWeighted", true, labelSmoothing, -1, 1001L, TFloat64.class); int[] trueArray = {0, 1, 0, 0, 0, 1}; double[] predArray = {1, 9, 0, 1, 8, 1}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, predictions, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(7.3356137, total); session.evaluate(2, count); session.evaluate(3.6678069, result); } } + + /** Test that Eager mode throws IllegalArgumentException */ + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + CategoricalCrossentropy instance = + new CategoricalCrossentropy<>("CCE_testWeighted", false, 0, -1, 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java index 78b25a21b60..728242eb59a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -32,8 +34,7 @@ public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); CategoricalHinge instance = - new CategoricalHinge<>(tf, "CH_testUnweighted", 1001L, TFloat64.class); - session.run(instance.resetStates()); + new CategoricalHinge<>("CH_testUnweighted", 1001L, TFloat64.class); int[] trueArray = { 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, @@ -49,11 +50,11 @@ public void testUnweighted() { Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(4, 5))); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, predictions, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(2., total); session.evaluate(4, count); session.evaluate(0.5, result); @@ -65,8 +66,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); CategoricalHinge instance = - new CategoricalHinge<>(tf, "CH_testWeighted", 1001L, TFloat64.class); - session.run(instance.resetStates()); + new CategoricalHinge<>("CH_testWeighted", 1001L, TFloat64.class); int[]
trueArray = { 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, @@ -84,14 +84,24 @@ public void testWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(4, 5))); Operand sampleWeight = tf.constant(new double[] {1., 1.5, 2., 2.5}); - Op op = instance.updateState(labels, predictions, sampleWeight); + Op op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(3.5F, total); session.evaluate(7, count); session.evaluate(0.5, result); } } + + /** Test that Eager mode throws IllegalArgumentException */ + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + CategoricalHinge instance = new CategoricalHinge<>(null, 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java index 18410416c42..f40e0647dcb 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -32,18 +34,17 @@ public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); CosineSimilarity instance = - new CosineSimilarity<>(tf, "CS_testUnweighted", 1001L, TFloat32.class); - session.run(instance.resetStates()); + new CosineSimilarity<>("CS_testUnweighted", 1001L, TFloat32.class); int[] trueArray = {1, 9, 2, -5, -2, 6}; float[] predArray = {4, 8, 12, 8, 1, 3}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, predictions, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(0.3744381F, total); session.evaluate(2, count); session.evaluate(0.18721905F, result); @@ -55,8 +56,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); CosineSimilarity instance = - new CosineSimilarity<>(tf, "CS_testWeighted", 1001L, TFloat32.class); - session.run(instance.resetStates()); + new CosineSimilarity<>("CS_testWeighted", 1001L, TFloat32.class); int[] trueArray = {1, 9, 2, -5, -2, 6}; float[] predArray = {4, 8, 12, 8, 1, 3}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); @@ -64,11 +64,11 @@ public void testWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(new float[] {1.2f, 3.4f}); - Op op = 
instance.updateState(labels, predictions, sampleWeight); + Op op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(-0.3119840621948241F, total); session.evaluate(4.6, count); session.evaluate(-0.06782262221626612F, result); @@ -81,21 +81,30 @@ public void test_axis() { Ops tf = session.getTF(); int axis = 1; CosineSimilarity instance = - new CosineSimilarity<>(tf, "CS_testWeighted", axis, 1001L, TFloat32.class); - session.run(instance.resetStates()); + new CosineSimilarity<>("CS_testWeighted", axis, 1001L, TFloat32.class); int[] trueArray = {1, 9, 2, -5, -2, 6}; float[] predArray = {4, 8, 12, 8, 1, 3}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, predictions, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(0.3744381F, total); session.evaluate(2, count); session.evaluate(0.18721905F, result); } } + + /** Test that Eager mode throws IllegalArgumentException */ + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + CosineSimilarity instance = new CosineSimilarity<>(null, 1, 1001L, TFloat32.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalseNegativesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalseNegativesTest.java index 4bd8d99586e..34e41f50fa3 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalseNegativesTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalseNegativesTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -45,11 +47,10 @@ public void testUnweighted() { Operand predictions = tf.constant(this.predArray); Operand labels = tf.constant(this.trueArray); - FalseNegatives instance = new FalseNegatives<>(tf, 1001L, TFloat64.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, null); + FalseNegatives instance = new FalseNegatives<>(1001L, TFloat64.class); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(3.0, result); } @@ -63,11 +64,10 @@ public void testWeighted() { Operand predictions = tf.constant(this.predArray); Operand labels = tf.constant(this.trueArray); Operand sampleWeight = tf.constant(this.sampleWeightArray); - FalseNegatives instance = new FalseNegatives<>(tf, 1001L, TFloat64.class); - session.run(instance.getInitializer()); - Op update = 
instance.updateState(labels, predictions, sampleWeight); + FalseNegatives instance = new FalseNegatives<>(1001L, TFloat64.class); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(5.0, result); } @@ -95,11 +95,10 @@ public void testUnweightedWithThresholds() { {1, 1, 1, 1} }); FalseNegatives instance = - new FalseNegatives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, null); + new FalseNegatives<>(new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); float[] expected = new float[] {1.f, 4.f, 6.f}; session.evaluate(expected, result); } @@ -129,13 +128,23 @@ public void testWeightedWithThresholds() { Operand sampleWeight = tf.constant(new double[][] {{3.0}, {5.0}, {7.0}, {4.0}}); FalseNegatives instance = - new FalseNegatives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, sampleWeight); + new FalseNegatives<>(new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); double[] expected = new double[] {4., 16., 23.}; session.evaluate(expected, result); } } + + /** Test that Eager mode throws IllegalArgumentException */ + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + FalseNegatives instance = + new FalseNegatives<>(new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalsePositivesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalsePositivesTest.java index 2584c7a3244..a4a105e2dfc 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalsePositivesTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalsePositivesTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -45,11 +47,10 @@ public void testUnweighted() { Operand predictions = tf.constant(this.predArray); Operand labels = tf.constant(this.trueArray); - FalsePositives instance = new FalsePositives<>(tf, 1001L, TFloat64.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, null); + FalsePositives instance = new FalsePositives<>(1001L, TFloat64.class); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(7.0, result); } @@ -63,11 +64,10 @@
public void testWeighted() { Operand predictions = tf.constant(this.predArray); Operand labels = tf.constant(this.trueArray); Operand sampleWeight = tf.constant(this.sampleWeightArray); - FalsePositives instance = new FalsePositives<>(tf, 1001L, TFloat64.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, sampleWeight); + FalsePositives instance = new FalsePositives<>(1001L, TFloat64.class); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(14.0, result); } @@ -95,11 +95,10 @@ public void testUnweightedWithThresholds() { {1, 1, 1, 1} }); FalsePositives instance = - new FalsePositives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, null); + new FalsePositives<>(new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); float[] expected = new float[] {7.f, 4.f, 2.f}; session.evaluate(expected, result); } @@ -136,13 +135,22 @@ public void testWeightedWithThresholds() { {19.0, 23.0, 29.0, 31.0} }); FalsePositives instance = - new FalsePositives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, sampleWeight); + new FalsePositives<>(new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); double[] expected = new double[] {125., 42., 12.}; session.evaluate(expected, result); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + FalsePositives instance = + new FalsePositives<>(new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java index 90531d21fde..cfa22415b73 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -32,18 +34,17 @@ class HingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); - session.run(instance.resetStates()); + Hinge instance = new Hinge<>("Hinge_testUnweighted", 1001L, TFloat64.class); int[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; double[] predArray = {-0.3, 0.2, -0.1, 1.6, -0.25, -1., 0.5, 
0.6}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, predictions, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(1.0125, total); session.evaluate(2, count); session.evaluate(.5062500, result); @@ -54,29 +55,34 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); - session.run(instance.resetStates()); + Hinge instance = new Hinge<>("Hinge_testWeighted", 1001L, TFloat64.class); int[] trueArray = { -1, 1, -1, 1, -1, -1, 1, 1 }; - float[] predArray = { - -0.3f, 0.2f, -0.1f, 1.6f, - -0.25f, -1.f, 0.5f, 0.6f - }; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.constant(new double[] {1.5, 2.}); - Op op = instance.updateState(labels, predictions, sampleWeight); + Op op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(1.7250f, total); session.evaluate(3.5, count); session.evaluate(.49285714f, result); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + Hinge instance = new Hinge<>(null, 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java index 267578a492c..d14346a3486 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -32,18 +34,17 @@ public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); KLDivergence instance = - new KLDivergence<>(tf, "KLD_testUnweighted", 1001L, TFloat64.class); - session.run(instance.resetStates()); + new KLDivergence<>("KLD_testUnweighted", 1001L, TFloat64.class); float[][] trueArray = {{.5f, .8f, .12f}, {.7f, .43f, .8f}}; float[][] predArray = {{.4f, .9f, .12f}, {.36f, .3f, .4f}}; Operand labels = tf.constant(trueArray); Operand predictions = tf.constant(predArray); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, 
predictions, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(1.1921477, total); session.evaluate(2, count); session.evaluate(0.5960738, result); @@ -55,8 +56,7 @@ public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); KLDivergence instance = - new KLDivergence<>(tf, "KLD_testWeighted", 1001L, TFloat64.class); - session.run(instance.resetStates()); + new KLDivergence<>("KLD_testWeighted", 1001L, TFloat64.class); float[] trueArray = { .5f, .8f, .12f, .7f, .43f, .8f @@ -70,14 +70,23 @@ public void testWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(new double[][] {{1.2}, {3.4}}); - Op op = instance.updateState(labels, predictions, sampleWeight); + Op op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(4.015142, total); session.evaluate(4.6, count); session.evaluate(0.872857, result); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + KLDivergence instance = new KLDivergence<>(null, 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java index 1b5b8fb7d49..bc99dcfd085 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -33,19 +35,18 @@ public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); LogCoshError instance = - new LogCoshError<>(tf, "LogCosh_testUnweighted", 1001L, TFloat64.class); - session.run(instance.resetStates()); + new LogCoshError<>("LogCosh_testUnweighted", 1001L, TFloat64.class); float[] trueArray = {1, 9, 2, -5, -2, 6}; float[] predArray = {4, 8, 12, 8, 1, 3}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, predictions, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(4.829245, result); session.evaluate(9.65849, total); session.evaluate(2, count); @@ -57,8 +58,7 @@ public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); 
LogCoshError instance = - new LogCoshError<>(tf, "LogCosh_testWeighted", 1001L, TFloat64.class); - session.run(instance.resetStates()); + new LogCoshError<>("LogCosh_testWeighted", 1001L, TFloat64.class); int[] trueArray = {1, 9, 2, -5, -2, 6}; float[] predArray = {4, 8, 12, 8, 1, 3}; double[][] sampleArray = {{1.2}, {3.4}}; @@ -67,14 +67,23 @@ public void testWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(sampleArray); - Op op = instance.updateState(labels, predictions, sampleWeight); + Op op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(5.2178759, result); session.evaluate(24.002228, total); session.evaluate(4.6, count); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + LogCoshError instance = new LogCoshError<>(null, 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java index 984895f2ad9..7878fdcf841 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java @@ -14,6 +14,9 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -33,11 +36,9 @@ public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); MeanAbsoluteError instance = - new MeanAbsoluteError<>(tf, "MAE_testUnweighted", 1001L, TFloat64.class); - session.run(instance.resetStates()); - session.evaluate(0.0f, instance.getTotal()); - session.evaluate(0f, instance.getCount()); - session.evaluate(0.f, instance.getCount()); + new MeanAbsoluteError<>("MAE_testUnweighted", 1001L, TFloat64.class); + assertNull(instance.getTotal()); + assertNull(instance.getCount()); int[] trueArray = { 0, 1, 0, 1, 0, @@ -54,16 +55,16 @@ public void testUnweighted() { Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); Operand yPrediction = tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); - Op op = instance.updateState(yTrue, yPrediction, null); + Op op = instance.updateState(tf, yTrue, yPrediction, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(2.0, total); session.evaluate(4, count); session.evaluate(0.5, result); - session.run(instance.resetStates()); + session.run(instance.resetStates(tf)); session.evaluate(0.0, instance.getTotal()); session.evaluate(0, instance.getCount()); session.evaluate(0., instance.getCount()); @@ -75,11 +76,9 @@ public void 
testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); MeanAbsoluteError instance = - new MeanAbsoluteError<>(tf, "MAE_testWeighted", 1001L, TFloat64.class); - session.run(instance.resetStates()); - session.evaluate(0.0, instance.getTotal()); - session.evaluate(0, instance.getCount()); - session.evaluate(0., instance.getCount()); + new MeanAbsoluteError<>("MAE_testWeighted", 1001L, TFloat64.class); + assertNull(instance.getTotal()); + assertNull(instance.getCount()); int[] trueArray = { 0, 1, 0, 1, 0, @@ -98,19 +97,28 @@ public void testWeighted() { tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); Operand sampleWeight = tf.constant(new double[] {1., 1.5, 2., 2.5}); - Op op = instance.updateState(yTrue, yPrediction, sampleWeight); + Op op = instance.updateState(tf, yTrue, yPrediction, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(3.8, total); session.evaluate(7, count); session.evaluate(0.54285, result); - session.run(instance.resetStates()); + session.run(instance.resetStates(tf)); session.evaluate(0.0, instance.getTotal()); session.evaluate(0, instance.getCount()); session.evaluate(0., instance.getCount()); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + MeanAbsoluteError instance = new MeanAbsoluteError<>("MAE", 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java index 0b9e7f6b538..a38095d4cf7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java @@ -14,6 +14,9 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -35,11 +38,9 @@ public void testUnweighted() { session.setEpsilon(1E-6f); Ops tf = session.getTF(); MeanAbsolutePercentageError instance = - new MeanAbsolutePercentageError<>(tf, "MAPE_testUnweighted", 1001L, TFloat32.class); - session.run(instance.resetStates()); - session.evaluate(0.0f, instance.getTotal()); - session.evaluate(0f, instance.getCount()); - session.evaluate(0.f, instance.getCount()); + new MeanAbsolutePercentageError<>("MAPE_testUnweighted", 1001L, TFloat32.class); + assertNull(instance.getTotal()); + assertNull(instance.getCount()); int[] trueArray = { 0, 1, 0, 1, 0, @@ -58,13 +59,13 @@ public void testUnweighted() { Operand yPrediction = tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); - Op op = instance.updateState(yTrue, yPrediction, null); + Op op = instance.updateState(tf, yTrue, yPrediction, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand 
result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(1.4E9f, total); session.evaluate(4f, count); session.evaluate(35e7f, result); @@ -77,11 +78,9 @@ public void testWeighted() { session.setEpsilon(1E-6f); Ops tf = session.getTF(); MeanAbsolutePercentageError instance = - new MeanAbsolutePercentageError<>(tf, "MAPE_testWeighted", 1001L, TFloat64.class); - session.run(instance.resetStates()); - session.evaluate(0.0, instance.getTotal()); - session.evaluate(0, instance.getCount()); - session.evaluate(0, instance.getCount()); + new MeanAbsolutePercentageError<>("MAPE_testWeighted", 1001L, TFloat64.class); + assertNull(instance.getTotal()); + assertNull(instance.getCount()); long[] trueArray = { 0, 1, 0, 1, 0, @@ -100,16 +99,26 @@ public void testWeighted() { tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); Operand sampleWeight = tf.constant(new double[] {1.f, 1.5f, 2.f, 2.5f}); - Op op = instance.updateState(yTrue, yPrediction, sampleWeight); + Op op = instance.updateState(tf, yTrue, yPrediction, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(2.800000067278928E9, total); session.evaluate(7, count); session.evaluate(4.000000096112754E8, result); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + MeanAbsolutePercentageError instance = + new MeanAbsolutePercentageError<>("MAPE", 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java index 686e6371bc0..f055497ff73 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -34,11 +36,10 @@ public void testUnweighted() { Ops tf = session.getTF().withSubScope("testUnweighted"); Operand predictions = tf.constant(new long[] {0, 1, 0, 1}); Operand labels = tf.constant(new long[] {0, 0, 1, 1}); - MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat64.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, null); + MeanIoU instance = new MeanIoU<>(numClasses, 1001L, TFloat64.class); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); double expected_result = (1. / (2. + 2. - 1.) + 1. / (2. + 2. 
- 1.)) / 2.; session.evaluate(expected_result, result); } @@ -51,11 +52,10 @@ public void testWeighted() { Operand predictions = tf.constant(new long[] {0, 1, 0, 1}); Operand labels = tf.constant(new long[] {0, 0, 1, 1}); Operand sampleWeight = tf.constant(new float[] {0.2f, 0.3f, 0.4f, 0.1f}); - MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat32.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, sampleWeight); + MeanIoU instance = new MeanIoU<>(numClasses, 1001L, TFloat32.class); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); float expected_result = (0.2f / (0.6f + 0.5f - 0.2f) + 0.1f / (0.4f + 0.5f - 0.1f)) / 2f; session.evaluate(expected_result, result); } @@ -69,11 +69,10 @@ public void testMultiDimInput() { Operand predictions = tf.constant(new long[][] {{0, 1}, {0, 1}}); Operand labels = tf.constant(new long[][] {{0, 0}, {1, 1}}); Operand sampleWeight = tf.constant(new float[][] {{0.2f, 0.3f}, {0.4f, 0.1f}}); - MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat32.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, sampleWeight); + MeanIoU instance = new MeanIoU<>(numClasses, 1001L, TFloat32.class); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); float expected_result = (0.2f / (0.6f + 0.5f - 0.2f) + 0.1f / (0.4f + 0.5f - 0.1f)) / 2f; session.evaluate(expected_result, result); } @@ -83,10 +82,9 @@ public void testMultiDimInput() { public void testZeroValidEntries() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF().withSubScope("testZeroValidEntries"); - MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat32.class); - session.run(instance.getInitializer()); - Operand result = instance.result(); - session.evaluate(0.0f, result); + MeanIoU instance = new MeanIoU<>(numClasses, 1001L, TFloat32.class); + Operand result = instance.result(tf, TFloat32.class); + session.evaluate(0f, result); } } @@ -97,13 +95,22 @@ public void testZeroAndNonZeroEntries() { Operand predictions = tf.constant(new float[] {1}); Operand labels = tf.constant(new int[] {1}); - MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat32.class); + MeanIoU instance = new MeanIoU<>(numClasses, 1001L, TFloat32.class); session.initialize(); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); float expected_result = (0f + 1f / (1f + 1f - 1f)) / 1f; session.evaluate(expected_result, result); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + MeanIoU instance = new MeanIoU<>(numClasses, 1001L, TFloat32.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java index ce5d87869ee..4504d1437d0 100644 
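The MeanIoU hunks above converge on the call pattern this patch introduces throughout: the metric is constructed without an Ops handle, the initializer no longer has to be run explicitly, and tf is passed to updateState and result on each call. A minimal graph-mode sketch of that pattern follows. It is illustrative only and not part of the patch; it mirrors the testUnweighted fixture above, restores explicit generic type arguments, and assumes a 2-class numClasses value plus sketch-only class and method names.

// Illustrative sketch only (not part of this patch): the migrated MeanIoU usage,
// mirroring testUnweighted above. Class/method names and numClasses = 2 are assumed.
import org.junit.jupiter.api.Test;
import org.tensorflow.Operand;
import org.tensorflow.framework.metrics.MeanIoU;
import org.tensorflow.framework.utils.TestSession;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt64;

public class MeanIoUUsageSketch {
  @Test
  public void migratedCallPattern() {
    try (TestSession session = TestSession.createTestSession(TestSession.Mode.GRAPH)) {
      Ops tf = session.getTF();
      long numClasses = 2L; // assumed to match the 2-class fixtures above
      // No Ops in the constructor and no session.run(instance.getInitializer()) any more.
      MeanIoU<TFloat64> instance = new MeanIoU<>(numClasses, 1001L, TFloat64.class);
      Operand<TInt64> labels = tf.constant(new long[] {0, 0, 1, 1});
      Operand<TInt64> predictions = tf.constant(new long[] {0, 1, 0, 1});
      // The Ops handle is supplied per call instead of being captured at construction.
      Op update = instance.updateState(tf, labels, predictions, null);
      session.run(update);
      Operand<TFloat64> result = instance.result(tf, TFloat64.class);
      double expected = (1.0 / (2.0 + 2.0 - 1.0) + 1.0 / (2.0 + 2.0 - 1.0)) / 2.0; // 1/3
      session.evaluate(expected, result);
    }
  }
}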
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java @@ -14,6 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.tensorflow.framework.utils.CastHelper.cast; import org.junit.jupiter.api.Test; @@ -36,13 +37,10 @@ public void testUnweighted() { Operand predictions = tf.constant(predArray); Operand labels = tf.constant(trueArray); - MeanRelativeError instance = - new MeanRelativeError<>(tf, labels, 1001L, TFloat32.class); - session.initialize(); - session.run(instance.resetStates()); - Op update = instance.updateState(labels, predictions, null); + MeanRelativeError instance = new MeanRelativeError<>(labels, 1001L, TFloat32.class); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); double expected_result = 1.25; session.evaluate(expected_result, result); @@ -61,13 +59,10 @@ public void testWeighted() { Operand labels = tf.constant(trueArray); Operand sampleWeight = tf.constant(sampleWeightArray); - MeanRelativeError instance = - new MeanRelativeError<>(tf, labels, 1001L, TFloat32.class); - session.initialize(); - session.run(instance.resetStates()); - Op update = instance.updateState(labels, predictions, sampleWeight); + MeanRelativeError instance = new MeanRelativeError<>(labels, 1001L, TFloat32.class); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); double expectedResult = 1.3; session.evaluate(expectedResult, result); @@ -86,15 +81,23 @@ public void testZeroNormalizer() { MeanRelativeError instance = new MeanRelativeError<>( - tf, cast(tf, tf.zerosLike(labels), TFloat32.class), 1001L, TFloat32.class); - session.initialize(); - session.run(instance.resetStates()); - Op update = instance.updateState(labels, predictions, null); + cast(tf, tf.zerosLike(labels), TFloat32.class), 1001L, TFloat32.class); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); double expectedResult = 0; session.evaluate(expectedResult, result); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + MeanRelativeError instance = + new MeanRelativeError<>(new float[] {0, 0, 0, 0}, 1001L, TFloat32.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java index e42052a9ef1..dec3e96787a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java @@ -14,6 +14,9 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static 
org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -34,11 +37,9 @@ public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); MeanSquaredError instance = - new MeanSquaredError<>(tf, "MSE_testUnweighted", 1001L, TFloat64.class); - session.run(instance.resetStates()); - session.evaluate(0.0, instance.getTotal()); - session.evaluate(0, instance.getCount()); - session.evaluate(0., instance.getCount()); + new MeanSquaredError<>("MSE_testUnweighted", 1001L, TFloat64.class); + assertNull(instance.getTotal()); + assertNull(instance.getCount()); int[] trueArray = { 0, 1, 0, 1, 0, @@ -55,11 +56,11 @@ public void testUnweighted() { Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); Operand yPrediction = tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); - Op op = instance.updateState(yTrue, yPrediction, null); + Op op = instance.updateState(tf, yTrue, yPrediction, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(2.0, total); session.evaluate(4, count); session.evaluate(0.5, result); @@ -71,11 +72,9 @@ public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); MeanSquaredError instance = - new MeanSquaredError<>(tf, "MSE_testWeighted", 1001L, TFloat64.class); - session.run(instance.resetStates()); - session.evaluate(0.0, instance.getTotal()); - session.evaluate(0, instance.getCount()); - session.evaluate(0., instance.getCount()); + new MeanSquaredError<>("MSE_testWeighted", 1001L, TFloat64.class); + assertNull(instance.getTotal()); + assertNull(instance.getCount()); long[] trueArray = { 0, 1, 0, 1, 0, @@ -94,14 +93,23 @@ public void testWeighted() { tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); Operand sampleWeight = tf.constant(new double[] {1., 1.5, 2., 2.5}); - Op op = instance.updateState(yTrue, yPrediction, sampleWeight); + Op op = instance.updateState(tf, yTrue, yPrediction, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(3.8, total); session.evaluate(7, count); session.evaluate(0.542857, result); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + MeanSquaredError instance = new MeanSquaredError<>("MSE", 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java index e68d63b8778..88d04205c37 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java @@ -14,6 +14,9 @@ 
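The MeanSquaredError hunks above also show why the old resetStates-then-check-zero preamble is replaced by assertNull: the total and count variables are now created lazily on the first updateState(tf, ...), and resetStates takes the Ops handle as well. Below is a short sketch of that lifecycle, under the same graph-mode TestSession assumptions; the 2x2 fixture, the "MSE_sketch" name, and the sketch class name are hypothetical and not part of the patch.

// Illustrative sketch only (not part of this patch): lazy variable lifecycle of the
// mean-reduction metrics after this change. Fixture values and names are hypothetical.
import static org.junit.jupiter.api.Assertions.assertNull;

import org.junit.jupiter.api.Test;
import org.tensorflow.Operand;
import org.tensorflow.framework.metrics.MeanSquaredError;
import org.tensorflow.framework.utils.TestSession;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat64;

public class MeanSquaredErrorUsageSketch {
  @Test
  public void lazyVariableLifecycle() {
    try (TestSession session = TestSession.createTestSession(TestSession.Mode.GRAPH)) {
      Ops tf = session.getTF();
      MeanSquaredError<TFloat64> instance =
          new MeanSquaredError<>("MSE_sketch", 1001L, TFloat64.class);
      // Aggregation variables do not exist until the first update, hence the
      // assertNull checks introduced throughout this patch.
      assertNull(instance.getTotal());
      assertNull(instance.getCount());
      Operand<TFloat64> yTrue = tf.constant(new double[][] {{0, 1}, {1, 0}});
      Operand<TFloat64> yPred = tf.constant(new double[][] {{1, 1}, {1, 0}});
      Op op = instance.updateState(tf, yTrue, yPred, null);
      session.run(op);
      // Per-example MSE: row 1 -> 0.5, row 2 -> 0.0; total = 0.5, count = 2, result = 0.25.
      session.evaluate(0.5, instance.getTotal());
      session.evaluate(2, instance.getCount());
      session.evaluate(0.25, instance.result(tf, TFloat64.class));
      // resetStates now also takes the Ops handle and re-zeros the variables.
      session.run(instance.resetStates(tf));
      session.evaluate(0.0, instance.getTotal());
      session.evaluate(0, instance.getCount());
    }
  }
}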
=======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -33,11 +36,9 @@ public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); MeanSquaredLogarithmicError instance = - new MeanSquaredLogarithmicError<>(tf, "MSLE_testUnweighted", 1001L, TFloat32.class); - session.run(instance.resetStates()); - session.evaluate(0.0f, instance.getTotal()); - session.evaluate(0f, instance.getCount()); - session.evaluate(0.f, instance.getCount()); + new MeanSquaredLogarithmicError<>("MSLE_testUnweighted", 1001L, TFloat32.class); + assertNull(instance.getTotal()); + assertNull(instance.getCount()); int[] trueArray = { 0, 1, 0, 1, 0, @@ -54,11 +55,11 @@ public void testUnweighted() { Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); Operand yPrediction = tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); - Op op = instance.updateState(yTrue, yPrediction, null); + Op op = instance.updateState(tf, yTrue, yPrediction, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(0.96090573f, total); session.evaluate(4f, count); session.evaluate(0.24022f, result); @@ -70,11 +71,9 @@ public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); MeanSquaredLogarithmicError instance = - new MeanSquaredLogarithmicError<>(tf, "MSLE_testWeighted", 1001L, TFloat64.class); - session.run(instance.resetStates()); - session.evaluate(0.0, instance.getTotal()); - session.evaluate(0, instance.getCount()); - session.evaluate(0., instance.getCount()); + new MeanSquaredLogarithmicError<>("MSLE_testWeighted", 1001L, TFloat64.class); + assertNull(instance.getTotal()); + assertNull(instance.getCount()); int[] trueArray = { 0, 1, 0, 1, 0, @@ -93,14 +92,24 @@ public void testWeighted() { tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); Operand sampleWeight = tf.constant(new double[] {1., 1.5, 2., 2.5}); - Op op = instance.updateState(yTrue, yPrediction, sampleWeight); + Op op = instance.updateState(tf, yTrue, yPrediction, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(1.8257208, total); session.evaluate(7, count); session.evaluate(0.26082, result); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + MeanSquaredLogarithmicError instance = + new MeanSquaredLogarithmicError<>("MSLE", 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java index 3fb11f86b45..83db73ab346 100644 --- 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -31,18 +33,18 @@ public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Operand values = tf.constant(new long[] {100, 40}); - MeanTensor instance = new MeanTensor<>(tf, 1001L, TFloat64.class); + MeanTensor instance = new MeanTensor<>(1001L, TFloat64.class); session.initialize(); - Op update = instance.updateState(values, null); + Op update = instance.updateState(tf, values, null); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); double[] expected_result = new double[] {100, 40}; session.evaluate(expected_result, result); session.evaluate(expected_result, instance.getTotal()); session.evaluate(new double[] {1, 1}, instance.getCount()); - session.run(instance.resetStates()); + session.run(instance.resetStates(tf)); session.evaluate(new double[] {0, 0}, instance.getTotal()); session.evaluate(new double[] {0, 0}, instance.getCount()); } @@ -53,13 +55,13 @@ public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Operand values = tf.constant(new long[] {100, 30}); - MeanTensor instance = new MeanTensor<>(tf, 1001L, TFloat64.class); + MeanTensor instance = new MeanTensor<>(1001L, TFloat64.class); session.initialize(); // check scalar weight - Op update = instance.updateState(values, tf.constant(0.5f)); + Op update = instance.updateState(tf, values, tf.constant(0.5f)); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); double[] expected_result = new double[] {100, 30}; session.evaluate(expected_result, result); session.evaluate(new double[] {50, 15}, instance.getTotal()); @@ -67,9 +69,9 @@ public void testWeighted() { // check weights not scalar and weights rank matches values rank values = tf.constant(new long[] {1, 5}); - update = instance.updateState(values, tf.constant(new double[] {1f, 0.2f})); + update = instance.updateState(tf, values, tf.constant(new double[] {1f, 0.2f})); session.run(update); - result = instance.result(); + result = instance.result(tf, TFloat64.class); expected_result = new double[] {51 / 1.5, 16 / 0.7}; session.evaluate(expected_result, result); session.evaluate(new double[] {51, 16}, instance.getTotal()); @@ -77,9 +79,9 @@ public void testWeighted() { // check weights broadcast values = tf.constant(new long[] {1, 2}); - update = instance.updateState(values, tf.constant(0.5f)); + update = instance.updateState(tf, values, tf.constant(0.5f)); session.run(update); - result = instance.result(); + result = instance.result(tf, TFloat64.class); expected_result = new double[] {51.5 / 2, 17 / 1.2}; session.evaluate(expected_result, result); session.evaluate(new double[] {51.5, 17}, instance.getTotal()); @@ -88,9 +90,9 @@ public void testWeighted() { // check weights squeeze values = tf.constant(new long[] {1, 5}); Operand sampleWeight = tf.constant(new double[][] {{1}, {0.2}}); - update = instance.updateState(values, 
sampleWeight); + update = instance.updateState(tf, values, sampleWeight); session.run(update); - result = instance.result(); + result = instance.result(tf, TFloat64.class); expected_result = new double[] {52.5 / 3, 18 / 1.4}; session.evaluate(expected_result, result); session.evaluate(new double[] {52.5, 18}, instance.getTotal()); @@ -104,16 +106,25 @@ public void testWeightedExpand() { Ops tf = session.getTF(); // check weights expand - MeanTensor instance = new MeanTensor<>(tf, 1001L, TFloat32.class); + MeanTensor instance = new MeanTensor<>(1001L, TFloat32.class); Operand values = tf.constant(new long[][] {{1}, {5}}); Operand sampleWeight = tf.constant(new float[] {1f, 0.2f}); - Op update = instance.updateState(values, sampleWeight); + Op update = instance.updateState(tf, values, sampleWeight); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(tf.constant(new float[][] {{1f}, {5f}}), result); session.evaluate(tf.constant(new float[][] {{1f}, {1f}}), instance.getTotal()); session.evaluate(tf.constant(new float[][] {{1f}, {0.2f}}), instance.getCount()); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + MeanTensor instance = new MeanTensor<>(1001L, TFloat32.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java index 5631bac15ee..0bd84f81561 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -32,19 +34,17 @@ class PoissonTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Poisson instance = - new Poisson<>(tf, "Poisson_testUnweighted", 1001L, TFloat64.class); - session.run(instance.resetStates()); + Poisson instance = new Poisson<>("Poisson_testUnweighted", 1001L, TFloat64.class); int[] trueArray = {4, 8, 12, 8, 1, 3}; float[] predArray = {1, 9, 2, 5, 2, 6}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, predictions, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(-6.6131644, total); session.evaluate(2, count); session.evaluate(-3.3065822, result); @@ -55,8 +55,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Poisson instance = new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); - session.run(instance.resetStates()); + Poisson 
instance = new Poisson<>("Poisson_testWeighted", 1001L, TFloat32.class); int[] trueArray = {4, 8, 12, 8, 1, 3}; float[] predArray = {1, 9, 2, 5, 2, 6}; @@ -65,14 +64,23 @@ public void testWeighted() { tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(new float[] {1.2f, 3.4f}); - Op op = instance.updateState(labels, predictions, sampleWeight); + Op op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(-12.29468f, total); session.evaluate(4.6f, count); session.evaluate(-2.6727562f, result); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + Poisson instance = new Poisson<>("Poisson", 1001L, TFloat32.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionAtRecallTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionAtRecallTest.java index a817a3dc5df..8132b74d7cd 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionAtRecallTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionAtRecallTest.java @@ -14,6 +14,10 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.Random; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -24,11 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; -import java.util.Random; - -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.tensorflow.framework.utils.CastHelper.cast; - public class PrecisionAtRecallTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; private final Random random = new Random(); @@ -37,9 +36,7 @@ public class PrecisionAtRecallTest { public void testValueIsIdempotent() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - PrecisionAtRecall instance = - new PrecisionAtRecall<>(tf, 0.7f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + PrecisionAtRecall instance = new PrecisionAtRecall<>(0.7f, 1001L, TFloat32.class); Operand predictions = tf.random.randomUniform( @@ -48,13 +45,14 @@ public void testValueIsIdempotent() { tf.random.randomUniform( tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); for (int i = 0; i < 10; i++) session.run(update); - Operand initialPrecision = instance.result(); + Operand initialPrecision = instance.result(tf, TFloat32.class); - for (int i = 0; i < 10; i++) session.evaluate(initialPrecision, instance.result()); + for (int i = 0; i < 10; i++) + session.evaluate(initialPrecision, instance.result(tf, TFloat32.class)); } } @@ -73,19 +71,17 @@ private int[][] generateRandomArray(int dim1, int dim2) { public void testUnweightedAllCorrect() { 
try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - PrecisionAtRecall instance = - new PrecisionAtRecall<>(tf, 0.7f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + PrecisionAtRecall instance = new PrecisionAtRecall<>(0.7f, 1001L, TFloat32.class); int[][] predArray = generateRandomArray(100, 1); int[][] trueArray = new int[100][1]; // 100,1 System.arraycopy(predArray, 0, trueArray, 0, predArray.length); Operand predictions = cast(tf, tf.constant(predArray), TFloat32.class); Operand labels = cast(tf, tf.constant(trueArray), TFloat32.class); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); session.evaluate(1f, precision); } @@ -95,17 +91,15 @@ public void testUnweightedAllCorrect() { public void testUnweightedHighRecall() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - PrecisionAtRecall instance = - new PrecisionAtRecall<>(tf, 0.8f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + PrecisionAtRecall instance = new PrecisionAtRecall<>(0.8f, 1001L, TFloat32.class); Operand predictions = tf.constant(new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.5f, 0.4f, 0.5f, 0.6f, 0.8f, 0.9f}); Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); session.evaluate(0.8f, precision); } @@ -115,17 +109,15 @@ public void testUnweightedHighRecall() { public void testUnweightedLowRecall() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - PrecisionAtRecall instance = - new PrecisionAtRecall<>(tf, 0.4f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + PrecisionAtRecall instance = new PrecisionAtRecall<>(0.4f, 1001L, TFloat32.class); Operand predictions = tf.constant(new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.1f, 0.15f, 0.25f, 0.26f, 0.26f}); Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); session.evaluate(0.5f, precision); } @@ -135,19 +127,17 @@ public void testUnweightedLowRecall() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - PrecisionAtRecall instance = - new PrecisionAtRecall<>(tf, 0.4f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + PrecisionAtRecall instance = new PrecisionAtRecall<>(0.4f, 1001L, TFloat32.class); Operand predictions = tf.constant( new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f}); Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); Operand sampleWeight = tf.constant(new float[] {2, 2, 1, 1, 1, 1, 1, 2, 2, 2}); - Op update = instance.updateState(labels, predictions, sampleWeight); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - Operand precision = instance.result(); + Operand 
precision = instance.result(tf, TFloat32.class); session.evaluate(2.f / 3.f, precision); } @@ -156,24 +146,23 @@ public void testWeighted() { @Test public void testInvalidSensitivity() { assertThrows( - IllegalArgumentException.class, - () -> { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - new PrecisionAtRecall<>(tf, -1f, 1001L, TFloat32.class); - } - }); + IllegalArgumentException.class, () -> new PrecisionAtRecall<>(-1f, 1001L, TFloat32.class)); } @Test public void testInvalidNumThresholds() { assertThrows( IllegalArgumentException.class, - () -> { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - new PrecisionAtRecall<>(tf, 0.7f, -1, 1001L, TFloat32.class); - } - }); + () -> new PrecisionAtRecall<>(0.7f, -1, 1001L, TFloat32.class)); + } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + PrecisionAtRecall instance = + new PrecisionAtRecall<>(0.7f, 9, 1001L, TFloat32.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java index cfe5b483e2b..b195432115e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -35,8 +37,7 @@ public void testValueIsIdempotent() { Ops tf = session.getTF(); Precision instance = - new Precision<>(tf, new float[] {0.3f, 0.72f}, 1001L, TFloat64.class); - session.run(instance.resetStates()); + new Precision<>(new float[] {0.3f, 0.72f}, 1001L, TFloat64.class); Operand predictions = tf.random.randomUniform( tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1001L)); @@ -44,16 +45,16 @@ public void testValueIsIdempotent() { tf.random.randomUniform( tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1001L)); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); for (int i = 0; i < 10; i++) { session.run(update); } - Operand initialPrecision = instance.result(); + Operand initialPrecision = instance.result(tf, TFloat64.class); for (int i = 0; i < 10; i++) { - session.evaluate(initialPrecision, instance.result()); + session.evaluate(initialPrecision, instance.result(tf, TFloat64.class)); } } } @@ -62,14 +63,13 @@ public void testValueIsIdempotent() { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Precision instance = new Precision<>(tf, 1001L, TFloat64.class); - session.run(instance.resetStates()); + Precision instance = new Precision<>(1001L, TFloat64.class); Operand predictions = tf.constant(new long[][] {{1, 0, 1, 0}}); Operand labels = tf.constant(new long[][] {{0, 1, 1, 0}}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, 
predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat64.class); session.evaluate(0.5, precision); } } @@ -78,15 +78,14 @@ public void testUnweighted() { public void testUnweightedAllIncorrect() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Precision instance = new Precision<>(tf, 0.5f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Precision instance = new Precision<>(0.5f, 1001L, TFloat32.class); Operand predictions = tf.random.randomUniformInt(tf.constant(Shape.of(100, 1)), tf.constant(0), tf.constant(2)); Operand labels = tf.math.sub(tf.constant(1), predictions); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); session.evaluate(0.0f, precision); } } @@ -95,15 +94,14 @@ public void testUnweightedAllIncorrect() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Precision instance = new Precision<>(tf, 1001L, TFloat64.class); - session.run(instance.resetStates()); + Precision instance = new Precision<>(1001L, TFloat64.class); Operand predictions = tf.constant(new long[][] {{1, 0, 1, 0}, {1, 0, 1, 0}}); Operand labels = tf.constant(new long[][] {{0, 1, 1, 0}, {1, 0, 0, 1}}); Operand sampleWeight = tf.constant(new double[][] {{1, 2, 3, 4}, {4, 3, 2, 1}}); - Op update = instance.updateState(labels, predictions, sampleWeight); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat64.class); double weightedTP = 3.0f + 4.0f; double weightedPositives = (1.0f + 3.0f) + (4.0f + 2.0f); @@ -117,14 +115,13 @@ public void testWeighted() { public void testDivByZero() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Precision instance = new Precision<>(tf, 1001L, TFloat64.class); - session.run(instance.resetStates()); + Precision instance = new Precision<>(1001L, TFloat64.class); Operand predictions = tf.constant(new int[] {0, 0, 0, 0}); Operand labels = tf.constant(new int[] {0, 0, 0, 0}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat64.class); session.evaluate(0, precision); } } @@ -134,14 +131,13 @@ public void testUnweightedWithThreshold() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Precision instance = - new Precision<>(tf, new float[] {0.5f, 0.7f}, 1001L, TFloat32.class); - session.run(instance.resetStates()); + new Precision<>(new float[] {0.5f, 0.7f}, 1001L, TFloat32.class); Operand predictions = tf.constant(new float[][] {{1f, 0f, 0.6f, 0f}}); Operand labels = tf.constant(new long[][] {{0, 1, 1, 0}}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); float[] expected = new float[] {0.5f, 0.f}; @@ -154,15 +150,14 @@ public void testWeightedWithThreshold() { try 
(TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Precision instance = - new Precision<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); - session.run(instance.resetStates()); + new Precision<>(new float[] {0.5f, 1.f}, 1001L, TFloat32.class); Operand predictions = tf.constant(new float[][] {{1f, 0f}, {0.6f, 0f}}); Operand labels = tf.constant(new long[][] {{0, 1}, {1, 0}}); Operand sampleWeight = tf.constant(new float[][] {{4, 0}, {3, 1}}); - Op update = instance.updateState(labels, predictions, sampleWeight); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); float weightedTP = 0f + 3.f; float weightedPositives = (0f + 3.f) + (4.f + 0.f); @@ -178,15 +173,14 @@ public void testMultipleUpdates() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Precision instance = - new Precision<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat64.class); - session.run(instance.resetStates()); + new Precision<>(new float[] {0.5f, 1.f}, 1001L, TFloat64.class); Operand predictions = tf.constant(new float[][] {{1f, 0f}, {0.6f, 0f}}); Operand labels = tf.constant(new long[][] {{0, 1}, {1, 0}}); Operand sampleWeight = tf.constant(new double[][] {{4, 0}, {3, 1}}); - Op update = instance.updateState(labels, predictions, sampleWeight); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); for (int i = 0; i < 2; i++) session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat64.class); double weighted_tp = (0 + 3.) + (0 + 3.); double weighted_positives = ((0 + 3.) + (4. + 0.)) + ((0 + 3.) + (4. 
+ 0.)); @@ -202,13 +196,12 @@ public void testUnweightedTopK() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); // set topK to 3 - Precision instance = new Precision<>(tf, null, 3, null, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Precision instance = new Precision<>(null, 3, null, 1001L, TFloat32.class); Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.5f, 0f, 0.2f}}); Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); session.evaluate(1.0f / 3.0f, precision); } } @@ -218,21 +211,20 @@ public void testWeightedTopK() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); // set topK to 3 - Precision instance = new Precision<>(tf, null, 3, null, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Precision instance = new Precision<>(null, 3, null, 1001L, TFloat32.class); Operand predictions = tf.constant(new float[] {0.2f, 0.1f, 0.4f, 0f, 0.2f}); Operand labels = tf.constant(new long[] {0, 1, 1, 0, 1}); Operand sampleWeight = tf.constant(new float[][] {{1, 4, 2, 3, 5}}); - Op update = instance.updateState(labels, predictions, sampleWeight); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.4f, 0.2f, 0.2f}}); labels = tf.constant(new long[][] {{1, 0, 1, 1, 1}}); - update = instance.updateState(labels, predictions, tf.constant(3.f)); + update = instance.updateState(tf, labels, predictions, tf.constant(3.f)); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); float tp = (2f + 5f) + (3f + 3f); float predicted_positives = (1f + 2f + 5f) + (3f + 3f + 3f); @@ -246,14 +238,13 @@ public void testUnweightedClassId() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); // set classId to 2 - Precision instance = new Precision<>(tf, null, null, 2, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Precision instance = new Precision<>(null, null, 2, 1001L, TFloat32.class); Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); session.evaluate(1, precision); session.evaluate(1, instance.getTruePositives()); @@ -261,9 +252,9 @@ public void testUnweightedClassId() { predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0f, 0f, 0.2f}}); labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); - update = instance.updateState(labels, predictions, null); + update = instance.updateState(tf, labels, predictions, null); session.run(update); - precision = instance.result(); + precision = instance.result(tf, TFloat32.class); session.evaluate(1, precision); session.evaluate(1, instance.getTruePositives()); @@ -271,9 +262,9 @@ public void testUnweightedClassId() { predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); labels = tf.constant(new long[][] 
{{0, 1, 0, 0, 0}}); - update = instance.updateState(labels, predictions, null); + update = instance.updateState(tf, labels, predictions, null); session.run(update); - precision = instance.result(); + precision = instance.result(tf, TFloat32.class); session.evaluate(0.5f, precision); session.evaluate(1, instance.getTruePositives()); @@ -286,14 +277,13 @@ public void testUnweightedTopKAndClassId() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); // set topK and classId to 2 - Precision instance = new Precision<>(tf, null, 2, 2, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Precision instance = new Precision<>(null, 2, 2, 1001L, TFloat32.class); Operand predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.3f, 0f, 0.2f}}); Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); session.evaluate(1, precision); session.evaluate(1, instance.getTruePositives()); @@ -301,9 +291,9 @@ public void testUnweightedTopKAndClassId() { predictions = tf.constant(new float[][] {{1f, 1f, 0.9f, 1f, 1f}}); labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); - update = instance.updateState(labels, predictions, null); + update = instance.updateState(tf, labels, predictions, null); session.run(update); - precision = instance.result(); + precision = instance.result(tf, TFloat32.class); session.evaluate(1, precision); session.evaluate(1, instance.getTruePositives()); @@ -316,18 +306,26 @@ public void testUnweightedTopKAndThreshold() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); // set topK to 2 - Precision instance = new Precision<>(tf, 0.7f, 2, null, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Precision instance = new Precision<>(0.7f, 2, null, 1001L, TFloat32.class); Operand predictions = tf.constant(new float[][] {{0.2f, 0.8f, 0.6f, 0f, 0.2f}}); Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 1}}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); session.evaluate(1, precision); session.evaluate(1, instance.getTruePositives()); session.evaluate(0, instance.getFalsePositives()); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + Precision instance = new Precision<>(0.7f, 2, null, 1001L, TFloat32.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallAtPrecisionTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallAtPrecisionTest.java index bd3a5273668..36dba3180b7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallAtPrecisionTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallAtPrecisionTest.java @@ -14,6 +14,10 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static 
org.junit.jupiter.api.Assertions.assertThrows; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.Random; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -24,11 +28,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; -import java.util.Random; - -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.tensorflow.framework.utils.CastHelper.cast; - public class RecallAtPrecisionTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; private final Random random = new Random(); @@ -37,9 +36,7 @@ public class RecallAtPrecisionTest { public void testValueIsIdempotent() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - RecallAtPrecision instance = - new RecallAtPrecision<>(tf, 0.7f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + RecallAtPrecision instance = new RecallAtPrecision<>(0.7f, 1001L, TFloat32.class); Operand predictions = tf.random.randomUniform( @@ -49,16 +46,16 @@ public void testValueIsIdempotent() { tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); labels = tf.math.mul(labels, tf.constant(2.0f)); - Op update = instance.updateState(labels, predictions); + Op update = instance.updateState(tf, labels, predictions); for (int i = 0; i < 10; i++) { session.run(update); } - Operand initialPrecision = instance.result(); + Operand initialPrecision = instance.result(tf, TFloat32.class); for (int i = 0; i < 10; i++) { - session.evaluate(initialPrecision, instance.result()); + session.evaluate(initialPrecision, instance.result(tf, TFloat32.class)); } } } @@ -78,18 +75,16 @@ private int[][] generateRandomArray(int dim1, int dim2, int maxVal) { public void test_unweighted_all_correct() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - RecallAtPrecision instance = - new RecallAtPrecision<>(tf, 0.7f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + RecallAtPrecision instance = new RecallAtPrecision<>(0.7f, 1001L, TFloat32.class); int[][] predArray = generateRandomArray(100, 1, 2); int[][] trueArray = new int[100][1]; // 100,1 System.arraycopy(predArray, 0, trueArray, 0, predArray.length); Operand predictions = cast(tf, tf.constant(predArray), TFloat32.class); Operand labels = cast(tf, tf.constant(trueArray), TFloat32.class); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); session.evaluate(1f, precision); } @@ -99,9 +94,7 @@ public void test_unweighted_all_correct() { public void testUnweightedHighPrecision() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - RecallAtPrecision instance = - new RecallAtPrecision<>(tf, 0.75f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + RecallAtPrecision instance = new RecallAtPrecision<>(0.75f, 1001L, TFloat32.class); Operand predictions = tf.constant( new float[] { @@ -109,10 +102,10 @@ public void testUnweightedHighPrecision() { }); Operand labels = tf.constant(new long[] {0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = 
instance.result(); + Operand precision = instance.result(tf, TFloat32.class); session.evaluate(0.5f, precision); } @@ -123,8 +116,7 @@ public void testUnweightedLowPrecision() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); RecallAtPrecision instance = - new RecallAtPrecision<>(tf, 2.0f / 3f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + new RecallAtPrecision<>(2.0f / 3f, 1001L, TFloat32.class); Operand predictions = tf.constant( new float[] { @@ -132,10 +124,10 @@ public void testUnweightedLowPrecision() { }); Operand labels = tf.constant(new long[] {0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); session.evaluate(5.f / 6f, precision); } @@ -145,18 +137,16 @@ public void testUnweightedLowPrecision() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - RecallAtPrecision instance = - new RecallAtPrecision<>(tf, 0.75f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + RecallAtPrecision instance = new RecallAtPrecision<>(0.75f, 1001L, TFloat32.class); Operand predictions = tf.constant(new float[] {0.1f, 0.2f, 0.3f, 0.5f, 0.6f, 0.9f, 0.9f}); Operand labels = tf.constant(new long[] {0, 1, 0, 0, 0, 1, 1}); Operand sampleWeight = tf.constant(new float[] {1, 2, 1, 2, 1, 2, 1}); - Op update = instance.updateState(labels, predictions, sampleWeight); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); session.evaluate(0.6f, precision); } @@ -167,15 +157,14 @@ public void testUnachievablePrecision() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); RecallAtPrecision instance = - new RecallAtPrecision<>(tf, 2.0f / 3f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + new RecallAtPrecision<>(2.0f / 3f, 1001L, TFloat32.class); Operand predictions = tf.constant(new float[] {0.1f, 0.2f, 0.3f, 0.9f}); Operand labels = tf.constant(new long[] {1, 1, 0, 0}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); // The highest possible precision is 1/2 which is below the required session.evaluate(0f, precision); } @@ -184,24 +173,23 @@ public void testUnachievablePrecision() { @Test public void test_invalid_sensitivity() { assertThrows( - IllegalArgumentException.class, - () -> { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - new RecallAtPrecision<>(tf, -1f, 1001L, TFloat32.class); - } - }); + IllegalArgumentException.class, () -> new RecallAtPrecision<>(-1f, 1001L, TFloat32.class)); } @Test public void test_invalid_num_thresholds() { assertThrows( IllegalArgumentException.class, - () -> { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - new RecallAtPrecision<>(tf, 0.7f, -1, 1001L, TFloat32.class); - } - }); + () -> new RecallAtPrecision<>(0.7f, -1, 1001L, TFloat32.class)); + } + + @Test + public void testEagerEnvironment() { 
+ try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(2.0f / 3f, 1001L, TFloat32.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java index bd9fbb1ab66..e820cbe0d74 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java @@ -14,6 +14,9 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.Random; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -22,8 +25,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; -import java.util.Random; - public class RecallTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; private final Random random = new Random(); @@ -32,20 +33,19 @@ public class RecallTest { public void testValueIsIdempotent() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = - new Recall<>(tf, new float[] {0.3f, 0.72f}, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Recall instance = new Recall<>(new float[] {0.3f, 0.72f}, 1001L, TFloat32.class); Operand predictions = tf.random.randomUniform(tf.constant(Shape.of(10, 3)), TFloat32.class); Operand labels = tf.random.randomUniform(tf.constant(Shape.of(10, 3)), TFloat32.class); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); for (int i = 0; i < 10; i++) session.run(update); - Operand initialRecall = instance.result(); - for (int i = 0; i < 10; i++) session.evaluate(initialRecall, instance.result()); + Operand initialRecall = instance.result(tf, TFloat32.class); + for (int i = 0; i < 10; i++) + session.evaluate(initialRecall, instance.result(tf, TFloat32.class)); } } @@ -53,15 +53,14 @@ public void testValueIsIdempotent() { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = new Recall<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Recall instance = new Recall<>(1001L, TFloat32.class); Operand predictions = tf.constant(new float[][] {{1, 0, 1, 0}}); Operand labels = tf.constant(new float[][] {{0, 1, 1, 0}}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - session.evaluate(0.5f, instance.result()); + session.evaluate(0.5f, instance.result(tf, TFloat32.class)); } } @@ -80,16 +79,15 @@ private int[][] generateRandomArray(int dim1, int dim2, int maxInt) { public void testUnweightedAllIncorrect() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = new Recall<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Recall instance = new Recall<>(1001L, TFloat32.class); int[][] array = generateRandomArray(100, 1, 2); Operand predictions = 
tf.dtypes.cast(tf.constant(array), TFloat32.class); Operand labels = tf.dtypes.cast(tf.math.sub(tf.constant(1), tf.constant(array)), TFloat32.class); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - session.evaluate(0.f, instance.result()); + session.evaluate(0.f, instance.result(tf, TFloat32.class)); } } @@ -97,8 +95,7 @@ public void testUnweightedAllIncorrect() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = new Recall<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Recall instance = new Recall<>(1001L, TFloat32.class); Operand predictions = tf.constant( new float[][] { @@ -118,14 +115,14 @@ public void testWeighted() { {1, 2, 3, 4}, {4, 3, 2, 1} }); - Op update = instance.updateState(labels, predictions, sampleWeights); + Op update = instance.updateState(tf, labels, predictions, sampleWeights); session.run(update); float weightedTp = 3.0f + 1.0f; float weightedT = (2.0f + 3.0f) + (4.0f + 1.0f); float expectedRecall = weightedTp / weightedT; - session.evaluate(expectedRecall, instance.result()); + session.evaluate(expectedRecall, instance.result(tf, TFloat32.class)); } } @@ -133,16 +130,15 @@ public void testWeighted() { public void testDivByZero() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = new Recall<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Recall instance = new Recall<>(1001L, TFloat32.class); Operand predictions = tf.constant(new float[] {0, 0, 0, 0}); Operand labels = tf.constant(new float[] {0, 0, 0, 0}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - session.evaluate(0f, instance.result()); + session.evaluate(0f, instance.result(tf, TFloat32.class)); } } @@ -150,17 +146,16 @@ public void testDivByZero() { public void testUnweightedWithThreshold() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = new Recall<>(tf, new float[] {0.5f, 0.7f}, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Recall instance = new Recall<>(new float[] {0.5f, 0.7f}, 1001L, TFloat32.class); Operand predictions = tf.constant(new float[][] {{1, 0, 0.6f, 0}}); Operand labels = tf.constant(new float[][] {{0, 1, 1, 0}}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); Float[] expected = new Float[] {0.5f, 0f}; - session.evaluate(expected, instance.result()); + session.evaluate(expected, instance.result(tf, TFloat32.class)); } } @@ -168,21 +163,20 @@ public void testUnweightedWithThreshold() { public void testWeightedWithThreshold() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = new Recall<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Recall instance = new Recall<>(new float[] {0.5f, 1.f}, 1001L, TFloat32.class); Operand labels = tf.constant(new float[][] {{0, 1}, {1, 0}}); Operand predictions = tf.constant(new float[][] {{1, 0}, {0.6f, 0}}); Operand weights = tf.constant(new float[][] {{1, 4}, {3, 2}}); - Op update = instance.updateState(labels, predictions, weights); + Op update = 
instance.updateState(tf, labels, predictions, weights); session.run(update); float weightedTp = 0 + 3.f; float weightedPositives = (0 + 3.f) + (4.f + 0.f); float expectedRecall = weightedTp / weightedPositives; float[] expected = new float[] {expectedRecall, 0f}; - session.evaluate(expected, instance.result()); + session.evaluate(expected, instance.result(tf, TFloat32.class)); } } @@ -190,21 +184,20 @@ public void testWeightedWithThreshold() { public void testMultipleUpdates() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = new Recall<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Recall instance = new Recall<>(new float[] {0.5f, 1.f}, 1001L, TFloat32.class); Operand labels = tf.constant(new float[][] {{0, 1}, {1, 0}}); Operand predictions = tf.constant(new float[][] {{1, 0}, {0.6f, 0}}); Operand weights = tf.constant(new float[][] {{1, 4}, {3, 2}}); - Op update = instance.updateState(labels, predictions, weights); + Op update = instance.updateState(tf, labels, predictions, weights); for (int i = 0; i < 2; i++) session.run(update); float weightedTp = (0f + 3.f) + (0f + 3.f); float weightedPositives = ((0f + 3.f) + (4.f + 0.f)) + ((0f + 3.f) + (4.f + 0.f)); float expectedRecall = weightedTp / weightedPositives; float[] expected = new float[] {expectedRecall, 0f}; - session.evaluate(expected, instance.result()); + session.evaluate(expected, instance.result(tf, TFloat32.class)); } } @@ -212,16 +205,15 @@ public void testMultipleUpdates() { public void testUnweightedTopK() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = new Recall<>(tf, null, null, 3, null, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Recall instance = new Recall<>(null, null, 3, null, 1001L, TFloat32.class); Operand labels = tf.constant(new float[][] {{0f, 1f, 1f, 0f, 0f}}); Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.5f, 0f, 0.2f}}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - session.evaluate(0.5f, instance.result()); + session.evaluate(0.5f, instance.result(tf, TFloat32.class)); } } @@ -229,27 +221,26 @@ public void testUnweightedTopK() { public void testWeightedTopK() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = new Recall<>(tf, null, null, 3, null, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Recall instance = new Recall<>(null, null, 3, null, 1001L, TFloat32.class); Operand labels = tf.constant(new float[][] {{0, 1, 1, 0, 1}}); Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.4f, 0f, 0.2f}}); Operand weights = tf.constant(new float[][] {{1, 4, 2, 3, 5}}); - Op update = instance.updateState(labels, predictions, weights); + Op update = instance.updateState(tf, labels, predictions, weights); session.run(update); labels = tf.constant(new float[][] {{1, 0, 1, 1, 1}}); predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.4f, 0.2f, 0.2f}}); weights = tf.constant(3.f); - update = instance.updateState(labels, predictions, weights); + update = instance.updateState(tf, labels, predictions, weights); session.run(update); float weightedTp = (2 + 5) + (3 + 3); float weightedPositives = (4 + 2 + 5) + (3 + 3 + 3 + 3); float expectedRecall = weightedTp / weightedPositives; - 
session.evaluate(expectedRecall, instance.result()); + session.evaluate(expectedRecall, instance.result(tf, TFloat32.class)); } } @@ -257,30 +248,29 @@ public void testWeightedTopK() { public void testUnweightedClassId() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = new Recall<>(tf, null, null, null, 2, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Recall instance = new Recall<>(null, null, null, 2, 1001L, TFloat32.class); Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); Operand labels = tf.constant(new float[][] {{0, 1, 1, 0, 0}}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - session.evaluate(1f, instance.result()); + session.evaluate(1f, instance.result(tf, TFloat32.class)); session.evaluate(1f, instance.getTruePositives()); session.evaluate(0f, instance.getFalseNegatives()); predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0f, 0f, 0.2f}}); labels = tf.constant(new float[][] {{0, 1, 1, 0, 0}}); - update = instance.updateState(labels, predictions, null); + update = instance.updateState(tf, labels, predictions, null); session.run(update); - session.evaluate(0.5f, instance.result()); + session.evaluate(0.5f, instance.result(tf, TFloat32.class)); session.evaluate(1f, instance.getTruePositives()); session.evaluate(1f, instance.getFalseNegatives()); predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); labels = tf.constant(new float[][] {{0, 1, 0, 0, 0}}); - update = instance.updateState(labels, predictions, null); + update = instance.updateState(tf, labels, predictions, null); session.run(update); - session.evaluate(0.5f, instance.result()); + session.evaluate(0.5f, instance.result(tf, TFloat32.class)); session.evaluate(1f, instance.getTruePositives()); session.evaluate(1f, instance.getFalseNegatives()); } @@ -290,24 +280,23 @@ public void testUnweightedClassId() { public void testUnweightedTopKAndClassId() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = new Recall<>(tf, null, null, 2, 2, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Recall instance = new Recall<>(null, null, 2, 2, 1001L, TFloat32.class); Operand predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.3f, 0, 0.2f}}); Operand labels = tf.constant(new float[][] {{0, 1, 1, 0, 0}}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - session.evaluate(1f, instance.result()); + session.evaluate(1f, instance.result(tf, TFloat32.class)); session.evaluate(1f, instance.getTruePositives()); session.evaluate(0f, instance.getFalseNegatives()); predictions = tf.constant(new float[][] {{1, 1, 0.9f, 1, 1}}); labels = tf.constant(new float[][] {{0, 1, 1, 0, 0}}); - update = instance.updateState(labels, predictions, null); + update = instance.updateState(tf, labels, predictions, null); session.run(update); - session.evaluate(0.5f, instance.result()); + session.evaluate(0.5f, instance.result(tf, TFloat32.class)); session.evaluate(1f, instance.getTruePositives()); session.evaluate(1f, instance.getFalseNegatives()); } @@ -317,17 +306,25 @@ public void testUnweightedTopKAndClassId() { public void testUnweightedTopKAndThreshold() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = 
session.getTF(); - Recall instance = new Recall<>(tf, null, 0.7f, 2, null, 1001L, TFloat32.class); - session.run(instance.resetStates()); + Recall instance = new Recall<>(null, 0.7f, 2, null, 1001L, TFloat32.class); Operand predictions = tf.constant(new float[][] {{0.2f, 0.8f, 0.6f, 0f, 0.2f}}); Operand labels = tf.constant(new float[][] {{1, 1, 1, 0, 1}}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - session.evaluate(0.25f, instance.result()); + session.evaluate(0.25f, instance.result(tf, TFloat32.class)); session.evaluate(1f, instance.getTruePositives()); session.evaluate(3f, instance.getFalseNegatives()); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(null, 0.7f, 2, null, 1001L, TFloat32.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RootMeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RootMeanSquaredErrorTest.java index c9ced9f5946..116c929a701 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RootMeanSquaredErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RootMeanSquaredErrorTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -30,18 +32,16 @@ public class RootMeanSquaredErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - RootMeanSquaredError instance = - new RootMeanSquaredError<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); + RootMeanSquaredError instance = new RootMeanSquaredError<>(1001L, TFloat32.class); Operand labels = tf.constant(new float[] {2, 4, 6}); Operand predictions = tf.constant(new float[] {1, 3, 2}); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, predictions, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(18, total); session.evaluate(3, count); session.evaluate(Math.sqrt(6), result); @@ -52,21 +52,28 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - RootMeanSquaredError instance = - new RootMeanSquaredError<>(tf, 1001L, TFloat64.class); - session.run(instance.resetStates()); + RootMeanSquaredError instance = new RootMeanSquaredError<>(1001L, TFloat64.class); Operand labels = tf.constant(new float[][] {{2, 4, 6, 8}}); Operand predictions = tf.constant(new float[][] {{1, 3, 2, 3}}); Operand sampleWeight = tf.constant(new double[][] {{0, 1, 0, 1}}); - Op op = instance.updateState(labels, predictions, sampleWeight); + Op op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = 
instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(26, total); session.evaluate(2, count); session.evaluate(Math.sqrt(13), result); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + RootMeanSquaredError instance = new RootMeanSquaredError<>(1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SensitivityAtSpecificityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SensitivityAtSpecificityTest.java index a65dc3b53da..d18ca9813fe 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SensitivityAtSpecificityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SensitivityAtSpecificityTest.java @@ -14,6 +14,10 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.Random; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -25,11 +29,6 @@ import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt64; -import java.util.Random; - -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.tensorflow.framework.utils.CastHelper.cast; - public class SensitivityAtSpecificityTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; private final Random random = new Random(); @@ -39,8 +38,7 @@ public void testValueIsIdempotent() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); SensitivityAtSpecificity instance = - new SensitivityAtSpecificity<>(tf, 0.7f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + new SensitivityAtSpecificity<>(0.7f, 1001L, TFloat32.class); Operand predictions = tf.random.randomUniform( tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); @@ -50,13 +48,14 @@ public void testValueIsIdempotent() { labels = tf.math.mul(labels, tf.constant(2.0f)); // instance.setDebug(session.getGraphSession()); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); for (int i = 0; i < 10; i++) session.run(update); - Operand initialSensitivity = instance.result(); + Operand initialSensitivity = instance.result(tf, TFloat32.class); - for (int i = 0; i < 10; i++) session.evaluate(initialSensitivity, instance.result()); + for (int i = 0; i < 10; i++) + session.evaluate(initialSensitivity, instance.result(tf, TFloat32.class)); // instance.setDebug(null); @@ -79,18 +78,17 @@ public void testUnweightedAllCorrect() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); SensitivityAtSpecificity instance = - new SensitivityAtSpecificity<>(tf, 0.7f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + new SensitivityAtSpecificity<>(0.7f, 1001L, TFloat32.class); int[][] predArray = generateRandomArray(100, 1); int[][] trueArray = new int[100][1]; // 100,1 System.arraycopy(predArray, 0, trueArray, 0, predArray.length); Operand predictions = cast(tf, 
tf.constant(predArray), TFloat32.class); Operand labels = cast(tf, tf.constant(trueArray), TFloat32.class); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); session.evaluate(1f, precision); } @@ -101,16 +99,15 @@ public void testUnweightedHighSpecificity() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); SensitivityAtSpecificity instance = - new SensitivityAtSpecificity<>(tf, 0.8f, 1001L, TFloat64.class); - session.run(instance.resetStates()); + new SensitivityAtSpecificity<>(0.8f, 1001L, TFloat64.class); Operand predictions = tf.constant(new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.1f, 0.45f, 0.5f, 0.8f, 0.9f}); Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat64.class); session.evaluate(0.8, precision); } @@ -121,17 +118,16 @@ public void testUnweightedLowSpecificity() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); SensitivityAtSpecificity instance = - new SensitivityAtSpecificity<>(tf, 0.4f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + new SensitivityAtSpecificity<>(0.4f, 1001L, TFloat32.class); Operand predictions = tf.constant( new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f}); Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); session.evaluate(0.6f, precision); } @@ -142,18 +138,17 @@ public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); SensitivityAtSpecificity instance = - new SensitivityAtSpecificity<>(tf, 0.4f, 1001L, TFloat64.class); - session.run(instance.resetStates()); + new SensitivityAtSpecificity<>(0.4f, 1001L, TFloat64.class); Operand predictions = tf.constant( new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f}); Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); Operand sampleWeight = tf.constant(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); - Op update = instance.updateState(labels, predictions, sampleWeight); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat64.class); session.evaluate(0.675, precision); } @@ -163,23 +158,23 @@ public void testWeighted() { public void testInvalidSensitivity() { assertThrows( IllegalArgumentException.class, - () -> { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - new SensitivityAtSpecificity<>(tf, -1f, 1001L, TFloat32.class); - } - }); + () -> new SensitivityAtSpecificity<>(-1f, 1001L, TFloat32.class)); } @Test public void testInvalidNumThresholds() { assertThrows( IllegalArgumentException.class, - () -> { - try (TestSession session = 
TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - new SensitivityAtSpecificity<>(tf, 0.7f, -1, 1001L, TFloat32.class); - } - }); + () -> new SensitivityAtSpecificity<>(0.7f, -1, 1001L, TFloat32.class)); + } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + SensitivityAtSpecificity instance = + new SensitivityAtSpecificity<>(0.4f, 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java index 0aece8c8ac9..4ce84a26ac7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -34,18 +36,17 @@ public void testUnweighted() { Ops tf = session.getTF(); SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( - tf, "SCE_testUnweighted", false, -1, 1001L, TFloat64.class); - session.run(instance.resetStates()); + "SCE_testUnweighted", false, -1, 1001L, TFloat64.class); int[] trueArray = {1, 2}; double[] predictionArray = {0.05, 0.95, 0, 0.1, 0.8, 0.1}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2))); Operand predictions = tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 3))); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, predictions, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(2.3538785, total); session.evaluate(2, count); session.evaluate(1.1769392, result); @@ -57,18 +58,16 @@ public void testUnweightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); SparseCategoricalCrossentropy instance = - new SparseCategoricalCrossentropy<>( - tf, "SCE_testWeighted", true, -1, 1001L, TFloat64.class); - session.run(instance.resetStates()); + new SparseCategoricalCrossentropy<>("SCE_testWeighted", true, -1, 1001L, TFloat64.class); int[] trueArray = {1, 2}; double[] logitsArray = {1, 9, 0, 1, 8, 1}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); - Op op = instance.updateState(labels, logits, null); + Op op = instance.updateState(tf, labels, logits, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(7.002277, total); session.evaluate(2, count); session.evaluate(3.501135, result); @@ -80,9 +79,7 @@ public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops 
tf = session.getTF(); SparseCategoricalCrossentropy instance = - new SparseCategoricalCrossentropy<>( - tf, "SCE_testWeighted", false, -1, 1001L, TFloat32.class); - session.run(instance.resetStates()); + new SparseCategoricalCrossentropy<>("SCE_testWeighted", false, -1, 1001L, TFloat32.class); int[] trueArray = {1, 2}; double[] predictionArray = {0.05, 0.95, 0, 0.1, 0.8, 0.1}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2))); @@ -90,11 +87,11 @@ public void testWeighted() { tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(new float[] {1.5F, 2.F}); - Op op = instance.updateState(labels, predictions, sampleWeight); + Op op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(4.6821103f, total); session.evaluate(3.5f, count); session.evaluate(1.3377458f, result); @@ -106,9 +103,7 @@ public void testWeightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); SparseCategoricalCrossentropy instance = - new SparseCategoricalCrossentropy<>( - tf, "SCE_testWeighted", true, -1, 1001L, TFloat64.class); - session.run(instance.resetStates()); + new SparseCategoricalCrossentropy<>("SCE_testWeighted", true, -1, 1001L, TFloat64.class); int[] trueArray = {1, 2}; double[] predictionArray = {1, 9, 0, 1, 8, 1}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2))); @@ -116,14 +111,24 @@ public void testWeightedLogits() { tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(new double[] {1.5, 2}); - Op op = instance.updateState(labels, predictions, sampleWeight); + Op op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(14.004333, total); session.evaluate(3.5, count); session.evaluate(4.001232, result); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + SparseCategoricalCrossentropy instance = + new SparseCategoricalCrossentropy<>("SCE", true, -1, 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SpecificityAtSensitivityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SpecificityAtSensitivityTest.java index ff5834eda8e..676b443cd1c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SpecificityAtSensitivityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SpecificityAtSensitivityTest.java @@ -14,6 +14,10 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.Random; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -26,11 +30,6 @@ import 
org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; -import java.util.Random; - -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.tensorflow.framework.utils.CastHelper.cast; - public class SpecificityAtSensitivityTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; private final Random random = new Random(); @@ -40,8 +39,7 @@ public void testValueIsIdempotent() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); SpecificityAtSensitivity instance = - new SpecificityAtSensitivity<>(tf, 0.7f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + new SpecificityAtSensitivity<>(0.7f, 1001L, TFloat32.class); Operand predictions = tf.random.randomUniform( @@ -51,13 +49,14 @@ public void testValueIsIdempotent() { tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); // instance.setDebug(session.getGraphSession()); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); for (int i = 0; i < 10; i++) session.run(update); - Operand initialSpecificity = instance.result(); + Operand initialSpecificity = instance.result(tf, TFloat32.class); - for (int i = 0; i < 10; i++) session.evaluate(initialSpecificity, instance.result()); + for (int i = 0; i < 10; i++) + session.evaluate(initialSpecificity, instance.result(tf, TFloat32.class)); } } @@ -77,8 +76,7 @@ public void testUnweightedAllCorrect() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); SpecificityAtSensitivity instance = - new SpecificityAtSensitivity<>(tf, 0.7f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + new SpecificityAtSensitivity<>(0.7f, 1001L, TFloat32.class); int[][] predArray = generateRandomArray(100, 1); int[][] trueArray = new int[100][1]; // 100,1 System.arraycopy(predArray, 0, trueArray, 0, predArray.length); @@ -86,10 +84,10 @@ public void testUnweightedAllCorrect() { Operand labels = tf.constant(trueArray); labels = tf.math.mul(labels, tf.constant(2)); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); session.evaluate(1f, precision); } @@ -100,16 +98,15 @@ public void testUnweightedHighSensitivity() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); SpecificityAtSensitivity instance = - new SpecificityAtSensitivity<>(tf, 0.8f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + new SpecificityAtSensitivity<>(0.8f, 1001L, TFloat32.class); Operand predictions = tf.constant(new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.1f, 0.45f, 0.5f, 0.8f, 0.9f}); Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); session.evaluate(0.4f, precision); } @@ -120,17 +117,16 @@ public void testUnweightedLowSensitivity() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); SpecificityAtSensitivity instance = - new SpecificityAtSensitivity<>(tf, 0.4f, 1001L, TFloat64.class); - session.run(instance.resetStates()); + new 
SpecificityAtSensitivity<>(0.4f, 1001L, TFloat64.class); Operand predictions = tf.constant( new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f}); Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat64.class); session.evaluate(0.6f, precision); } @@ -141,18 +137,17 @@ public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); SpecificityAtSensitivity instance = - new SpecificityAtSensitivity<>(tf, 0.4f, 1001L, TFloat32.class); - session.run(instance.resetStates()); + new SpecificityAtSensitivity<>(0.4f, 1001L, TFloat32.class); Operand predictions = tf.constant( new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f}); Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); Operand sampleWeight = tf.constant(new float[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); - Op update = instance.updateState(labels, predictions, sampleWeight); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - Operand precision = instance.result(); + Operand precision = instance.result(tf, TFloat32.class); session.evaluate(0.4f, precision); } @@ -162,23 +157,23 @@ public void testWeighted() { public void testInvalidSensitivity() { assertThrows( IllegalArgumentException.class, - () -> { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - new SpecificityAtSensitivity<>(tf, -1f, 1001L, TFloat32.class); - } - }); + () -> new SpecificityAtSensitivity<>(-1f, 1001L, TFloat32.class)); } @Test public void testInvalidNumThresholds() { assertThrows( IllegalArgumentException.class, - () -> { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - new SpecificityAtSensitivity<>(tf, 0.4f, -1, 1001L, TFloat32.class); - } - }); + () -> new SpecificityAtSensitivity<>(0.4f, -1, 1001L, TFloat32.class)); + } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + SpecificityAtSensitivity instance = + new SpecificityAtSensitivity<>(0.4f, 1001L, TFloat32.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java index 2c80b3451ad..b6f0a2956d4 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -33,24 +35,20 @@ public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); SquaredHinge instance = - new SquaredHinge<>(tf, "SCE_testUnweighted", 1001L, TFloat32.class); - 
session.run(instance.resetStates()); + new SquaredHinge<>("SCE_testUnweighted", 1001L, TFloat32.class); int[] trueArray = { 0, 1, 0, 1, 0, 0, 1, 1 }; - float[] predArray = { - -0.3f, 0.2f, -0.1f, 1.6f, - -0.25f, -1.f, 0.5f, 0.6f - }; + float[] predArray = {-0.3f, 0.2f, -0.1f, 1.6f, -0.25f, -1.f, 0.5f, 0.6f}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); - Op op = instance.updateState(labels, predictions, null); + Op op = instance.updateState(tf, labels, predictions, null); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); session.evaluate(0.72812f, total); session.evaluate(2f, count); session.evaluate(0.3640625f, result); @@ -62,29 +60,34 @@ public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); SquaredHinge instance = - new SquaredHinge<>(tf, "SCE_testWeighted", 1001L, TFloat64.class); - session.run(instance.resetStates()); + new SquaredHinge<>("SCE_testWeighted", 1001L, TFloat64.class); int[] trueArray = { 0, 1, 0, 1, 0, 0, 1, 1 }; - double[] predArray = { - -0.3, 0.2, -0.1, 1.6, - -0.25, -1., 0.5, 0.6 - }; + double[] predArray = {-0.3, 0.2, -0.1, 1.6, -0.25, -1., 0.5, 0.6}; Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); Operand sampleWeight = tf.constant(new double[] {1.5f, 2.f}); - Op op = instance.updateState(labels, predictions, sampleWeight); + Op op = instance.updateState(tf, labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); Variable count = instance.getCount(); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(1.2137499, total); session.evaluate(3.5, count); session.evaluate(0.3467857, result); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + SquaredHinge instance = new SquaredHinge<>("SCE", 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SumTest.java index 941f882b8c8..09a35c406db 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SumTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SumTest.java @@ -14,6 +14,10 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -22,8 +26,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertEquals; - public class SumTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -32,22 +34,21 @@ public class SumTest { public void 
testUnWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Sum instance = new Sum<>(tf, 1001L, TFloat32.class); - session.run(instance.resetStates()); - assertEquals(TFloat32.class, instance.getResultType()); - session.evaluate(0f, instance.getTotal()); + Sum instance = new Sum<>(1001L, TFloat32.class); + assertEquals(TFloat32.class, instance.getInternalType()); + assertNull(instance.getTotal()); - Op update = instance.updateState(tf.constant(100f), null); + Op update = instance.updateState(tf, tf.constant(100f), null); session.run(update); - session.evaluate(100f, instance.result()); + session.evaluate(100f, instance.result(tf, TFloat64.class)); session.evaluate(100f, instance.getTotal()); - update = instance.updateState(tf.constant(new float[] {1, 5}), null); + update = instance.updateState(tf, tf.constant(new float[] {1, 5}), null); session.run(update); - session.evaluate(106f, instance.result()); + session.evaluate(106f, instance.result(tf, TFloat64.class)); session.evaluate(106f, instance.getTotal()); - session.run(instance.resetStates()); + session.run(instance.resetStates(tf)); session.evaluate(0f, instance.getTotal()); } } @@ -56,58 +57,68 @@ public void testUnWeighted() { public void testSumWithSampleWeight() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Sum instance = new Sum<>(tf, 1001L, TFloat64.class); - session.run(instance.resetStates()); + Sum instance = new Sum<>(1001L, TFloat64.class); // check scalar weight - Op op = instance.updateState(tf.constant(100f), tf.constant(0.5)); + Op op = instance.updateState(tf, tf.constant(100f), tf.constant(0.5)); session.run(op); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(50.0, instance.getTotal()); session.evaluate(50.0, result); // check weights not scalar and weights rank matches values rank op = - instance.updateState(tf.constant(new float[] {1, 5}), tf.constant(new double[] {1, 0.2})); + instance.updateState( + tf, tf.constant(new float[] {1, 5}), tf.constant(new double[] {1, 0.2})); session.run(op); - result = instance.result(); + result = instance.result(tf, TFloat64.class); session.evaluate(52., instance.getTotal()); session.evaluate(52., result); // check weights broadcast - op = instance.updateState(tf.constant(new float[] {1, 2}), tf.constant(0.5)); + op = instance.updateState(tf, tf.constant(new float[] {1, 2}), tf.constant(0.5)); session.run(op); - result = instance.result(); + result = instance.result(tf, TFloat64.class); session.evaluate(53.5, instance.getTotal()); session.evaluate(53.5, result); // check weights squeeze op = instance.updateState( - tf.constant(new float[] {1, 5}), tf.constant(new double[][] {{1}, {0.2}})); + tf, tf.constant(new float[] {1, 5}), tf.constant(new double[][] {{1}, {0.2}})); session.run(op); - result = instance.result(); + result = instance.result(tf, TFloat64.class); session.evaluate(55.5, instance.getTotal()); session.evaluate(55.5, result); // check weights expand op = instance.updateState( - tf.constant(new float[][] {{1}, {5}}), tf.constant(new double[] {1, 0.2})); + tf, tf.constant(new float[][] {{1}, {5}}), tf.constant(new double[] {1, 0.2})); session.run(op); - result = instance.result(); + result = instance.result(tf, TFloat64.class); session.evaluate(57.5, instance.getTotal()); session.evaluate(57.5, result); // heck values reduced to the dimensions of weight op = instance.updateState( + tf, tf.constant(new 
float[][][] {{{1.f, 2.f}, {3.f, 2.f}, {0.5f, 4.f}}}), tf.constant(new double[] {0.5})); session.run(op); - result = instance.result(); + result = instance.result(tf, TFloat64.class); session.evaluate(63.75, instance.getTotal()); session.evaluate(63.75, result); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + Sum instance = new Sum<>(1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java index 023796ba367..95e73e5893d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -30,23 +32,20 @@ public void testCorrectness() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); TopKCategoricalAccuracy instance = - new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted", 5, 1001L, TFloat64.class); - session.run(instance.resetStates()); + new TopKCategoricalAccuracy<>("TopK_testUnweighted", 5, 1001L, TFloat64.class); Operand labels = tf.constant(new float[][] {{0, 0, 1}, {0, 1, 0}}); Operand predictions = tf.constant(new double[][] {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}}); - Op update = instance.updateState(labels, predictions, null); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - session.evaluate(1., instance.result()); + session.evaluate(1., instance.result(tf, TFloat64.class)); // With `k` < 5. - instance = - new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted1", 1, 1001L, TFloat64.class); - session.run(instance.resetStates()); - update = instance.updateState(labels, predictions, null); + instance = new TopKCategoricalAccuracy<>("TopK_testUnweighted1", 1, 1001L, TFloat64.class); + update = instance.updateState(tf, labels, predictions, null); session.run(update); - session.evaluate(0.5, instance.result()); + session.evaluate(0.5, instance.result(tf, TFloat64.class)); // With `k` > 5. 
labels = @@ -61,12 +60,10 @@ public void testCorrectness() { {0.5f, 0.9f, 0.1f, 0.7f, 0.6f, 0.5f, 0.4f}, {0.05f, 0.95f, 0f, 0f, 0f, 0f, 0f} }); - instance = - new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted6", 6, 1001L, TFloat64.class); - session.run(instance.resetStates()); - update = instance.updateState(labels, predictions, null); + instance = new TopKCategoricalAccuracy<>("TopK_testUnweighted6", 6, 1001L, TFloat64.class); + update = instance.updateState(tf, labels, predictions, null); session.run(update); - session.evaluate(0.5, instance.result()); + session.evaluate(0.5, instance.result(tf, TFloat64.class)); } } @@ -75,8 +72,7 @@ public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); TopKCategoricalAccuracy instance = - new TopKCategoricalAccuracy<>(tf, "TopK_testWeighted", 5, 1001L, TFloat64.class); - session.run(instance.resetStates()); + new TopKCategoricalAccuracy<>("TopK_testWeighted", 5, 1001L, TFloat64.class); Operand labels = tf.constant( @@ -95,9 +91,19 @@ public void testWeighted() { Operand sampleWeight = tf.constant(new double[] {1, 0, 1}); - Op update = instance.updateState(labels, predictions, sampleWeight); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - session.evaluate(1., instance.result()); + session.evaluate(1., instance.result(tf, TFloat64.class)); + } + } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + TopKCategoricalAccuracy instance = + new TopKCategoricalAccuracy<>("TopK", 5, 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); } } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TrueNegativesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TrueNegativesTest.java index 1a68c2ed8b8..d6cdc51e2fc 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TrueNegativesTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TrueNegativesTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -45,11 +47,10 @@ public void testUnweighted() { Operand predictions = tf.constant(this.predArray); Operand labels = tf.constant(this.trueArray); - TrueNegatives instance = new TrueNegatives<>(tf, 1001L, TFloat64.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, null); + TrueNegatives instance = new TrueNegatives<>(1001L, TFloat64.class); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(3.0, result); } @@ -63,11 +64,10 @@ public void testWeighted() { Operand predictions = tf.constant(this.predArray); Operand labels = tf.constant(this.trueArray); Operand sampleWeight = tf.constant(this.sampleWeightArray); - TrueNegatives instance = new TrueNegatives<>(tf, 1001L, TFloat64.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, sampleWeight); + 
TrueNegatives instance = new TrueNegatives<>(1001L, TFloat64.class); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(4.0, result); } @@ -95,11 +95,10 @@ public void testUnweightedWithThresholds() { {1, 1, 1, 1} }); TrueNegatives instance = - new TrueNegatives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, null); + new TrueNegatives<>(new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); float[] expected = new float[] {2.f, 5.f, 7.f}; session.evaluate(expected, result); } @@ -129,13 +128,22 @@ public void testWeightedWithThresholds() { Operand sampleWeight = tf.constant(new double[][] {{0.0, 2.0, 3.0, 5.0}}); TrueNegatives instance = - new TrueNegatives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, sampleWeight); + new TrueNegatives<>(new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); double[] expected = new double[] {5., 15., 23.}; session.evaluate(expected, result); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + TrueNegatives instance = + new TrueNegatives<>(new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TruePositivesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TruePositivesTest.java index c22c1245d97..15645355230 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TruePositivesTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TruePositivesTest.java @@ -14,6 +14,8 @@ =======================================================================*/ package org.tensorflow.framework.metrics; +import static org.junit.jupiter.api.Assertions.assertThrows; + import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; @@ -45,11 +47,10 @@ public void testUnweighted() { Operand predictions = tf.constant(this.predArray); Operand labels = tf.constant(this.trueArray); - TruePositives instance = new TruePositives<>(tf, 1001L, TFloat64.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, null); + TruePositives instance = new TruePositives<>(1001L, TFloat64.class); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(7.0, result); } @@ -63,11 +64,10 @@ public void testWeighted() { Operand predictions = tf.constant(this.predArray); Operand labels = tf.constant(this.trueArray); Operand sampleWeight 
= tf.constant(this.sampleWeightArray); - TruePositives instance = new TruePositives<>(tf, 1001L, TFloat64.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, sampleWeight); + TruePositives instance = new TruePositives<>(1001L, TFloat64.class); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); session.evaluate(12.0, result); } @@ -95,11 +95,10 @@ public void testUnweightedWithThresholds() { {1, 1, 1, 1} }); TruePositives instance = - new TruePositives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, null); + new TruePositives<>(new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); + Op update = instance.updateState(tf, labels, predictions, null); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat32.class); float[] expected = new float[] {6.f, 3.f, 1.f}; session.evaluate(expected, result); } @@ -129,13 +128,22 @@ public void testWeightedWithThresholds() { Operand sampleWeight = tf.constant(37.); TruePositives instance = - new TruePositives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); - session.run(instance.getInitializer()); - Op update = instance.updateState(labels, predictions, sampleWeight); + new TruePositives<>(new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + Op update = instance.updateState(tf, labels, predictions, sampleWeight); session.run(update); - Operand result = instance.result(); + Operand result = instance.result(tf, TFloat64.class); double[] expected = new double[] {222., 111., 37.}; session.evaluate(expected, result); } } + + @Test + public void testEagerEnvironment() { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) { + Ops tf = session.getTF(); + TruePositives instance = + new TruePositives<>(new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java deleted file mode 100644 index eceff2797f8..00000000000 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java +++ /dev/null @@ -1,120 +0,0 @@ -package org.tensorflow.framework.metrics.impl; - -import org.junit.jupiter.api.Test; -import org.tensorflow.Operand; -import org.tensorflow.framework.utils.TestSession; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Ops; -import org.tensorflow.types.TInt32; -import org.tensorflow.types.TInt64; -import org.tensorflow.types.TUint8; -import org.tensorflow.types.family.TType; - -import java.util.Arrays; -import java.util.List; - -import static org.tensorflow.framework.utils.CastHelper.cast; - -class SetsOpsTest { - - private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - - List> types = Arrays.asList(TInt32.class, TInt64.class, TUint8.class); - - @Test - @SuppressWarnings({"unchecked", "rawtypes"}) - public void testSetIntersectionMultirow2() { - - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = 
session.getTF(); - Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); - Operand b = tf.constant(new int[][] {{1, 9}, {1, 5}}); - int[][] expected = new int[][] {{1, 9}, {0, 0}}; - Shape expectedShape = Shape.of(2, 2); - for (Class type : types) { - Operand aa = cast(tf, a, type); - Operand bb = cast(tf, b, type); - Operand intersection = SetsOps.intersection(tf, aa, bb); - session.evaluate(cast(tf, tf.constant(expected), type), intersection); - session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); - } - } - } - - @Test - @SuppressWarnings({"unchecked", "rawtypes"}) - public void testSetIntersectionDuplicates2d() { - - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - Operand a = tf.constant(new int[][] {{1, 1, 3}}); - Operand b = tf.constant(new int[][] {{1, 1}}); - int[][] expected = {{1}}; - Shape expectedShape = Shape.of(1, 1); - for (Class type : types) { - Operand aa = cast(tf, a, type); - Operand bb = cast(tf, b, type); - Operand intersection = SetsOps.intersection(tf, aa, bb); - - session.evaluate(cast(tf, tf.constant(expected), type), intersection); - - session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); - } - } - } - - @Test - @SuppressWarnings({"unchecked", "rawtypes"}) - public void testDenseSetDifferenceMultirow2d() { - - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - Operand a = tf.constant(new int[][] {{1, 5, 9}, {4, 5, 3}}); - Operand b = tf.constant(new int[][] {{1, 2, 6}, {1, 2, 2}}); - - for (Class type : types) { - Operand aa = cast(tf, a, type); - Operand bb = cast(tf, b, type); - int[][] expected = {{5, 9, 0}, {3, 4, 5}}; - // a- b - Shape expectedShape = Shape.of(2, 3); - Operand intersection = SetsOps.difference(tf, aa, bb); - session.evaluate(cast(tf, tf.constant(expected), type), intersection); - session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); - - // b - a - expected = new int[][] {{2, 6}, {1, 2}}; - expectedShape = Shape.of(2, 2); - intersection = SetsOps.difference(tf, aa, bb, false); - - session.evaluate(cast(tf, tf.constant(expected), type), intersection); - session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); - } - } - } - - @Test - @SuppressWarnings({"unchecked", "rawtypes"}) - public void testDenseUnionMultirow2d() { - - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); - Operand b = tf.constant(new int[][] {{1, 9}, {1, 2}}); - int[][] expected = new int[][] {{5, 0}, {3, 4}}; - for (Class type : types) { - Operand aa = cast(tf, a, type); - Operand bb = cast(tf, b, type); - Shape expectedShape = Shape.of(2, 2); - // a- b - Operand intersection = SetsOps.difference(tf, aa, bb); - session.evaluate(cast(tf, tf.constant(expected), type), intersection); - session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); - } - } - } -}
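Note on the API migration exercised by the tests above: each metric is now constructed without an Ops, updateState/result/resetStates take the Ops as their first argument, result additionally takes the desired result type, and the new testEagerEnvironment cases expect updateState to throw IllegalArgumentException when handed an eager Ops. Below is a minimal sketch of that pattern, assembled only from calls that appear in the hunks above: TestSession is the framework's test helper used throughout these tests, the inputs and the expected 0.5f recall come from RecallTest#testUnweighted, the class name RecallMigrationSketch and the use of a plain main method are illustrative, and resetStates(tf) is shown above only in the SumTest hunk, so applying it to Recall here is an assumption.

    import static org.junit.jupiter.api.Assertions.assertThrows;

    import org.tensorflow.Operand;
    import org.tensorflow.framework.metrics.Recall;
    import org.tensorflow.framework.utils.TestSession;
    import org.tensorflow.op.Op;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    public class RecallMigrationSketch {
      public static void main(String[] args) {
        // Graph mode: the metric no longer holds an Ops; pass tf to each call instead.
        try (TestSession session = TestSession.createTestSession(TestSession.Mode.GRAPH)) {
          Ops tf = session.getTF();
          Recall<TFloat32> instance = new Recall<>(1001L, TFloat32.class);

          Operand<TFloat32> predictions = tf.constant(new float[][] {{1, 0, 1, 0}});
          Operand<TFloat32> labels = tf.constant(new float[][] {{0, 1, 1, 0}});

          // updateState(tf, labels, predictions, sampleWeights) builds the update op.
          Op update = instance.updateState(tf, labels, predictions, null);
          session.run(update);

          // result(tf, resultType) builds the result operand: 1 TP and 1 FN give recall 0.5.
          session.evaluate(0.5f, instance.result(tf, TFloat32.class));

          // resetStates(tf) mirrors the SumTest hunk above; assumed to behave the same for Recall.
          session.run(instance.resetStates(tf));
        }

        // Eager mode: updateState is expected to reject an eager Ops, as in the
        // testEagerEnvironment cases added in this diff.
        try (TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER)) {
          Ops tf = session.getTF();
          Recall<TFloat32> instance = new Recall<>(1001L, TFloat32.class);
          assertThrows(IllegalArgumentException.class, () -> instance.updateState(tf, null, null));
        }
      }
    }

The same shape applies to every metric touched in this diff; only the constructor arguments differ from class to class.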