From 8f57a7af1d7e031c2f7dfc6acceff07b00f4ca18 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 31 Dec 2020 19:29:42 -0500 Subject: [PATCH 01/56] Initial checkin --- .../tensorflow/framework/losses/Losses.java | 3 + .../framework/metrics/BinaryCrossentropy.java | 66 ++++ .../metrics/CategoricalCrossentropy.java | 99 ++++++ .../framework/metrics/CategoricalHinge.java | 46 +++ .../framework/metrics/CosineSimilarity.java | 68 ++++ .../tensorflow/framework/metrics/Hinge.java | 46 +++ .../framework/metrics/KLDivergence.java | 46 +++ .../framework/metrics/LogCoshError.java | 49 +++ .../tensorflow/framework/metrics/Mean.java | 40 +++ .../framework/metrics/MeanAbsoluteError.java | 46 +++ .../metrics/MeanAbsolutePercentageError.java | 46 +++ .../framework/metrics/MeanSquaredError.java | 46 +++ .../metrics/MeanSquaredLogarithmicError.java | 46 +++ .../tensorflow/framework/metrics/Metric.java | 320 ++++++++++++++++++ .../framework/metrics/MetricReduction.java | 26 ++ .../tensorflow/framework/metrics/Metrics.java | 192 +++++++++++ .../tensorflow/framework/metrics/Poisson.java | 46 +++ .../SparseCategoricalCrossentropy.java | 54 +++ .../SparseTopKCategoricalAccuracy.java | 65 ++++ .../framework/metrics/SquaredHinge.java | 46 +++ .../metrics/TopKCategoricalAccuracy.java | 63 ++++ .../framework/metrics/impl/LossInterface.java | 36 ++ .../metrics/impl/MeanMetricWrapper.java | 125 +++++++ .../metrics/impl/MetricVariable.java | 122 +++++++ .../framework/metrics/impl/MetricsHelper.java | 154 +++++++++ .../framework/metrics/impl/Reduce.java | 204 +++++++++++ .../metrics/BinaryCrossentropyTest.java | 150 ++++++++ .../metrics/CategoricalCrossentropyTest.java | 151 +++++++++ .../metrics/CategoricalHingeTest.java | 97 ++++++ .../metrics/CosineSimilarityTest.java | 101 ++++++ .../framework/metrics/HingeTest.java | 84 +++++ .../framework/metrics/KLDivergenceTest.java | 83 +++++ .../framework/metrics/LogCoshErrorTest.java | 80 +++++ .../metrics/MeanAbsoluteErrorTest.java | 116 +++++++ .../MeanAbsolutePercentageErrorTest.java | 115 +++++++ .../metrics/MeanSquaredErrorTest.java | 107 ++++++ .../MeanSquaredLogarithmicErrorTest.java | 106 ++++++ .../framework/metrics/PoissonTest.java | 79 +++++ .../SparseCategoricalCrossentropyTest.java | 129 +++++++ .../SparseTopKCategoricalAccuracyTest.java | 96 ++++++ .../framework/metrics/SquaredHingeTest.java | 90 +++++ .../metrics/TopKCategoricalAccuracyTest.java | 103 ++++++ 42 files changed, 3787 insertions(+) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java create 
mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MetricReduction.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossInterface.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracyTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java diff --git 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 81d9e13c8a9..3894bee0d0f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -36,6 +36,9 @@ public class Losses { /** Default Fuzz factor. */ public static final float EPSILON = 1e-7f; + public static final int CHANNELS_LAST = -1; + public static final int CHANNELS_FIRST = 1; + /** * Calculates the mean absolute error between labels and predictions. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java new file mode 100644 index 00000000000..d13d20bfdee --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -0,0 +1,66 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes the binary cross-entropy loss between true labels and predicted labels. + * + * @param the data type for the predictions. + * @param The data type for the metric result + */ +public class BinaryCrossentropy + extends MeanMetricWrapper implements LossInterface { + + private final boolean fromLogits; + private final float labelSmoothing; + + /** + * Creates a BinaryCrossentropy metric + * + *

This is the crossentropy metric class to be used when there are only two label classes (0 + * and 1). + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. + * @param labelSmoothing value used to smooth labels, When 0, no smoothing occurs. When > 0, + * compute the loss between the predicted labels and a smoothed version of the true labels, + * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing + * correspond to heavier smoothing. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public BinaryCrossentropy( + Ops tf, String name, boolean fromLogits, float labelSmoothing, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + this.fromLogits = fromLogits; + this.labelSmoothing = labelSmoothing; + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.binaryCrossentropy(getTF(), labels, predictions, fromLogits, labelSmoothing); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java new file mode 100644 index 00000000000..cf9ecd0858a --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -0,0 +1,99 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes the categorical cross-entropy loss between true labels and predicted labels. + * + *

This is the crossentropy metric class to be used when there are multiple label classes (2 or + * more). Here we assume that labels are given as a one_hot representation. eg., When labels values + * are [2, 0, 1], the labels Operand contains = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] + * . + */ +public class CategoricalCrossentropy + extends MeanMetricWrapper implements LossInterface { + + private final boolean fromLogits; + private final float labelSmoothing; + private int axis; + + /** + * Creates a CategoricalCrossentropy metric that Computes the crossentropy metric between the + * labels and predictions. + * + *

Uses a {@link Losses#CHANNELS_LAST} for the channel axis. + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. + * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, + * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and 0.9 + * for label 1 + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public CategoricalCrossentropy( + Ops tf, String name, boolean fromLogits, float labelSmoothing, long seed, Class type) { + this(tf, name, fromLogits, labelSmoothing, Losses.CHANNELS_LAST, seed, type); + } + + /** + * Creates a CategoricalCrossentropy metric that Computes the crossentropy metric between the + * labels and predictions. + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. + * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, + * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and 0.9 + * for label 1 + * @param axis Int specifying the channels axis. axis={@link Losses#CHANNELS_LAST} + * corresponds to data format channels_last, and + * axis={@link Losses#CHANNELS_FIRST} corresponds to data format + * channels_first. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public CategoricalCrossentropy( + Ops tf, + String name, + boolean fromLogits, + float labelSmoothing, + int axis, + long seed, + Class type) { + super(tf, name, seed, type); + setLoss(this); + this.fromLogits = fromLogits; + this.labelSmoothing = labelSmoothing; + this.axis = axis; + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.categoricalCrossentropy( + getTF(), labels, predictions, fromLogits, labelSmoothing, axis); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java new file mode 100644 index 00000000000..a9500b79d9e --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Computes the categorical hinge loss metric between labels and predictions. */ +public class CategoricalHinge extends MeanMetricWrapper + implements LossInterface { + + /** + * Creates a CategoricalHinge metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public CategoricalHinge(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.categoricalHinge(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java new file mode 100644 index 00000000000..61802572c7b --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -0,0 +1,68 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Computes the cosine similarity metric between labels and predictions. */ +// TODO: this is weird, the metric is called CosineSimilarity in Keras, +// but it calls Metrics.cosineProximity instead of Losses.cosineSimilarity. +// The metric is calculating the Euclidean distance using L2 norms, while the loss +// is using the dot product proportional to the product of their magnitudes. +// While the 2 concepts are similar, they are different. +// Should we rename this metric to CosineProximity? +public class CosineSimilarity extends MeanMetricWrapper + implements LossInterface { + public static final int[] DEFAULT_AXIS = {-1}; + private final int[] axis; + + /** + * Creates a CosineSimilarity metric with a default axis, {@link #DEFAULT_AXIS} + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public CosineSimilarity(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_AXIS, seed, type); + } + + /** + * Creates a CosineSimilarity metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param axis The dimension along which the cosine similarity is computed. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class type) { + super(tf, name, seed, type); + this.axis = axis; + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + // NOTE: cosineProximity is a different algorithm than Losses.cosineSimilarity + return Metrics.cosineProximity(getTF(), labels, predictions, axis); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java new file mode 100644 index 00000000000..d655f8d8237 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Computes the hinge loss metric between labels and predictions. */ +public class Hinge extends MeanMetricWrapper + implements LossInterface { + + /** + * Creates a Hinge metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ */ + public Hinge(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.hinge(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java new file mode 100644 index 00000000000..3f31383381a --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Computes the Kullback-Leibler divergence loss metric between labels and predictions. */ +public class KLDivergence extends MeanMetricWrapper + implements LossInterface { + + /** + * Creates a KLDivergence metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public KLDivergence(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java new file mode 100644 index 00000000000..7d4b8a9fad7 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes the logarithm of the hyperbolic cosine of the prediction error metric between labels and + * predictions. + */ +public class LogCoshError extends MeanMetricWrapper + implements LossInterface { + + /** + * Creates a LogCoshError metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public LogCoshError(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.logCosh(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java new file mode 100644 index 00000000000..08d1083dd05 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.framework.metrics.impl.Reduce; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Represents a Metric that implements a weighted mean {@link MetricReduction#WEIGHTED_MEAN } + * + * @param The data type for the metric values + * @param The data type for the metric result + */ +public class Mean extends Reduce { + + /** + * Creates a Reducible Metric with a metric reduction of {@link MetricReduction#WEIGHTED_MEAN} + * + * @param tf the TensorFlow Ops + * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type.
+ */ + protected Mean(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, MetricReduction.WEIGHTED_MEAN, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java new file mode 100644 index 00000000000..6b29c72fe82 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Computes the mean of absolute difference between labels and predictions. */ +public class MeanAbsoluteError extends MeanMetricWrapper + implements LossInterface { + + /** + * Creates a Mean Absolute Error metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.meanAbsoluteError(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java new file mode 100644 index 00000000000..6209245d881 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Computes the mean absolute percentage error between labels and predictions. */ +public class MeanAbsolutePercentageError + extends MeanMetricWrapper implements LossInterface { + + /** + * Creates a Mean Absolute Percentage Error metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.meanAbsolutePercentageError(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java new file mode 100644 index 00000000000..ce30e378e8d --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Computes the mean of squares of errors between labels and predictions. */ +public class MeanSquaredError extends MeanMetricWrapper + implements LossInterface { + + /** + * Creates a Mean Squared Error metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type.
+ */ + public MeanSquaredError(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.meanSquaredError(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java new file mode 100644 index 00000000000..9baeac2f320 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Computes the mean squared logarithmic error between labels and predictions. */ +public class MeanSquaredLogarithmicError + extends MeanMetricWrapper implements LossInterface { + + /** + * Creates a Mean Squared Logarithmic Error metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java new file mode 100644 index 00000000000..62ec5439269 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -0,0 +1,320 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow.framework.metrics; + +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +import org.tensorflow.ExecutionEnvironment; +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Initializer; +import org.tensorflow.framework.metrics.impl.MetricVariable; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.*; +import java.util.stream.Collectors; + +/** + * Base class for Metrics + * + * @param The data type for the metric values + * @param The data type for the metric result + */ +public abstract class Metric { + + /** variables are stored by ExecutionEnvironment, and then by an identifier name */ + protected static Map>> + variableMap = new WeakHashMap<>(); + /** The TensorFlow Ops */ + private final Ops tf; + /** The random number generator seed value */ + private final long seed; + + // TODO: how to handle variables across new ExecutionEnvironments. + // Metrics may be instantiated multiple times using the same variables, + // These variables become stale when a new ExecutionEnvironment is created + // (most commonly seen in Unit Tests), so the question is how to best handle this. + // Option 1, which is used here is to map the variables against an instance of + // an ExecutionEnvironment in a WeakHashMap, when a new ExecutionEnvironment is presented, the + // new + // variables are mapped to it. A WeakHashMap is used to throw away the old ExecutionEnvironment + // mappings, when the old ExecutionEnvironment is finalized. + // Option 2, keep an instance of the newly presented ExecutionEnvironment and if it changes, + // clear the variable maps. + // My guess is that in a non-unit test environment, only one ExecutionEnvironment will be used, + // I welcome thoughts on this. + /** The name for this metric. Defaults to {@link Class#getSimpleName()}. */ + private final String name; + + private final Class type; + + /** + * Creates a Metric with a name of {@link Class#getSimpleName()} } + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + protected Metric(Ops tf, long seed, Class type) { + this(tf, null, seed, type); + } + + /** + * Creates a Metric + * + * @param tf the TensorFlow Ops + * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ */ + protected Metric(Ops tf, String name, long seed, Class type) { + if (!tf.scope().env().isGraph()) + throw new IllegalArgumentException("Metrics are required to execute in Graph mode."); + this.seed = seed; + this.name = name != null ? name : this.getClass().getSimpleName(); + this.tf = tf.withSubScope(this.name); + this.type = type; + } + + /** + * Creates a List of Operations to update the metric state based on input values. + * + *

This is an empty implementation that should be overridden in a subclass, if needed. + * + * @param values the inputs to be passed to update state, this may not be null + * @param sampleWeights sample weights to be applied to values, may be null. + * @return a List of Operations to update the metric state + */ + @SuppressWarnings({"unchecked", "unused"}) + public List updateStateList(Operand values, Operand sampleWeights) { + return Collections.EMPTY_LIST; + } + + /** + * Creates a List of Operations to update the metric state based on labels and predictions. + * + *

This is an empty implementation that should be overridden in a sub class, if needed. + * + * @param labels the labels + * @param predictions the predictions + * @param sampleWeights sample weights to be applied to values, may be null. + * @param the data type for the sample weights + * @return a List of Operations to update the metric state + */ + @SuppressWarnings({"unchecked","unused"}) + public List updateStateList( + Operand labels, Operand predictions, Operand sampleWeights) { + return Collections.EMPTY_LIST; + } + + /** + * Creates a NoOp Operation with control dependencies to update the metric state + * + * @param values the inputs to be passed to update state, this may not be null + * @param sampleWeights sample weights to be applied to values, may be null. + * @return the Operation to update the metric state + */ + public final Op updateState(Operand values, Operand sampleWeights) { + List controlOps = updateStateList(values, sampleWeights); + return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); + } + + /** + * Creates a NoOp Operation with control dependencies to update the metric state + * + * @param labels the labels + * @param predictions the predictions + * @param sampleWeights sample weights to be applied to values, may be null. + * @return the Operation to update the metric state + */ + public final Op updateState( + Operand labels, Operand predictions, Operand sampleWeights) { + List controlOps = updateStateList(labels, predictions, sampleWeights); + return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); + } + + /** + * Gets the current result of the metric + * + * @param tf the TensorFlow Ops used to create the result + * @return the result, possibly with control dependencies + */ + public abstract Operand result(Ops tf); + + /** + * Gets the current result of the metric using the metric's {@link #getTF()} + * + * @return the result, possibly with control dependencies + */ + public Operand result() { + return result(this.tf); + } + + /** + * Calls update state once, followed by a call to get the result + * + * @param values the inputs to be passed to update state, this may not be null + * @param sampleWeights sample weights to be applied to values, may be null. + * @return the result, possibly with control dependencies + */ + public final Operand callOnce( + Operand values, Operand sampleWeights) { + List controlOps = updateStateList(values, sampleWeights); + Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps); + return result(ltf); + } + + /** + * Adds a variable to collect metric values + * + * @param variable the variable + * @param initializer the initializer for the variable, if null, then the default for floating + * point types is {@link org.tensorflow.framework.initializers.Glorot} with distribution + * {@link org.tensorflow.framework.initializers.VarianceScaling.Distribution#UNIFORM}, for + * other types the default initializer is {@link org.tensorflow.framework.initializers.Zeros} + */ + protected void addVariable( + String varName, Variable variable, Initializer initializer) { + // TODO option 2 would be to keep track of tf.scope().env() and if it changes, clear to old Map. 
+ Map> variables = + variableMap.computeIfAbsent(tf.scope().env(), k -> new HashMap<>()); + variables.put(varName, new MetricVariable<>(tf, variable, initializer, seed)); + } + + /** + * Gets the list of added variables + * + * @return the list of added variables + */ + public List> getVariables() { + List> result = new ArrayList<>(); + Map> variables = variableMap.get(tf.scope().env()); + if (variables != null) variables.values().forEach(mv -> result.add(mv.getVariable())); + return result; + } + + /** + * Gets a formatted name for a variable, in the form {@link #name} + "_" + varName. + * + * @param varName the base name for the variable + * @return the formatted variable name + */ + protected String getVariableName(String varName) { + return String.format("%s_%s", this.name, varName); + } + + /** + * Gets an Operation that initializes the variables. + * + * @param subScopeName the sub scope name + * @return the Operation used to initialize the variables. + */ + public Op initialize(String subScopeName) { + + List initializeOperations = initializeVarsList(subScopeName); + return tf.withControlDependencies(initializeOperations).noOp(); + } + + /** + * Gets the list of Operations that initializes the variables + * + * @param subScopeName the sub scope name + * @return the list of Operations that initializes the variables + */ + @SuppressWarnings("unchecked") + private List initializeVarsList(String subScopeName) { + Map> variables = variableMap.get(tf.scope().env()); + if (variables != null) + return variables.values().stream() + .map(metricVariable -> variableAssign(subScopeName, metricVariable)) + .collect(Collectors.toList()); + else return Collections.EMPTY_LIST; + } + + /** + * Resets all variables to their initial state + * + * @return An Operation that sets all variables to their initial state + */ + public Op resetStates() { + return initialize("resetStates"); + } + + /** + * Assigns a value to a Variable + * + *

This assumes the variable has already been initialized + * + * @param subScopeName the subscope for creating the variable + * @param mv the metric value used to assign the initializer to the variable. + * @return the variable add operation with necessary control dependencies + */ + private Operand variableAssign( + String subScopeName, MetricVariable mv) { + return tf.withSubScope(subScopeName).assign(mv.getVariable(), mv.initialize()); + } + + /** + * Gets a stored variable by name, Variables are cached first by the TensorFlow Environment, then + * by a variable name. + * + * @param varName the name assigned to the variable + * @return the variable, or null if the variable is not found. + */ + public Variable getVariable(String varName) { + Map> variables = variableMap.get(tf.scope().env()); + if (variables == null) return null; + MetricVariable mv = variables.get(varName); + return mv != null ? mv.getVariable() : null; + } + + /** + * Gets the TensorFlow Ops + * + * @return the TensorFlow Ops + */ + public Ops getTF() { + return tf; + } + + /** + * Gets the name of this metric. + * + * @return the name of this metric + */ + public String getName() { + return name; + } + + public Class getType() { + return type; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MetricReduction.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MetricReduction.java new file mode 100644 index 00000000000..d837ff626b3 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MetricReduction.java @@ -0,0 +1,26 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +/** Defines the different types of metric reductions */ +public enum MetricReduction { + + /** Scalar sum of weighted values. */ + SUM, + /** Scalar sum of weighted values divided by number of elements. */ + SUM_OVER_BATCH_SIZE, + /** Scalar sum of weighted values divided by sum of weights. */ + WEIGHTED_MEAN +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java new file mode 100644 index 00000000000..f4282bfd0a9 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -0,0 +1,192 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
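A minimal usage sketch of the Metric base class above (illustrative only, not part of this patch): it builds one of the wrapper metrics from this change in graph mode, updates the state once, and reads the result. The sample labels, predictions, name, and seed are arbitrary assumptions, and the generic type arguments are assumptions as well, since the patch text shown here has its type parameters stripped.

    // imports assumed: org.tensorflow.Graph, org.tensorflow.Session, org.tensorflow.Operand,
    // org.tensorflow.op.Op, org.tensorflow.op.Ops, org.tensorflow.types.TFloat32
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);
      // Metrics must be created in graph mode; the Metric constructor enforces this.
      BinaryCrossentropy<TFloat32, TFloat32> metric =
          new BinaryCrossentropy<>(tf, "bce", false, 0.0f, 1001L, TFloat32.class);
      Operand<TFloat32> labels = tf.constant(new float[][] {{0f, 1f}, {0f, 0f}});
      Operand<TFloat32> predictions = tf.constant(new float[][] {{0.6f, 0.4f}, {0.4f, 0.6f}});
      Op update = metric.updateState(labels, predictions, null); // accumulates total and count
      Operand<TFloat32> result = metric.result();                // weighted mean seen so far
      try (Session session = new Session(graph)) {
        session.run(metric.resetStates()); // initialize the metric variables
        session.run(update);               // apply one batch
        // read the value by fetching `result` with session.runner().fetch(result).run()
      }
    }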
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.CastHelper; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TNumber; + +/** Built-in metrics functions. */ +public class Metrics { + + public static final float L2_NORM_EPSILON = 1e-12f; + + /** + * Computes how often targets are in the top K predictions. + * + *

Standalone usage: + * + *

+   *     Operand<TInt32> labels = tf.constant(new int[][]
+   *                                    {{0, 0, 1}, {0, 1, 0}});
+   *     Operand<TFloat32> predictions = tf.constant(new float[][]
+   *                                    {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}});
+   *     Operand<TFloat32> m = Metrics.topKCategoricalAccuracy(
+   *                                    tf, labels, predictions, 3);
+   *     //m.asOutput().shape().toString == "[2]"
+   * 
+ * + * @param tf the TensorFlow Ops. + * @param labels the ground truth values. + * @param predictions The prediction values. + * @param k Number of top elements to look at for computing accuracy. + * @param the data type for the predictions and results + * @param the data type ofr the labels. + * @return the Operand for the Top K categorical accuracy value. + */ + public static Operand topKCategoricalAccuracy( + Ops tf, Operand labels, Operand predictions, long k) { + Operand fPredictions = CastHelper.cast(tf, predictions, TFloat32.class); + return CastHelper.cast( + tf, + tf.nn.inTopK(fPredictions, tf.math.argMax(labels, tf.constant(-1)), tf.constant(k)), + predictions.type()); + } + + /** + * Computes how often integer targets are in the top K predictions. + * + *
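As a plain-Java cross-check of what topKCategoricalAccuracy above (and the sparse variant below) compute, the sketch below mirrors the in-top-K test on small arrays. It is illustrative only and not part of this patch; it assumes labels have already been reduced to class indices, which is what the argMax call above does for one-hot rows, and it averages the per-row indicator that the graph op returns.

    // Fraction of rows whose true class is among the k highest prediction scores.
    static double topKAccuracy(int[] trueClasses, float[][] predictions, int k) {
      int hits = 0;
      for (int row = 0; row < predictions.length; row++) {
        float trueScore = predictions[row][trueClasses[row]];
        int strictlyHigher = 0;
        for (float score : predictions[row]) {
          if (score > trueScore) {
            strictlyHigher++; // count scores that beat the true class score
          }
        }
        if (strictlyHigher < k) {
          hits++; // the true class sits within the top k
        }
      }
      return (double) hits / predictions.length;
    }

For the example above (class indices {2, 1}, k = 3 over three classes) every row is a hit, so the averaged value is 1.0, consistent with the per-row values implied by the "[2]"-shaped result.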

Standalone usage: + * + *

+   *     Operand<TInt32> labels = tf.constant(new int[]{2, 1});
+   *     Operand<TFloat32> predictions = tf.constant(new float[][]
+   *                            {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}});
+   *     Operand<TFloat32> m = Metrics.sparseTopKCategoricalAccuracy(
+   *                                    tf, labels, predictions, 3);
+   *     //m.asOutput().shape().toString == "[2]"
+   * 
+ * + * @param tf the TensorFlow Ops. + * @param labels the ground truth values. + * @param predictions The prediction values. + * @param k Number of top elements to look at for computing accuracy. + * @param the data type for the predictions and results + * @param the data type ofr the labels. + * @return the Operand for the Sparse top K categorical accuracy value. + */ + @SuppressWarnings("unchecked") + public static Operand sparseTopKCategoricalAccuracy( + Ops tf, Operand labels, Operand predictions, int k) { + Operand tLabels; + if (labels.type() != predictions.type()) + tLabels = CastHelper.cast(tf, labels, predictions.type()); + else tLabels = (Operand) labels; + + int predictionsRank = predictions.asOutput().shape().numDimensions(); + int labelsRank = tLabels.asOutput().shape().numDimensions(); + + Operand castPredictions = CastHelper.cast(tf, predictions, TFloat32.class); + if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { + if (predictionsRank > 2) { + castPredictions = tf.shape.reduceDims(castPredictions, tf.constant(1)); + } + if (labelsRank > 1) { + tLabels = tf.shape.flatten(tLabels); + } + } + return CastHelper.cast( + tf, + tf.nn.inTopK(castPredictions, CastHelper.cast(tf, tLabels, TInt32.class), tf.constant(k)), + predictions.type()); + } + + /** + * Computes the cosine similarity between labels and predictions. + * + * @param tf the TensorFlow Ops + * @param labels The ground truth values. + * @param predictions The prediction values. + * @param axis The dimension along which the cosine similarity is computed. + * @param the data type for the labels + * @param the data type for the predictions and result + * @return Cosine similarity value. + */ + @SuppressWarnings("unchecked") + public static Operand cosineProximity( + Ops tf, Operand labels, Operand predictions, int[] axis) { + Operand labelsNorm; + if (labels.type() != predictions.type()) + labelsNorm = CastHelper.cast(tf, labels, predictions.type()); + else labelsNorm = (Operand) labels; + labelsNorm = l2Normalize(tf, labelsNorm, axis); + + Operand predictionsNorm = l2Normalize(tf, predictions, axis); + Operand mathMul = tf.math.mul(labelsNorm, predictionsNorm); + return tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); + } + + /** + * Normalizes along dimension axis using an L2 norm with an epsilon of {@link + * #L2_NORM_EPSILON}. + * + *
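For a single pair of 1-D vectors, the cosineProximity computation above reduces to L2-normalizing both inputs and summing their elementwise product, with the l2Normalize helpers documented below supplying output = x / sqrt(max(sum(x**2), epsilon)). The following plain-Java sketch of that 1-D case is illustrative only and not part of this patch; epsilon plays the role of L2_NORM_EPSILON.

    static float[] l2Normalize(float[] x, float epsilon) {
      float squareSum = 0f;
      for (float v : x) {
        squareSum += v * v;
      }
      // x / sqrt(max(sum(x**2), epsilon)), matching the square/reduceSum/maximum/rsqrt graph ops below
      float inverseNorm = (float) (1.0 / Math.sqrt(Math.max(squareSum, epsilon)));
      float[] out = new float[x.length];
      for (int i = 0; i < x.length; i++) {
        out[i] = x[i] * inverseNorm;
      }
      return out;
    }

    static float cosineProximity(float[] labels, float[] predictions, float epsilon) {
      float[] a = l2Normalize(labels, epsilon);
      float[] b = l2Normalize(predictions, epsilon);
      float sum = 0f;
      for (int i = 0; i < a.length; i++) {
        sum += a[i] * b[i]; // reduce-sum of the elementwise product
      }
      return sum;
    }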

For a 1-D tensor with axis = 0, computes + * + *

+   *       output = x / sqrt(max(sum(x**2), epsilon))
+   * 
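+   * A minimal usage sketch (assuming an executing {@code Ops tf}; results are approximate):
+   *
+   *     Operand<TFloat32> x = tf.constant(new float[] {3f, 4f});
+   *     Operand<TFloat32> y = Metrics.l2Normalize(tf, x, new int[] {0});
+   *     // y evaluates to approximately [0.6f, 0.8f], since sqrt(3*3 + 4*4) = 5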
+ * + *

For x with more dimensions, independently normalizes each 1-D slice along + * dimension axis. + * + * @param tf The TensorFlow ops + * @param x The operand to normalize + * @param axes Dimension(s) along which to normalize. + * @param The data type for x. + * @return the normalized values of x. + */ + // TODO this was tf.math.l2_normalize in TF Python + + public static Operand l2Normalize(Ops tf, Operand x, int[] axes) { + return l2Normalize(tf, x, axes, L2_NORM_EPSILON); + } + + /** + * Normalizes along dimension axis using an L2 norm. + * + *

For a 1-D tensor with axis = 0, computes + * + *

+   *       output = x / sqrt(max(sum(x**2), epsilon))
+   * 
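+   * A minimal usage sketch (assuming an executing {@code Ops tf}): for an all-zero input the sum
+   * of squares is 0, so the epsilon lower bound keeps the divisor finite and the result stays
+   * all zeros rather than NaN.
+   *
+   *     Operand<TFloat32> zeros = tf.constant(new float[] {0f, 0f});
+   *     Operand<TFloat32> y = Metrics.l2Normalize(tf, zeros, new int[] {0}, 1e-12f);
+   *     // y evaluates to [0f, 0f]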
+ * + *

For x with more dimensions, independently normalizes each 1-D slice along + * dimension axis. + * + * @param tf The TensorFlow ops + * @param x The operand to normalize + * @param axes Dimension(s) along which to normalize. + * @param epsilon A lower bound value for the norm. Will use `sqrt(epsilon)` as the divisor if + * `norm < sqrt(epsilon)`. + * @param The data type for the values. + * @return the normalized values of x. + */ + // TODO this was tf.math.l2_normalize in TF Python + public static Operand l2Normalize( + Ops tf, Operand x, int[] axes, float epsilon) { + Operand squareSum = + tf.reduceSum(tf.math.square(x), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE)); + Operand y = + tf.math.rsqrt( + tf.math.maximum( + squareSum, CastHelper.cast(tf, tf.constant(epsilon), x.type()))); + return tf.math.mul(x, y); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java new file mode 100644 index 00000000000..f5730b07f42 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Computes the poisson loss metric between labels and predictions. */ +public class Poisson extends MeanMetricWrapper + implements LossInterface { + + /** + * Creates a Poisson metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public Poisson(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.poisson(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java new file mode 100644 index 00000000000..403e11af8c0 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -0,0 +1,54 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Computes the sparse categorical cross-entropy loss between true labels and predicted labels. */ +public class SparseCategoricalCrossentropy + extends MeanMetricWrapper implements LossInterface { + + private final boolean fromLogits; + private final int axes; + + /** + * Creates a SparseCategoricalCrossentropy metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. + * @param axes The dimension along which the entropy is computed. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public SparseCategoricalCrossentropy( + Ops tf, String name, boolean fromLogits, int axes, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + this.fromLogits = fromLogits; + this.axes = axes; + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axes); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java new file mode 100644 index 00000000000..1412465bd89 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Computes how often integer targets are in the top K predictions. */ +public class SparseTopKCategoricalAccuracy + extends MeanMetricWrapper implements LossInterface { + public static final int DEFAULT_K = 5; + /** Number of top elements to look at for computing accuracy. */ + private final int k; + + /** + * Creates a SparseTopKCategoricalAccuracy metric using {@link #DEFAULT_K} for k, the number of + * top elements to look at for computing accuracy. + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the result + */ + public SparseTopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_K, seed, type); + } + + /** + * Creates a SparseTopKCategoricalAccuracy metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param k Number of top elements to look at for computing accuracy. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the result + */ + public SparseTopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { + super(tf, name, seed, type); + this.k = k; + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Metrics.sparseTopKCategoricalAccuracy(getTF(), labels, predictions, k); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java new file mode 100644 index 00000000000..7ce8091f2a0 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Computes the squared hinge loss metric between labels and predictions. */ +public class SquaredHinge extends MeanMetricWrapper + implements LossInterface { + + /** + * Creates a SquaredHinge metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public SquaredHinge(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.squaredHinge(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java new file mode 100644 index 00000000000..3198ab0ee04 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Computes how often targets are in the top K predictions. */ +public class TopKCategoricalAccuracy + extends MeanMetricWrapper implements LossInterface { + public static final int DEFAULT_K = 5; + /** Number of top elements to look at for computing accuracy. */ + private final int k; + + /** + * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for k, the number of + * top elements to look at for computing accuracy. + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type.
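+   * A minimal usage sketch (assuming an executing {@code Ops tf} plus one-hot labels and
+   * predictions operands shaped [batch, numClasses]):
+   *
+   *     TopKCategoricalAccuracy instance =
+   *         new TopKCategoricalAccuracy<>(tf, "topK", 1001L, TFloat32.class);
+   *     Op update = instance.updateState(labels, predictions, null);
+   *     // after running update, instance.result() is the fraction of examples whose
+   *     // true class is among the top DEFAULT_K predictions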
+ */ + public TopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_K, seed, type); + } + + /** + * Creates a TopKCategoricalAccuracy metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param k Number of top elements to look at for computing accuracy. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public TopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { + super(tf, name, seed, type); + this.k = k; + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Metrics.topKCategoricalAccuracy(getTF(), labels, predictions, k); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossInterface.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossInterface.java new file mode 100644 index 00000000000..aadc211c3c4 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossInterface.java @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.types.family.TNumber; + +/** + * Interface for Metrics that wrap Loss functions. + * + * @param The data type of the predictions. + */ +public interface LossInterface { + + /** + * Calculates the weighted loss between labels and predictions + * + * @param labels the truth values or labels + * @param predictions the predictions + * @param The data type of the labels. + * @return the loss + */ + Operand call(Operand labels, Operand predictions); +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java new file mode 100644 index 00000000000..5e0023c4dbe --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -0,0 +1,125 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.Mean; +import org.tensorflow.framework.metrics.MetricReduction; +import org.tensorflow.framework.utils.CastHelper; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import java.util.List; + +/** + * Bridges a stateless loss function with the {@link Mean} metric using a reduction of {@link + * MetricReduction#WEIGHTED_MEAN}. + * + *

The loss function calculates the loss between the labels and predictions + * then passes this loss to the {@link Mean} metric to calculate the weighted mean of the + * loss over many iterations or epochs + * + * @param the data type for the loss. + */ +public class MeanMetricWrapper extends Mean { + + /** The loss function interface */ + protected LossInterface loss; + + /** + * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#WEIGHTED_MEAN} + * + * @param tf the TensorFlow Ops + * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + protected MeanMetricWrapper(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + } + + /** + * Gets the loss function. + * + * @return the loss function. + */ + public LossInterface getLoss() { + return loss; + } + + /** + * Sets the Loss function for this wrapper. + * + * @param loss the loss function. + */ + public void setLoss(LossInterface loss) { + this.loss = loss; + } + + /** + * Creates Operations that update the state of the mean metric, by calling the loss function and + * passing the loss to the Mean metric to calculate the weighted mean of the loss over many + * iterations. + * + * @param labels the truth values or labels + * @param predictions the predictions + * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is + * provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor + * of size [batch_size], then the total loss for each sample of the batch is rescaled by the + * corresponding element in the sampleWeights vector. If the shape of sampleWeights is + * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of + * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss + * functions reduce by 1 dimension, usually axis=-1.) + * @param the datatype of the predictions + * @return a List of control operations that updates the Mean state variables. + */ + public List updateLossStateList( + Operand labels, Operand predictions, Operand sampleWeights) { + if (labels == null || predictions == null) + throw new IllegalArgumentException("missing required inputs for labels and predictions"); + + Class type = predictions.type(); + Operand tPredicitons = CastHelper.cast(getTF(), predictions, getType()); + + Operand losses = loss.call(labels, tPredicitons); + Operand uLossess = CastHelper.cast(getTF(), losses, type); + + return super.updateStateList(uLossess, sampleWeights); + } + + /** + * Creates a Control Operation that updates the state of the mean metric by calculating the loss + * between the labels and predictions and then applying a weighted mean + * metric across the multiple iterations. + * + * @param labels the truth values or labels + * @param predictions the predictions + * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is + * provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor + * of size [batch_size], then the total loss for each sample of the batch is rescaled by the + * corresponding element in the sampleWeights vector. If the shape of sampleWeights is + * [batch_size, d0, .. 
dN-1] (or can be broadcasted to this shape), then each loss element of + * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss + * functions reduce by 1 dimension, usually axis=-1.) + * @param the datatype of the labels + * @return a NoOp with control dependencies that update the state of the mean metric. + */ + public final Op updateLossState( + Operand labels, Operand predictions, Operand sampleWeights) { + List controlOps = updateLossStateList(labels, predictions, sampleWeights); + return getTF().withSubScope("updateState").withControlDependencies(controlOps).noOp(); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java new file mode 100644 index 00000000000..78d7459697c --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java @@ -0,0 +1,122 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Glorot; +import org.tensorflow.framework.initializers.Initializer; +import org.tensorflow.framework.initializers.VarianceScaling; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TIntegral; +import org.tensorflow.types.family.TNumber; + +/** + * Helper class that holds a metric variable + * + * @param the data type of the variable + */ +// TODO handle distributed variables with VariableAggregation and VariableSynchronization +public class MetricVariable { + private final Variable variable; + private final Initializer initializer; + private final Ops tf; + private boolean initialized; + + /** + * Creates a Metric Variable + * + * @param tf the TensorFlow Ops + * @param variable the variable + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public MetricVariable(Ops tf, Variable variable, long seed) { + this(tf, variable, null, seed); + } + /** + * Creates a Metric Variable + * + * @param tf the TensorFlow Ops + * @param variable the variable + * @param initializer the initializer for the variable, if null, then the default for floating + * point types is {@link org.tensorflow.framework.initializers.Glorot} with distribution + * {@link org.tensorflow.framework.initializers.VarianceScaling.Distribution#UNIFORM}, for + * other types the default initializer is {@link org.tensorflow.framework.initializers.Zeros} + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + @SuppressWarnings("unchecked") + public MetricVariable(Ops tf, Variable variable, Initializer initializer, long seed) { + this.tf = tf; + this.variable = variable; + + Class type = variable.type(); + if (initializer == null) { + if (TFloating.class.isAssignableFrom(type)) { + this.initializer = + (Initializer) new Glorot<>(tf, VarianceScaling.Distribution.UNIFORM, seed); + } else if (TIntegral.class.isAssignableFrom(type)) { + this.initializer = new Zeros<>(tf); + } else { + throw new IllegalArgumentException( + String.format( + "An initializer for variable %s of type %s is required", + variable.toString(), type)); + } + } else { + this.initializer = initializer; + } + } + + /** + * Initializers the variable based on the initializer + * + * @return the initialized variable + */ + public Operand initialize() { + initialized = true; + return initializer.call(tf.constant(variable.asOutput().shape()), variable.type()); + } + + /** + * Gets the variable + * + * @return the variable + */ + public Variable getVariable() { + return variable; + } + + /** + * Gets the initializer + * + * @return the initializer + */ + public Initializer getInitializer() { + return initializer; + } + + /** + * Gets the value of initialized + * + * @return the value of initialized + */ + public boolean isInitialized() { + return initialized; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java new file mode 100644 index 00000000000..5395cccf4a7 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -0,0 +1,154 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.CastHelper; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TNumber; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * These are helper methods for Metrics and will be module private when Java modularity is applied + * to TensorFlow Java. These methods should not be used outside of the metrics packages. + */ +public class MetricsHelper { + private static final String ASSERT_BROADCASTABLE_ERROR_PREFIX = + "weights can not be broadcast to values."; + + /** + * Asserts that the sampleWeight can be broadcast to values + * + * @param tf the TensorFlow Ops + * @param sampleWeights the sample weights. + * @param values the values to which weights are applied. 
+ * @return Operation raising InvalidArgumentError if sampleWeight + * has incorrect shape. no_op if static checks determine + * sampleWeight has correct shape. + * @param the type of Operand + * @throws IllegalArgumentException If static checks determine `weights` has incorrect shape. + */ + @SuppressWarnings("unchecked") + public static Op broadcastWeights( + Ops tf, Operand sampleWeights, Operand values) { + + Operand weightsShape = tf.shape(sampleWeights); + Operand weightsRank = tf.rank(sampleWeights); + Shape weightsShapeStatic = sampleWeights.asOutput().shape(); + int weightsRankStatic = weightsShapeStatic.numDimensions(); + + Operand valuesShape = tf.shape(values); + Operand valuesRank = tf.rank(values); + Shape valuesShapeStatic = values.asOutput().shape(); + int valuesRankStatic = valuesShapeStatic.numDimensions(); + + if (weightsRankStatic != -1 && valuesRankStatic != -1) { + if (weightsRankStatic == 0) { + return tf.withSubScope("static_scalar_check_success") + .withControlDependencies(Collections.EMPTY_LIST) + .noOp(); + } + if (weightsRankStatic != valuesRankStatic) { + throw new IllegalArgumentException( + String.format( + "%s values.rank=%d. weights.rank=%d. values.shape=%s. weights.shape=%s.", + ASSERT_BROADCASTABLE_ERROR_PREFIX, + valuesRankStatic, + weightsRankStatic, + valuesShapeStatic.toString(), + weightsShapeStatic.toString())); + } + + for (int i = 0; i < valuesRankStatic; i++) { + if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i)) { + throw new IllegalArgumentException( + String.format( + "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.", + ASSERT_BROADCASTABLE_ERROR_PREFIX, + i, + valuesShapeStatic.toString(), + weightsShapeStatic.toString())); + } + } + return tf.withSubScope("static_dims_check_success") + .withControlDependencies(Collections.EMPTY_LIST) + .noOp(); + } + // Dynamic checks. + Operand is_scalar = tf.math.equal(weightsRank, tf.constant(0)); + List> data = + Arrays.asList( + tf.constant(ASSERT_BROADCASTABLE_ERROR_PREFIX), + tf.constant("weights.shape="), + weightsShape, + tf.constant("values.shape="), + valuesShape, + tf.constant("is_scalar="), + is_scalar); + + Operand isValidShape = + tf.select( + is_scalar, + is_scalar, + hasValidNonscalarShape(tf, weightsRank, weightsShape, valuesRank, valuesShape)); + + return tf.assertThat(isValidShape, data); + } + + /** + * Gets an operand that tests if the shapes have the same rank and valid dimensions. + * + * @param tf the TensorFlow Ops + * @param weightsRank the operand for the rank of the sample weights + * @param weightsShape the operand for the shape of the sample weights + * @param valuesRank the operand for the rank of the sample weights + * @param valuesShape the operand for the shape of the sample weights + * @param the data type for the operands + * @return a boolean operand to determine if the Shape is scalar or not. + */ + private static Operand hasValidNonscalarShape( + Ops tf, + Operand weightsRank, + Operand weightsShape, + Operand valuesRank, + Operand valuesShape) { + tf = tf.withSubScope("has_valid_nonscalar_shape"); + Operand isSameRank = tf.math.equal(valuesRank, weightsRank); + return tf.select(isSameRank, hasValidDims(tf, weightsShape, valuesShape), isSameRank); + } + + /** + * Gets an operand that tests if the shapes have valid dimensions or not. 
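+   * The check subtracts the values shape from the weights shape element-wise and treats a
+   * reduced sum of zero as matching dimensions.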
+ * + * @param tf the TensorFlow Ops + * @param weightsShape the operand for the shape of the sample weights + * @param valuesShape the operand for the shape of the sample weights + * @param the data type for the operands + * @return a boolean operand to determine if the shapes have valid dimensions or not. + */ + private static Operand hasValidDims( + Ops tf, Operand weightsShape, Operand valuesShape) { + tf = tf.withSubScope("has_invalid_dims"); + Operand diff = tf.reduceSum(tf.math.sub(weightsShape, valuesShape), tf.constant(0)); + return tf.math.equal(CastHelper.cast(tf, tf.constant(0), diff.type()), diff); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java new file mode 100644 index 00000000000..d2c8b2dec93 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -0,0 +1,204 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.losses.impl.LossesHelper; +import org.tensorflow.framework.metrics.Metric; +import org.tensorflow.framework.metrics.MetricReduction; +import org.tensorflow.framework.utils.CastHelper; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Encapsulates metrics that perform a reduce operation on the metric values. + * + * @param The data type for the metric values + */ +public abstract class Reduce extends Metric { + public static final String TOTAL = "total"; + public static final String COUNT = "count"; + protected final MetricReduction reduction; + private final String totalName; + private final String countName; + /** the variable that holds the total of the metric values */ + protected Variable total; + /** the variable that holds the count of the metric values */ + protected Variable count; + + protected boolean initialized; + + /** + * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#SUM} + * + * @param tf the TensorFlow Ops + * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ */ + protected Reduce(Ops tf, String name, long seed, Class type) { + this(tf, name, seed, MetricReduction.SUM, type); + } + + /** + * @param tf The TensorFlow Ops + * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param reduction The type of metric reduction to apply + */ + protected Reduce(Ops tf, String name, long seed, MetricReduction reduction, Class type) { + super(tf, name, seed, type); + this.reduction = reduction; + this.totalName = this.getVariableName(TOTAL); + this.countName = this.getVariableName(COUNT); + setupVars(); + } + /** initialize the Variables */ + @SuppressWarnings("unchecked") + private void setupVars() { + Zeros fZeros = new Zeros<>(getTF()); + total = (Variable) getVariable(totalName); + if (total == null) { + total = getTF().withSubScope(totalName).variable(Shape.scalar(), getType()); + addVariable(totalName, total, fZeros); + } + if (reduction == MetricReduction.SUM_OVER_BATCH_SIZE + || reduction == MetricReduction.WEIGHTED_MEAN) { + count = (Variable) getVariable(countName); + if (count == null) { + count = getTF().withSubScope(countName).variable(Shape.scalar(), getType()); + addVariable(countName, count, fZeros); + } + } + } + + /** + * Updates the metric variables based on the inputs. At least one input arg required for + * values, an optional additional input for the sampleWeights + * + * @param values the inputs to be passed to update state, this may not be null + * @param sampleWeights sample weights to be applied to values, may be null. + * @return the result with a control dependency on update state Operands + * @throws IllegalArgumentException if values is null + */ + @Override + public List updateStateList(Operand values, Operand sampleWeights) { + + if (values == null) throw new IllegalArgumentException("values is required."); + List updateOperations = new ArrayList<>(); + // cast everything to match the variables + + Operand tValues = CastHelper.cast(getTF(), values, getType()); + Operand tSampleWeights = sampleWeights; + if (sampleWeights != null) { + LossTuple tuple = + LossesHelper.squeezeOrExpandDimensions(getTF(), null, tValues, sampleWeights); + tValues = tuple.getTarget(); + tSampleWeights = tuple.getSampleWeights(); + Op broadcastWeightsCheck = MetricsHelper.broadcastWeights(getTF(), tSampleWeights, tValues); + tValues = + getTF() + .withSubScope("broadcastWeightsCheck") + .withControlDependencies(Collections.singletonList(broadcastWeightsCheck)) + .math + .mul(tValues, tSampleWeights); + } + + Operand valueSum = getTF().reduceSum(tValues, LossesHelper.allAxes(getTF(), tValues)); + Operand totalUpdate = + getTF().assignAdd(total, CastHelper.cast(getTF(), valueSum, total.type())); + updateOperations.add(totalUpdate); + Operand numValues; + if (reduction != MetricReduction.SUM) { + switch (reduction) { + case SUM_OVER_BATCH_SIZE: + numValues = + CastHelper.cast( + getTF(), getTF().constant(tValues.asOutput().shape().size()), getType()); + break; + case WEIGHTED_MEAN: + if (tSampleWeights == null) { + numValues = + CastHelper.cast( + getTF(), getTF().constant(tValues.asOutput().shape().size()), getType()); + } else { + numValues = + CastHelper.cast( + getTF(), + getTF() + .reduceSum(tSampleWeights, LossesHelper.allAxes(getTF(), tSampleWeights)), + getType()); + } + break; + default: + throw new 
UnsupportedOperationException( + String.format("reduction [%s] not implemented", reduction)); + } + Operand totalCount = getTF().assignAdd(this.count, numValues); + + updateOperations.add(totalCount); + } + + return updateOperations; + } + + /** {@inheritDoc} */ + @Override + public Operand result(Ops rtf) { + Operand fResult; + + switch (this.reduction) { + case SUM: + fResult = rtf.identity(total); + break; + case WEIGHTED_MEAN: + case SUM_OVER_BATCH_SIZE: + fResult = rtf.math.divNoNan(total, CastHelper.cast(rtf, count, getType())); + break; + default: + throw new UnsupportedOperationException( + String.format("reduction [%s] not implemented", reduction)); + } + return fResult; + } + + /** + * Gets the total variable + * + * @return the total variable + */ + public Variable getTotal() { + return total; + } + + /** + * Gets the count variable + * + * @return the count variable + */ + public Variable getCount() { + return count; + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java new file mode 100644 index 00000000000..1f07b9567cb --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java @@ -0,0 +1,150 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +class BinaryCrossentropyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryCrossentropy instance = + new BinaryCrossentropy<>(tf, "BCE_testUnweighted", false, 0, 1001L, TFloat64.class); + session.run(instance.resetStates()); + float[] trueArray = {1, 0, 1, 0}; + float[] predictionArray = {1, 1, 1, 0}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 2))); + Op op = instance.updateState(labels, yPrediction, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(7.666619F, total); + session.evaluate(2, count); + session.evaluate(3.833309F, result); + } + } + + @Test + public void testUnweightedLogits() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryCrossentropy instance = + new BinaryCrossentropy<>(tf, "BCE_testUnweightedLogits", true, 0, 1001L, TFloat64.class); + session.run(instance.resetStates()); + double[] trueArray = {1, 0, 1, 0, 1, 1}; + double[] logitsArray = {100.0, -100.0, 100.0, 100.0, 100.0, -100.0}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, logits, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(66.66667, total); + session.evaluate(2, count); + session.evaluate(33.333332, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryCrossentropy instance = + new BinaryCrossentropy<>(tf, "BCE_testWeighted", false, 0, 1001L, TFloat32.class); + session.run(instance.resetStates()); + float[] trueArray = {1, 0, 1, 0}; + float[] predictionArray = {1, 1, 1, 0}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 2))); + Operand sampleWeight = tf.constant(new float[] {1.5f, 2.f}); + Op op = instance.updateState(labels, yPrediction, sampleWeight); + session.run(op); + + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(11.499929f, total); + session.evaluate(3.5f, count); + session.evaluate(3.285694f, result); + } + } + + @Test + public void testWeightedLogits() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryCrossentropy instance = + new BinaryCrossentropy<>(tf, "BCE_testWeightedLogits", true, 0, 1001L, TFloat64.class); + 
session.run(instance.resetStates()); + double[] trueArray = {1, 0, 1, 0, 1, 1}; + double[] logitsArray = {100.0, -100.0, 100.0, 100.0, 100.0, -100.0}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(new double[] {2, 2.5}); + + Op op = instance.updateState(labels, logits, sampleWeight); + session.run(op); + + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(166.66666, total); + session.evaluate(4.5, count); + session.evaluate(37.037033, result); + } + } + + @Test + public void testLabelSmoothing() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float labelSmoothing = 0.1F; + BinaryCrossentropy instance = + new BinaryCrossentropy<>( + tf, "BCE_testWeightedLabS", true, labelSmoothing, 1001L, TFloat64.class); + session.run(instance.resetStates()); + double[] trueArray = {1, 0, 1}; + double[] logitsArray = {100., -100., -100.}; + Operand labels = tf.constant(trueArray); + Operand logits = tf.constant(logitsArray); + + Op op = instance.updateState(labels, logits, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + + session.evaluate(35, total); + session.evaluate(1, count); + session.evaluate(35, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java new file mode 100644 index 00000000000..2b4a1d75467 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java @@ -0,0 +1,151 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class CategoricalCrossentropyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalCrossentropy instance = + new CategoricalCrossentropy<>( + tf, "CCE_testUnweighted", false, 0, -1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {0, 1, 0, 0, 0, 1}; + double[] predArray = {0.05, 0.95, 0, 0.1, 0.8, 0.1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2.3538785, total); + session.evaluate(2, count); + session.evaluate(1.1769392, result); + } + } + + @Test + public void testUnweightedLogits() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalCrossentropy instance = + new CategoricalCrossentropy<>( + tf, "CCE_testUnweightedLogits", true, 0, -1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {0, 1, 0, 0, 0, 1}; + double[] predArray = {1, 9, 0, 1, 8, 1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(7.0022807, total); + session.evaluate(2, count); + session.evaluate(3.5011404, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalCrossentropy instance = + new CategoricalCrossentropy<>( + tf, "CCE_testWeighted", false, 0, -1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {0, 1, 0, 0, 0, 1}; + double[] predArray = {0.05f, 0.95f, 0f, 0.1f, 0.8f, 0.1f}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(new double[] {1.5f, 2.}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(4.6821095, total); + session.evaluate(3.5, count); + session.evaluate(1.3377455, result); + } + } + + @Test + public void testWeightedLogits() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalCrossentropy instance = + new CategoricalCrossentropy<>(tf, 
"CCE_testWeighted", true, 0, -1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {0, 1, 0, 0, 0, 1}; + double[] predArray = {1, 9, 0, 1, 8, 1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(new double[] {1.5, 2.f}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(14.004333, total); + session.evaluate(3.5, count); + session.evaluate(4.0012328, result); + } + } + + @Test + public void testLabelSmoothing() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float labelSmoothing = 0.1F; + CategoricalCrossentropy instance = + new CategoricalCrossentropy<>( + tf, "CCE_testWeighted", true, labelSmoothing, -1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {0, 1, 0, 0, 0, 1}; + double[] predArray = {1, 9, 0, 1, 8, 1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(7.3356137, total); + session.evaluate(2, count); + session.evaluate(3.6678069, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java new file mode 100644 index 00000000000..87248d95e48 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java @@ -0,0 +1,97 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class CategoricalHingeTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalHinge instance = + new CategoricalHinge<>(tf, "CH_testUnweighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + double[] predArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(4, 5))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2., total); + session.evaluate(4, count); + session.evaluate(0.5, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalHinge instance = + new CategoricalHinge<>(tf, "CH_testWeighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + double[] predArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(4, 5))); + + Operand sampleWeight = tf.constant(new double[] {1., 1.5, 2., 2.5}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(3.5F, total); + session.evaluate(7, count); + session.evaluate(0.5, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java new file mode 100644 index 00000000000..848e2051af3 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java @@ -0,0 +1,101 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +class CosineSimilarityTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CosineSimilarity instance = + new CosineSimilarity<>(tf, "CS_testUnweighted", 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4, 8, 12, 8, 1, 3}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.3744381F, total); + session.evaluate(2, count); + session.evaluate(0.18721905F, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CosineSimilarity instance = + new CosineSimilarity<>(tf, "CS_testWeighted", 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4, 8, 12, 8, 1, 3}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + + Operand sampleWeight = tf.constant(new float[] {1.2f, 3.4f}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(-0.3119840621948241F, total); + session.evaluate(4.6, count); + session.evaluate(-0.06782262221626612F, result); + } + } + + @Test + public void test_axis() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + int[] axis = new int[] {1}; + CosineSimilarity instance = + new CosineSimilarity<>(tf, "CS_testWeighted", axis, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4, 8, 12, 8, 1, 3}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.3744381F, total); + session.evaluate(2, count); + session.evaluate(0.18721905F, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java new file mode 100644 index 00000000000..6af5fed4889 --- /dev/null +++ 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java @@ -0,0 +1,84 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class HingeTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Hinge instance = + new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; + double[] predArray = {-0.3, 0.2, -0.1, 1.6, -0.25, -1., 0.5, 0.6}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(1.0125, total); + session.evaluate(2, count); + session.evaluate(.5062500, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Hinge instance = + new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = { + -1, 1, -1, 1, + -1, -1, 1, 1 + }; + float[] predArray = { + -0.3f, 0.2f, -0.1f, 1.6f, + -0.25f, -1.f, 0.5f, 0.6f + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + + Operand sampleWeight = tf.constant(new double[] {1.5, 2.}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(1.7250f, total); + session.evaluate(3.5, count); + session.evaluate(.49285714f, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java new file mode 100644 index 00000000000..bf98ec4eba4 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java @@ -0,0 +1,83 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +class KLDivergenceTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + KLDivergence instance = + new KLDivergence<>(tf, "KLD_testUnweighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + float[][] trueArray = {{.5f, .8f, .12f}, {.7f, .43f, .8f}}; + float[][] predArray = {{.4f, .9f, .12f}, {.36f, .3f, .4f}}; + Operand labels = tf.constant(trueArray); + Operand predictions = tf.constant(predArray); + + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(1.1921477, total); + session.evaluate(2, count); + session.evaluate(0.5960738, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + KLDivergence instance = + new KLDivergence<>(tf, "KLD_testWeighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + float[] trueArray = { + .5f, .8f, .12f, + .7f, .43f, .8f + }; + float[] predArray = { + .4f, .9f, .12f, + .36f, .3f, .4f + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + + Operand sampleWeight = tf.constant(new double[] {1.2, 3.4}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(4.015142, total); + session.evaluate(4.6, count); + session.evaluate(0.872857, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java new file mode 100644 index 00000000000..31c043e0473 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class LogCoshErrorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + LogCoshError instance = + new LogCoshError<>(tf, "LogCosh_testUnweighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + float[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4, 8, 12, 8, 1, 3}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(4.829245, result); + session.evaluate(9.65849, total); + session.evaluate(2, count); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + LogCoshError instance = + new LogCoshError<>(tf, "LogCosh_testWeighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4, 8, 12, 8, 1, 3}; + double[][] sampleArray = {{1.2}, {3.4}}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(sampleArray); + + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(5.2178759, result); + session.evaluate(24.002228, total); + session.evaluate(4.6, count); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java new file mode 100644 index 00000000000..73241ecbe9f --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java @@ -0,0 +1,116 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class MeanAbsoluteErrorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + MeanAbsoluteError instance = + new MeanAbsoluteError<>(tf, "MAE_testUnweighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + session.evaluate(0.0f, instance.getTotal()); + session.evaluate(0f, instance.getCount()); + session.evaluate(0.f, instance.getCount()); + + int[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + float[] predictionArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); + Op op = instance.updateState(yTrue, yPrediction, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2.0, total); + session.evaluate(4, count); + session.evaluate(0.5, result); + + session.run(instance.resetStates()); + session.evaluate(0.0, instance.getTotal()); + session.evaluate(0, instance.getCount()); + session.evaluate(0., instance.getCount()); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + MeanAbsoluteError instance = + new MeanAbsoluteError<>(tf, "MAE_testWeighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + session.evaluate(0.0, instance.getTotal()); + session.evaluate(0, instance.getCount()); + session.evaluate(0., instance.getCount()); + + int[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + double[] predictionArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); + + Operand sampleWeight = tf.constant(new double[] {1., 1.5, 2., 2.5}); + Op op = instance.updateState(yTrue, yPrediction, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(3.8, total); + session.evaluate(7, count); + session.evaluate(0.54285, result); + + session.run(instance.resetStates()); + session.evaluate(0.0, instance.getTotal()); + session.evaluate(0, 
instance.getCount()); + session.evaluate(0., instance.getCount()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java new file mode 100644 index 00000000000..4c92844b217 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java @@ -0,0 +1,115 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; + +class MeanAbsolutePercentageErrorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + session.setEpsilon(1E-6f); + Ops tf = session.getTF(); + MeanAbsolutePercentageError instance = + new MeanAbsolutePercentageError<>(tf, "MAPE_testUnweighted", 1001L, TFloat32.class); + session.run(instance.resetStates()); + session.evaluate(0.0f, instance.getTotal()); + session.evaluate(0f, instance.getCount()); + session.evaluate(0.f, instance.getCount()); + + int[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + float[] predictionArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); + + Op op = instance.updateState(yTrue, yPrediction, null); + + session.run(op); + + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(1.4E9f, total); + session.evaluate(4f, count); + session.evaluate(35e7f, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + session.setEpsilon(1E-6f); + Ops tf = session.getTF(); + MeanAbsolutePercentageError instance = + new MeanAbsolutePercentageError<>(tf, "MAPE_testWeighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + session.evaluate(0.0, instance.getTotal()); + session.evaluate(0, instance.getCount()); + session.evaluate(0, instance.getCount()); + + long[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + double[] predictionArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 
1, 1, 1, 1 + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); + + Operand sampleWeight = tf.constant(new double[] {1.f, 1.5f, 2.f, 2.5f}); + Op op = instance.updateState(yTrue, yPrediction, sampleWeight); + + session.run(op); + + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2.800000067278928E9, total); + session.evaluate(7, count); + session.evaluate(4.000000096112754E8, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java new file mode 100644 index 00000000000..0b760213015 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java @@ -0,0 +1,107 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; + +class MeanSquaredErrorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + MeanSquaredError instance = + new MeanSquaredError<>(tf, "MSE_testUnweighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + session.evaluate(0.0, instance.getTotal()); + session.evaluate(0, instance.getCount()); + session.evaluate(0., instance.getCount()); + + int[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + float[] predictionArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); + Op op = instance.updateState(yTrue, yPrediction, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2.0, total); + session.evaluate(4, count); + session.evaluate(0.5, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + MeanSquaredError instance = + new MeanSquaredError<>(tf, "MSE_testWeighted", 1001L, 
TFloat64.class); + session.run(instance.resetStates()); + session.evaluate(0.0, instance.getTotal()); + session.evaluate(0, instance.getCount()); + session.evaluate(0., instance.getCount()); + + long[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + float[] predictionArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); + + Operand sampleWeight = tf.constant(new double[] {1., 1.5, 2., 2.5}); + Op op = instance.updateState(yTrue, yPrediction, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(3.8, total); + session.evaluate(7, count); + session.evaluate(0.542857, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java new file mode 100644 index 00000000000..098a5cb9725 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java @@ -0,0 +1,106 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class MeanSquaredLogarithmicErrorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + MeanSquaredLogarithmicError instance = + new MeanSquaredLogarithmicError<>(tf, "MSLE_testUnweighted", 1001L, TFloat32.class); + session.run(instance.resetStates()); + session.evaluate(0.0f, instance.getTotal()); + session.evaluate(0f, instance.getCount()); + session.evaluate(0.f, instance.getCount()); + + int[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + float[] predictionArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); + Op op = instance.updateState(yTrue, yPrediction, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.96090573f, total); + session.evaluate(4f, count); + session.evaluate(0.24022f, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + MeanSquaredLogarithmicError instance = + new MeanSquaredLogarithmicError<>(tf, "MSLE_testWeighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + session.evaluate(0.0, instance.getTotal()); + session.evaluate(0, instance.getCount()); + session.evaluate(0., instance.getCount()); + + int[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + double[] predictionArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); + + Operand sampleWeight = tf.constant(new double[] {1., 1.5, 2., 2.5}); + Op op = instance.updateState(yTrue, yPrediction, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(1.8257208, total); + session.evaluate(7, count); + session.evaluate(0.26082, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java new file mode 100644 index 00000000000..cf3c3e44719 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java @@ -0,0 +1,79 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class PoissonTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Poisson instance = + new Poisson<>(tf, "Poisson_testUnweighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {4, 8, 12, 8, 1, 3}; + float[] predArray = {1, 9, 2, 5, 2, 6}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(-6.6131644, total); + session.evaluate(2, count); + session.evaluate(-3.3065822, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Poisson instance = + new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {4, 8, 12, 8, 1, 3}; + float[] predArray = {1, 9, 2, 5, 2, 6}; + + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + + Operand sampleWeight = tf.constant(new float[] {1.2f, 3.4f}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(-12.29468f, total); + session.evaluate(4.6f, count); + session.evaluate(-2.6727562f, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java new file mode 100644 index 00000000000..87af1bd8448 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java @@ -0,0 +1,129 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class SparseCategoricalCrossentropyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SparseCategoricalCrossentropy instance = + new SparseCategoricalCrossentropy<>( + tf, "SCE_testUnweighted", false, -1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 2}; + double[] predictionArray = {0.05, 0.95, 0, 0.1, 0.8, 0.1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2))); + Operand predictions = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2.3538785, total); + session.evaluate(2, count); + session.evaluate(1.1769392, result); + } + } + + @Test + public void testUnweightedLogits() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SparseCategoricalCrossentropy instance = + new SparseCategoricalCrossentropy<>( + tf, "SCE_testWeighted", true, -1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 2}; + double[] logitsArray = {1, 9, 0, 1, 8, 1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2))); + Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, logits, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(7.002277, total); + session.evaluate(2, count); + session.evaluate(3.501135, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SparseCategoricalCrossentropy instance = + new SparseCategoricalCrossentropy<>( + tf, "SCE_testWeighted", false, -1, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 2}; + double[] predictionArray = {0.05, 0.95, 0, 0.1, 0.8, 0.1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2))); + Operand predictions = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 3))); + + Operand sampleWeight = tf.constant(new float[] {1.5F, 2.F}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = 
instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(4.6821103f, total); + session.evaluate(3.5f, count); + session.evaluate(1.3377458f, result); + } + } + + @Test + public void testWeightedLogits() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SparseCategoricalCrossentropy instance = + new SparseCategoricalCrossentropy<>( + tf, "SCE_testWeighted", true, -1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 2}; + double[] predictionArray = {1, 9, 0, 1, 8, 1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2))); + Operand predictions = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 3))); + + Operand sampleWeight = tf.constant(new double[] {1.5, 2}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(14.004333, total); + session.evaluate(3.5, count); + session.evaluate(4.001232, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracyTest.java new file mode 100644 index 00000000000..4a0cdefe492 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracyTest.java @@ -0,0 +1,96 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class SparseTopKCategoricalAccuracyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCorrectness() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SparseTopKCategoricalAccuracy instance = + new SparseTopKCategoricalAccuracy<>( + tf, "SparseTopK_testCorrectness", 5, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new double[] {2, 1}); + Operand predictions = + tf.constant(new double[][] {{0.1, 0.9, 0.8}, {0.05, 0.95, 0}}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(1., instance.result()); + + // With `k` < 5. 
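+      // (k = 1 keeps only the top-scoring index: for predictions {0.1, 0.9, 0.8} and {0.05, 0.95, 0} the
+      // top index is 1 in both rows, so label 2 is a miss and label 1 is a hit, giving the expected 0.5.)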
+ instance = + new SparseTopKCategoricalAccuracy<>( + tf, "SparseTopK_testCorrectness", 1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5, instance.result()); + + // With `k` > 5. + predictions = + tf.constant( + new double[][] { + {0.5, 0.9, 0.1, 0.7, 0.6, 0.5, 0.4}, + {0.05, 0.95, 0, 0, 0, 0, 0} + }); + instance = + new SparseTopKCategoricalAccuracy<>( + tf, "SparseTopK_testCorrectness", 6, 1001L, TFloat64.class); + session.run(instance.resetStates()); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5, instance.result()); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SparseTopKCategoricalAccuracy instance = + new SparseTopKCategoricalAccuracy<>( + tf, "SparseTopK_testWeighted", 5, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new int[] {1, 0, 2}); + Operand predictions = + tf.constant( + new double[][] { + {0, 0.9, 0.1}, + {0, 0.9, 0.1}, + {0, 0.9, 0.1} + }); + + Operand sampleWeight = tf.constant(new double[] {1, 0, 1}); + + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + session.evaluate(1., instance.result()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java new file mode 100644 index 00000000000..e3376c224f3 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java @@ -0,0 +1,90 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class SquaredHingeTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SquaredHinge instance = + new SquaredHinge<>(tf, "SCE_testUnweighted", 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = { + 0, 1, 0, 1, + 0, 0, 1, 1 + }; + float[] predArray = { + -0.3f, 0.2f, -0.1f, 1.6f, + -0.25f, -1.f, 0.5f, 0.6f + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.72812f, total); + session.evaluate(2f, count); + session.evaluate(0.3640625f, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SquaredHinge instance = + new SquaredHinge<>(tf, "SCE_testWeighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = { + 0, 1, 0, 1, + 0, 0, 1, 1 + }; + double[] predArray = { + -0.3, 0.2, -0.1, 1.6, + -0.25, -1., 0.5, 0.6 + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + + Operand sampleWeight = tf.constant(new double[] {1.5f, 2.f}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(1.2137499, total); + session.evaluate(3.5, count); + session.evaluate(0.3467857, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java new file mode 100644 index 00000000000..52ccde29196 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java @@ -0,0 +1,103 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +class TopKCategoricalAccuracyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCorrectness() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + TopKCategoricalAccuracy instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted", 5, 1001L, TFloat64.class); + session.run(instance.resetStates()); + Operand labels = tf.constant(new float[][] {{0, 0, 1}, {0, 1, 0}}); + Operand predictions = + tf.constant(new double[][] {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(1., instance.result()); + + // With `k` < 5. + instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted1", 1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5, instance.result()); + + // With `k` > 5. + labels = + tf.constant( + new float[][] { + {0, 0, 1, 0, 0, 0, 0}, + {0, 1, 0, 0, 0, 0, 0} + }); + predictions = + tf.constant( + new double[][] { + {0.5f, 0.9f, 0.1f, 0.7f, 0.6f, 0.5f, 0.4f}, + {0.05f, 0.95f, 0f, 0f, 0f, 0f, 0f} + }); + instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted6", 6, 1001L, TFloat64.class); + session.run(instance.resetStates()); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5, instance.result()); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + TopKCategoricalAccuracy instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testWeighted", 5, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand labels = + tf.constant( + new double[][] { + {1, 0, 2}, + {1, 0, 0}, + {0, 0, 1} + }); + Operand predictions = + tf.constant( + new double[][] { + {0f, 0.9f, 0.1f}, + {0f, 0.9f, 0.1f}, + {0f, 0.9f, 0.1f} + }); + + Operand sampleWeight = tf.constant(new double[] {1, 0, 1}); + + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + session.evaluate(1., instance.result()); + } + } +} From 090bde4424deccadb7a0efb6c902377e410cb23d Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 31 Dec 2020 20:03:32 -0500 Subject: [PATCH 02/56] Initial checkin and sync with master --- .../framework/metrics/BinaryCrossentropy.java | 2 +- .../metrics/CategoricalCrossentropy.java | 2 +- .../framework/metrics/CosineSimilarity.java | 16 ++- .../tensorflow/framework/metrics/Mean.java | 3 +- .../tensorflow/framework/metrics/Metric.java | 42 ++----- .../tensorflow/framework/metrics/Metrics.java | 19 ++-- .../SparseTopKCategoricalAccuracy.java | 65 ----------- .../metrics/TopKCategoricalAccuracy.java | 63 ----------- .../metrics/impl/MeanMetricWrapper.java | 36 ++---- .../metrics/impl/MetricVariable.java | 13 ++- .../framework/metrics/impl/MetricsHelper.java | 100 +++++++++++++++-- .../framework/metrics/impl/Reduce.java | 88 +++++++++------ .../metrics/BinaryCrossentropyTest.java | 17 +-- 
.../metrics/CosineSimilarityTest.java | 2 +- .../framework/metrics/KLDivergenceTest.java | 2 +- .../SparseTopKCategoricalAccuracyTest.java | 96 ---------------- .../metrics/TopKCategoricalAccuracyTest.java | 103 ------------------ 17 files changed, 211 insertions(+), 458 deletions(-) delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java delete mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracyTest.java delete mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index d13d20bfdee..41a5533b5d1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -28,7 +28,7 @@ * @param The data type for the metric result */ public class BinaryCrossentropy - extends MeanMetricWrapper implements LossInterface { + extends MeanMetricWrapper implements LossInterface { private final boolean fromLogits; private final float labelSmoothing; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index cf9ecd0858a..79481f608a1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -34,7 +34,7 @@ public class CategoricalCrossentropy private final boolean fromLogits; private final float labelSmoothing; - private int axis; + private final int axis; /** * Creates a CategoricalCrossentropy metric that Computes the crossentropy metric between the diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 61802572c7b..4a5214aea8d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -27,9 +27,9 @@ // is using the dot product proportional to the product of their magnitudes. // While the 2 concepts are similar, they are different. // Should we rename this metric to CosineProximity? -public class CosineSimilarity extends MeanMetricWrapper +public class CosineSimilarity extends MeanMetricWrapper implements LossInterface { - public static final int[] DEFAULT_AXIS = {-1}; + public static final int DEFAULT_AXIS = -1; private final int[] axis; /** @@ -44,6 +44,18 @@ public CosineSimilarity(Ops tf, String name, long seed, Class type) { this(tf, name, DEFAULT_AXIS, seed, type); } + /** + * Creates a CosineSimilarity metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param axis The dimension along which the cosine similarity is computed. + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public CosineSimilarity(Ops tf, String name, int axis, long seed, Class type) { + this(tf, name, new int[] {axis}, seed, type); + } /** * Creates a CosineSimilarity metric * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java index 08d1083dd05..c68a70902a7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java @@ -33,8 +33,9 @@ public class Mean extends Reduce { * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the result. */ protected Mean(Ops tf, String name, long seed, Class type) { - super(tf, name, seed, MetricReduction.WEIGHTED_MEAN, type); + super(tf, name, MetricReduction.WEIGHTED_MEAN, seed, type); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 62ec5439269..28a2ae0fa94 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -1,20 +1,3 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.tensorflow.framework.metrics; - /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -29,6 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. =======================================================================*/ +package org.tensorflow.framework.metrics; import org.tensorflow.ExecutionEnvironment; import org.tensorflow.Operand; @@ -74,17 +58,15 @@ public abstract class Metric { /** The name for this metric. Defaults to {@link Class#getSimpleName()}. */ private final String name; - private final Class type; - /** - * Creates a Metric with a name of {@link Class#getSimpleName()} } + * Creates a Metric with a name of {@link Class#getSimpleName()} * * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
*/ - protected Metric(Ops tf, long seed, Class type) { - this(tf, null, seed, type); + protected Metric(Ops tf, long seed) { + this(tf, null, seed); } /** @@ -95,13 +77,12 @@ protected Metric(Ops tf, long seed, Class type) { * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. */ - protected Metric(Ops tf, String name, long seed, Class type) { + protected Metric(Ops tf, String name, long seed) { if (!tf.scope().env().isGraph()) throw new IllegalArgumentException("Metrics are required to execute in Graph mode."); this.seed = seed; this.name = name != null ? name : this.getClass().getSimpleName(); this.tf = tf.withSubScope(this.name); - this.type = type; } /** @@ -113,8 +94,8 @@ protected Metric(Ops tf, String name, long seed, Class type) { * @param sampleWeights sample weights to be applied to values, may be null. * @return a List of Operations to update the metric state */ - @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList(Operand values, Operand sampleWeights) { + @SuppressWarnings({"unchecked","unused"}) + public List updateStateList(Operand values, Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -142,7 +123,7 @@ public List updateStateList( * @param sampleWeights sample weights to be applied to values, may be null. * @return the Operation to update the metric state */ - public final Op updateState(Operand values, Operand sampleWeights) { + public final Op updateState(Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -153,6 +134,7 @@ public final Op updateState(Operand values, Operand sa * @param labels the labels * @param predictions the predictions * @param sampleWeights sample weights to be applied to values, may be null. + * @param the data type for the sample weights * @return the Operation to update the metric state */ public final Op updateState( @@ -206,7 +188,7 @@ protected void addVariable( // TODO option 2 would be to keep track of tf.scope().env() and if it changes, clear to old Map. Map> variables = variableMap.computeIfAbsent(tf.scope().env(), k -> new HashMap<>()); - variables.put(varName, new MetricVariable<>(tf, variable, initializer, seed)); + variables.put(varName, new MetricVariable<>(tf, variable, initializer, seed, variable.type())); } /** @@ -313,8 +295,4 @@ public Ops getTF() { public String getName() { return name; } - - public Class getType() { - return type; - } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index f4282bfd0a9..c3c44ef6134 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -40,7 +40,7 @@ public class Metrics { * {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}}); * Operand<TFloat32> m = Metrics.topKCategoricalAccuracy( * labels, predictions, 3) - * //m.asOutput().shape().toString == "[2]" + * //m.shape().toString == "[2]" * * * @param tf the TensorFlow Ops. 
@@ -71,7 +71,7 @@ public static Operand topKCategoricalA * {{0.1f, 0.9f, 0.f8}, {0.05f, 0.95f, 0f}}); * Operand<TFloat32> m = Metrics.topKCategoricalAccuracy( * labels, predictions, 3) - * //m.asOutput().shape().toString == "[2]" + * //m.shape().toString == "[2]" * * * @param tf the TensorFlow Ops. @@ -90,8 +90,8 @@ public static Operand sparseTopKCatego tLabels = CastHelper.cast(tf, labels, predictions.type()); else tLabels = (Operand) labels; - int predictionsRank = predictions.asOutput().shape().numDimensions(); - int labelsRank = tLabels.asOutput().shape().numDimensions(); + int predictionsRank = predictions.shape().numDimensions(); + int labelsRank = tLabels.shape().numDimensions(); Operand castPredictions = CastHelper.cast(tf, predictions, TFloat32.class); if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { @@ -152,9 +152,10 @@ public static Operand cosineProximity( * @param The data type for x. * @return the normalized values of x. */ - // TODO this was tf.math.l2_normalize in TF Python + // TODO this was tf.math.l2_normalize in TF Python, does it belong here? - public static Operand l2Normalize(Ops tf, Operand x, int[] axes) { + public static Operand l2Normalize( + Ops tf, Operand x, int[] axes) { return l2Normalize(tf, x, axes, L2_NORM_EPSILON); } @@ -178,15 +179,15 @@ public static Operand l2Normalize(Ops tf, Operand x, i * @param The data type for the values. * @return the normalized values of x. */ - // TODO this was tf.math.l2_normalize in TF Python + // TODO this was tf.math.l2_normalize in TF Python, does it belong here? public static Operand l2Normalize( Ops tf, Operand x, int[] axes, float epsilon) { Operand squareSum = tf.reduceSum(tf.math.square(x), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE)); Operand y = tf.math.rsqrt( - tf.math.maximum( - squareSum, CastHelper.cast(tf, tf.constant(epsilon), x.type()))); + tf.math.maximum(squareSum, CastHelper.cast(tf, tf.constant(epsilon), x.type()))); return tf.math.mul(x, y); } + } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java deleted file mode 100644 index 1412465bd89..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.metrics; - -import org.tensorflow.Operand; -import org.tensorflow.framework.metrics.impl.LossInterface; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; -import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TNumber; - -/** Computes the poisson loss metric between labels and predictions. 
*/ -public class SparseTopKCategoricalAccuracy - extends MeanMetricWrapper implements LossInterface { - public static final int DEFAULT_K = 5; - /** Number of top elements to look at for computing accuracy. */ - private final int k; - - /** - * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for k, Number of - * top elements to look at for computing accuracy. - * - * @param tf the TensorFlow Ops - * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and data type. - * @param type the date type for the result - */ - public SparseTopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { - this(tf, name, DEFAULT_K, seed, type); - } - - /** - * Creates a TopKCategoricalAccuracy metric - * - * @param tf the TensorFlow Ops - * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param k Number of top elements to look at for computing accuracy. - * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and data type. - * @param type the date type for the result - */ - public SparseTopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { - super(tf, name, seed, type); - this.k = k; - setLoss(this); - } - - /** {@inheritDoc} */ - @Override - public Operand call(Operand labels, Operand predictions) { - return Metrics.sparseTopKCategoricalAccuracy(getTF(), labels, predictions, k); - } -} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java deleted file mode 100644 index 3198ab0ee04..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.metrics; - -import org.tensorflow.Operand; -import org.tensorflow.framework.metrics.impl.LossInterface; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; -import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TNumber; - -/** Computes the poisson loss metric between labels and predictions. */ -public class TopKCategoricalAccuracy - extends MeanMetricWrapper implements LossInterface { - public static final int DEFAULT_K = 5; - /** Number of top elements to look at for computing accuracy. */ - private final int k; - - /** - * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for k, Number of - * top elements to look at for computing accuracy. 
- * - * @param tf the TensorFlow Ops - * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and data type. - */ - public TopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { - this(tf, name, DEFAULT_K, seed, type); - } - - /** - * Creates a TopKCategoricalAccuracy metric - * - * @param tf the TensorFlow Ops - * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param k Number of top elements to look at for computing accuracy. - * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and data type. - */ - public TopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { - super(tf, name, seed, type); - this.k = k; - setLoss(this); - } - - /** {@inheritDoc} */ - @Override - public Operand call(Operand labels, Operand predictions) { - return Metrics.topKCategoricalAccuracy(getTF(), labels, predictions, k); - } -} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 5e0023c4dbe..77566e5c400 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -32,7 +32,8 @@ * then passes this loss to the {@link Mean} metric to calculate the weighted mean of the * loss over many iterations or epochs * - * @param the data type for the loss. + * @param the data type for the predictions. + * @param The data type for the metric result */ public class MeanMetricWrapper extends Mean { @@ -86,40 +87,17 @@ public void setLoss(LossInterface loss) { * @param the datatype of the predictions * @return a List of control operations that updates the Mean state variables. */ - public List updateLossStateList( + public List updateStateList( Operand labels, Operand predictions, Operand sampleWeights) { if (labels == null || predictions == null) throw new IllegalArgumentException("missing required inputs for labels and predictions"); - Class type = predictions.type(); - Operand tPredicitons = CastHelper.cast(getTF(), predictions, getType()); + Operand tLabels = CastHelper.cast(getTF(), labels, getType()); + Operand tPredictions = CastHelper.cast(getTF(), predictions, getType()); - Operand losses = loss.call(labels, tPredicitons); - Operand uLossess = CastHelper.cast(getTF(), losses, type); - return super.updateStateList(uLossess, sampleWeights); - } + Operand losses = loss.call(tLabels, tPredictions); - /** - * Creates a Control Operation that updates the state of the mean metric by calculating the loss - * between the labels and predictions and then applying a weighted mean - * metric across the multiple iterations. - * - * @param labels the truth values or labels - * @param predictions the predictions - * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is - * provided, then the loss is simply scaled by the given value. 
If sampleWeights is a tensor - * of size [batch_size], then the total loss for each sample of the batch is rescaled by the - * corresponding element in the sampleWeights vector. If the shape of sampleWeights is - * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of - * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss - * functions reduce by 1 dimension, usually axis=-1.) - * @param the datatype of the labels - * @return a NoOp with control dependencies that update the state of the mean metric. - */ - public final Op updateLossState( - Operand labels, Operand predictions, Operand sampleWeights) { - List controlOps = updateLossStateList(labels, predictions, sampleWeights); - return getTF().withSubScope("updateState").withControlDependencies(controlOps).noOp(); + return super.updateStateList(CastHelper.cast(getTF(), losses, predictions.type()), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java index 78d7459697c..cb5e987b4cf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java @@ -45,8 +45,8 @@ public class MetricVariable { * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. */ - public MetricVariable(Ops tf, Variable variable, long seed) { - this(tf, variable, null, seed); + public MetricVariable(Ops tf, Variable variable, long seed, Class type) { + this(tf, variable, null, seed, type); } /** * Creates a Metric Variable @@ -61,13 +61,14 @@ public MetricVariable(Ops tf, Variable variable, long seed) { * will always produce the same random tensor for a given shape and data type. 
*/ @SuppressWarnings("unchecked") - public MetricVariable(Ops tf, Variable variable, Initializer initializer, long seed) { + public MetricVariable( + Ops tf, Variable variable, Initializer initializer, long seed, Class type) { this.tf = tf; this.variable = variable; - Class type = variable.type(); if (initializer == null) { if (TFloating.class.isAssignableFrom(type)) { + //noinspection RedundantCast this.initializer = (Initializer) new Glorot<>(tf, VarianceScaling.Distribution.UNIFORM, seed); } else if (TIntegral.class.isAssignableFrom(type)) { @@ -76,7 +77,7 @@ public MetricVariable(Ops tf, Variable variable, Initializer initializer, throw new IllegalArgumentException( String.format( "An initializer for variable %s of type %s is required", - variable.toString(), type)); + variable.toString(), type.getSimpleName())); } } else { this.initializer = initializer; @@ -90,7 +91,7 @@ public MetricVariable(Ops tf, Variable variable, Initializer initializer, */ public Operand initialize() { initialized = true; - return initializer.call(tf.constant(variable.asOutput().shape()), variable.type()); + return initializer.call(tf.constant(variable.shape()), variable.type()); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 5395cccf4a7..042badbb615 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -15,24 +15,30 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; -import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; +import org.tensorflow.op.math.Mean; import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; import java.util.Arrays; import java.util.Collections; import java.util.List; +import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * These are helper methods for Metrics and will be module private when Java modularity is applied * to TensorFlow Java. These methods should not be used outside of the metrics packages. */ public class MetricsHelper { - private static final String ASSERT_BROADCASTABLE_ERROR_PREFIX = + public static final float NEG_INF = -1e10f; + private static final String ASSERT_BROADCAST_ERROR_PREFIX = "weights can not be broadcast to values."; /** @@ -53,12 +59,12 @@ public static Op broadcastWeights( Operand weightsShape = tf.shape(sampleWeights); Operand weightsRank = tf.rank(sampleWeights); - Shape weightsShapeStatic = sampleWeights.asOutput().shape(); + Shape weightsShapeStatic = sampleWeights.shape(); int weightsRankStatic = weightsShapeStatic.numDimensions(); Operand valuesShape = tf.shape(values); Operand valuesRank = tf.rank(values); - Shape valuesShapeStatic = values.asOutput().shape(); + Shape valuesShapeStatic = values.shape(); int valuesRankStatic = valuesShapeStatic.numDimensions(); if (weightsRankStatic != -1 && valuesRankStatic != -1) { @@ -71,7 +77,7 @@ public static Op broadcastWeights( throw new IllegalArgumentException( String.format( "%s values.rank=%d. weights.rank=%d. values.shape=%s. 
weights.shape=%s.", - ASSERT_BROADCASTABLE_ERROR_PREFIX, + ASSERT_BROADCAST_ERROR_PREFIX, valuesRankStatic, weightsRankStatic, valuesShapeStatic.toString(), @@ -83,7 +89,7 @@ public static Op broadcastWeights( throw new IllegalArgumentException( String.format( "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.", - ASSERT_BROADCASTABLE_ERROR_PREFIX, + ASSERT_BROADCAST_ERROR_PREFIX, i, valuesShapeStatic.toString(), weightsShapeStatic.toString())); @@ -97,7 +103,7 @@ public static Op broadcastWeights( Operand is_scalar = tf.math.equal(weightsRank, tf.constant(0)); List> data = Arrays.asList( - tf.constant(ASSERT_BROADCASTABLE_ERROR_PREFIX), + tf.constant(ASSERT_BROADCAST_ERROR_PREFIX), tf.constant("weights.shape="), weightsShape, tf.constant("values.shape="), @@ -111,7 +117,7 @@ public static Op broadcastWeights( is_scalar, hasValidNonscalarShape(tf, weightsRank, weightsShape, valuesRank, valuesShape)); - return tf.assertThat(isValidShape, data); + return tf.withSubScope("broadcastWeights-dynamic").assertThat(isValidShape, data); } /** @@ -149,6 +155,82 @@ private static Operand hasValidDims( Ops tf, Operand weightsShape, Operand valuesShape) { tf = tf.withSubScope("has_invalid_dims"); Operand diff = tf.reduceSum(tf.math.sub(weightsShape, valuesShape), tf.constant(0)); - return tf.math.equal(CastHelper.cast(tf, tf.constant(0), diff.type()), diff); + return tf.math.equal(cast(tf, tf.constant(0), diff.asOutput().type()), diff); + } + + // alias for mean + + /** + * Calculate the mean of the operand, along all axes and keepDims is false + * + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @param the type of the Operand. + * @return the mean of the operand + */ + public static Operand mean(Ops tf, Operand x) { + return mean(tf, x, null, false); + } + + /** + * Calculate the mean of the operand, alongside the specified axis with keepDims is + * false + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @param axis Axes to compute the mean. + * @param the type of the Operand. + * @param the type of the axis. + * @return the mean of the operand, alongside the specified axis. + */ + public static Operand mean( + Ops tf, Operand x, Operand axis) { + return mean(tf, x, axis, false); + } + + /** + * Calculate the mean of the operand, along all axis. + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axis + * . If keepdims is true, the reduced dimensions are retained + * with length 1. + * @param the type of the operand + * @return the mean of elements of x. + */ + public static Operand mean(Ops tf, Operand x, boolean keepDims) { + return mean(tf, x, null, keepDims); + } + + /** + * Calculate the mean of the operand, alongside the specified axis. + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @param axis Axes to compute the mean. + * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the + * * rank of the tensor is reduced by 1 for each entry in `axis`. If `keepdims` is `true`, the + * * reduced dimensions are retained with length 1. + * @param the data type of the Operand + * @param the data type of the axis + * @return the mean of elements of `x`. 
+ */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public static Operand mean( + Ops tf, Operand x, Operand axis, boolean keepDims) { + // Cannot use generics here because xf may change from TBool to TFloat32 + Operand xf; + if (x.asOutput().type() == TBool.class) { + xf = tf.dtypes.cast(x, TFloat32.class); + } else { + xf = x; + } + if (axis == null) { + axis = allAxes(tf, xf); + } + return tf.math.mean(xf, axis, Mean.keepDims(keepDims)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index d2c8b2dec93..d3b7caa54cc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -35,6 +35,7 @@ * Encapsulates metrics that perform a reduce operation on the metric values. * * @param The data type for the metric values + * @param The data type for the metric result */ public abstract class Reduce extends Metric { public static final String TOTAL = "total"; @@ -42,6 +43,8 @@ public abstract class Reduce extends Metri protected final MetricReduction reduction; private final String totalName; private final String countName; + + private final Class type; /** the variable that holds the total of the metric values */ protected Variable total; /** the variable that holds the count of the metric values */ @@ -56,23 +59,26 @@ public abstract class Reduce extends Metri * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ protected Reduce(Ops tf, String name, long seed, Class type) { - this(tf, name, seed, MetricReduction.SUM, type); + this(tf, name, MetricReduction.SUM, seed, type); } /** * @param tf The TensorFlow Ops * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param reduction The type of metric reduction to apply * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
- * @param reduction The type of metric reduction to apply + * @param type the type for the variables and result */ - protected Reduce(Ops tf, String name, long seed, MetricReduction reduction, Class type) { - super(tf, name, seed, type); + protected Reduce(Ops tf, String name, MetricReduction reduction, long seed, Class type) { + super(tf, name, seed); this.reduction = reduction; this.totalName = this.getVariableName(TOTAL); this.countName = this.getVariableName(COUNT); + this.type = type; setupVars(); } /** initialize the Variables */ @@ -81,14 +87,14 @@ private void setupVars() { Zeros fZeros = new Zeros<>(getTF()); total = (Variable) getVariable(totalName); if (total == null) { - total = getTF().withSubScope(totalName).variable(Shape.scalar(), getType()); + total = getTF().withSubScope(totalName).variable(Shape.scalar(), type); addVariable(totalName, total, fZeros); } if (reduction == MetricReduction.SUM_OVER_BATCH_SIZE || reduction == MetricReduction.WEIGHTED_MEAN) { count = (Variable) getVariable(countName); if (count == null) { - count = getTF().withSubScope(countName).variable(Shape.scalar(), getType()); + count = getTF().withSubScope(countName).variable(Shape.scalar(), type); addVariable(countName, count, fZeros); } } @@ -104,29 +110,48 @@ private void setupVars() { * @throws IllegalArgumentException if values is null */ @Override - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList(Operand values, Operand sampleWeights) { if (values == null) throw new IllegalArgumentException("values is required."); List updateOperations = new ArrayList<>(); // cast everything to match the variables + Operand lSampleWeights = null; + Operand lValues = values; - Operand tValues = CastHelper.cast(getTF(), values, getType()); - Operand tSampleWeights = sampleWeights; if (sampleWeights != null) { - LossTuple tuple = - LossesHelper.squeezeOrExpandDimensions(getTF(), null, tValues, sampleWeights); - tValues = tuple.getTarget(); - tSampleWeights = tuple.getSampleWeights(); - Op broadcastWeightsCheck = MetricsHelper.broadcastWeights(getTF(), tSampleWeights, tValues); - tValues = - getTF() - .withSubScope("broadcastWeightsCheck") - .withControlDependencies(Collections.singletonList(broadcastWeightsCheck)) - .math - .mul(tValues, tSampleWeights); + lSampleWeights = CastHelper.cast(getTF(), sampleWeights, lValues.type()); + LossTuple tuple = + LossesHelper.squeezeOrExpandDimensions(getTF(), null, lValues, lSampleWeights); + lValues = tuple.getTarget(); + lSampleWeights = tuple.getSampleWeights(); + // lSampleWeights = WeightsBroadcastOps.broadcastWeights(getTF(), lSampleWeights, lValues); + try { + + Op broadcastWeightsCheck = MetricsHelper.broadcastWeights(getTF(), lSampleWeights, lValues); + lValues = + getTF() + .withSubScope("broadcastWeightsCheck") + .withControlDependencies(Collections.singletonList(broadcastWeightsCheck)) + .math + .mul(lValues, lSampleWeights); + } catch (IllegalArgumentException ex) { + System.out.println("Reduce: Fall back from broadcast"); + // reduce the values down to the rank of the samples + int nDim = lValues.shape().numDimensions(); + int wDim = lSampleWeights.shape().numDimensions(); + int numAxes = nDim - wDim; + int[] axes = new int[numAxes]; + for (int i = 0; i < numAxes; i++) axes[i] = i + wDim; + if (reduction == MetricReduction.SUM) { + lValues = getTF().reduceSum(lValues, getTF().constant(axes)); + } else { + lValues = getTF().math.mean(lValues, getTF().constant(axes)); + } + lValues = getTF().math.mul(lValues, 
lSampleWeights); + } } - Operand valueSum = getTF().reduceSum(tValues, LossesHelper.allAxes(getTF(), tValues)); + Operand valueSum = getTF().reduceSum(lValues, LossesHelper.allAxes(getTF(), lValues)); Operand totalUpdate = getTF().assignAdd(total, CastHelper.cast(getTF(), valueSum, total.type())); updateOperations.add(totalUpdate); @@ -134,22 +159,18 @@ public List updateStateList(Operand values, Operand< if (reduction != MetricReduction.SUM) { switch (reduction) { case SUM_OVER_BATCH_SIZE: - numValues = - CastHelper.cast( - getTF(), getTF().constant(tValues.asOutput().shape().size()), getType()); + numValues = CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), type); break; case WEIGHTED_MEAN: - if (tSampleWeights == null) { - numValues = - CastHelper.cast( - getTF(), getTF().constant(tValues.asOutput().shape().size()), getType()); + if (lSampleWeights == null) { + numValues = CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), type); } else { numValues = CastHelper.cast( getTF(), getTF() - .reduceSum(tSampleWeights, LossesHelper.allAxes(getTF(), tSampleWeights)), - getType()); + .reduceSum(lSampleWeights, LossesHelper.allAxes(getTF(), lSampleWeights)), + type); } break; default: @@ -175,7 +196,7 @@ public Operand result(Ops rtf) { break; case WEIGHTED_MEAN: case SUM_OVER_BATCH_SIZE: - fResult = rtf.math.divNoNan(total, CastHelper.cast(rtf, count, getType())); + fResult = rtf.math.divNoNan(total, CastHelper.cast(rtf, count, type)); break; default: throw new UnsupportedOperationException( @@ -201,4 +222,9 @@ public Variable getTotal() { public Variable getCount() { return count; } + + /** Gets the type for the variables */ + public Class getType() { + return type; + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java index 1f07b9567cb..0529026ce8f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java @@ -23,6 +23,7 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; class BinaryCrossentropyTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -57,9 +58,9 @@ public void testUnweightedLogits() { BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testUnweightedLogits", true, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); - double[] trueArray = {1, 0, 1, 0, 1, 1}; + float[] trueArray = {1, 0, 1, 0, 1, 1}; double[] logitsArray = {100.0, -100.0, 100.0, 100.0, 100.0, -100.0}; - Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); Op op = instance.updateState(labels, logits, null); session.run(op); @@ -79,9 +80,9 @@ public void testWeighted() { BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testWeighted", false, 0, 1001L, TFloat32.class); session.run(instance.resetStates()); - float[] trueArray = {1, 0, 1, 0}; + int[] trueArray = {1, 0, 1, 0}; float[] predictionArray = {1, 1, 1, 0}; - Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); + Operand labels = 
tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); Operand yPrediction = tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 2))); Operand sampleWeight = tf.constant(new float[] {1.5f, 2.f}); @@ -104,9 +105,9 @@ public void testWeightedLogits() { BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testWeightedLogits", true, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); - double[] trueArray = {1, 0, 1, 0, 1, 1}; + float[] trueArray = {1, 0, 1, 0, 1, 1}; double[] logitsArray = {100.0, -100.0, 100.0, 100.0, 100.0, -100.0}; - Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(new double[] {2, 2.5}); @@ -131,9 +132,9 @@ public void testLabelSmoothing() { new BinaryCrossentropy<>( tf, "BCE_testWeightedLabS", true, labelSmoothing, 1001L, TFloat64.class); session.run(instance.resetStates()); - double[] trueArray = {1, 0, 1}; + float[] trueArray = {1, 0, 1}; double[] logitsArray = {100., -100., -100.}; - Operand labels = tf.constant(trueArray); + Operand labels = tf.constant(trueArray); Operand logits = tf.constant(logitsArray); Op op = instance.updateState(labels, logits, null); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java index 848e2051af3..a9721ef2f8f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java @@ -79,7 +79,7 @@ public void testWeighted() { public void test_axis() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - int[] axis = new int[] {1}; + int axis = 1; CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testWeighted", axis, 1001L, TFloat32.class); session.run(instance.resetStates()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java index bf98ec4eba4..28020c0fa1c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java @@ -69,7 +69,7 @@ public void testWeighted() { Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.constant(new double[] {1.2, 3.4}); + Operand sampleWeight = tf.constant(new double[][] {{1.2}, {3.4}}); Op op = instance.updateState(labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracyTest.java deleted file mode 100644 index 4a0cdefe492..00000000000 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracyTest.java +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.metrics; - -import org.junit.jupiter.api.Test; -import org.tensorflow.Operand; -import org.tensorflow.framework.utils.TestSession; -import org.tensorflow.op.Op; -import org.tensorflow.op.Ops; -import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; - -class SparseTopKCategoricalAccuracyTest { - private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; - - @Test - public void testCorrectness() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - SparseTopKCategoricalAccuracy instance = - new SparseTopKCategoricalAccuracy<>( - tf, "SparseTopK_testCorrectness", 5, 1001L, TFloat64.class); - session.run(instance.resetStates()); - - Operand labels = tf.constant(new double[] {2, 1}); - Operand predictions = - tf.constant(new double[][] {{0.1, 0.9, 0.8}, {0.05, 0.95, 0}}); - - Op update = instance.updateState(labels, predictions, null); - session.run(update); - session.evaluate(1., instance.result()); - - // With `k` < 5. - instance = - new SparseTopKCategoricalAccuracy<>( - tf, "SparseTopK_testCorrectness", 1, 1001L, TFloat64.class); - session.run(instance.resetStates()); - update = instance.updateState(labels, predictions, null); - session.run(update); - session.evaluate(0.5, instance.result()); - - // With `k` > 5. - predictions = - tf.constant( - new double[][] { - {0.5, 0.9, 0.1, 0.7, 0.6, 0.5, 0.4}, - {0.05, 0.95, 0, 0, 0, 0, 0} - }); - instance = - new SparseTopKCategoricalAccuracy<>( - tf, "SparseTopK_testCorrectness", 6, 1001L, TFloat64.class); - session.run(instance.resetStates()); - update = instance.updateState(labels, predictions, null); - session.run(update); - session.evaluate(0.5, instance.result()); - } - } - - @Test - public void testWeighted() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - SparseTopKCategoricalAccuracy instance = - new SparseTopKCategoricalAccuracy<>( - tf, "SparseTopK_testWeighted", 5, 1001L, TFloat64.class); - session.run(instance.resetStates()); - - Operand labels = tf.constant(new int[] {1, 0, 2}); - Operand predictions = - tf.constant( - new double[][] { - {0, 0.9, 0.1}, - {0, 0.9, 0.1}, - {0, 0.9, 0.1} - }); - - Operand sampleWeight = tf.constant(new double[] {1, 0, 1}); - - Op update = instance.updateState(labels, predictions, sampleWeight); - session.run(update); - session.evaluate(1., instance.result()); - } - } -} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java deleted file mode 100644 index 52ccde29196..00000000000 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java +++ /dev/null @@ -1,103 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.metrics; - -import org.junit.jupiter.api.Test; -import org.tensorflow.Operand; -import org.tensorflow.framework.utils.TestSession; -import org.tensorflow.op.Op; -import org.tensorflow.op.Ops; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TFloat64; - -class TopKCategoricalAccuracyTest { - private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; - - @Test - public void testCorrectness() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - TopKCategoricalAccuracy instance = - new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted", 5, 1001L, TFloat64.class); - session.run(instance.resetStates()); - Operand labels = tf.constant(new float[][] {{0, 0, 1}, {0, 1, 0}}); - Operand predictions = - tf.constant(new double[][] {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}}); - - Op update = instance.updateState(labels, predictions, null); - session.run(update); - session.evaluate(1., instance.result()); - - // With `k` < 5. - instance = - new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted1", 1, 1001L, TFloat64.class); - session.run(instance.resetStates()); - update = instance.updateState(labels, predictions, null); - session.run(update); - session.evaluate(0.5, instance.result()); - - // With `k` > 5. 
- labels = - tf.constant( - new float[][] { - {0, 0, 1, 0, 0, 0, 0}, - {0, 1, 0, 0, 0, 0, 0} - }); - predictions = - tf.constant( - new double[][] { - {0.5f, 0.9f, 0.1f, 0.7f, 0.6f, 0.5f, 0.4f}, - {0.05f, 0.95f, 0f, 0f, 0f, 0f, 0f} - }); - instance = - new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted6", 6, 1001L, TFloat64.class); - session.run(instance.resetStates()); - update = instance.updateState(labels, predictions, null); - session.run(update); - session.evaluate(0.5, instance.result()); - } - } - - @Test - public void testWeighted() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - TopKCategoricalAccuracy instance = - new TopKCategoricalAccuracy<>(tf, "TopK_testWeighted", 5, 1001L, TFloat64.class); - session.run(instance.resetStates()); - - Operand labels = - tf.constant( - new double[][] { - {1, 0, 2}, - {1, 0, 0}, - {0, 0, 1} - }); - Operand predictions = - tf.constant( - new double[][] { - {0f, 0.9f, 0.1f}, - {0f, 0.9f, 0.1f}, - {0f, 0.9f, 0.1f} - }); - - Operand sampleWeight = tf.constant(new double[] {1, 0, 1}); - - Op update = instance.updateState(labels, predictions, sampleWeight); - session.run(update); - session.evaluate(1., instance.result()); - } - } -} From ce5fa27abc587827d43e8e9962dee56a4678e509 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 1 Jan 2021 09:04:32 -0500 Subject: [PATCH 03/56] Initial checkin and sync with master --- .../tensorflow/framework/metrics/BinaryCrossentropyTest.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java index 0529026ce8f..7ceedded018 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java @@ -45,9 +45,9 @@ public void testUnweighted() { Variable total = instance.getTotal(); Variable count = instance.getCount(); Operand result = instance.result(); - session.evaluate(7.666619F, total); + session.evaluate(7.71247434F, total); session.evaluate(2, count); - session.evaluate(3.833309F, result); + session.evaluate(3.85623717F, result); } } From 475ca362f9c5ccd69f530eddceeb1e08d59bada4 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 1 Jan 2021 09:46:18 -0500 Subject: [PATCH 04/56] JavaDoc cleanup --- .../framework/metrics/BinaryCrossentropy.java | 4 +- .../metrics/CategoricalCrossentropy.java | 11 +++- .../framework/metrics/CategoricalHinge.java | 8 ++- .../framework/metrics/CosineSimilarity.java | 19 +++--- .../tensorflow/framework/metrics/Hinge.java | 8 ++- .../framework/metrics/KLDivergence.java | 9 ++- .../framework/metrics/LogCoshError.java | 8 ++- .../tensorflow/framework/metrics/Mean.java | 4 +- .../framework/metrics/MeanAbsoluteError.java | 8 ++- .../metrics/MeanAbsolutePercentageError.java | 8 ++- .../framework/metrics/MeanSquaredError.java | 8 ++- .../metrics/MeanSquaredLogarithmicError.java | 8 ++- .../tensorflow/framework/metrics/Metric.java | 14 ----- .../tensorflow/framework/metrics/Metrics.java | 59 +------------------ .../tensorflow/framework/metrics/Poisson.java | 8 ++- .../SparseCategoricalCrossentropy.java | 9 ++- .../framework/metrics/SquaredHinge.java | 8 ++- .../metrics/impl/MeanMetricWrapper.java | 9 +-- .../metrics/impl/MetricVariable.java | 4 +- 19 files changed, 110 insertions(+), 104 
deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 41a5533b5d1..2372293d0d3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -22,7 +22,7 @@ import org.tensorflow.types.family.TNumber; /** - * Computes the binary cross-entropy loss between true labels and predicted labels. + * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. * * @param the data type for the predictions. * @param The data type for the metric result @@ -48,7 +48,7 @@ public class BinaryCrossentropy * correspond to heavier smoothing. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the data type for the variables + * @param type the type for the variables and result */ public BinaryCrossentropy( Ops tf, String name, boolean fromLogits, float labelSmoothing, long seed, Class type) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index 79481f608a1..6bfd471401b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -22,12 +22,16 @@ import org.tensorflow.types.family.TNumber; /** - * Computes the categorical cross-entropy loss between true labels and predicted labels. + * A Metric that computes the categorical cross-entropy loss between true labels and predicted + * labels. * *
<p>
This is the crossentropy metric class to be used when there are multiple label classes (2 or - * more). Here we assume that labels are given as a one_hot representation. eg., When labels values - * are [2, 0, 1], the labels Operand contains = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] + * more). The labels should be given as a one_hot representation. eg., When labels values are + * [2, 0, 1], the labels Operand contains = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] * . + * + * @param the data type for the predictions. + * @param The data type for the metric result */ public class CategoricalCrossentropy extends MeanMetricWrapper implements LossInterface { @@ -51,6 +55,7 @@ public class CategoricalCrossentropy * for label 1 * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public CategoricalCrossentropy( Ops tf, String name, boolean fromLogits, float labelSmoothing, long seed, Class type) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index a9500b79d9e..21f19d88ade 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -21,7 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the categorical hinge loss metric between labels and predictions. */ +/** + * A Metric that computes the categorical hinge loss metric between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result + */ public class CategoricalHinge extends MeanMetricWrapper implements LossInterface { @@ -32,6 +37,7 @@ public class CategoricalHinge extends Mean * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public CategoricalHinge(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 4a5214aea8d..9ceccf7fc13 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -20,32 +20,33 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the cosine similarity metric between labels and predictions. */ -// TODO: this is weird, the metric is called CosineSimilarity in Keras, -// but it calls Metrics.cosineProximity instead of Losses.cosineSimilarity. -// The metric is calculating the Euclidean distance using L2 norms, while the loss -// is using the dot product proportional to the product of their magnitudes. -// While the 2 concepts are similar, they are different. -// Should we rename this metric to CosineProximity? +/** + * A metric that computes the cosine similarity metric between labels and predictions. 
+ * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class CosineSimilarity extends MeanMetricWrapper implements LossInterface { public static final int DEFAULT_AXIS = -1; private final int[] axis; /** - * Creates a CosineSimilarity metric with a default axis, {@link #DEFAULT_AXIS} + * Creates a metric that computes the cosine similarity metric between labels and predictions with + * a default axis, {@link #DEFAULT_AXIS} * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public CosineSimilarity(Ops tf, String name, long seed, Class type) { this(tf, name, DEFAULT_AXIS, seed, type); } /** - * Creates a CosineSimilarity metric + * Creates a metric that computes the cosine similarity metric between labels and predictions. * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index d655f8d8237..b276f0b9426 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -21,7 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the hinge loss metric between labels and predictions. */ +/** + * A metric that computes the hinge loss metric between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class Hinge extends MeanMetricWrapper implements LossInterface { @@ -32,6 +37,7 @@ public class Hinge extends MeanMetricWrapp * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public Hinge(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index 3f31383381a..a3cbc6f16e6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -21,7 +21,13 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes Computes Kullback-Leibler divergence loss metric between labels and predictions. */ +/** + * A metric that computes the Kullback-Leibler divergence loss metric between labels and + * predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. 
+ */ public class KLDivergence extends MeanMetricWrapper implements LossInterface { @@ -32,6 +38,7 @@ public class KLDivergence extends MeanMetr * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public KLDivergence(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index 7d4b8a9fad7..d6fe903f5a1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -22,8 +22,11 @@ import org.tensorflow.types.family.TNumber; /** - * Computes the logarithm of the hyperbolic cosine of the prediction error metric between labels and - * predictions. + * A metric that computes the logarithm of the hyperbolic cosine of the prediction error metric + * between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. */ public class LogCoshError extends MeanMetricWrapper implements LossInterface { @@ -35,6 +38,7 @@ public class LogCoshError extends MeanMetr * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public LogCoshError(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java index c68a70902a7..de1f5a5629e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java @@ -19,7 +19,7 @@ import org.tensorflow.types.family.TNumber; /** - * Represents a Metric that implements a weighted mean {@link MetricReduction#WEIGHTED_MEAN } + * A metric that that implements a weighted mean {@link MetricReduction#WEIGHTED_MEAN } * * @param The data type for the metric values * @param The data type for the metric result @@ -33,7 +33,7 @@ public class Mean extends Reduce { * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the result. 
+ * @param type the type for the variables and result */ protected Mean(Ops tf, String name, long seed, Class type) { super(tf, name, MetricReduction.WEIGHTED_MEAN, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index 6b29c72fe82..79da80ef191 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -21,7 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the mean of absolute difference between labels and predictions. */ +/** + * A metric that computes the mean of absolute difference between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class MeanAbsoluteError extends MeanMetricWrapper implements LossInterface { @@ -32,6 +37,7 @@ public class MeanAbsoluteError extends Mea * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 6209245d881..558c194074f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -21,7 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the mean of absolute difference between labels and predictions. */ +/** + * A metric that computes the mean of absolute difference between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class MeanAbsolutePercentageError extends MeanMetricWrapper implements LossInterface { @@ -32,6 +37,7 @@ public class MeanAbsolutePercentageError * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
+ * @param type the type for the variables and result */ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index ce30e378e8d..10704d14bd4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -21,7 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the mean of absolute difference between labels and predictions. */ +/** + * A metric that computes the mean of absolute difference between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class MeanSquaredError extends MeanMetricWrapper implements LossInterface { @@ -32,6 +37,7 @@ public class MeanSquaredError extends Mean * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public MeanSquaredError(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index 9baeac2f320..585fc312e5a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -21,7 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the mean of absolute difference between labels and predictions. */ +/** + * A metric that computes the mean of absolute difference between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class MeanSquaredLogarithmicError extends MeanMetricWrapper implements LossInterface { @@ -32,6 +37,7 @@ public class MeanSquaredLogarithmicError * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 28a2ae0fa94..89e5436ed0a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -42,19 +42,6 @@ public abstract class Metric { /** The random number generator seed value */ private final long seed; - // TODO: how to handle variables across new ExecutionEnvironments. 
- // Metrics may be instantiated multiple times using the same variables, - // These variables become stale when a new ExecutionEnvironment is created - // (most commonly seen in Unit Tests), so the question is how to best handle this. - // Option 1, which is used here is to map the variables against an instance of - // an ExecutionEnvironment in a WeakHashMap, when a new ExecutionEnvironment is presented, the - // new - // variables are mapped to it. A WeakHashMap is used to throw away the old ExecutionEnvironment - // mappings, when the old ExecutionEnvironment is finalized. - // Option 2, keep an instance of the newly presented ExecutionEnvironment and if it changes, - // clear the variable maps. - // My guess is that in a non-unit test environment, only one ExecutionEnvironment will be used, - // I welcome thoughts on this. /** The name for this metric. Defaults to {@link Class#getSimpleName()}. */ private final String name; @@ -185,7 +172,6 @@ public final Operand callOnce( */ protected void addVariable( String varName, Variable variable, Initializer initializer) { - // TODO option 2 would be to keep track of tf.scope().env() and if it changes, clear to old Map. Map> variables = variableMap.computeIfAbsent(tf.scope().env(), k -> new HashMap<>()); variables.put(varName, new MetricVariable<>(tf, variable, initializer, seed, variable.type())); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index c3c44ef6134..8a8ddf3694c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -16,14 +16,12 @@ import org.tensorflow.Operand; import org.tensorflow.framework.utils.CastHelper; -import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.op.core.ReduceSum; import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -/** Built-in metrics functions. */ +/** Helper class with built-in metrics functions. */ public class Metrics { public static final float L2_NORM_EPSILON = 1e-12f; @@ -60,54 +58,6 @@ public static Operand topKCategoricalA predictions.type()); } - /** - * Computes how often integer targets are in the top K predictions. - * - *

<p>Standalone usage: - * - * <pre>
-   *     Operand<TInt32> labels = tf.constant(new int[]{2, 1});
-   *     Operand<TFloat32> predictions = tf.constant(new float[][]
-   *                            {{0.1f, 0.9f, 0.f8}, {0.05f, 0.95f, 0f}});
-   *     Operand<TFloat32> m = Metrics.topKCategoricalAccuracy(
-   *                                    labels, predictions, 3)
-   *     //m.shape().toString == "[2]"
-   * </pre>
- * - * @param tf the TensorFlow Ops. - * @param labels the ground truth values. - * @param predictions The prediction values. - * @param k Number of top elements to look at for computing accuracy. - * @param the data type for the predictions and results - * @param the data type ofr the labels. - * @return the Operand for the Sparse top K categorical accuracy value. - */ - @SuppressWarnings("unchecked") - public static Operand sparseTopKCategoricalAccuracy( - Ops tf, Operand labels, Operand predictions, int k) { - Operand tLabels; - if (labels.type() != predictions.type()) - tLabels = CastHelper.cast(tf, labels, predictions.type()); - else tLabels = (Operand) labels; - - int predictionsRank = predictions.shape().numDimensions(); - int labelsRank = tLabels.shape().numDimensions(); - - Operand castPredictions = CastHelper.cast(tf, predictions, TFloat32.class); - if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { - if (predictionsRank > 2) { - castPredictions = tf.shape.reduceDims(castPredictions, tf.constant(1)); - } - if (labelsRank > 1) { - tLabels = tf.shape.flatten(tLabels); - } - } - return CastHelper.cast( - tf, - tf.nn.inTopK(castPredictions, CastHelper.cast(tf, tLabels, TInt32.class), tf.constant(k)), - predictions.type()); - } - /** * Computes the cosine similarity between labels and predictions. * @@ -152,10 +102,7 @@ public static Operand cosineProximity( * @param The data type for x. * @return the normalized values of x. */ - // TODO this was tf.math.l2_normalize in TF Python, does it belong here? - - public static Operand l2Normalize( - Ops tf, Operand x, int[] axes) { + public static Operand l2Normalize(Ops tf, Operand x, int[] axes) { return l2Normalize(tf, x, axes, L2_NORM_EPSILON); } @@ -179,7 +126,6 @@ public static Operand l2Normalize( * @param The data type for the values. * @return the normalized values of x. */ - // TODO this was tf.math.l2_normalize in TF Python, does it belong here? public static Operand l2Normalize( Ops tf, Operand x, int[] axes, float epsilon) { Operand squareSum = @@ -189,5 +135,4 @@ public static Operand l2Normalize( tf.math.maximum(squareSum, CastHelper.cast(tf, tf.constant(epsilon), x.type()))); return tf.math.mul(x, y); } - } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index f5730b07f42..07ab129eb08 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -21,7 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the poisson loss metric between labels and predictions. */ +/** + * A metric that computes the poisson loss metric between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class Poisson extends MeanMetricWrapper implements LossInterface { @@ -32,6 +37,7 @@ public class Poisson extends MeanMetricWra * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
+ * @param type the type for the variables and result */ public Poisson(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index 403e11af8c0..c2f916217e4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -21,7 +21,13 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the sparse categorical cross-entropy loss between true labels and predicted labels. */ +/** + * A metric that computes the sparse categorical cross-entropy loss between true labels and + * predicted labels. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class SparseCategoricalCrossentropy extends MeanMetricWrapper implements LossInterface { @@ -37,6 +43,7 @@ public class SparseCategoricalCrossentropy * @param axes The dimension along which the entropy is computed. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public SparseCategoricalCrossentropy( Ops tf, String name, boolean fromLogits, int axes, long seed, Class type) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 7ce8091f2a0..d8c7aa097fe 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -21,7 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the squared hinge loss metric between labels and predictions. */ +/** + * A metric that computes the squared hinge loss metric between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class SquaredHinge extends MeanMetricWrapper implements LossInterface { @@ -32,6 +37,7 @@ public class SquaredHinge extends MeanMetr * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
+ * @param type the type for the variables and result */ public SquaredHinge(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 77566e5c400..5894b24c4cd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -25,8 +25,8 @@ import java.util.List; /** - * Bridges a stateless loss function with the {@link Mean} metric using a reduction of {@link - * MetricReduction#WEIGHTED_MEAN}. + * A class that bridges a stateless loss function with the {@link Mean} metric using a reduction of + * {@link MetricReduction#WEIGHTED_MEAN}. * *

The loss function calculates the loss between the labels and predictions * then passes this loss to the {@link Mean} metric to calculate the weighted mean of the @@ -47,6 +47,7 @@ public class MeanMetricWrapper extends Mea * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ protected MeanMetricWrapper(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); @@ -95,9 +96,9 @@ public List updateStateList( Operand tLabels = CastHelper.cast(getTF(), labels, getType()); Operand tPredictions = CastHelper.cast(getTF(), predictions, getType()); - Operand losses = loss.call(tLabels, tPredictions); - return super.updateStateList(CastHelper.cast(getTF(), losses, predictions.type()), sampleWeights); + return super.updateStateList( + CastHelper.cast(getTF(), losses, predictions.type()), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java index cb5e987b4cf..786d5db1261 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java @@ -30,7 +30,6 @@ * * @param the data type of the variable */ -// TODO handle distributed variables with VariableAggregation and VariableSynchronization public class MetricVariable { private final Variable variable; private final Initializer initializer; @@ -44,10 +43,12 @@ public class MetricVariable { * @param variable the variable * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variable */ public MetricVariable(Ops tf, Variable variable, long seed, Class type) { this(tf, variable, null, seed, type); } + /** * Creates a Metric Variable * @@ -59,6 +60,7 @@ public MetricVariable(Ops tf, Variable variable, long seed, Class type) { * other types the default initializer is {@link org.tensorflow.framework.initializers.Zeros} * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
+ * @param type the type for the variable */ @SuppressWarnings("unchecked") public MetricVariable( From b48c09a2541f4c5c3eb963ebea441ad6978322a2 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 3 Jan 2021 12:07:45 -0500 Subject: [PATCH 05/56] Javadoc fixes --- .../tensorflow/framework/metrics/CategoricalCrossentropy.java | 1 + .../org/tensorflow/framework/metrics/CosineSimilarity.java | 2 ++ .../main/java/org/tensorflow/framework/metrics/Metric.java | 2 ++ .../main/java/org/tensorflow/framework/metrics/Metrics.java | 4 ++-- .../java/org/tensorflow/framework/metrics/impl/Reduce.java | 4 +++- 5 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index 6bfd471401b..72e15f1b22b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -79,6 +79,7 @@ public CategoricalCrossentropy( * channels_first. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public CategoricalCrossentropy( Ops tf, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 9ceccf7fc13..5bd0c53b416 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -53,6 +53,7 @@ public CosineSimilarity(Ops tf, String name, long seed, Class type) { * @param axis The dimension along which the cosine similarity is computed. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public CosineSimilarity(Ops tf, String name, int axis, long seed, Class type) { this(tf, name, new int[] {axis}, seed, type); @@ -65,6 +66,7 @@ public CosineSimilarity(Ops tf, String name, int axis, long seed, Class type) * @param axis The dimension along which the cosine similarity is computed. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
+ * @param type the type for the variables and result */ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 89e5436ed0a..378d026e69c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -164,11 +164,13 @@ public final Operand callOnce( /** * Adds a variable to collect metric values * + * @param varName the name for the variable * @param variable the variable * @param initializer the initializer for the variable, if null, then the default for floating * point types is {@link org.tensorflow.framework.initializers.Glorot} with distribution * {@link org.tensorflow.framework.initializers.VarianceScaling.Distribution#UNIFORM}, for * other types the default initializer is {@link org.tensorflow.framework.initializers.Zeros} + * @param the date type for the variable */ protected void addVariable( String varName, Variable variable, Initializer initializer) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index 8a8ddf3694c..e31cb54a4d1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -121,8 +121,8 @@ public static Operand l2Normalize(Ops tf, Operand x, i * @param tf The TensorFlow ops * @param x The operand to normalize * @param axes Dimension(s) along which to normalize. - * @param epsilon A lower bound value for the norm. Will use `sqrt(epsilon)` as the divisor if - * `norm < sqrt(epsilon)`. + * @param epsilon A lower bound value for the norm. Will use sqrt(epsilon) as the divisor if + * norm < sqrt(epsilon). * @param The data type for the values. * @return the normalized values of x. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index d3b7caa54cc..c8499bc1599 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -223,7 +223,9 @@ public Variable getCount() { return count; } - /** Gets the type for the variables */ + /** Gets the type for the variables + * @return the type for the variables + */ public Class getType() { return type; } From c4c06de27c32c924c22d8029bc58378afccdaab6 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 5 Jan 2021 12:32:34 -0500 Subject: [PATCH 06/56] Change LossInterface to LossMetric. Fix JavaDoc, modify one line code block to include braces. 
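For reference, the sketch below shows how a concrete metric is expected to plug into the renamed interface: it extends MeanMetricWrapper, implements LossMetric, and registers itself as its own loss function. This is an illustration only, not part of this change; the class name ExampleMAE, the single shared type parameter, and the exact call and Losses.meanAbsoluteError signatures are assumptions.

    import org.tensorflow.Operand;
    import org.tensorflow.framework.losses.Losses;
    import org.tensorflow.framework.metrics.impl.LossMetric;
    import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.family.TNumber;

    /** Hypothetical metric used only to illustrate the LossMetric rename. */
    public class ExampleMAE<T extends TNumber> extends MeanMetricWrapper<T, T>
        implements LossMetric<T> {

      public ExampleMAE(Ops tf, String name, long seed, Class<T> type) {
        super(tf, name, seed, type);
        // register this metric as its own stateless loss function
        setLoss(this);
      }

      @Override
      public Operand<T> call(Operand<T> labels, Operand<T> predictions) {
        // delegate to the stateless loss; MeanMetricWrapper keeps the weighted mean
        return Losses.meanAbsoluteError(getTF(), labels, predictions);
      }
    }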
--- .../framework/metrics/BinaryCrossentropy.java | 10 +++++----- .../framework/metrics/CategoricalCrossentropy.java | 8 ++++---- .../tensorflow/framework/metrics/CategoricalHinge.java | 4 ++-- .../tensorflow/framework/metrics/CosineSimilarity.java | 4 ++-- .../java/org/tensorflow/framework/metrics/Hinge.java | 4 ++-- .../org/tensorflow/framework/metrics/KLDivergence.java | 4 ++-- .../org/tensorflow/framework/metrics/LogCoshError.java | 4 ++-- .../framework/metrics/MeanAbsoluteError.java | 4 ++-- .../framework/metrics/MeanAbsolutePercentageError.java | 4 ++-- .../tensorflow/framework/metrics/MeanSquaredError.java | 4 ++-- .../framework/metrics/MeanSquaredLogarithmicError.java | 4 ++-- .../java/org/tensorflow/framework/metrics/Metric.java | 10 +++++----- .../java/org/tensorflow/framework/metrics/Metrics.java | 6 +----- .../java/org/tensorflow/framework/metrics/Poisson.java | 4 ++-- .../metrics/SparseCategoricalCrossentropy.java | 4 ++-- .../org/tensorflow/framework/metrics/SquaredHinge.java | 4 ++-- .../impl/{LossInterface.java => LossMetric.java} | 2 +- .../framework/metrics/impl/MeanMetricWrapper.java | 9 +++++---- .../framework/metrics/impl/MetricVariable.java | 5 +++-- .../framework/metrics/impl/MetricsHelper.java | 10 +++++----- .../org/tensorflow/framework/metrics/impl/Reduce.java | 5 +++-- 21 files changed, 56 insertions(+), 57 deletions(-) rename tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/{LossInterface.java => LossMetric.java} (95%) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 2372293d0d3..c339b977007 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -24,11 +24,14 @@ /** * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. * + *

This is the crossentropy metric class to be used when there are only two label classes (0 and + * 1). + * * @param the data type for the predictions. * @param The data type for the metric result */ public class BinaryCrossentropy - extends MeanMetricWrapper implements LossInterface { + extends MeanMetricWrapper implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; @@ -36,9 +39,6 @@ public class BinaryCrossentropy /** * Creates a BinaryCrossentropy metric * - *

This is the crossentropy metric class to be used when there are only two label classes (0 - * and 1). - * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index 72e15f1b22b..7b8cf0054a4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -34,14 +34,14 @@ * @param The data type for the metric result */ public class CategoricalCrossentropy - extends MeanMetricWrapper implements LossInterface { + extends MeanMetricWrapper implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; private final int axis; /** - * Creates a CategoricalCrossentropy metric that Computes the crossentropy metric between the + * Creates a CategoricalCrossentropy metric that computes the crossentropy metric between the * labels and predictions. * *

Uses a {@link Losses#CHANNELS_LAST} for the channel axis. @@ -63,7 +63,7 @@ public CategoricalCrossentropy( } /** - * Creates a CategoricalCrossentropy metric that Computes the crossentropy metric between the + * Creates a CategoricalCrossentropy metric that computes the crossentropy metric between the * labels and predictions. * * @param tf the TensorFlow Ops diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index 21f19d88ade..2741a36edb6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -28,7 +28,7 @@ * @param The data type for the metric result */ public class CategoricalHinge extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { /** * Creates a CategoricalHinge metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 5bd0c53b416..458de092bec 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -15,7 +15,7 @@ package org.tensorflow.framework.metrics; import org.tensorflow.Operand; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -27,7 +27,7 @@ * @param The data type for the metric result. */ public class CosineSimilarity extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { public static final int DEFAULT_AXIS = -1; private final int[] axis; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index b276f0b9426..baf9ad8ab7d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -28,7 +28,7 @@ * @param The data type for the metric result. 
*/ public class Hinge extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { /** * Creates a Hinge metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index a3cbc6f16e6..efcbbcbb7f0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -29,7 +29,7 @@ * @param The data type for the metric result. */ public class KLDivergence extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { /** * Creates a KLDivergence metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index d6fe903f5a1..3df8505d54b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -29,7 +29,7 @@ * @param The data type for the metric result. */ public class LogCoshError extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { /** * Creates a LogCoshError metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index 79da80ef191..e27676932ff 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -28,7 +28,7 @@ * @param The data type for the metric result. 
*/ public class MeanAbsoluteError extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { /** * Creates a Mean Absolute Error metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 558c194074f..84fa9b627b2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -28,7 +28,7 @@ * @param The data type for the metric result. */ public class MeanAbsolutePercentageError - extends MeanMetricWrapper implements LossInterface { + extends MeanMetricWrapper implements LossMetric { /** * Creates a Mean Absolute Error metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index 10704d14bd4..c7edd6ebe93 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -28,7 +28,7 @@ * @param The data type for the metric result. */ public class MeanSquaredError extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { /** * Creates a Mean Absolute Error metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index 585fc312e5a..199b6e0e114 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -28,7 +28,7 @@ * @param The data type for the metric result. 
*/ public class MeanSquaredLogarithmicError - extends MeanMetricWrapper implements LossInterface { + extends MeanMetricWrapper implements LossMetric { /** * Creates a Mean Absolute Error metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 378d026e69c..c816b1a98d0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -65,8 +65,9 @@ protected Metric(Ops tf, long seed) { * will always produce the same random tensor for a given shape and data type. */ protected Metric(Ops tf, String name, long seed) { - if (!tf.scope().env().isGraph()) + if (!tf.scope().env().isGraph()) { throw new IllegalArgumentException("Metrics are required to execute in Graph mode."); + } this.seed = seed; this.name = name != null ? name : this.getClass().getSimpleName(); this.tf = tf.withSubScope(this.name); @@ -81,7 +82,7 @@ protected Metric(Ops tf, String name, long seed) { * @param sampleWeights sample weights to be applied to values, may be null. * @return a List of Operations to update the metric state */ - @SuppressWarnings({"unchecked","unused"}) + @SuppressWarnings({"unchecked", "unused"}) public List updateStateList(Operand values, Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -97,7 +98,7 @@ public List updateStateList(Operand values, Operand sampleWeights) { * @param the data type for the sample weights * @return a List of Operations to update the metric state */ - @SuppressWarnings({"unchecked","unused"}) + @SuppressWarnings({"unchecked", "unused"}) public List updateStateList( Operand labels, Operand predictions, Operand sampleWeights) { return Collections.EMPTY_LIST; @@ -154,8 +155,7 @@ public Operand result() { * @param sampleWeights sample weights to be applied to values, may be null. * @return the result, possibly with control dependencies */ - public final Operand callOnce( - Operand values, Operand sampleWeights) { + public final Operand callOnce(Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps); return result(ltf); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index e31cb54a4d1..b8e79efa450 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -69,13 +69,9 @@ public static Operand topKCategoricalA * @param the data type for the predictions and result * @return Cosine similarity value. 
*/ - @SuppressWarnings("unchecked") public static Operand cosineProximity( Ops tf, Operand labels, Operand predictions, int[] axis) { - Operand labelsNorm; - if (labels.type() != predictions.type()) - labelsNorm = CastHelper.cast(tf, labels, predictions.type()); - else labelsNorm = (Operand) labels; + Operand labelsNorm = CastHelper.cast(tf, labels, predictions.type()); labelsNorm = l2Normalize(tf, labelsNorm, axis); Operand predictionsNorm = l2Normalize(tf, predictions, axis); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 07ab129eb08..75a2031fbb5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -28,7 +28,7 @@ * @param The data type for the metric result. */ public class Poisson extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { /** * Creates a Poisson metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index c2f916217e4..3fde8b2ecf6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -29,7 +29,7 @@ * @param The data type for the metric result. */ public class SparseCategoricalCrossentropy - extends MeanMetricWrapper implements LossInterface { + extends MeanMetricWrapper implements LossMetric { private final boolean fromLogits; private final int axes; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index d8c7aa097fe..430dbbcc229 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -28,7 +28,7 @@ * @param The data type for the metric result. 
*/ public class SquaredHinge extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { /** * Creates a SquaredHinge metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossInterface.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java similarity index 95% rename from tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossInterface.java rename to tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index aadc211c3c4..b7b87d313aa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossInterface.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -22,7 +22,7 @@ * * @param The data type of the predictions. */ -public interface LossInterface { +public interface LossMetric { /** * Calculates the weighted loss between labels and predictions diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 5894b24c4cd..cd17e2a9de4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -38,7 +38,7 @@ public class MeanMetricWrapper extends Mean { /** The loss function interface */ - protected LossInterface loss; + protected LossMetric loss; /** * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#WEIGHTED_MEAN} @@ -58,7 +58,7 @@ protected MeanMetricWrapper(Ops tf, String name, long seed, Class type) { * * @return the loss function. */ - public LossInterface getLoss() { + public LossMetric getLoss() { return loss; } @@ -67,7 +67,7 @@ public LossInterface getLoss() { * * @param loss the loss function. */ - public void setLoss(LossInterface loss) { + protected void setLoss(LossMetric loss) { this.loss = loss; } @@ -90,8 +90,9 @@ public void setLoss(LossInterface loss) { */ public List updateStateList( Operand labels, Operand predictions, Operand sampleWeights) { - if (labels == null || predictions == null) + if (labels == null || predictions == null) { throw new IllegalArgumentException("missing required inputs for labels and predictions"); + } Operand tLabels = CastHelper.cast(getTF(), labels, getType()); Operand tPredictions = CastHelper.cast(getTF(), predictions, getType()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java index 786d5db1261..c5c5dbb2ab2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java @@ -61,6 +61,7 @@ public MetricVariable(Ops tf, Variable variable, long seed, Class type) { * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
* @param type the type for the variable + * @throws IllegalArgumentException if the type does not inherit from TNumber and the initializer is null */ @SuppressWarnings("unchecked") public MetricVariable( @@ -78,8 +79,8 @@ public MetricVariable( } else { throw new IllegalArgumentException( String.format( - "An initializer for variable %s of type %s is required", - variable.toString(), type.getSimpleName())); + "Type %s is not a supported for metric variables", + type.getSimpleName())); } } else { this.initializer = initializer; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 042badbb615..9699ccd323c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -42,16 +42,16 @@ public class MetricsHelper { "weights can not be broadcast to values."; /** - * Asserts that the sampleWeight can be broadcast to values + * Asserts that the sampleWeights can be broadcast to values * * @param tf the TensorFlow Ops * @param sampleWeights the sample weights. * @param values the values to which weights are applied. - * @return Operation raising InvalidArgumentError if sampleWeight - * has incorrect shape. no_op if static checks determine - * sampleWeight has correct shape. + * @return Operation with control dependencies to ensure sampleWeight + * can be broadcast to values * @param the type of Operand - * @throws IllegalArgumentException If static checks determine `weights` has incorrect shape. + * @throws IllegalArgumentException If static checks determine sampleWeights has an + * incorrect shape that prohibit broadcasting to to values */ @SuppressWarnings("unchecked") public static Op broadcastWeights( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index c8499bc1599..3ec6540779c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -112,7 +112,9 @@ private void setupVars() { @Override public List updateStateList(Operand values, Operand sampleWeights) { - if (values == null) throw new IllegalArgumentException("values is required."); + if (values == null) { + throw new IllegalArgumentException("values is required."); + } List updateOperations = new ArrayList<>(); // cast everything to match the variables Operand lSampleWeights = null; @@ -124,7 +126,6 @@ public List updateStateList(Operand values, Operand sampleWeights) { LossesHelper.squeezeOrExpandDimensions(getTF(), null, lValues, lSampleWeights); lValues = tuple.getTarget(); lSampleWeights = tuple.getSampleWeights(); - // lSampleWeights = WeightsBroadcastOps.broadcastWeights(getTF(), lSampleWeights, lValues); try { Op broadcastWeightsCheck = MetricsHelper.broadcastWeights(getTF(), lSampleWeights, lValues); From 2d4c17b85eb7ba85dfa6a6018171cd26793b00f4 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 7 Jan 2021 17:15:30 -0500 Subject: [PATCH 07/56] Removed hashmap for variables, they are not needed as the variables only live within a single instance of a Metric. 
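In outline, each Metric instance now creates and owns its state variables directly and re-initializes them in resetStates(). The fragment below is a simplified sketch of that pattern, not a verbatim excerpt; it keeps only a single "total" variable and drops the count variable and control-dependency plumbing that the real Reduce class uses.

    // simplified per-instance state, assuming a single scalar "total" variable
    private Variable<T> total;

    private void setupVars() {
      // the variable lives only inside this Metric instance, so no static map
      // keyed by ExecutionEnvironment is needed
      total = getTF().withName(totalName).variable(Shape.scalar(), type);
    }

    @Override
    public Op resetStates() {
      // resetting state is just re-assigning the initial (zero) value in place
      return getTF().assign(total, CastHelper.cast(getTF(), getTF().constant(0), total.type()));
    }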
--- .../tensorflow/framework/metrics/Metric.java | 121 +++--------------- .../framework/metrics/impl/Reduce.java | 37 +++--- 2 files changed, 37 insertions(+), 121 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index c816b1a98d0..a6f2cf0f26d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -14,17 +14,13 @@ =======================================================================*/ package org.tensorflow.framework.metrics; -import org.tensorflow.ExecutionEnvironment; import org.tensorflow.Operand; -import org.tensorflow.framework.initializers.Initializer; -import org.tensorflow.framework.metrics.impl.MetricVariable; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TNumber; -import java.util.*; -import java.util.stream.Collectors; +import java.util.Collections; +import java.util.List; /** * Base class for Metrics @@ -34,12 +30,8 @@ */ public abstract class Metric { - /** variables are stored by ExecutionEnvironment, and then by an identifier name */ - protected static Map>> - variableMap = new WeakHashMap<>(); /** The TensorFlow Ops */ private final Ops tf; - /** The random number generator seed value */ private final long seed; /** The name for this metric. Defaults to {@link Class#getSimpleName()}. */ @@ -70,7 +62,7 @@ protected Metric(Ops tf, String name, long seed) { } this.seed = seed; this.name = name != null ? name : this.getClass().getSimpleName(); - this.tf = tf.withSubScope(this.name); + this.tf = tf.withName(this.getClass().getSimpleName()); } /** @@ -139,6 +131,13 @@ public final Op updateState( */ public abstract Operand result(Ops tf); + /** + * Resets any state variables to their initial values + * + * @return the control operation for doing the reset + */ + public abstract Op resetStates(); + /** * Gets the current result of the metric using the metric's {@link #getTF()} * @@ -161,36 +160,6 @@ public final Operand callOnce(Operand values, Operand sampleWeights) { return result(ltf); } - /** - * Adds a variable to collect metric values - * - * @param varName the name for the variable - * @param variable the variable - * @param initializer the initializer for the variable, if null, then the default for floating - * point types is {@link org.tensorflow.framework.initializers.Glorot} with distribution - * {@link org.tensorflow.framework.initializers.VarianceScaling.Distribution#UNIFORM}, for - * other types the default initializer is {@link org.tensorflow.framework.initializers.Zeros} - * @param the date type for the variable - */ - protected void addVariable( - String varName, Variable variable, Initializer initializer) { - Map> variables = - variableMap.computeIfAbsent(tf.scope().env(), k -> new HashMap<>()); - variables.put(varName, new MetricVariable<>(tf, variable, initializer, seed, variable.type())); - } - - /** - * Gets the list of added variables - * - * @return the list of added variables - */ - public List> getVariables() { - List> result = new ArrayList<>(); - Map> variables = variableMap.get(tf.scope().env()); - if (variables != null) variables.values().forEach(mv -> result.add(mv.getVariable())); - return result; - } - /** * Gets a formatted name for a variable, in the form {@link #name} + "_" + varName. 
* @@ -201,71 +170,6 @@ protected String getVariableName(String varName) { return String.format("%s_%s", this.name, varName); } - /** - * Gets an Operation that initializes the variables. - * - * @param subScopeName the sub scope name - * @return the Operation used to initialize the variables. - */ - public Op initialize(String subScopeName) { - - List initializeOperations = initializeVarsList(subScopeName); - return tf.withControlDependencies(initializeOperations).noOp(); - } - - /** - * Gets the list of Operations that initializes the variables - * - * @param subScopeName the sub scope name - * @return the list of Operations that initializes the variables - */ - @SuppressWarnings("unchecked") - private List initializeVarsList(String subScopeName) { - Map> variables = variableMap.get(tf.scope().env()); - if (variables != null) - return variables.values().stream() - .map(metricVariable -> variableAssign(subScopeName, metricVariable)) - .collect(Collectors.toList()); - else return Collections.EMPTY_LIST; - } - - /** - * Resets all variables to their initial state - * - * @return An Operation that sets all variables to their initial state - */ - public Op resetStates() { - return initialize("resetStates"); - } - - /** - * Assigns a value to a Variable - * - *

This assumes the variable has already been initialized - * - * @param subScopeName the subscope for creating the variable - * @param mv the metric value used to assign the initializer to the variable. - * @return the variable add operation with necessary control dependencies - */ - private Operand variableAssign( - String subScopeName, MetricVariable mv) { - return tf.withSubScope(subScopeName).assign(mv.getVariable(), mv.initialize()); - } - - /** - * Gets a stored variable by name, Variables are cached first by the TensorFlow Environment, then - * by a variable name. - * - * @param varName the name assigned to the variable - * @return the variable, or null if the variable is not found. - */ - public Variable getVariable(String varName) { - Map> variables = variableMap.get(tf.scope().env()); - if (variables == null) return null; - MetricVariable mv = variables.get(varName); - return mv != null ? mv.getVariable() : null; - } - /** * Gets the TensorFlow Ops * @@ -283,4 +187,9 @@ public Ops getTF() { public String getName() { return name; } + + /** The random number generator seed value */ + public long getSeed() { + return seed; + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 3ec6540779c..2c387cc152e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; -import org.tensorflow.framework.initializers.Zeros; import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.framework.metrics.Metric; @@ -50,7 +49,6 @@ public abstract class Reduce extends Metri /** the variable that holds the count of the metric values */ protected Variable count; - protected boolean initialized; /** * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#SUM} @@ -81,25 +79,33 @@ protected Reduce(Ops tf, String name, MetricReduction reduction, long seed, Clas this.type = type; setupVars(); } - /** initialize the Variables */ - @SuppressWarnings("unchecked") + /** Initializes the Variables */ private void setupVars() { - Zeros fZeros = new Zeros<>(getTF()); - total = (Variable) getVariable(totalName); if (total == null) { - total = getTF().withSubScope(totalName).variable(Shape.scalar(), type); - addVariable(totalName, total, fZeros); + total = getTF().withName(totalName).variable(Shape.scalar(), type); } if (reduction == MetricReduction.SUM_OVER_BATCH_SIZE || reduction == MetricReduction.WEIGHTED_MEAN) { - count = (Variable) getVariable(countName); if (count == null) { - count = getTF().withSubScope(countName).variable(Shape.scalar(), type); - addVariable(countName, count, fZeros); + count = getTF().withName(countName).variable(Shape.scalar(), type); } } } + /** {@inheritDoc} */ + public Op resetStates() { + List controls = new ArrayList<>(); + if (total != null) { + controls.add( + getTF().assign(total, CastHelper.cast(getTF(), getTF().constant(0), total.type()))); + } + if (count != null) { + controls.add( + getTF().assign(count, CastHelper.cast(getTF(), getTF().constant(0), count.type()))); + } + return getTF().withControlDependencies(controls).noOp(); + } + /** * Updates the metric variables based on the inputs. 
At least one input arg required for * values, an optional additional input for the sampleWeights @@ -110,7 +116,7 @@ private void setupVars() { * @throws IllegalArgumentException if values is null */ @Override - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList(Operand values, Operand sampleWeights) { if (values == null) { throw new IllegalArgumentException("values is required."); @@ -136,7 +142,6 @@ public List updateStateList(Operand values, Operand sampleWeights) { .math .mul(lValues, lSampleWeights); } catch (IllegalArgumentException ex) { - System.out.println("Reduce: Fall back from broadcast"); // reduce the values down to the rank of the samples int nDim = lValues.shape().numDimensions(); int wDim = lSampleWeights.shape().numDimensions(); @@ -224,8 +229,10 @@ public Variable getCount() { return count; } - /** Gets the type for the variables - * @return the type for the variables + /** + * Gets the type for the variables + * + * @return the type for the variables */ public Class getType() { return type; From fbb12f4f6a242bf4661d4c3114513c1fbf73763c Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 7 Jan 2021 17:16:31 -0500 Subject: [PATCH 08/56] reformat code --- .../src/main/java/org/tensorflow/framework/metrics/Metric.java | 1 + .../main/java/org/tensorflow/framework/metrics/impl/Reduce.java | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index a6f2cf0f26d..9efb3dde20a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -32,6 +32,7 @@ public abstract class Metric { /** The TensorFlow Ops */ private final Ops tf; + private final long seed; /** The name for this metric. Defaults to {@link Class#getSimpleName()}. 
*/ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 2c387cc152e..f304ad04cb4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -49,7 +49,6 @@ public abstract class Reduce extends Metri /** the variable that holds the count of the metric values */ protected Variable count; - /** * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#SUM} * From 30196af35bda8189346fca3115a37d4bacda6875 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 11 Jan 2021 18:02:27 -0500 Subject: [PATCH 09/56] Add tests for assertBroadcastable --- .../framework/metrics/impl/MetricsHelper.java | 49 ++- .../metrics/impl/WeightBroadcastTest.java | 335 ++++++++++++++++++ 2 files changed, 365 insertions(+), 19 deletions(-) create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 9699ccd323c..5ecc06a388f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -18,6 +18,7 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.SetDiff1d; import org.tensorflow.op.math.Mean; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; @@ -30,7 +31,6 @@ import java.util.List; import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; -import static org.tensorflow.framework.utils.CastHelper.cast; /** * These are helper methods for Metrics and will be module private when Java modularity is applied @@ -42,7 +42,12 @@ public class MetricsHelper { "weights can not be broadcast to values."; /** - * Asserts that the sampleWeights can be broadcast to values + * Asserts that the sampleWeights can be broadcast to the same shape as values + * + * + *

In losses and metrics, limited weight broadcasting is supported. Weights be either scalar, + * or the same rank as the target values, with each dimension either 1, or the same as the + * corresponding values dimension. * * @param tf the TensorFlow Ops * @param sampleWeights the sample weights. @@ -54,9 +59,10 @@ public class MetricsHelper { * incorrect shape that prohibit broadcasting to to values */ @SuppressWarnings("unchecked") - public static Op broadcastWeights( + public static Op assertBroadcastable( Ops tf, Operand sampleWeights, Operand values) { + // try static check for exact match Operand weightsShape = tf.shape(sampleWeights); Operand weightsRank = tf.rank(sampleWeights); Shape weightsShapeStatic = sampleWeights.shape(); @@ -67,9 +73,9 @@ public static Op broadcastWeights( Shape valuesShapeStatic = values.shape(); int valuesRankStatic = valuesShapeStatic.numDimensions(); - if (weightsRankStatic != -1 && valuesRankStatic != -1) { + if (weightsRankStatic != Shape.UNKNOWN_SIZE && valuesRankStatic != Shape.UNKNOWN_SIZE) { if (weightsRankStatic == 0) { - return tf.withSubScope("static_scalar_check_success") + return tf.withSubScope("staticScalarCheckSuccess") .withControlDependencies(Collections.EMPTY_LIST) .noOp(); } @@ -85,7 +91,7 @@ public static Op broadcastWeights( } for (int i = 0; i < valuesRankStatic; i++) { - if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i)) { + if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i) && weightsShapeStatic.size(i) != 1) { throw new IllegalArgumentException( String.format( "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.", @@ -95,12 +101,12 @@ public static Op broadcastWeights( weightsShapeStatic.toString())); } } - return tf.withSubScope("static_dims_check_success") + return tf.withSubScope("staticDimsCheckSuccess") .withControlDependencies(Collections.EMPTY_LIST) .noOp(); } // Dynamic checks. 
- Operand is_scalar = tf.math.equal(weightsRank, tf.constant(0)); + Operand isScalar = tf.math.equal(weightsRank, tf.constant(0)); List> data = Arrays.asList( tf.constant(ASSERT_BROADCAST_ERROR_PREFIX), @@ -108,14 +114,13 @@ public static Op broadcastWeights( weightsShape, tf.constant("values.shape="), valuesShape, - tf.constant("is_scalar="), - is_scalar); + tf.constant("isScalar="), + isScalar); + + Operand validNonsclar = + hasValidNonscalarShape(tf, weightsRank, weightsShape, valuesRank, valuesShape); - Operand isValidShape = - tf.select( - is_scalar, - is_scalar, - hasValidNonscalarShape(tf, weightsRank, weightsShape, valuesRank, valuesShape)); + Operand isValidShape = tf.select(isScalar, isScalar, validNonsclar); return tf.withSubScope("broadcastWeights-dynamic").assertThat(isValidShape, data); } @@ -137,7 +142,7 @@ private static Operand hasValidNonscalarShape( Operand weightsShape, Operand valuesRank, Operand valuesShape) { - tf = tf.withSubScope("has_valid_nonscalar_shape"); + tf = tf.withSubScope("hasValidNonscalarShape"); Operand isSameRank = tf.math.equal(valuesRank, weightsRank); return tf.select(isSameRank, hasValidDims(tf, weightsShape, valuesShape), isSameRank); } @@ -153,9 +158,15 @@ private static Operand hasValidNonscalarShape( */ private static Operand hasValidDims( Ops tf, Operand weightsShape, Operand valuesShape) { - tf = tf.withSubScope("has_invalid_dims"); - Operand diff = tf.reduceSum(tf.math.sub(weightsShape, valuesShape), tf.constant(0)); - return tf.math.equal(cast(tf, tf.constant(0), diff.asOutput().type()), diff); + tf = tf.withSubScope("hasValidDims"); + Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); + Operand validDims = + tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); + SetDiff1d invalidDimsDiff = + tf.setDiff1d(tf.shape.flatten(valuesShape2d), tf.shape.flatten(validDims)); + Operand invalidDims = invalidDimsDiff.out(); + Operand numInvalidDims = tf.size(invalidDims); + return tf.math.equal(tf.constant(0), numInvalidDims); } // alias for mean diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java new file mode 100644 index 00000000000..c89cff93dc2 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java @@ -0,0 +1,335 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class WeightBroadcastTest { + + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + private void testValid( + TestSession testSession, Ops tf, Operand weights, Operand values, Class type) { + + Op staticOp = MetricsHelper.assertBroadcastable(tf, weights, values); + testSession.run(staticOp); + + // dynamic test + Operand weightsPlaceholder = tf.placeholder(type); + Operand valuesPlaceholder = tf.placeholder(type); + + List tensors = + testSession.getGraphSession().runner().fetch(weights).fetch(values).run(); + try (Tensor weightsTensor = tensors.get(0); + Tensor valuesTensor = tensors.get(1)) { + + Op dynamicOp = MetricsHelper.assertBroadcastable(tf, weightsPlaceholder, valuesPlaceholder); + + testSession + .getGraphSession() + .runner() + .feed(weightsPlaceholder, weightsTensor) + .feed(valuesPlaceholder, valuesTensor) + .addTarget(dynamicOp) + .run(); + } + } + + @Test + public void testValidScalar() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new float[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(5f); + testValid(testSession, tf, weights, values, TFloat32.class); + } + } + + @Test + public void test1x1x1() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new double[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new double[][][] {{{5}}}); + testValid(testSession, tf, weights, values, TFloat64.class); + } + } + + @Test + public void test1x1xN() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new long[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new long[][][] {{{5, 7, 11, 3}}}); + testValid(testSession, tf, weights, values, TInt64.class); + } + } + + @Test + public void test1xNx1() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][][] {{{5}, {11}}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + } + + @Test + public void test1xNxN() { + // no exception should be thrown + try (TestSession testSession = 
TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}, {2, 13, 7, 5}}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + } + + @Test + public void testNx1x1() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][][] {{{5}}, {{7}}, {{11}}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + } + + @Test + public void testNx1xN() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = + tf.constant(new int[][][] {{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + } + + @Test + public void testNxNxN() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = + tf.constant( + new int[][][] { + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} + }); + testValid(testSession, tf, weights, values, TInt32.class); + } + } + + @Test + public void testInvalid1x1() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][] {{5}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } + + @Test + public void testInvalidPrefixMatch() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][] {{5, 7}, {11, 3}, {2, 12}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } + + @Test + public void testInvalidSuffixMatch() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][] {{5, 7, 11, 3}, {2, 12, 7, 5}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } + + 
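// A note on the invalid cases that follow (shapes referenced here are the ones these tests
// build): values have shape [3, 2, 4], so a weight operand must be a scalar or have rank 3
// with every dimension equal to the corresponding values dimension or to 1. Rank-2 weights
// such as [3, 2] or [2, 4], and rank-4 weights such as [1, 1, 1, 1] or [3, 2, 4, 1], are
// rejected with IllegalArgumentException even though trailing-dimension (NumPy-style)
// broadcasting would accept some of them.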
@Test + public void testInvalidOnesExtraDim() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][][][] {{{{5}}}} ); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } + + @Test + public void testInvalidPrefixMatchExtraDim() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + + Operand weights = tf.constant(new int[][][][] { + {{ { 5},{ 7}, {11}, { 3}}, {{ 2}, {12}, { 7}, { 5}} }, + {{ { 2}, {17}, {11}, { 3}}, {{ 2}, {17}, {11}, { 3}} }, + {{ { 5}, { 7}, {11}, { 3}}, {{ 2}, {12}, { 7}, { 5}} } + }); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } + + @Test + public void testInvalidSuffixMatchExtraDim() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][][][] {{ + { { 5, 7, 11, 3}, { 2, 12, 7, 5} }, + { { 2, 17, 11, 3}, { 2, 17, 11, 3} }, + { { 5, 7, 11, 3}, { 2, 12, 7, 5} } + }}); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } +} From 12485814aca42815bf2816e81265491cf675c92a Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 11 Jan 2021 18:41:00 -0500 Subject: [PATCH 10/56] Change type to resultType --- .../tensorflow/framework/metrics/impl/MeanMetricWrapper.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index cd17e2a9de4..173167c5370 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -94,8 +94,8 @@ public List updateStateList( throw new IllegalArgumentException("missing required inputs for labels and predictions"); } - Operand tLabels = CastHelper.cast(getTF(), labels, getType()); - Operand tPredictions = CastHelper.cast(getTF(), predictions, getType()); + Operand tLabels = CastHelper.cast(getTF(), labels, getResultType()); + Operand tPredictions = CastHelper.cast(getTF(), predictions, getResultType()); Operand losses = loss.call(tLabels, tPredictions); From 96bd55b2e0b17554bf895c47a89f3771820461a3 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 11 Jan 2021 18:42:18 -0500 Subject: [PATCH 11/56] Added V data type for sampleWeights so that it is not forced to be the same type as the return or internal variables, --- .../tensorflow/framework/metrics/Metric.java | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 9efb3dde20a..123abae61d7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -76,7 +76,7 @@ protected Metric(Ops tf, String name, long seed) { * @return a List of Operations to update the metric state */ @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList(Operand values, Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -88,7 +88,7 @@ public List updateStateList(Operand values, Operand sampleWeights) { * @param labels the labels * @param predictions the predictions * @param sampleWeights sample weights to be applied to values, may be null. - * @param the data type for the sample weights + * @param the data type for the labels * @return a List of Operations to update the metric state */ @SuppressWarnings({"unchecked", "unused"}) @@ -104,7 +104,7 @@ public List updateStateList( * @param sampleWeights sample weights to be applied to values, may be null. * @return the Operation to update the metric state */ - public final Op updateState(Operand values, Operand sampleWeights) { + public final Op updateState(Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -115,7 +115,7 @@ public final Op updateState(Operand values, Operand sampleWeights) { * @param labels the labels * @param predictions the predictions * @param sampleWeights sample weights to be applied to values, may be null. 
- * @param the data type for the sample weights + * @param the data type for the labels * @return the Operation to update the metric state */ public final Op updateState( @@ -127,10 +127,9 @@ public final Op updateState( /** * Gets the current result of the metric * - * @param tf the TensorFlow Ops used to create the result * @return the result, possibly with control dependencies */ - public abstract Operand result(Ops tf); + public abstract Operand result(); /** * Resets any state variables to their initial values @@ -139,14 +138,6 @@ public final Op updateState( */ public abstract Op resetStates(); - /** - * Gets the current result of the metric using the metric's {@link #getTF()} - * - * @return the result, possibly with control dependencies - */ - public Operand result() { - return result(this.tf); - } /** * Calls update state once, followed by a call to get the result @@ -158,7 +149,7 @@ public Operand result() { public final Operand callOnce(Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps); - return result(ltf); + return ltf.identity(result()); } /** From d706f55203d7cdebd030b7875e64723fa834d790 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 11 Jan 2021 18:42:49 -0500 Subject: [PATCH 12/56] change 'type' to 'resultType' --- .../framework/metrics/impl/Reduce.java | 62 ++++++++++--------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index f304ad04cb4..fb8e39f3f1f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -43,7 +43,7 @@ public abstract class Reduce extends Metri private final String totalName; private final String countName; - private final Class type; + private final Class resultType; /** the variable that holds the total of the metric values */ protected Variable total; /** the variable that holds the count of the metric values */ @@ -56,10 +56,10 @@ public abstract class Reduce extends Metri * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result + * @param resultType the type for the variables and result */ - protected Reduce(Ops tf, String name, long seed, Class type) { - this(tf, name, MetricReduction.SUM, seed, type); + protected Reduce(Ops tf, String name, long seed, Class resultType) { + this(tf, name, MetricReduction.SUM, seed, resultType); } /** @@ -68,25 +68,25 @@ protected Reduce(Ops tf, String name, long seed, Class type) { * @param reduction The type of metric reduction to apply * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
- * @param type the type for the variables and result + * @param resultType the type for the variables and result */ - protected Reduce(Ops tf, String name, MetricReduction reduction, long seed, Class type) { + protected Reduce(Ops tf, String name, MetricReduction reduction, long seed, Class resultType) { super(tf, name, seed); this.reduction = reduction; this.totalName = this.getVariableName(TOTAL); this.countName = this.getVariableName(COUNT); - this.type = type; + this.resultType = resultType; setupVars(); } /** Initializes the Variables */ private void setupVars() { if (total == null) { - total = getTF().withName(totalName).variable(Shape.scalar(), type); + total = getTF().withName(totalName).variable(Shape.scalar(), resultType); } if (reduction == MetricReduction.SUM_OVER_BATCH_SIZE || reduction == MetricReduction.WEIGHTED_MEAN) { if (count == null) { - count = getTF().withName(countName).variable(Shape.scalar(), type); + count = getTF().withName(countName).variable(Shape.scalar(), resultType); } } } @@ -115,7 +115,7 @@ public Op resetStates() { * @throws IllegalArgumentException if values is null */ @Override - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList(Operand values, Operand sampleWeights) { if (values == null) { throw new IllegalArgumentException("values is required."); @@ -133,7 +133,7 @@ public List updateStateList(Operand values, Operand sampleWeights) { lSampleWeights = tuple.getSampleWeights(); try { - Op broadcastWeightsCheck = MetricsHelper.broadcastWeights(getTF(), lSampleWeights, lValues); + Op broadcastWeightsCheck = MetricsHelper.assertBroadcastable(getTF(), lSampleWeights, lValues); lValues = getTF() .withSubScope("broadcastWeightsCheck") @@ -141,16 +141,20 @@ public List updateStateList(Operand values, Operand sampleWeights) { .math .mul(lValues, lSampleWeights); } catch (IllegalArgumentException ex) { - // reduce the values down to the rank of the samples - int nDim = lValues.shape().numDimensions(); - int wDim = lSampleWeights.shape().numDimensions(); - int numAxes = nDim - wDim; - int[] axes = new int[numAxes]; - for (int i = 0; i < numAxes; i++) axes[i] = i + wDim; - if (reduction == MetricReduction.SUM) { - lValues = getTF().reduceSum(lValues, getTF().constant(axes)); - } else { - lValues = getTF().math.mean(lValues, getTF().constant(axes)); + // if we get here we have static shapes with either + // different ranks or different dimension sizes. + // first, reduce the values down to the rank of the samples + int valuesDim = lValues.shape().numDimensions(); + int weightsDim = lSampleWeights.shape().numDimensions(); + int numAxes = Math.min(0, valuesDim - weightsDim); + if (numAxes > 0) { // values rank is greater than weights rank, reduce values to weights rank. 
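// Worked example of the intended reduction (shapes assumed for this sketch): with values
// of shape [2, 5, 3] and weights of shape [2], the weights cannot be broadcast, so values
// is reduced over axes {1, 2} (sum for SUM, mean otherwise) down to shape [2] before the
// element-wise multiply by the weights below.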
+ int[] axes = new int[numAxes]; + for (int i = 0; i < numAxes; i++) axes[i] = i + weightsDim; + if (reduction == MetricReduction.SUM) { + lValues = getTF().reduceSum(lValues, getTF().constant(axes)); + } else { + lValues = getTF().math.mean(lValues, getTF().constant(axes)); + } } lValues = getTF().math.mul(lValues, lSampleWeights); } @@ -164,18 +168,18 @@ public List updateStateList(Operand values, Operand sampleWeights) { if (reduction != MetricReduction.SUM) { switch (reduction) { case SUM_OVER_BATCH_SIZE: - numValues = CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), type); + numValues = CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); break; case WEIGHTED_MEAN: if (lSampleWeights == null) { - numValues = CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), type); + numValues = CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); } else { numValues = CastHelper.cast( getTF(), getTF() .reduceSum(lSampleWeights, LossesHelper.allAxes(getTF(), lSampleWeights)), - type); + resultType); } break; default: @@ -192,16 +196,16 @@ public List updateStateList(Operand values, Operand sampleWeights) { /** {@inheritDoc} */ @Override - public Operand result(Ops rtf) { + public Operand result() { Operand fResult; switch (this.reduction) { case SUM: - fResult = rtf.identity(total); + fResult = getTF().identity(total); break; case WEIGHTED_MEAN: case SUM_OVER_BATCH_SIZE: - fResult = rtf.math.divNoNan(total, CastHelper.cast(rtf, count, type)); + fResult = getTF().math.divNoNan(total, CastHelper.cast(getTF(), count, resultType)); break; default: throw new UnsupportedOperationException( @@ -233,7 +237,7 @@ public Variable getCount() { * * @return the type for the variables */ - public Class getType() { - return type; + public Class getResultType() { + return resultType; } } From 014744210420a3fd2e155cbe6db28c168764239f Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 11 Jan 2021 18:43:43 -0500 Subject: [PATCH 13/56] clean up mean and fix assert assertBroadcastable --- .../framework/metrics/impl/MetricsHelper.java | 60 +++++++++++-------- 1 file changed, 34 insertions(+), 26 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 5ecc06a388f..eb7c1fbd221 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -23,6 +23,7 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; @@ -91,7 +92,8 @@ public static Op assertBroadcastable( } for (int i = 0; i < valuesRankStatic; i++) { - if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i) && weightsShapeStatic.size(i) != 1) { + if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i) + && weightsShapeStatic.size(i) != 1) { throw new IllegalArgumentException( String.format( "%s Mismatch at dim %d. 
values.shape=%s weights.shape=%s.", @@ -152,7 +154,7 @@ private static Operand hasValidNonscalarShape( * * @param tf the TensorFlow Ops * @param weightsShape the operand for the shape of the sample weights - * @param valuesShape the operand for the shape of the sample weights + * @param valuesShape the operand for the shape of the values * @param the data type for the operands * @return a boolean operand to determine if the shapes have valid dimensions or not. */ @@ -163,7 +165,7 @@ private static Operand hasValidDims( Operand validDims = tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); SetDiff1d invalidDimsDiff = - tf.setDiff1d(tf.shape.flatten(valuesShape2d), tf.shape.flatten(validDims)); + tf.setDiff1d(weightsShape, tf.shape.flatten(validDims)); Operand invalidDims = invalidDimsDiff.out(); Operand numInvalidDims = tf.size(invalidDims); return tf.math.equal(tf.constant(0), numInvalidDims); @@ -178,9 +180,10 @@ private static Operand hasValidDims( * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param the type of the Operand. + * @param the data type for the result * @return the mean of the operand */ - public static Operand mean(Ops tf, Operand x) { + public static Operand mean(Ops tf, Operand x) { return mean(tf, x, null, false); } @@ -190,58 +193,63 @@ public static Operand mean(Ops tf, Operand x) { * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean - * @param axis Axes to compute the mean. + * @param axes Axes to compute the mean. * @param the type of the Operand. - * @param the type of the axis. - * @return the mean of the operand, alongside the specified axis. + * @param the type of the axes. + * @param the data type for the result + * @return the mean of the operand, along the specified axes. */ - public static Operand mean( - Ops tf, Operand x, Operand axis) { - return mean(tf, x, axis, false); + public static Operand mean( + Ops tf, Operand x, Operand axes) { + return mean(tf, x, axes, false); } /** - * Calculate the mean of the operand, along all axis. + * Calculate the mean of the operand, along all axes. * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axis + * false, the rank of the tensor is reduced by 1 for each entry in axes * . If keepdims is true, the reduced dimensions are retained * with length 1. * @param the type of the operand + * @param the data type for the result * @return the mean of elements of x. */ - public static Operand mean(Ops tf, Operand x, boolean keepDims) { + public static Operand mean( + Ops tf, Operand x, boolean keepDims) { return mean(tf, x, null, keepDims); } /** - * Calculate the mean of the operand, alongside the specified axis. + * Calculate the mean of the operand, alongside the specified axes. * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean - * @param axis Axes to compute the mean. + * @param axes Axes to compute the mean. * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axis`. If `keepdims` is `true`, the + * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the * * reduced dimensions are retained with length 1. 
* @param the data type of the Operand - * @param the data type of the axis + * @param the data type of the axes + * @param the data type for the result * @return the mean of elements of `x`. */ @SuppressWarnings({"unchecked", "rawtypes"}) - public static Operand mean( - Ops tf, Operand x, Operand axis, boolean keepDims) { + public static Operand mean( + Ops tf, Operand x, Operand axes, boolean keepDims) { // Cannot use generics here because xf may change from TBool to TFloat32 - Operand xf; - if (x.asOutput().type() == TBool.class) { - xf = tf.dtypes.cast(x, TFloat32.class); + Operand xf; + if (x.type().equals(TBool.class)) { + xf = (Operand) tf.dtypes.cast(x, TFloat32.class); } else { - xf = x; + xf = (Operand) x; } - if (axis == null) { - axis = allAxes(tf, xf); + if (axes == null) { + axes = (Operand) allAxes(tf, xf); } - return tf.math.mean(xf, axis, Mean.keepDims(keepDims)); + Operand theMean = tf.math.mean(xf, axes, Mean.keepDims(keepDims)); + return x.type().equals(TBool.class) ? tf.dtypes.cast(theMean, TBool.class) : theMean; } } From c0c127faf60cc1a84a9f5f7b9b2ad71ec0e75166 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 11 Jan 2021 18:44:07 -0500 Subject: [PATCH 14/56] fix error message --- .../org/tensorflow/framework/metrics/impl/MetricVariable.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java index c5c5dbb2ab2..aae5a8f30c4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java @@ -79,7 +79,7 @@ public MetricVariable( } else { throw new IllegalArgumentException( String.format( - "Type %s is not a supported for metric variables", + "Type %s is not supported for metric variables", type.getSimpleName())); } } else { From 89ec9ed6305afe6960a4adef8470eb9bed6f518b Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 12 Jan 2021 18:19:07 -0500 Subject: [PATCH 15/56] Change sampleWeights to have its own generic type --- .../org/tensorflow/framework/metrics/Metric.java | 16 ++++++++++------ .../metrics/impl/MeanMetricWrapper.java | 5 +++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 123abae61d7..20151eb1408 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -74,9 +74,10 @@ protected Metric(Ops tf, String name, long seed) { * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. 
* @return a List of Operations to update the metric state + * @param the data type for sampleWeights */ @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList(Operand values, Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -89,11 +90,12 @@ public List updateStateList(Operand values, Operand the data type for the labels + * @param the data type for the sampleWeights * @return a List of Operations to update the metric state */ @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, Operand predictions, Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -102,9 +104,10 @@ public List updateStateList( * * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. + * @param the data type for sampleWeights * @return the Operation to update the metric state */ - public final Op updateState(Operand values, Operand sampleWeights) { + public final Op updateState(Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -116,10 +119,11 @@ public final Op updateState(Operand values, Operand sa * @param predictions the predictions * @param sampleWeights sample weights to be applied to values, may be null. * @param the data type for the labels + * @param the data type for the sampleWeights * @return the Operation to update the metric state */ - public final Op updateState( - Operand labels, Operand predictions, Operand sampleWeights) { + public final Op updateState( + Operand labels, Operand predictions, Operand sampleWeights) { List controlOps = updateStateList(labels, predictions, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 173167c5370..98447142da6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -86,10 +86,11 @@ protected void setLoss(LossMetric loss) { * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param the datatype of the predictions + * @param the data type for sampleWeights * @return a List of control operations that updates the Mean state variables. 
*/ - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, Operand predictions, Operand sampleWeights) { if (labels == null || predictions == null) { throw new IllegalArgumentException("missing required inputs for labels and predictions"); } From 1b81c8263e115b74fc8ed3f79ea5550cffaf55ea Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 12 Jan 2021 18:19:56 -0500 Subject: [PATCH 16/56] Add commment about invalid tests expecting IllegalArgumentExceptions --- .../metrics/impl/WeightBroadcastTest.java | 118 ++++++++++-------- 1 file changed, 65 insertions(+), 53 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java index c89cff93dc2..08e19f82a89 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java @@ -204,8 +204,14 @@ public void testNxNxN() { } } + // Note: For invalid tests, either NotBroadcastableException is thrown for static shapes or + // TFInvalidInvalidException is thrown for dynamic shapes. Both of these extend + // IllegalArgumentException, + // To simply the assertThrows, only IllegalArgumentException is expected. + // The private method, testValid, tests for both static and dynamic shapes. @Test public void testInvalid1x1() { + assertThrows( IllegalArgumentException.class, () -> { @@ -267,69 +273,75 @@ public void testInvalidSuffixMatch() { @Test public void testInvalidOnesExtraDim() { assertThrows( - IllegalArgumentException.class, - () -> { - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - Operand weights = tf.constant(new int[][][][] {{{{5}}}} ); - testValid(testSession, tf, weights, values, TInt32.class); - } - }); + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][][][] {{{{5}}}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); } @Test public void testInvalidPrefixMatchExtraDim() { assertThrows( - IllegalArgumentException.class, - () -> { - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); - Operand weights = tf.constant(new int[][][][] { - {{ { 5},{ 7}, {11}, { 3}}, {{ 2}, {12}, { 7}, { 5}} }, - {{ { 2}, {17}, {11}, { 3}}, {{ 2}, {17}, {11}, { 3}} }, - {{ { 5}, { 7}, {11}, { 3}}, {{ 
2}, {12}, { 7}, { 5}} } - }); - testValid(testSession, tf, weights, values, TInt32.class); - } - }); + Operand weights = + tf.constant( + new int[][][][] { + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}}, + {{{2}, {17}, {11}, {3}}, {{2}, {17}, {11}, {3}}}, + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}} + }); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); } @Test public void testInvalidSuffixMatchExtraDim() { assertThrows( - IllegalArgumentException.class, - () -> { - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - Operand weights = tf.constant(new int[][][][] {{ - { { 5, 7, 11, 3}, { 2, 12, 7, 5} }, - { { 2, 17, 11, 3}, { 2, 17, 11, 3} }, - { { 5, 7, 11, 3}, { 2, 12, 7, 5} } - }}); - testValid(testSession, tf, weights, values, TInt32.class); - } - }); + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = + tf.constant( + new int[][][][] { + { + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} + } + }); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); } } From 54d4ae9efd5511f8bf64707f7deff0cf6152d7d8 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 12 Jan 2021 18:20:39 -0500 Subject: [PATCH 17/56] Add this exception instead of the more generic IllegalArgumentException when static shapes cannot boradcast. --- .../exceptions/NotBroadcastableException.java | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java new file mode 100644 index 00000000000..73f07b977c2 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.exceptions; + +import org.tensorflow.ndarray.Shape; + +/** + * Exception that indicates that static shapes are not able to broadcast among each other during arithmetic operations. + * Static shapes do not have unknown rank or any unknown dimensions {@link Shape#hasUnknownDimension()}. 
+ * The term broadcasting describes how TensorFlow treats arrays with different shapes during arithmetic operations. + * + *
<p>
Broadcasting is the process of making arrays to have compatible shapes for arithmetic + * operations. Two shapes are compatible if for each dimension pair they are either equal or one of + * them is one. When trying to broadcast a Tensor to a shape, it starts with the trailing + * dimensions, and works its way forward. + * + * + * @see Numpy Broadcasting + */ +public class NotBroadcastableException extends IllegalArgumentException { + + /** + * Creates a new NotBroadcastableException exception with the specified detail message + * @param message the detail message. + */ + public NotBroadcastableException(String message) { + super(message); + } + + /** + * Creates a new NotBroadcastableException exception with the specified detail message + * @param message the detail message. + * @param cause the cause + */ + public NotBroadcastableException(String message, Throwable cause) { + super(message, cause); + } +} From ca50dfa9349191c90a6f2d010f93123fd413a61d Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 12 Jan 2021 18:22:51 -0500 Subject: [PATCH 18/56] change IllegalArgumentException to NotBroadcastableException. change hasValidNonscalarShape to canBroadcastNonscalarShapes change hasValidNonscalarShape to canBroadcastNonscalarShapes --- .../framework/metrics/impl/MetricsHelper.java | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index eb7c1fbd221..b25a03b07c9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -56,8 +57,8 @@ public class MetricsHelper { * @return Operation with control dependencies to ensure sampleWeight * can be broadcast to values * @param the type of Operand - * @throws IllegalArgumentException If static checks determine sampleWeights has an - * incorrect shape that prohibit broadcasting to to values + * @throws NotBroadcastableException If static checks determine sampleWeights has an + * incorrect shape that prohibit broadcasting to values */ @SuppressWarnings("unchecked") public static Op assertBroadcastable( @@ -81,7 +82,7 @@ public static Op assertBroadcastable( .noOp(); } if (weightsRankStatic != valuesRankStatic) { - throw new IllegalArgumentException( + throw new NotBroadcastableException( String.format( "%s values.rank=%d. weights.rank=%d. values.shape=%s. weights.shape=%s.", ASSERT_BROADCAST_ERROR_PREFIX, @@ -94,7 +95,7 @@ public static Op assertBroadcastable( for (int i = 0; i < valuesRankStatic; i++) { if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i) && weightsShapeStatic.size(i) != 1) { - throw new IllegalArgumentException( + throw new NotBroadcastableException( String.format( "%s Mismatch at dim %d. 
values.shape=%s weights.shape=%s.", ASSERT_BROADCAST_ERROR_PREFIX, @@ -120,7 +121,7 @@ public static Op assertBroadcastable( isScalar); Operand validNonsclar = - hasValidNonscalarShape(tf, weightsRank, weightsShape, valuesRank, valuesShape); + canBroadcastNonscalarShapes(tf, weightsRank, weightsShape, valuesRank, valuesShape); Operand isValidShape = tf.select(isScalar, isScalar, validNonsclar); @@ -138,7 +139,7 @@ public static Op assertBroadcastable( * @param the data type for the operands * @return a boolean operand to determine if the Shape is scalar or not. */ - private static Operand hasValidNonscalarShape( + private static Operand canBroadcastNonscalarShapes( Ops tf, Operand weightsRank, Operand weightsShape, @@ -146,7 +147,7 @@ private static Operand hasValidNonscalarShape( Operand valuesShape) { tf = tf.withSubScope("hasValidNonscalarShape"); Operand isSameRank = tf.math.equal(valuesRank, weightsRank); - return tf.select(isSameRank, hasValidDims(tf, weightsShape, valuesShape), isSameRank); + return tf.select(isSameRank, canBroadcastDims(tf, weightsShape, valuesShape), isSameRank); } /** @@ -158,7 +159,7 @@ private static Operand hasValidNonscalarShape( * @param the data type for the operands * @return a boolean operand to determine if the shapes have valid dimensions or not. */ - private static Operand hasValidDims( + private static Operand canBroadcastDims( Ops tf, Operand weightsShape, Operand valuesShape) { tf = tf.withSubScope("hasValidDims"); Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); From ca5ad092a68c2cbfc20f459b0dc9b6a7fb6fb231 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 12 Jan 2021 18:23:58 -0500 Subject: [PATCH 19/56] reformat code --- .../org/tensorflow/framework/metrics/Metric.java | 1 - .../org/tensorflow/framework/metrics/Metrics.java | 4 ++-- .../exceptions/NotBroadcastableException.java | 10 ++++++---- .../framework/metrics/impl/MeanMetricWrapper.java | 2 +- .../framework/metrics/impl/MetricVariable.java | 7 +++---- .../framework/metrics/impl/MetricsHelper.java | 7 +++---- .../tensorflow/framework/metrics/impl/Reduce.java | 14 +++++++++----- 7 files changed, 24 insertions(+), 21 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 20151eb1408..57e332a0843 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -142,7 +142,6 @@ public final Op updateState( */ public abstract Op resetStates(); - /** * Calls update state once, followed by a call to get the result * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index b8e79efa450..e2cd5e368c2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -117,8 +117,8 @@ public static Operand l2Normalize(Ops tf, Operand x, i * @param tf The TensorFlow ops * @param x The operand to normalize * @param axes Dimension(s) along which to normalize. - * @param epsilon A lower bound value for the norm. Will use sqrt(epsilon) as the divisor if - * norm < sqrt(epsilon). + * @param epsilon A lower bound value for the norm. Will use sqrt(epsilon) as the + * divisor if norm < sqrt(epsilon). 
* @param The data type for the values. * @return the normalized values of x. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java index 73f07b977c2..66640e72f50 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java @@ -17,22 +17,23 @@ import org.tensorflow.ndarray.Shape; /** - * Exception that indicates that static shapes are not able to broadcast among each other during arithmetic operations. - * Static shapes do not have unknown rank or any unknown dimensions {@link Shape#hasUnknownDimension()}. - * The term broadcasting describes how TensorFlow treats arrays with different shapes during arithmetic operations. + * Exception that indicates that static shapes are not able to broadcast among each other during + * arithmetic operations. Static shapes do not have unknown rank or any unknown dimensions {@link + * Shape#hasUnknownDimension()}. The term broadcasting describes how TensorFlow treats arrays with + * different shapes during arithmetic operations. * *
<p>
Broadcasting is the process of making arrays to have compatible shapes for arithmetic * operations. Two shapes are compatible if for each dimension pair they are either equal or one of * them is one. When trying to broadcast a Tensor to a shape, it starts with the trailing * dimensions, and works its way forward. * - * * @see Numpy Broadcasting */ public class NotBroadcastableException extends IllegalArgumentException { /** * Creates a new NotBroadcastableException exception with the specified detail message + * * @param message the detail message. */ public NotBroadcastableException(String message) { @@ -41,6 +42,7 @@ public NotBroadcastableException(String message) { /** * Creates a new NotBroadcastableException exception with the specified detail message + * * @param message the detail message. * @param cause the cause */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 98447142da6..e2f1345f356 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -89,7 +89,7 @@ protected void setLoss(LossMetric loss) { * @param the data type for sampleWeights * @return a List of control operations that updates the Mean state variables. */ - public List updateStateList( + public List updateStateList( Operand labels, Operand predictions, Operand sampleWeights) { if (labels == null || predictions == null) { throw new IllegalArgumentException("missing required inputs for labels and predictions"); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java index aae5a8f30c4..6b208c0d7bf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java @@ -61,7 +61,8 @@ public MetricVariable(Ops tf, Variable variable, long seed, Class type) { * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
* @param type the type for the variable - * @throws IllegalArgumentException if the type does not inherit from TNumber and the initializer is null + * @throws IllegalArgumentException if the type does not inherit from TNumber and the initializer + * is null */ @SuppressWarnings("unchecked") public MetricVariable( @@ -78,9 +79,7 @@ public MetricVariable( this.initializer = new Zeros<>(tf); } else { throw new IllegalArgumentException( - String.format( - "Type %s is not supported for metric variables", - type.getSimpleName())); + String.format("Type %s is not supported for metric variables", type.getSimpleName())); } } else { this.initializer = initializer; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index b25a03b07c9..05bfe17a1be 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -58,7 +58,7 @@ public class MetricsHelper { * can be broadcast to values * @param the type of Operand * @throws NotBroadcastableException If static checks determine sampleWeights has an - * incorrect shape that prohibit broadcasting to values + * incorrect shape that prohibit broadcasting to values */ @SuppressWarnings("unchecked") public static Op assertBroadcastable( @@ -121,7 +121,7 @@ public static Op assertBroadcastable( isScalar); Operand validNonsclar = - canBroadcastNonscalarShapes(tf, weightsRank, weightsShape, valuesRank, valuesShape); + canBroadcastNonscalarShapes(tf, weightsRank, weightsShape, valuesRank, valuesShape); Operand isValidShape = tf.select(isScalar, isScalar, validNonsclar); @@ -165,8 +165,7 @@ private static Operand canBroadcastDims( Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); Operand validDims = tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); - SetDiff1d invalidDimsDiff = - tf.setDiff1d(weightsShape, tf.shape.flatten(validDims)); + SetDiff1d invalidDimsDiff = tf.setDiff1d(weightsShape, tf.shape.flatten(validDims)); Operand invalidDims = invalidDimsDiff.out(); Operand numInvalidDims = tf.size(invalidDims); return tf.math.equal(tf.constant(0), numInvalidDims); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index fb8e39f3f1f..6e1795af2eb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -133,7 +133,8 @@ public List updateStateList(Operand values, Operand List updateStateList(Operand values, Operand 0) { // values rank is greater than weights rank, reduce values to weights rank. + if (numAxes + > 0) { // values rank is greater than weights rank, reduce values to weights rank. 
int[] axes = new int[numAxes]; for (int i = 0; i < numAxes; i++) axes[i] = i + weightsDim; if (reduction == MetricReduction.SUM) { @@ -168,18 +170,20 @@ public List updateStateList(Operand values, Operand Date: Wed, 13 Jan 2021 07:07:26 -0500 Subject: [PATCH 20/56] Fis=x Javadoc move the dynamic shapes and rank down to the dynamic section so they are created needlessly when static Fix if statement to check for unknown size and unknown dimensions --- .../framework/metrics/impl/MetricsHelper.java | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 05bfe17a1be..00af7a6d1af 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -47,8 +47,8 @@ public class MetricsHelper { * Asserts that the sampleWeights can be broadcast to the same shape as values * * - *
<p>
In losses and metrics, limited weight broadcasting is supported. Weights be either scalar, - * or the same rank as the target values, with each dimension either 1, or the same as the + *
<p>
In losses and metrics, limited weight broadcasting is supported. Weights must be either + * scalar, or the same rank as the target values, with each dimension either 1, or the same as the * corresponding values dimension. * * @param tf the TensorFlow Ops @@ -65,17 +65,17 @@ public static Op assertBroadcastable( Ops tf, Operand sampleWeights, Operand values) { // try static check for exact match - Operand weightsShape = tf.shape(sampleWeights); - Operand weightsRank = tf.rank(sampleWeights); + Shape weightsShapeStatic = sampleWeights.shape(); int weightsRankStatic = weightsShapeStatic.numDimensions(); - Operand valuesShape = tf.shape(values); - Operand valuesRank = tf.rank(values); Shape valuesShapeStatic = values.shape(); int valuesRankStatic = valuesShapeStatic.numDimensions(); - if (weightsRankStatic != Shape.UNKNOWN_SIZE && valuesRankStatic != Shape.UNKNOWN_SIZE) { + // if (weightsRankStatic != Shape.UNKNOWN_SIZE && valuesRankStatic != Shape.UNKNOWN_SIZE) { + if (!weightsShapeStatic.isUnknown() + && !valuesShapeStatic.isUnknown() + && !weightsShapeStatic.hasUnknownDimension() & !valuesShapeStatic.hasUnknownDimension()) { if (weightsRankStatic == 0) { return tf.withSubScope("staticScalarCheckSuccess") .withControlDependencies(Collections.EMPTY_LIST) @@ -109,6 +109,11 @@ public static Op assertBroadcastable( .noOp(); } // Dynamic checks. + Operand weightsShape = tf.shape(sampleWeights); + Operand weightsRank = tf.rank(sampleWeights); + Operand valuesShape = tf.shape(values); + Operand valuesRank = tf.rank(values); + Operand isScalar = tf.math.equal(weightsRank, tf.constant(0)); List> data = Arrays.asList( From 85bdde7db97a5cfb0aba4c6d4955a5c4b709ce85 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 17 Jan 2021 13:00:06 -0500 Subject: [PATCH 21/56] Fix Reduce to use boradcastWeights, renamed WeightBroadcastTest to AssertBroadcastableTest and added BroadcastWeightsTest --- .../tensorflow/framework/metrics/Metric.java | 5 +- .../framework/metrics/impl/MetricsHelper.java | 31 +- .../framework/metrics/impl/Reduce.java | 26 +- ...Test.java => AssertBroadcastableTest.java} | 140 ++----- .../metrics/impl/BroadcastWeightsTest.java | 380 ++++++++++++++++++ .../framework/utils/GraphTestSession.java | 12 +- .../framework/utils/TestSession.java | 29 ++ 7 files changed, 498 insertions(+), 125 deletions(-) rename tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/{WeightBroadcastTest.java => AssertBroadcastableTest.java} (68%) create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 57e332a0843..bbb2aa73da2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -33,6 +33,7 @@ public abstract class Metric { /** The TensorFlow Ops */ private final Ops tf; + /** The seed for random number generation */ private final long seed; /** The name for this metric. Defaults to {@link Class#getSimpleName()}. */ @@ -148,8 +149,10 @@ public final Op updateState( * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. * @return the result, possibly with control dependencies + * @param the data type for the sampleWeights. 
*/ - public final Operand callOnce(Operand values, Operand sampleWeights) { + public final Operand callOnce( + Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps); return ltf.identity(result()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 00af7a6d1af..fbe50151854 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -75,7 +75,8 @@ public static Op assertBroadcastable( // if (weightsRankStatic != Shape.UNKNOWN_SIZE && valuesRankStatic != Shape.UNKNOWN_SIZE) { if (!weightsShapeStatic.isUnknown() && !valuesShapeStatic.isUnknown() - && !weightsShapeStatic.hasUnknownDimension() & !valuesShapeStatic.hasUnknownDimension()) { + && !weightsShapeStatic.hasUnknownDimension() + && !valuesShapeStatic.hasUnknownDimension()) { if (weightsRankStatic == 0) { return tf.withSubScope("staticScalarCheckSuccess") .withControlDependencies(Collections.EMPTY_LIST) @@ -176,6 +177,34 @@ private static Operand canBroadcastDims( return tf.math.equal(tf.constant(0), numInvalidDims); } + /** + * Broadcast `weights` to the same shape as `values`. + * + * @param tf the TensorFlow ops + * @param weights `Tensor` whose shape is broadcastable to `values` + * @param values Tensor` of any shape + * @param the type of Operands + * @return weights broadcast to values shape + */ + public static Operand broadcastWeights( + Ops tf, Operand weights, Operand values) { + + Shape weightsShape = weights.shape(); + Shape valuesShape = values.shape(); + + if (!weightsShape.hasUnknownDimension() + && !valuesShape.hasUnknownDimension() + && weightsShape.isCompatibleWith(valuesShape)) { + return weights; + } + + Ops ctf = + tf.withSubScope("broadcastWeights") + .withControlDependencies( + Collections.singletonList(assertBroadcastable(tf, weights, tf.onesLike(values)))); + return ctf.math.mul(weights, tf.onesLike(values)); + } + // alias for mean /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 6e1795af2eb..771f4804dea 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -27,7 +27,6 @@ import org.tensorflow.types.family.TNumber; import java.util.ArrayList; -import java.util.Collections; import java.util.List; /** @@ -132,39 +131,32 @@ public List updateStateList(Operand values, Operand 0) { // values rank is greater than weights rank, reduce values to weights rank. 
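As an illustrative aside (not part of the patch): a minimal sketch of how the broadcastWeights helper added above is meant to be called, assuming plain graph execution and the same [3, 2, 4] values shape used by the new tests. The class name BroadcastWeightsSketch and the main/Session boilerplate are hypothetical; MetricsHelper.broadcastWeights is the method introduced in this commit.

    import org.tensorflow.Graph;
    import org.tensorflow.Operand;
    import org.tensorflow.Session;
    import org.tensorflow.framework.metrics.impl.MetricsHelper;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    public class BroadcastWeightsSketch {
      public static void main(String[] args) {
        try (Graph graph = new Graph()) {
          Ops tf = Ops.create(graph);

          // values has shape [3, 2, 4], the same shape the new tests use
          Operand<TFloat32> values =
              tf.constant(
                  new float[][][] {
                    {{1, 2, 3, 4}, {5, 6, 7, 8}},
                    {{9, 10, 11, 12}, {13, 14, 15, 16}},
                    {{17, 18, 19, 20}, {21, 22, 23, 24}}
                  });

          // weights has shape [1, 2, 1]: same rank as values, and every dimension is
          // either 1 or equal to the corresponding values dimension
          Operand<TFloat32> weights = tf.constant(new float[][][] {{{0.5f}, {2f}}});

          // tiles the weights up to the values shape, guarded by assertBroadcastable
          Operand<TFloat32> broadcast = MetricsHelper.broadcastWeights(tf, weights, values);

          try (Session session = new Session(graph);
              TFloat32 result = (TFloat32) session.runner().fetch(broadcast).run().get(0)) {
            System.out.println(result.shape()); // expected: [3, 2, 4]
          }
        }
      }
    }

A scalar weight would broadcast the same way, while a weight whose rank or dimensions cannot be matched to the values fails the assertBroadcastable control dependency with an IllegalArgumentException, as exercised by the invalid-shape tests below.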
int[] axes = new int[numAxes]; - for (int i = 0; i < numAxes; i++) axes[i] = i + weightsDim; + for (int i = 0; i < numAxes; i++) axes[i] = i + weightsRank; if (reduction == MetricReduction.SUM) { lValues = getTF().reduceSum(lValues, getTF().constant(axes)); } else { lValues = getTF().math.mean(lValues, getTF().constant(axes)); } } - lValues = getTF().math.mul(lValues, lSampleWeights); } + lValues = getTF().math.mul(lValues, lSampleWeights); } - Operand valueSum = getTF().reduceSum(lValues, LossesHelper.allAxes(getTF(), lValues)); + Operand weightedValueSum = + getTF().reduceSum(lValues, LossesHelper.allAxes(getTF(), lValues)); Operand totalUpdate = - getTF().assignAdd(total, CastHelper.cast(getTF(), valueSum, total.type())); + getTF().assignAdd(total, CastHelper.cast(getTF(), weightedValueSum, total.type())); updateOperations.add(totalUpdate); Operand numValues; if (reduction != MetricReduction.SUM) { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java similarity index 68% rename from tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java rename to tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java index 08e19f82a89..af4a89692d1 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java @@ -30,15 +30,39 @@ import static org.junit.jupiter.api.Assertions.assertThrows; -public class WeightBroadcastTest { +public class AssertBroadcastableTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + int[][][] valueArrayI = + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }; + long[][][] valueArrayL = + new long[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }; + float[][][] valueArrayF = + new float[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }; + double[][][] valueArrayD = + new double[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }; + private void testValid( TestSession testSession, Ops tf, Operand weights, Operand values, Class type) { Op staticOp = MetricsHelper.assertBroadcastable(tf, weights, values); - testSession.run(staticOp); // dynamic test Operand weightsPlaceholder = tf.placeholder(type); @@ -66,13 +90,7 @@ public void testValidScalar() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new float[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); + Operand values = tf.constant(valueArrayF); Operand weights = tf.constant(5f); testValid(testSession, tf, weights, values, TFloat32.class); } @@ -83,13 +101,7 @@ public void test1x1x1() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new double[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, 
{13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); + Operand values = tf.constant(valueArrayD); Operand weights = tf.constant(new double[][][] {{{5}}}); testValid(testSession, tf, weights, values, TFloat64.class); } @@ -100,13 +112,7 @@ public void test1x1xN() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new long[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); + Operand values = tf.constant(valueArrayL); Operand weights = tf.constant(new long[][][] {{{5, 7, 11, 3}}}); testValid(testSession, tf, weights, values, TInt64.class); } @@ -117,13 +123,7 @@ public void test1xNx1() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant(new int[][][] {{{5}, {11}}}); testValid(testSession, tf, weights, values, TInt32.class); } @@ -134,13 +134,7 @@ public void test1xNxN() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}, {2, 13, 7, 5}}}); testValid(testSession, tf, weights, values, TInt32.class); } @@ -151,13 +145,7 @@ public void testNx1x1() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant(new int[][][] {{{5}}, {{7}}, {{11}}}); testValid(testSession, tf, weights, values, TInt32.class); } @@ -168,13 +156,7 @@ public void testNx1xN() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); testValid(testSession, tf, weights, values, TInt32.class); @@ -186,13 +168,7 @@ public void testNxNxN() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant( new int[][][] { @@ -217,13 +193,7 @@ public void testInvalid1x1() { () -> { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, 
{5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant(new int[][] {{5}}); testValid(testSession, tf, weights, values, TInt32.class); } @@ -237,13 +207,7 @@ public void testInvalidPrefixMatch() { () -> { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant(new int[][] {{5, 7}, {11, 3}, {2, 12}}); testValid(testSession, tf, weights, values, TInt32.class); } @@ -257,13 +221,7 @@ public void testInvalidSuffixMatch() { () -> { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant(new int[][] {{5, 7, 11, 3}, {2, 12, 7, 5}}); testValid(testSession, tf, weights, values, TInt32.class); } @@ -277,13 +235,7 @@ public void testInvalidOnesExtraDim() { () -> { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant(new int[][][][] {{{{5}}}}); testValid(testSession, tf, weights, values, TInt32.class); } @@ -297,13 +249,7 @@ public void testInvalidPrefixMatchExtraDim() { () -> { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant( @@ -324,13 +270,7 @@ public void testInvalidSuffixMatchExtraDim() { () -> { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant( new int[][][][] { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java new file mode 100644 index 00000000000..3322a81fe5b --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java @@ -0,0 +1,380 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; + +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class BroadcastWeightsTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + int[][][] valueArrayI = + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }; + long[][][] valueArrayL = + new long[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }; + float[][][] valueArrayF = + new float[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }; + double[][][] valueArrayD = + new double[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }; + + private void testValid( + TestSession testSession, + Ops tf, + Operand weights, + Operand values, + Number[] expected, // flattened array + Class type) { + + Operand staticOp = MetricsHelper.broadcastWeights(tf, weights, values); + if (expected != null) { + testSession.evaluate(expected, staticOp); + } else { + testSession.run(staticOp); + } + + // dynamic test + Operand weightsPlaceholder = tf.placeholder(type); + Operand valuesPlaceholder = tf.placeholder(type); + + List tensors = + testSession.getGraphSession().runner().fetch(weights).fetch(values).run(); + try (Tensor weightsTensor = tensors.get(0); + Tensor valuesTensor = tensors.get(1)) { + + Operand dynamicOp = + MetricsHelper.broadcastWeights(tf, weightsPlaceholder, valuesPlaceholder); + + List result = + testSession + .getGraphSession() + .runner() + .feed(weightsPlaceholder, weightsTensor) + .feed(valuesPlaceholder, valuesTensor) + .fetch(dynamicOp) + .run(); + + if (expected != null) { + if (type.equals(TInt32.class)) { + TInt32 intT = (TInt32) result.get(0); + AtomicInteger i = new AtomicInteger(); + intT.scalars() + .forEachIndexed( + (idx, f) -> assertEquals(expected[i.getAndIncrement()].intValue(), f.getInt())); + } else if (type.equals(TInt64.class)) { + TInt64 floatT = (TInt64) result.get(0); + AtomicInteger i = new AtomicInteger(); + floatT + .scalars() + .forEachIndexed( + (idx, f) -> assertEquals(expected[i.getAndIncrement()].longValue(), f.getLong())); + } else if (type.equals(TFloat32.class)) { + TFloat32 floatT = (TFloat32) result.get(0); + AtomicInteger i = new AtomicInteger(); + floatT + .scalars() + .forEachIndexed( + (idx, f) -> + assertEquals( + expected[i.getAndIncrement()].floatValue(), f.getFloat(), 1e-5F)); + } else if (type.equals(TFloat64.class)) { + TFloat64 doubleT = (TFloat64) result.get(0); + AtomicInteger i = new AtomicInteger(); + doubleT + .scalars() + .forEachIndexed( + (idx, f) -> + assertEquals( + expected[i.getAndIncrement()].doubleValue(), f.getDouble(), 
1e-5F)); + } + } + } + } + + @Test + public void testValidScalar() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayF); + Operand weights = tf.constant(5f); + Float[] expected = { + 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, + 5f + }; + testValid(testSession, tf, weights, values, expected, TFloat32.class); + } + } + + @Test + public void test1x1x1() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayD); + Operand weights = tf.constant(new double[][][] {{{5}}}); + Double[] expected = { + 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., + 5. + }; + + testValid(testSession, tf, weights, values, expected, TFloat64.class); + } + } + + @Test + public void test1x1xN() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayL); + Operand weights = tf.constant(new long[][][] {{{5, 7, 11, 3}}}); + Long[] expected = { + 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, + 11L, 3L, + }; + testValid(testSession, tf, weights, values, expected, TInt64.class); + } + } + + @Test + public void test1xNx1() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][][] {{{5}, {11}}}); + Integer[] expected = { + 5, 5, 5, 5, 11, 11, 11, 11, 5, 5, 5, 5, 11, 11, 11, 11, 5, 5, 5, 5, 11, 11, 11, 11 + }; + testValid(testSession, tf, weights, values, expected, TInt32.class); + } + } + + @Test + public void test1xNxN() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}, {2, 13, 7, 5}}}); + Integer[] expected = { + 5, 7, 11, 3, 2, 13, 7, 5, 5, 7, 11, 3, 2, 13, 7, 5, 5, 7, 11, 3, 2, 13, 7, 5, + }; + testValid(testSession, tf, weights, values, expected, TInt32.class); + } + } + + @Test + public void testNx1x1() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][][] {{{5}}, {{7}}, {{11}}}); + Integer[] expected = { + 5, 5, 5, 5, 5, 5, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 11, 11, 11, 11, 11, 11, 11, 11 + }; + testValid(testSession, tf, weights, values, expected, TInt32.class); + } + } + + @Test + public void testNx1xN() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = + tf.constant(new int[][][] {{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); + Integer[] expected = { + 5, 7, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5, 2, 12, 7, 5, 2, 17, 11, 3, 2, 17, 11, 3 + }; + testValid(testSession, tf, weights, values, expected, TInt32.class); + } + } + + @Test + public void testNxNxN() { + // no exception should be thrown + try 
(TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + + Operand weights = + tf.constant( + new int[][][] { + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} + }); + Integer[] expected = { + 5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3, 2, 17, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5 + }; + testValid(testSession, tf, weights, values, expected, TInt32.class); + } + } + + // Note: For invalid tests, either NotBroadcastableException is thrown for static shapes or + // TFInvalidInvalidException is thrown for dynamic shapes. Both of these extend + // IllegalArgumentException, + // To simply the assertThrows, only IllegalArgumentException is expected. + // The private method, testValid, tests for both static and dynamic shapes. + @Test + public void testInvalid1() { + + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[] {5}); + + testValid(testSession, tf, weights, values, null, TInt32.class); + } + }); + } + + @Test + public void testInvalid1x1() { + + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][] {{5}}); + + testValid(testSession, tf, weights, values, null, TInt32.class); + } + }); + } + + @Test + public void testInvalidPrefixMatch() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][] {{5, 7}, {11, 3}, {2, 12}}); + testValid(testSession, tf, weights, values, null, TInt32.class); + } + }); + } + + @Test + public void testInvalidSuffixMatch() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][] {{5, 7, 11, 3}, {2, 12, 7, 5}}); + testValid(testSession, tf, weights, values, null, TInt32.class); + } + }); + } + + @Test + public void testInvalidOnesExtraDim() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][][][] {{{{5}}}}); + testValid(testSession, tf, weights, values, null, TInt32.class); + } + }); + } + + @Test + public void testInvalidPrefixMatchExtraDim() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + + Operand weights = + tf.constant( + new int[][][][] { + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}}, + {{{2}, {17}, {11}, {3}}, {{2}, {17}, {11}, {3}}}, + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}} + }); + testValid(testSession, tf, weights, values, null, TInt32.class); + } + }); + } + + @Test + public void testInvalidSuffixMatchExtraDim() { + assertThrows( + IllegalArgumentException.class, + () 
-> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = + tf.constant( + new int[][][][] { + { + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} + } + }); + testValid(testSession, tf, weights, values, null, TInt32.class); + } + }); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java index 33c4e064e69..8e401c21627 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java @@ -1025,7 +1025,7 @@ public void print(PrintWriter writer, Output input) { (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %f\n", index.getAndIncrement(), ((Output) input).asTensor().getDouble()); + "%d). %f\n", index.getAndIncrement(), result.getDouble()); } else { result .scalars() @@ -1040,7 +1040,7 @@ public void print(PrintWriter writer, Output input) { (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %d\n", index.getAndIncrement(), ((Output) input).asTensor().getInt()); + "%d). %d\n", index.getAndIncrement(),result.getInt()); } else { result .scalars() @@ -1055,7 +1055,7 @@ public void print(PrintWriter writer, Output input) { (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %d\n", index.getAndIncrement(), ((Output) input).asTensor().getLong()); + "%d). %d\n", index.getAndIncrement(), result.getLong()); } else { result .scalars() @@ -1070,7 +1070,7 @@ public void print(PrintWriter writer, Output input) { (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %x\n", index.getAndIncrement(), ((Output) input).asTensor().getByte()); + "%d). %x\n", index.getAndIncrement(), result.getByte()); } else { result .scalars() @@ -1085,7 +1085,7 @@ public void print(PrintWriter writer, Output input) { (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %b\n", index.getAndIncrement(), ((Output) input).asTensor().getBoolean()); + "%d). %b\n", index.getAndIncrement(), result.getBoolean()); } else { result .scalars() @@ -1100,7 +1100,7 @@ public void print(PrintWriter writer, Output input) { (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %s\n", index.getAndIncrement(), ((Output) input).asTensor().getObject()); + "%d). 
%s\n", index.getAndIncrement(), result.getObject()); } else { result .scalars() diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java index 3fccd0f0506..db39a330522 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java @@ -492,6 +492,16 @@ public void evaluate(FloatNdArray input, Predicate predicate) { input.scalars().forEach(f -> assertTrue(predicate.test(f.getFloat()))); } + /** + * Print the input to standard out + * + + * @param input the operand to print + * @param the data type of the input + */ + public void print(Operand input) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input.asOutput()); + } /** * Print the input * @@ -503,6 +513,15 @@ public void print(OutputStream out, Operand input) { print(new PrintWriter(new OutputStreamWriter(out)), input.asOutput()); } + /** + * Print the input to standard out + * + * @param input the op to print + */ + public void print(Op input) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input.op().output(0)); + } + /** * Print the input * @@ -513,6 +532,16 @@ public void print(OutputStream out, Op input) { print(new PrintWriter(new OutputStreamWriter(out)), input.op().output(0)); } + /** + * Print the input to standard out + * + * @param input the op to print + * @param the data type of the input + */ + public void print(Output input) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input); + } + /** * Print the input * From 3f46bd280caf5cff44165bcb5fa36c7c03aceb76 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 17 Jan 2021 13:09:32 -0500 Subject: [PATCH 22/56] Added comment to count to indicate that it may be weighted. --- .../java/org/tensorflow/framework/metrics/impl/Reduce.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 771f4804dea..8e48cb4e573 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -45,7 +45,8 @@ public abstract class Reduce extends Metri private final Class resultType; /** the variable that holds the total of the metric values */ protected Variable total; - /** the variable that holds the count of the metric values */ + /** the variable that holds the count of the metric values. 
+ * For {@link MetricReduction#WEIGHTED_MEAN}, this count may be weighted */ protected Variable count; /** From fe65ae765e510fd690798afe60b5092284c43d3e Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 18 Jan 2021 20:22:24 -0500 Subject: [PATCH 23/56] Added SetsOps and fixed AssertBroadcastable to use SetsOps methods, --- .../framework/metrics/impl/MetricsHelper.java | 127 +++++++++++----- .../framework/metrics/impl/SetsOps.java | 141 ++++++++++++++++++ .../metrics/impl/AssertBroadcastableTest.java | 7 +- .../framework/metrics/impl/SetsOpsTest.java | 110 ++++++++++++++ .../framework/utils/GraphTestSession.java | 22 ++- 5 files changed, 364 insertions(+), 43 deletions(-) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index fbe50151854..ad8ff58e417 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -19,10 +19,10 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.SetDiff1d; import org.tensorflow.op.math.Mean; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; @@ -33,6 +33,7 @@ import java.util.List; import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; /** * These are helper methods for Metrics and will be module private when Java modularity is applied @@ -126,10 +127,17 @@ public static Op assertBroadcastable( tf.constant("isScalar="), isScalar); - Operand validNonsclar = + // hack to work around the non-lazy select for isValidShape, otherwise validNonscalar fails on a + // scalar weight. If select was lazy, that branch wouldn't get executed when iScalar is true. 
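      // (Editor's note, illustrative only, not part of the patch.) tf.select computes both
      // of its value inputs before choosing one, so the non-scalar shape check would still
      // be evaluated for a rank-0 weight. The select below therefore substitutes a
      // pre-broadcast weight first: a scalar weight against values of shape [2, 3], for
      // example, becomes a [2, 3] tensor via tf.math.mul(sampleWeights, tf.onesLike(values)),
      // so the rank and shape operands that feed the check are always taken from a
      // full-rank tensor.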
+ Operand reshapedWeights = + tf.select(isScalar, tf.math.mul(sampleWeights, tf.onesLike(values)), sampleWeights); + weightsShape = tf.shape(reshapedWeights); + weightsRank = tf.rank(reshapedWeights); + + Operand validNonscalar = canBroadcastNonscalarShapes(tf, weightsRank, weightsShape, valuesRank, valuesShape); - Operand isValidShape = tf.select(isScalar, isScalar, validNonsclar); + Operand isValidShape = tf.select(isScalar, isScalar, validNonscalar); return tf.withSubScope("broadcastWeights-dynamic").assertThat(isValidShape, data); } @@ -151,7 +159,7 @@ private static Operand canBroadcastNonscalarShapes( Operand weightsShape, Operand valuesRank, Operand valuesShape) { - tf = tf.withSubScope("hasValidNonscalarShape"); + tf = tf.withSubScope("canBroadcastNonscalarShapes"); Operand isSameRank = tf.math.equal(valuesRank, weightsRank); return tf.select(isSameRank, canBroadcastDims(tf, weightsShape, valuesShape), isSameRank); } @@ -167,22 +175,23 @@ private static Operand canBroadcastNonscalarShapes( */ private static Operand canBroadcastDims( Ops tf, Operand weightsShape, Operand valuesShape) { - tf = tf.withSubScope("hasValidDims"); + tf = tf.withSubScope("canBroadcastDims"); Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); Operand validDims = tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); - SetDiff1d invalidDimsDiff = tf.setDiff1d(weightsShape, tf.shape.flatten(validDims)); - Operand invalidDims = invalidDimsDiff.out(); - Operand numInvalidDims = tf.size(invalidDims); + Operand weightsShape2D = tf.expandDims(weightsShape, tf.constant(-1)); + + Operand diffResult = SetsOps.difference(tf, weightsShape2D, validDims); + Operand numInvalidDims = tf.size(diffResult); return tf.math.equal(tf.constant(0), numInvalidDims); } /** - * Broadcast `weights` to the same shape as `values`. + * Broadcast weights to the same shape as values. * * @param tf the TensorFlow ops - * @param weights `Tensor` whose shape is broadcastable to `values` - * @param values Tensor` of any shape + * @param weights Operand whose shape is broadcastable to values. + * @param values Operand of any shape * @param the type of Operands * @return weights broadcast to values shape */ @@ -205,7 +214,7 @@ public static Operand broadcastWeights( return ctf.math.mul(weights, tf.onesLike(values)); } - // alias for mean + // aliases for mean /** * Calculate the mean of the operand, along all axes and keepDims is false @@ -214,10 +223,9 @@ public static Operand broadcastWeights( * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param the type of the Operand. - * @param the data type for the result * @return the mean of the operand */ - public static Operand mean(Ops tf, Operand x) { + public static Operand mean(Ops tf, Operand x) { return mean(tf, x, null, false); } @@ -230,16 +238,15 @@ public static Operand mean(Ops tf, Opera * @param axes Axes to compute the mean. * @param the type of the Operand. * @param the type of the axes. - * @param the data type for the result * @return the mean of the operand, along the specified axes. */ - public static Operand mean( + public static Operand mean( Ops tf, Operand x, Operand axes) { return mean(tf, x, axes, false); } /** - * Calculate the mean of the operand, along all axes. + * Calculates the mean of the operand, along all axes. * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -248,16 +255,17 @@ public static Operand< * . 
If keepdims is true, the reduced dimensions are retained * with length 1. * @param the type of the operand - * @param the data type for the result * @return the mean of elements of x. */ - public static Operand mean( + public static Operand mean( Ops tf, Operand x, boolean keepDims) { return mean(tf, x, null, keepDims); } + + /** - * Calculate the mean of the operand, alongside the specified axes. + * Calculates the mean of the operand, alongside the specified axes. * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -267,23 +275,74 @@ public static Operand mean( * * reduced dimensions are retained with length 1. * @param the data type of the Operand * @param the data type of the axes - * @param the data type for the result - * @return the mean of elements of `x`. + * @return the mean of elements of x. */ - @SuppressWarnings({"unchecked", "rawtypes"}) - public static Operand mean( + + public static Operand mean( Ops tf, Operand x, Operand axes, boolean keepDims) { - // Cannot use generics here because xf may change from TBool to TFloat32 - Operand xf; - if (x.type().equals(TBool.class)) { - xf = (Operand) tf.dtypes.cast(x, TFloat32.class); - } else { - xf = (Operand) x; - } if (axes == null) { - axes = (Operand) allAxes(tf, xf); + axes = (Operand) allAxes(tf, x); } - Operand theMean = tf.math.mean(xf, axes, Mean.keepDims(keepDims)); - return x.type().equals(TBool.class) ? tf.dtypes.cast(theMean, TBool.class) : theMean; + return tf.math.mean(x, axes, Mean.keepDims(keepDims)); + } + + /** + * Calculate the mean of the operand, along all axes and keepDims is false + * + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @return the mean of the operand containing floating point numbers + */ + public static Operand booleanMean(Ops tf, Operand x) { + return booleanMean(tf, x, null, false); } + + /** + * Calculate the mean of the operand, alongside the specified axis with keepDims is + * false + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @param axes Axes to compute the mean. + * @param the type of the axes. + * @return the mean of the operand, along the specified axes, containing floating point numbers + */ + public static Operand booleanMean( + Ops tf, Operand x,Operand axes) { + return booleanMean(tf, x, axes, false); + } + + /** + * Calculates the mean of the boolean operand, alongside all axes. + * + * @param x the boolean Operand used to calculate the mean + * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the + * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the + * * reduced dimensions are retained with length 1. + * @param the data type of the axes + * @return the mean of elements of x containing floating point numbers + */ + public static Operand booleanMean( + Ops tf, Operand x, boolean keepDims) { + return booleanMean(tf, x, null, keepDims); + } + + /** + * Calculates the mean of the boolean operand, alongside the specified axes. + * + * @param x the boolean Operand used to calculate the mean + * @param axes Axes to compute the mean. + * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the + * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the + * * reduced dimensions are retained with length 1. 
+ * @param the data type of the axes + * @return the mean of elements of x containing floating point numbers + */ + public static Operand booleanMean( + Ops tf, Operand x, Operand axes, boolean keepDims) { + Operand xf = cast(tf, x, TFloat64.class); + return mean(tf, xf, axes, keepDims); + } + } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java new file mode 100644 index 00000000000..236b3d9084d --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java @@ -0,0 +1,141 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.op.SparseOps; +import org.tensorflow.op.sparse.DenseToDenseSetOperation; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** Implementation of set operations */ +public class SetsOps { + + /** + * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops + * function {@link SparseOps#denseToDenseSetOperation} + */ + public enum Operation { + A_MINUS_B("a-b"), + B_MINUS_A("b-a"), + INTERSECTION("intersection"), + UNION("union"); + + private final String setOperation; + + Operation(String setOperation) { + this.setOperation = setOperation; + } + + /** + * Gets the set operation String value used to pass as the stringOperation value to {@link + * SparseOps#denseToDenseSetOperation} + * + * @return the set operation String value + */ + public String getSetOperation() { + return setOperation; + } + } + + /** + * Computes set difference of elements in last dimension of a and b with + * aMinusB set to true/ + * + *

All but the last dimension of a and b must match + * + * @param tf the TensorFlow Ops + * @param a The first operand representing set a + * @param b The other operand representing set b + * @param the data type for the sets + * @return An Operand with the same rank as a and b, and all but the + * last dimension the * same. Elements along the last dimension contain the results of the set + * operation. + */ + public static Operand difference(Ops tf, Operand a, Operand b) { + return difference(tf, a, b, true); + } + /** + * Computes set difference of elements in last dimension of a and b. + * + *

All but the last dimension of a and b must match + * + * @param tf the TensorFlow Ops + * @param a The first operand representing set a + * @param b The other operand representing set b + * @param aMinusB whether to subtract b from a, vs vice versa. + * @param the data type for the sets + * @return An Operand with the same rank as a and b, and all but the + * last dimension the * same. Elements along the last dimension contain the results of the set + * operation. + */ + public static Operand difference( + Ops tf, Operand a, Operand b, boolean aMinusB) { + return setOperation(tf, a, b, aMinusB ? Operation.A_MINUS_B : Operation.B_MINUS_A); + } + + /** + * Computes set union of elements in last dimension of a and b. + * + * @param tf the TensorFlow Ops + * @param a The first operand representing set a + * @param b The other operand representing set b + * @param the data type for the sets + * @return An Operand with the same rank as a and b, and all but the + * last dimension the * same. Elements along the last dimension contain the results of the set + * operation. + */ + public static Operand union(Ops tf, Operand a, Operand b) { + return setOperation(tf, a, b, Operation.UNION); + } + + /** + * Computes set intersection of elements in last dimension of a and b. + * + * @param tf the TensorFlow Ops + * @param a The first operand representing set a + * @param b The other operand representing set b + * @param the data type for the sets + * @return An Operand with the same rank as a and b, and all but the + * last dimension the * same. Elements along the last dimension contain the results of the set + * operation. + */ + public static Operand intersection(Ops tf, Operand a, Operand b) { + return setOperation(tf, a, b, Operation.INTERSECTION); + } + + /** + * Compute set operation of elements in last dimension of a and b. + * + * @param tf the TensorFlow Ops + * @param a The first set operation operand + * @param b The other et operation operand + * @param setOperation The set operation to perform, {@link Operation}. + * @param the data type for the sets + * @return An Operand with the same rank as a and b, and all but the + * last dimension the same. Elements along the last dimension contain the results of the set + * operation. 
+ */ + public static Operand setOperation( + Ops tf, Operand a, Operand b, Operation setOperation) { + + DenseToDenseSetOperation setOperationResult = + tf.sparse.denseToDenseSetOperation( + a, b, setOperation.getSetOperation(), DenseToDenseSetOperation.validateIndices(true)); + return setOperationResult.resultValues(); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java index af4a89692d1..63d666f8640 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java @@ -72,7 +72,6 @@ private void testValid( testSession.getGraphSession().runner().fetch(weights).fetch(values).run(); try (Tensor weightsTensor = tensors.get(0); Tensor valuesTensor = tensors.get(1)) { - Op dynamicOp = MetricsHelper.assertBroadcastable(tf, weightsPlaceholder, valuesPlaceholder); testSession @@ -90,6 +89,7 @@ public void testValidScalar() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayF); Operand weights = tf.constant(5f); testValid(testSession, tf, weights, values, TFloat32.class); @@ -101,6 +101,7 @@ public void test1x1x1() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayD); Operand weights = tf.constant(new double[][][] {{{5}}}); testValid(testSession, tf, weights, values, TFloat64.class); @@ -134,6 +135,7 @@ public void test1xNxN() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}, {2, 13, 7, 5}}}); testValid(testSession, tf, weights, values, TInt32.class); @@ -145,6 +147,7 @@ public void testNx1x1() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant(new int[][][] {{{5}}, {{7}}, {{11}}}); testValid(testSession, tf, weights, values, TInt32.class); @@ -156,6 +159,7 @@ public void testNx1xN() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); @@ -168,6 +172,7 @@ public void testNxNxN() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant( diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java new file mode 100644 index 00000000000..5250c22d740 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java @@ -0,0 +1,110 @@ +package org.tensorflow.framework.metrics.impl; + +import org.junit.jupiter.api.Test; +import 
org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TUint8; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +class SetsOpsTest { + + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + List> types = Arrays.asList(TInt32.class, TInt64.class, TUint8.class); + + @Test + @SuppressWarnings({"unchecked", "rawtypes"}) + public void testSetIntersectionMultirow2() { + + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); + Operand b = tf.constant(new int[][] {{1, 9}, {1, 5}}); + Integer[] expected = new Integer[] {1, 9}; + Shape expectedShape = Shape.of(2); + for (Class type : types) { + Operand aa = cast(tf, a, type); + Operand bb = cast(tf, b, type); + Operand intersection = SetsOps.intersection(tf, aa, bb); + session.evaluate(expected, intersection); + session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); + } + } + } + + @Test + @SuppressWarnings({"unchecked", "rawtypes"}) + public void testSetIntersectionDuplicates2d() { + + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand a = tf.constant(new int[][] {{1, 1, 3}}); + Operand b = tf.constant(new int[][] {{1}}); + Integer[] expected = new Integer[] {1}; + Shape expectedShape = Shape.of(1); + for (Class type : types) { + Operand aa = cast(tf, a, type); + Operand bb = cast(tf, b, type); + Operand intersection = SetsOps.intersection(tf, aa, bb); + session.evaluate(expected, intersection); + session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); + } + } + } + + public void testDenseSetDifferenceMultirow2d() { + + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand a = tf.constant(new int[][] {{1, 5, 9}, {4, 5, 3}}); + Operand b = tf.constant(new int[][] {{1, 2, 6}, {1, 2, 2}}); + Integer[] expected = new Integer[] {5, 9, 3, 4, 5}; + for (Class type : types) { + Operand aa = cast(tf, a, type); + Operand bb = cast(tf, b, type); + // a- b + Operand intersection = SetsOps.difference(tf, aa, bb); + session.evaluate(expected, intersection); + session.evaluate(tf.constant(5L), tf.shape(intersection, TInt64.class)); + + // b - a + expected = new Integer[] {2, 6, 1, 2}; + intersection = SetsOps.difference(tf, aa, bb, false); + session.evaluate(expected, intersection); + session.evaluate(tf.constant(4L), tf.shape(intersection, TInt64.class)); + } + } + } + + public void testDenseUnionMultirow2d() { + + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); + Operand b = tf.constant(new int[][] {{1, 9}, {1, 2}}); + Integer[] expected = new Integer[] {1, 5, 9, 1, 2, 3, 4}; + for (Class type : types) { + Operand aa = cast(tf, a, type); + Operand bb = cast(tf, b, type); + // a- b + Operand intersection = SetsOps.difference(tf, aa, bb); + session.evaluate(expected, intersection); 
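          // (Editor's note, not part of the patch.) SetsOps is backed by
          // tf.sparse.denseToDenseSetOperation and returns the flattened 1-D values of the
          // sparse result, computed row-wise along the last dimension. That is why these
          // tests compare shapes against a single flattened length, e.g. Shape.of(2) for
          // the two intersection values {1, 9} in testSetIntersectionMultirow2 above,
          // rather than against a padded two-dimensional shape.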
+ session.evaluate(tf.constant(7L), tf.shape(intersection, TInt64.class)); + + } + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java index 8e401c21627..43c0642939e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java @@ -213,10 +213,13 @@ public void evaluate(double expected, Operand input) { @Override public void evaluate(Number[] expected, Output input) { int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); + if (size != Shape.UNKNOWN_SIZE) { + assertEquals( + expected.length, + size, + () -> + String.format("expected length (%d) != to input length (%d)", expected.length, size)); + } Class inputType = input.type(); if (inputType == TFloat32.class) { AtomicInteger index = new AtomicInteger(); @@ -425,10 +428,13 @@ public void evaluate(FloatNdArray expected, Output input) { @Override public void evaluate(String[] expected, Output input) { int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); + if (size != Shape.UNKNOWN_SIZE) { + assertEquals( + expected.length, + size, + () -> + String.format("expected length (%d) != to input length (%d)", expected.length, size)); + } AtomicInteger index = new AtomicInteger(); if (debug) { try (TString result = From f8d38cf3db16ebecd46c5a9d3dddce688f6dcc88 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 18 Jan 2021 20:22:46 -0500 Subject: [PATCH 24/56] Fixed based on various PR comments. --- .../framework/metrics/BinaryCrossentropy.java | 2 +- .../framework/metrics/CategoricalCrossentropy.java | 4 ++-- .../org/tensorflow/framework/metrics/Metrics.java | 10 +++++----- .../metrics/SparseCategoricalCrossentropy.java | 12 ++++++------ .../framework/metrics/impl/MeanMetricWrapper.java | 2 +- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index c339b977007..651a6fac0b0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -41,7 +41,7 @@ public class BinaryCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. * @param labelSmoothing value used to smooth labels, When 0, no smoothing occurs. When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, * where the smoothing squeezes the labels towards 0.5. 
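(Editor's illustration, not part of the patch; the authoritative behaviour is whatever the framework's Losses implementation applies.) As commonly implemented for binary targets, the smoothed label is

    ySmooth = y * (1 - labelSmoothing) + 0.5 * labelSmoothing

so with labelSmoothing = 0.2 a hard 0 or 1 label becomes 0.1 or 0.9, which is the squeeze toward 0.5 described here.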
Larger values of label_smoothing diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index 7b8cf0054a4..c330ea88eaa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -48,7 +48,7 @@ public class CategoricalCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. + * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to a probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and 0.9 @@ -68,7 +68,7 @@ public CategoricalCrossentropy( * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and 0.9 diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index e2cd5e368c2..0169bc6b8bc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -64,19 +64,19 @@ public static Operand topKCategoricalA * @param tf the TensorFlow Ops * @param labels The ground truth values. * @param predictions The prediction values. - * @param axis The dimension along which the cosine similarity is computed. + * @param axes The dimensions along which the cosine similarity is computed. * @param the data type for the labels * @param the data type for the predictions and result * @return Cosine similarity value. 
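A small, hypothetical usage sketch for the renamed axes parameter (not part of the patch): it computes the cosine similarity of two row vectors along axis 1, assuming graph execution. Metrics.cosineProximity is the method shown above; the class name and Session boilerplate are assumptions.

    import org.tensorflow.Graph;
    import org.tensorflow.Operand;
    import org.tensorflow.Session;
    import org.tensorflow.framework.metrics.Metrics;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    public class CosineProximitySketch {
      public static void main(String[] args) {
        try (Graph graph = new Graph()) {
          Ops tf = Ops.create(graph);

          Operand<TFloat32> labels = tf.constant(new float[][] {{1f, 0f}});
          Operand<TFloat32> predictions = tf.constant(new float[][] {{1f, 1f}});

          // L2-normalizes both operands along axis 1 and sums their elementwise product,
          // i.e. the cosine of the angle between the two row vectors
          Operand<TFloat32> similarity =
              Metrics.cosineProximity(tf, labels, predictions, new int[] {1});

          try (Session session = new Session(graph);
              TFloat32 result = (TFloat32) session.runner().fetch(similarity).run().get(0)) {
            System.out.println(result.getFloat(0)); // expected: ~0.7071 (1 / sqrt(2))
          }
        }
      }
    }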
*/ public static Operand cosineProximity( - Ops tf, Operand labels, Operand predictions, int[] axis) { + Ops tf, Operand labels, Operand predictions, int[] axes) { Operand labelsNorm = CastHelper.cast(tf, labels, predictions.type()); - labelsNorm = l2Normalize(tf, labelsNorm, axis); + labelsNorm = l2Normalize(tf, labelsNorm, axes); - Operand predictionsNorm = l2Normalize(tf, predictions, axis); + Operand predictionsNorm = l2Normalize(tf, predictions, axes); Operand mathMul = tf.math.mul(labelsNorm, predictionsNorm); - return tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); + return tf.reduceSum(mathMul, tf.constant(axes), ReduceSum.keepDims(Boolean.FALSE)); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index 3fde8b2ecf6..2e01f722de6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -32,30 +32,30 @@ public class SparseCategoricalCrossentropy extends MeanMetricWrapper implements LossMetric { private final boolean fromLogits; - private final int axes; + private final int axis; /** * Creates a SparseCategoricalCrossentropy metric * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. - * @param axes The dimension along which the entropy is computed. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param axis The dimension along which the entropy is computed. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result */ public SparseCategoricalCrossentropy( - Ops tf, String name, boolean fromLogits, int axes, long seed, Class type) { + Ops tf, String name, boolean fromLogits, int axis, long seed, Class type) { super(tf, name, seed, type); setLoss(this); this.fromLogits = fromLogits; - this.axes = axes; + this.axis = axis; } /** {@inheritDoc} */ @Override public Operand call(Operand labels, Operand predictions) { - return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axes); + return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index e2f1345f356..17c209a8fed 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -85,7 +85,7 @@ protected void setLoss(LossMetric loss) { * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) 
- * @param the datatype of the predictions + * @param the datatype of the labels * @param the data type for sampleWeights * @return a List of control operations that updates the Mean state variables. */ From 00ce5db0bdf97f1cbf52fa225d28c8bbef7422ac Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 18 Jan 2021 20:23:44 -0500 Subject: [PATCH 25/56] Deleted, no longer needed after change to Variable handling in Metrics. --- .../metrics/impl/MetricVariable.java | 125 ------------------ 1 file changed, 125 deletions(-) delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java deleted file mode 100644 index 6b208c0d7bf..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java +++ /dev/null @@ -1,125 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.metrics.impl; - -import org.tensorflow.Operand; -import org.tensorflow.framework.initializers.Glorot; -import org.tensorflow.framework.initializers.Initializer; -import org.tensorflow.framework.initializers.VarianceScaling; -import org.tensorflow.framework.initializers.Zeros; -import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Variable; -import org.tensorflow.types.family.TFloating; -import org.tensorflow.types.family.TIntegral; -import org.tensorflow.types.family.TNumber; - -/** - * Helper class that holds a metric variable - * - * @param the data type of the variable - */ -public class MetricVariable { - private final Variable variable; - private final Initializer initializer; - private final Ops tf; - private boolean initialized; - - /** - * Creates a Metric Variable - * - * @param tf the TensorFlow Ops - * @param variable the variable - * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variable - */ - public MetricVariable(Ops tf, Variable variable, long seed, Class type) { - this(tf, variable, null, seed, type); - } - - /** - * Creates a Metric Variable - * - * @param tf the TensorFlow Ops - * @param variable the variable - * @param initializer the initializer for the variable, if null, then the default for floating - * point types is {@link org.tensorflow.framework.initializers.Glorot} with distribution - * {@link org.tensorflow.framework.initializers.VarianceScaling.Distribution#UNIFORM}, for - * other types the default initializer is {@link org.tensorflow.framework.initializers.Zeros} - * @param seed the seed for random number generation. 
An initializer created with a given seed - * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variable - * @throws IllegalArgumentException if the type does not inherit from TNumber and the initializer - * is null - */ - @SuppressWarnings("unchecked") - public MetricVariable( - Ops tf, Variable variable, Initializer initializer, long seed, Class type) { - this.tf = tf; - this.variable = variable; - - if (initializer == null) { - if (TFloating.class.isAssignableFrom(type)) { - //noinspection RedundantCast - this.initializer = - (Initializer) new Glorot<>(tf, VarianceScaling.Distribution.UNIFORM, seed); - } else if (TIntegral.class.isAssignableFrom(type)) { - this.initializer = new Zeros<>(tf); - } else { - throw new IllegalArgumentException( - String.format("Type %s is not supported for metric variables", type.getSimpleName())); - } - } else { - this.initializer = initializer; - } - } - - /** - * Initializers the variable based on the initializer - * - * @return the initialized variable - */ - public Operand initialize() { - initialized = true; - return initializer.call(tf.constant(variable.shape()), variable.type()); - } - - /** - * Gets the variable - * - * @return the variable - */ - public Variable getVariable() { - return variable; - } - - /** - * Gets the initializer - * - * @return the initializer - */ - public Initializer getInitializer() { - return initializer; - } - - /** - * Gets the value of initialized - * - * @return the value of initialized - */ - public boolean isInitialized() { - return initialized; - } -} From 51104f16328f0175fa8ae8262a980c9b1064d631 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 20 Jan 2021 14:02:22 -0500 Subject: [PATCH 26/56] Fix Losses to use CHANNELS_FIRST/LAST for CategoricalCrossentropy --- .../framework/losses/CategoricalCrossentropy.java | 7 ++++--- .../main/java/org/tensorflow/framework/losses/Losses.java | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 77c6ab2bf87..363291fa5cc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -69,7 +69,7 @@ public class CategoricalCrossentropy extends Loss { public static final boolean FROM_LOGITS_DEFAULT = false; public static final float LABEL_SMOOTHING_DEFAULT = 0.0f; - public static final int DEFAULT_AXIS = -1; + public static final int DEFAULT_AXIS = Losses.CHANNELS_LAST; private final boolean fromLogits; private final float labelSmoothing; @@ -203,8 +203,9 @@ public CategoricalCrossentropy( * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a * value of 0.1 for label 0 and 0.9 for label 1 * @param reduction Type of Reduction to apply to loss. - * @param axis The channels axis. axis=-1 corresponds to data format `Channels Last' - * and axis=1 corresponds to data format 'Channels First'. + * @param axis The channels axis. axis=-1 corresponds to data format "Channels Last" + * and axis=1 corresponds to data format "Channels First". + * {@link Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST} * @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1. 
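// A worked example of the label smoothing described above (a sketch of the usual transform,
// smoothed = label * (1 - s) + s / numClasses, not a quote of the framework's Losses code):
// with s = 0.2 and two classes, a hard 0 becomes 0 * 0.8 + 0.1 = 0.1 and a hard 1 becomes
// 1 * 0.8 + 0.1 = 0.9, matching the 0.1 / 0.9 values quoted in the javadoc.
float s = 0.2f;
int numClasses = 2;
float smoothedZero = 0f * (1f - s) + s / numClasses; // 0.1f
float smoothedOne = 1f * (1f - s) + s / numClasses;  // 0.9f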
*/ public CategoricalCrossentropy( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 3894bee0d0f..0d25bd5e7e2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -242,7 +242,7 @@ public static Operand categoricalCross tLabels = smoothCategoricalLabels(tf, tLabels, labelSmoothing); } if (fromLogits) { - return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, -1); + return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, axis); } /* TODO if (!(predictions instanceof Variable) && (!tf.scope().env().isEager())) { From e918df416fecf40b458ccb0ba66c1c29c628aeb0 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sat, 30 Jan 2021 15:20:10 -0500 Subject: [PATCH 27/56] Fix SetOps to properly convert sparse tensor to dense tensor using tf.sparse.sparseToDense with the output of tf.sparse.denseToDenseSetOperation --- .../framework/metrics/impl/SetsOps.java | 9 +++- .../framework/metrics/impl/SetsOpsTest.java | 46 +++++++++++-------- 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java index 236b3d9084d..1841c7ee238 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java @@ -54,7 +54,7 @@ public String getSetOperation() { /** * Computes set difference of elements in last dimension of a and b with - * aMinusB set to true/ + * aMinusB set to true. * *

All but the last dimension of a and b must match * @@ -136,6 +136,11 @@ public static Operand setOperation( DenseToDenseSetOperation setOperationResult = tf.sparse.denseToDenseSetOperation( a, b, setOperation.getSetOperation(), DenseToDenseSetOperation.validateIndices(true)); - return setOperationResult.resultValues(); + + return tf.sparse.sparseToDense( + setOperationResult.resultIndices(), + setOperationResult.resultShape(), + setOperationResult.resultValues(), + cast(tf, tf.constant(0), a.type())); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java index 5250c22d740..eceff2797f8 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java @@ -30,13 +30,13 @@ public void testSetIntersectionMultirow2() { Ops tf = session.getTF(); Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); Operand b = tf.constant(new int[][] {{1, 9}, {1, 5}}); - Integer[] expected = new Integer[] {1, 9}; - Shape expectedShape = Shape.of(2); + int[][] expected = new int[][] {{1, 9}, {0, 0}}; + Shape expectedShape = Shape.of(2, 2); for (Class type : types) { Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); - Operand intersection = SetsOps.intersection(tf, aa, bb); - session.evaluate(expected, intersection); + Operand intersection = SetsOps.intersection(tf, aa, bb); + session.evaluate(cast(tf, tf.constant(expected), type), intersection); session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); } } @@ -50,19 +50,23 @@ public void testSetIntersectionDuplicates2d() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Operand a = tf.constant(new int[][] {{1, 1, 3}}); - Operand b = tf.constant(new int[][] {{1}}); - Integer[] expected = new Integer[] {1}; - Shape expectedShape = Shape.of(1); + Operand b = tf.constant(new int[][] {{1, 1}}); + int[][] expected = {{1}}; + Shape expectedShape = Shape.of(1, 1); for (Class type : types) { Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); Operand intersection = SetsOps.intersection(tf, aa, bb); - session.evaluate(expected, intersection); + + session.evaluate(cast(tf, tf.constant(expected), type), intersection); + session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); } } } + @Test + @SuppressWarnings({"unchecked", "rawtypes"}) public void testDenseSetDifferenceMultirow2d() { for (TestSession.Mode tfMode : tfModes) @@ -70,24 +74,30 @@ public void testDenseSetDifferenceMultirow2d() { Ops tf = session.getTF(); Operand a = tf.constant(new int[][] {{1, 5, 9}, {4, 5, 3}}); Operand b = tf.constant(new int[][] {{1, 2, 6}, {1, 2, 2}}); - Integer[] expected = new Integer[] {5, 9, 3, 4, 5}; + for (Class type : types) { Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); + int[][] expected = {{5, 9, 0}, {3, 4, 5}}; // a- b + Shape expectedShape = Shape.of(2, 3); Operand intersection = SetsOps.difference(tf, aa, bb); - session.evaluate(expected, intersection); - session.evaluate(tf.constant(5L), tf.shape(intersection, TInt64.class)); + session.evaluate(cast(tf, tf.constant(expected), type), intersection); + session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); // b - a - expected = new Integer[] {2, 6, 1, 2}; + expected = new int[][] 
{{2, 6}, {1, 2}}; + expectedShape = Shape.of(2, 2); intersection = SetsOps.difference(tf, aa, bb, false); - session.evaluate(expected, intersection); - session.evaluate(tf.constant(4L), tf.shape(intersection, TInt64.class)); + + session.evaluate(cast(tf, tf.constant(expected), type), intersection); + session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); } } } + @Test + @SuppressWarnings({"unchecked", "rawtypes"}) public void testDenseUnionMultirow2d() { for (TestSession.Mode tfMode : tfModes) @@ -95,15 +105,15 @@ public void testDenseUnionMultirow2d() { Ops tf = session.getTF(); Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); Operand b = tf.constant(new int[][] {{1, 9}, {1, 2}}); - Integer[] expected = new Integer[] {1, 5, 9, 1, 2, 3, 4}; + int[][] expected = new int[][] {{5, 0}, {3, 4}}; for (Class type : types) { Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); + Shape expectedShape = Shape.of(2, 2); // a- b Operand intersection = SetsOps.difference(tf, aa, bb); - session.evaluate(expected, intersection); - session.evaluate(tf.constant(7L), tf.shape(intersection, TInt64.class)); - + session.evaluate(cast(tf, tf.constant(expected), type), intersection); + session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); } } } From 9cdc274f0c682260360ec8b5b4ed60d79076ced6 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 31 Dec 2020 19:29:42 -0500 Subject: [PATCH 28/56] Initial checkin --- .../framework/metrics/BinaryCrossentropy.java | 16 +- .../metrics/CategoricalCrossentropy.java | 26 +- .../framework/metrics/CategoricalHinge.java | 12 +- .../framework/metrics/CosineSimilarity.java | 39 +-- .../tensorflow/framework/metrics/Hinge.java | 12 +- .../framework/metrics/KLDivergence.java | 13 +- .../framework/metrics/LogCoshError.java | 12 +- .../tensorflow/framework/metrics/Mean.java | 5 +- .../framework/metrics/MeanAbsoluteError.java | 12 +- .../metrics/MeanAbsolutePercentageError.java | 12 +- .../framework/metrics/MeanSquaredError.java | 12 +- .../metrics/MeanSquaredLogarithmicError.java | 12 +- .../tensorflow/framework/metrics/Metric.java | 199 ++++++++++--- .../tensorflow/framework/metrics/Metrics.java | 80 ++++- .../tensorflow/framework/metrics/Poisson.java | 12 +- .../SparseCategoricalCrossentropy.java | 25 +- .../SparseTopKCategoricalAccuracy.java | 65 +++++ .../framework/metrics/SquaredHinge.java | 12 +- .../metrics/TopKCategoricalAccuracy.java | 63 ++++ .../framework/metrics/impl/LossInterface.java | 36 +++ .../metrics/impl/MeanMetricWrapper.java | 57 ++-- .../metrics/impl/MetricVariable.java | 122 ++++++++ .../framework/metrics/impl/MetricsHelper.java | 274 +++--------------- .../framework/metrics/impl/Reduce.java | 128 +++----- .../metrics/BinaryCrossentropyTest.java | 21 +- .../metrics/CosineSimilarityTest.java | 2 +- .../framework/metrics/KLDivergenceTest.java | 2 +- .../SparseTopKCategoricalAccuracyTest.java | 96 ++++++ .../metrics/TopKCategoricalAccuracyTest.java | 103 +++++++ 29 files changed, 925 insertions(+), 555 deletions(-) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossInterface.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java create mode 100644 
tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracyTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 651a6fac0b0..d13d20bfdee 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -16,22 +16,19 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.LossInterface; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; /** - * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. - * - *

This is the crossentropy metric class to be used when there are only two label classes (0 and - * 1). + * Computes the binary cross-entropy loss between true labels and predicted labels. * * @param the data type for the predictions. * @param The data type for the metric result */ public class BinaryCrossentropy - extends MeanMetricWrapper implements LossMetric { + extends MeanMetricWrapper implements LossInterface { private final boolean fromLogits; private final float labelSmoothing; @@ -39,16 +36,19 @@ public class BinaryCrossentropy /** * Creates a BinaryCrossentropy metric * + *

This is the crossentropy metric class to be used when there are only two label classes (0 + * and 1). + * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. * @param labelSmoothing value used to smooth labels, When 0, no smoothing occurs. When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing * correspond to heavier smoothing. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result + * @param type the data type for the variables */ public BinaryCrossentropy( Ops tf, String name, boolean fromLogits, float labelSmoothing, long seed, Class type) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index c330ea88eaa..cf9ecd0858a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -16,46 +16,41 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.LossInterface; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; /** - * A Metric that computes the categorical cross-entropy loss between true labels and predicted - * labels. + * Computes the categorical cross-entropy loss between true labels and predicted labels. * *

This is the crossentropy metric class to be used when there are multiple label classes (2 or - * more). The labels should be given as a one_hot representation. eg., When labels values are - * [2, 0, 1], the labels Operand contains = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] + * more). Here we assume that labels are given as a one_hot representation. eg., When labels values + * are [2, 0, 1], the labels Operand contains = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] * . - * - * @param the data type for the predictions. - * @param The data type for the metric result */ public class CategoricalCrossentropy - extends MeanMetricWrapper implements LossMetric { + extends MeanMetricWrapper implements LossInterface { private final boolean fromLogits; private final float labelSmoothing; - private final int axis; + private int axis; /** - * Creates a CategoricalCrossentropy metric that computes the crossentropy metric between the + * Creates a CategoricalCrossentropy metric that Computes the crossentropy metric between the * labels and predictions. * *
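// The class javadoc above expects one-hot labels; a small sketch of producing them from
// sparse indices, assuming an in-scope Ops instance named tf and that tf.oneHot takes
// (indices, depth, onValue, offValue) as in the core ops. Indices [2, 0, 1] with depth 3
// yield [[0, 0, 1], [1, 0, 0], [0, 1, 0]], the matrix shown in the javadoc.
Operand<TFloat32> oneHotLabels =
    tf.oneHot(tf.constant(new int[] {2, 0, 1}), tf.constant(3), tf.constant(1f), tf.constant(0f));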

Uses a {@link Losses#CHANNELS_LAST} for the channel axis. * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and 0.9 * for label 1 * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result */ public CategoricalCrossentropy( Ops tf, String name, boolean fromLogits, float labelSmoothing, long seed, Class type) { @@ -63,12 +58,12 @@ public CategoricalCrossentropy( } /** - * Creates a CategoricalCrossentropy metric that computes the crossentropy metric between the + * Creates a CategoricalCrossentropy metric that Computes the crossentropy metric between the * labels and predictions. * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and 0.9 @@ -79,7 +74,6 @@ public CategoricalCrossentropy( * channels_first. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result */ public CategoricalCrossentropy( Ops tf, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index 2741a36edb6..a9500b79d9e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -16,19 +16,14 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.LossInterface; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** - * A Metric that computes the categorical hinge loss metric between labels and predictions. - * - * @param the data type for the predictions. - * @param The data type for the metric result - */ +/** Computes the categorical hinge loss metric between labels and predictions. 
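// For reference, categorical hinge is conventionally computed per sample as
// max(0, neg - pos + 1), where pos = sum(labels * predictions) and
// neg = max((1 - labels) * predictions) over the class axis. A plain-Java sketch of that
// convention (an illustration, not a quote of the framework's Losses implementation):
float[] labels = {0f, 1f, 0f};
float[] predictions = {0.05f, 0.80f, 0.15f};
float pos = 0f;
float neg = Float.NEGATIVE_INFINITY;
for (int i = 0; i < labels.length; i++) {
  pos += labels[i] * predictions[i];
  neg = Math.max(neg, (1f - labels[i]) * predictions[i]);
}
float loss = Math.max(0f, neg - pos + 1f); // max(0, 0.15 - 0.80 + 1) = 0.35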
*/ public class CategoricalHinge extends MeanMetricWrapper - implements LossMetric { + implements LossInterface { /** * Creates a CategoricalHinge metric @@ -37,7 +32,6 @@ public class CategoricalHinge extends Mean * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result */ public CategoricalHinge(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 458de092bec..61802572c7b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -15,49 +15,35 @@ package org.tensorflow.framework.metrics; import org.tensorflow.Operand; -import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.LossInterface; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** - * A metric that computes the cosine similarity metric between labels and predictions. - * - * @param the data type for the predictions. - * @param The data type for the metric result. - */ -public class CosineSimilarity extends MeanMetricWrapper - implements LossMetric { - public static final int DEFAULT_AXIS = -1; +/** Computes the cosine similarity metric between labels and predictions. */ +// TODO: this is weird, the metric is called CosineSimilarity in Keras, +// but it calls Metrics.cosineProximity instead of Losses.cosineSimilarity. +// The metric is calculating the Euclidean distance using L2 norms, while the loss +// is using the dot product proportional to the product of their magnitudes. +// While the 2 concepts are similar, they are different. +// Should we rename this metric to CosineProximity? +public class CosineSimilarity extends MeanMetricWrapper + implements LossInterface { + public static final int[] DEFAULT_AXIS = {-1}; private final int[] axis; /** - * Creates a metric that computes the cosine similarity metric between labels and predictions with - * a default axis, {@link #DEFAULT_AXIS} + * Creates a CosineSimilarity metric with a default axis, {@link #DEFAULT_AXIS} * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result */ public CosineSimilarity(Ops tf, String name, long seed, Class type) { this(tf, name, DEFAULT_AXIS, seed, type); } - /** - * Creates a metric that computes the cosine similarity metric between labels and predictions. - * - * @param tf the TensorFlow Ops - * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param axis The dimension along which the cosine similarity is computed. - * @param seed the seed for random number generation. 
An initializer created with a given seed - * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result - */ - public CosineSimilarity(Ops tf, String name, int axis, long seed, Class type) { - this(tf, name, new int[] {axis}, seed, type); - } /** * Creates a CosineSimilarity metric * @@ -66,7 +52,6 @@ public CosineSimilarity(Ops tf, String name, int axis, long seed, Class type) * @param axis The dimension along which the cosine similarity is computed. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result */ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index baf9ad8ab7d..d655f8d8237 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -16,19 +16,14 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.LossInterface; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** - * A metric that computes the hinge loss metric between labels and predictions. - * - * @param the data type for the predictions. - * @param The data type for the metric result. - */ +/** Computes the hinge loss metric between labels and predictions. */ public class Hinge extends MeanMetricWrapper - implements LossMetric { + implements LossInterface { /** * Creates a Hinge metric @@ -37,7 +32,6 @@ public class Hinge extends MeanMetricWrapp * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result */ public Hinge(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index efcbbcbb7f0..3f31383381a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -16,20 +16,14 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.LossInterface; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** - * A metric that computes the Kullback-Leibler divergence loss metric between labels and - * predictions. - * - * @param the data type for the predictions. - * @param The data type for the metric result. - */ +/** Computes Computes Kullback-Leibler divergence loss metric between labels and predictions. 
*/ public class KLDivergence extends MeanMetricWrapper - implements LossMetric { + implements LossInterface { /** * Creates a KLDivergence metric @@ -38,7 +32,6 @@ public class KLDivergence extends MeanMetr * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result */ public KLDivergence(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index 3df8505d54b..7d4b8a9fad7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -16,20 +16,17 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.LossInterface; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; /** - * A metric that computes the logarithm of the hyperbolic cosine of the prediction error metric - * between labels and predictions. - * - * @param the data type for the predictions. - * @param The data type for the metric result. + * Computes the logarithm of the hyperbolic cosine of the prediction error metric between labels and + * predictions. */ public class LogCoshError extends MeanMetricWrapper - implements LossMetric { + implements LossInterface { /** * Creates a LogCoshError metric @@ -38,7 +35,6 @@ public class LogCoshError extends MeanMetr * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result */ public LogCoshError(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java index de1f5a5629e..08d1083dd05 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java @@ -19,7 +19,7 @@ import org.tensorflow.types.family.TNumber; /** - * A metric that that implements a weighted mean {@link MetricReduction#WEIGHTED_MEAN } + * Represents a Metric that implements a weighted mean {@link MetricReduction#WEIGHTED_MEAN } * * @param The data type for the metric values * @param The data type for the metric result @@ -33,9 +33,8 @@ public class Mean extends Reduce { * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
- * @param type the type for the variables and result */ protected Mean(Ops tf, String name, long seed, Class type) { - super(tf, name, MetricReduction.WEIGHTED_MEAN, seed, type); + super(tf, name, seed, MetricReduction.WEIGHTED_MEAN, type); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index e27676932ff..6b29c72fe82 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -16,19 +16,14 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.LossInterface; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** - * A metric that computes the mean of absolute difference between labels and predictions. - * - * @param the data type for the predictions. - * @param The data type for the metric result. - */ +/** Computes the mean of absolute difference between labels and predictions. */ public class MeanAbsoluteError extends MeanMetricWrapper - implements LossMetric { + implements LossInterface { /** * Creates a Mean Absolute Error metric @@ -37,7 +32,6 @@ public class MeanAbsoluteError extends Mea * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result */ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 84fa9b627b2..6209245d881 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -16,19 +16,14 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.LossInterface; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** - * A metric that computes the mean of absolute difference between labels and predictions. - * - * @param the data type for the predictions. - * @param The data type for the metric result. - */ +/** Computes the mean of absolute difference between labels and predictions. */ public class MeanAbsolutePercentageError - extends MeanMetricWrapper implements LossMetric { + extends MeanMetricWrapper implements LossInterface { /** * Creates a Mean Absolute Error metric @@ -37,7 +32,6 @@ public class MeanAbsolutePercentageError * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. 
An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result */ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index c7edd6ebe93..ce30e378e8d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -16,19 +16,14 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.LossInterface; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** - * A metric that computes the mean of absolute difference between labels and predictions. - * - * @param the data type for the predictions. - * @param The data type for the metric result. - */ +/** Computes the mean of absolute difference between labels and predictions. */ public class MeanSquaredError extends MeanMetricWrapper - implements LossMetric { + implements LossInterface { /** * Creates a Mean Absolute Error metric @@ -37,7 +32,6 @@ public class MeanSquaredError extends Mean * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result */ public MeanSquaredError(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index 199b6e0e114..9baeac2f320 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -16,19 +16,14 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.LossInterface; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** - * A metric that computes the mean of absolute difference between labels and predictions. - * - * @param the data type for the predictions. - * @param The data type for the metric result. - */ +/** Computes the mean of absolute difference between labels and predictions. */ public class MeanSquaredLogarithmicError - extends MeanMetricWrapper implements LossMetric { + extends MeanMetricWrapper implements LossInterface { /** * Creates a Mean Absolute Error metric @@ -37,7 +32,6 @@ public class MeanSquaredLogarithmicError * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. 
An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result */ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index bbb2aa73da2..62ec5439269 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -1,3 +1,20 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow.framework.metrics; + /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,15 +29,18 @@ See the License for the specific language governing permissions and limitations under the License. =======================================================================*/ -package org.tensorflow.framework.metrics; +import org.tensorflow.ExecutionEnvironment; import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Initializer; +import org.tensorflow.framework.metrics.impl.MetricVariable; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TNumber; -import java.util.Collections; -import java.util.List; +import java.util.*; +import java.util.stream.Collectors; /** * Base class for Metrics @@ -30,24 +50,41 @@ */ public abstract class Metric { + /** variables are stored by ExecutionEnvironment, and then by an identifier name */ + protected static Map>> + variableMap = new WeakHashMap<>(); /** The TensorFlow Ops */ private final Ops tf; - - /** The seed for random number generation */ + /** The random number generator seed value */ private final long seed; + // TODO: how to handle variables across new ExecutionEnvironments. + // Metrics may be instantiated multiple times using the same variables, + // These variables become stale when a new ExecutionEnvironment is created + // (most commonly seen in Unit Tests), so the question is how to best handle this. + // Option 1, which is used here is to map the variables against an instance of + // an ExecutionEnvironment in a WeakHashMap, when a new ExecutionEnvironment is presented, the + // new + // variables are mapped to it. A WeakHashMap is used to throw away the old ExecutionEnvironment + // mappings, when the old ExecutionEnvironment is finalized. + // Option 2, keep an instance of the newly presented ExecutionEnvironment and if it changes, + // clear the variable maps. + // My guess is that in a non-unit test environment, only one ExecutionEnvironment will be used, + // I welcome thoughts on this. /** The name for this metric. 
Defaults to {@link Class#getSimpleName()}. */ private final String name; + private final Class type; + /** - * Creates a Metric with a name of {@link Class#getSimpleName()} + * Creates a Metric with a name of {@link Class#getSimpleName()} } * * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. */ - protected Metric(Ops tf, long seed) { - this(tf, null, seed); + protected Metric(Ops tf, long seed, Class type) { + this(tf, null, seed, type); } /** @@ -58,13 +95,13 @@ protected Metric(Ops tf, long seed) { * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. */ - protected Metric(Ops tf, String name, long seed) { - if (!tf.scope().env().isGraph()) { + protected Metric(Ops tf, String name, long seed, Class type) { + if (!tf.scope().env().isGraph()) throw new IllegalArgumentException("Metrics are required to execute in Graph mode."); - } this.seed = seed; this.name = name != null ? name : this.getClass().getSimpleName(); - this.tf = tf.withName(this.getClass().getSimpleName()); + this.tf = tf.withSubScope(this.name); + this.type = type; } /** @@ -75,10 +112,9 @@ protected Metric(Ops tf, String name, long seed) { * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. * @return a List of Operations to update the metric state - * @param the data type for sampleWeights */ @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList(Operand values, Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -90,13 +126,12 @@ public List updateStateList(Operand values, Operand the data type for the labels - * @param the data type for the sampleWeights + * @param the data type for the sample weights * @return a List of Operations to update the metric state */ - @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + @SuppressWarnings({"unchecked","unused"}) + public List updateStateList( + Operand labels, Operand predictions, Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -105,10 +140,9 @@ public List updateStateList( * * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. - * @param the data type for sampleWeights * @return the Operation to update the metric state */ - public final Op updateState(Operand values, Operand sampleWeights) { + public final Op updateState(Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -119,12 +153,10 @@ public final Op updateState(Operand values, Operand sa * @param labels the labels * @param predictions the predictions * @param sampleWeights sample weights to be applied to values, may be null. 
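// The sample weights described above feed a WEIGHTED_MEAN reduction: each value is scaled by
// its weight and the running total is divided by the running weight sum. A plain-Java sketch
// of that bookkeeping (an illustration of the reduction, not a quote of the Reduce class):
float[] values = {1f, 2f, 4f};
float[] sampleWeights = {1f, 1f, 0.5f};
float total = 0f;
float count = 0f;
for (int i = 0; i < values.length; i++) {
  total += values[i] * sampleWeights[i];
  count += sampleWeights[i];
}
float weightedMean = total / count; // (1 + 2 + 2) / 2.5 = 2.0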
- * @param the data type for the labels - * @param the data type for the sampleWeights * @return the Operation to update the metric state */ - public final Op updateState( - Operand labels, Operand predictions, Operand sampleWeights) { + public final Op updateState( + Operand labels, Operand predictions, Operand sampleWeights) { List controlOps = updateStateList(labels, predictions, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -132,16 +164,19 @@ public final Op updateState( /** * Gets the current result of the metric * + * @param tf the TensorFlow Ops used to create the result * @return the result, possibly with control dependencies */ - public abstract Operand result(); + public abstract Operand result(Ops tf); /** - * Resets any state variables to their initial values + * Gets the current result of the metric using the metric's {@link #getTF()} * - * @return the control operation for doing the reset + * @return the result, possibly with control dependencies */ - public abstract Op resetStates(); + public Operand result() { + return result(this.tf); + } /** * Calls update state once, followed by a call to get the result @@ -149,13 +184,41 @@ public final Op updateState( * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. * @return the result, possibly with control dependencies - * @param the data type for the sampleWeights. */ - public final Operand callOnce( - Operand values, Operand sampleWeights) { + public final Operand callOnce( + Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps); - return ltf.identity(result()); + return result(ltf); + } + + /** + * Adds a variable to collect metric values + * + * @param variable the variable + * @param initializer the initializer for the variable, if null, then the default for floating + * point types is {@link org.tensorflow.framework.initializers.Glorot} with distribution + * {@link org.tensorflow.framework.initializers.VarianceScaling.Distribution#UNIFORM}, for + * other types the default initializer is {@link org.tensorflow.framework.initializers.Zeros} + */ + protected void addVariable( + String varName, Variable variable, Initializer initializer) { + // TODO option 2 would be to keep track of tf.scope().env() and if it changes, clear to old Map. + Map> variables = + variableMap.computeIfAbsent(tf.scope().env(), k -> new HashMap<>()); + variables.put(varName, new MetricVariable<>(tf, variable, initializer, seed)); + } + + /** + * Gets the list of added variables + * + * @return the list of added variables + */ + public List> getVariables() { + List> result = new ArrayList<>(); + Map> variables = variableMap.get(tf.scope().env()); + if (variables != null) variables.values().forEach(mv -> result.add(mv.getVariable())); + return result; } /** @@ -168,6 +231,71 @@ protected String getVariableName(String varName) { return String.format("%s_%s", this.name, varName); } + /** + * Gets an Operation that initializes the variables. + * + * @param subScopeName the sub scope name + * @return the Operation used to initialize the variables. 
+ */ + public Op initialize(String subScopeName) { + + List initializeOperations = initializeVarsList(subScopeName); + return tf.withControlDependencies(initializeOperations).noOp(); + } + + /** + * Gets the list of Operations that initializes the variables + * + * @param subScopeName the sub scope name + * @return the list of Operations that initializes the variables + */ + @SuppressWarnings("unchecked") + private List initializeVarsList(String subScopeName) { + Map> variables = variableMap.get(tf.scope().env()); + if (variables != null) + return variables.values().stream() + .map(metricVariable -> variableAssign(subScopeName, metricVariable)) + .collect(Collectors.toList()); + else return Collections.EMPTY_LIST; + } + + /** + * Resets all variables to their initial state + * + * @return An Operation that sets all variables to their initial state + */ + public Op resetStates() { + return initialize("resetStates"); + } + + /** + * Assigns a value to a Variable + * + *

This assumes the variable has already been initialized + * + * @param subScopeName the subscope for creating the variable + * @param mv the metric value used to assign the initializer to the variable. + * @return the variable add operation with necessary control dependencies + */ + private Operand variableAssign( + String subScopeName, MetricVariable mv) { + return tf.withSubScope(subScopeName).assign(mv.getVariable(), mv.initialize()); + } + + /** + * Gets a stored variable by name, Variables are cached first by the TensorFlow Environment, then + * by a variable name. + * + * @param varName the name assigned to the variable + * @return the variable, or null if the variable is not found. + */ + public Variable getVariable(String varName) { + Map> variables = variableMap.get(tf.scope().env()); + if (variables == null) return null; + MetricVariable mv = variables.get(varName); + return mv != null ? mv.getVariable() : null; + } + /** * Gets the TensorFlow Ops * @@ -186,8 +314,7 @@ public String getName() { return name; } - /** The random number generator seed value */ - public long getSeed() { - return seed; + public Class getType() { + return type; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index 0169bc6b8bc..f4282bfd0a9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -16,12 +16,14 @@ import org.tensorflow.Operand; import org.tensorflow.framework.utils.CastHelper; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.op.core.ReduceSum; import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -/** Helper class with built-in metrics functions. */ +/** Built-in metrics functions. */ public class Metrics { public static final float L2_NORM_EPSILON = 1e-12f; @@ -38,7 +40,7 @@ public class Metrics { * {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}}); * Operand<TFloat32> m = Metrics.topKCategoricalAccuracy( * labels, predictions, 3) - * //m.shape().toString == "[2]" + * //m.asOutput().shape().toString == "[2]" * * * @param tf the TensorFlow Ops. @@ -58,25 +60,77 @@ public static Operand topKCategoricalA predictions.type()); } + /** + * Computes how often integer targets are in the top K predictions. + * + *

Standalone usage: + * + *

+   *     Operand<TInt32> labels = tf.constant(new int[]{2, 1});
+   *     Operand<TFloat32> predictions = tf.constant(new float[][]
+   *                            {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}});
+   *     Operand<TFloat32> m = Metrics.sparseTopKCategoricalAccuracy(
+   *                                    labels, predictions, 3)
+   *     //m.asOutput().shape().toString == "[2]"
+   * 
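// [Editor's illustration, not part of this patch] A minimal sketch of how the
// sparseTopKCategoricalAccuracy helper documented above might be exercised in
// eager mode. The class name, the Ops.create() setup and the printed shape are
// assumptions made for a self-contained example; only tf.constant and the helper
// introduced in this patch are used.
import org.tensorflow.Operand;
import org.tensorflow.framework.metrics.Metrics;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;

public class SparseTopKAccuracyExample {
  public static void main(String[] args) {
    Ops tf = Ops.create(); // eager execution environment

    Operand<TInt32> labels = tf.constant(new int[] {2, 1});
    Operand<TFloat32> predictions =
        tf.constant(new float[][] {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}});

    // With k = 3 and only three classes, every label falls inside the top k,
    // so each sample scores 1.0 and the result has shape [2].
    Operand<TFloat32> accuracy =
        Metrics.sparseTopKCategoricalAccuracy(tf, labels, predictions, 3);
    System.out.println(accuracy.asOutput().shape()); // [2]
  }
}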
+ * + * @param tf the TensorFlow Ops. + * @param labels the ground truth values. + * @param predictions The prediction values. + * @param k Number of top elements to look at for computing accuracy. + * @param the data type for the predictions and results + * @param the data type ofr the labels. + * @return the Operand for the Sparse top K categorical accuracy value. + */ + @SuppressWarnings("unchecked") + public static Operand sparseTopKCategoricalAccuracy( + Ops tf, Operand labels, Operand predictions, int k) { + Operand tLabels; + if (labels.type() != predictions.type()) + tLabels = CastHelper.cast(tf, labels, predictions.type()); + else tLabels = (Operand) labels; + + int predictionsRank = predictions.asOutput().shape().numDimensions(); + int labelsRank = tLabels.asOutput().shape().numDimensions(); + + Operand castPredictions = CastHelper.cast(tf, predictions, TFloat32.class); + if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { + if (predictionsRank > 2) { + castPredictions = tf.shape.reduceDims(castPredictions, tf.constant(1)); + } + if (labelsRank > 1) { + tLabels = tf.shape.flatten(tLabels); + } + } + return CastHelper.cast( + tf, + tf.nn.inTopK(castPredictions, CastHelper.cast(tf, tLabels, TInt32.class), tf.constant(k)), + predictions.type()); + } + /** * Computes the cosine similarity between labels and predictions. * * @param tf the TensorFlow Ops * @param labels The ground truth values. * @param predictions The prediction values. - * @param axes The dimensions along which the cosine similarity is computed. + * @param axis The dimension along which the cosine similarity is computed. * @param the data type for the labels * @param the data type for the predictions and result * @return Cosine similarity value. */ + @SuppressWarnings("unchecked") public static Operand cosineProximity( - Ops tf, Operand labels, Operand predictions, int[] axes) { - Operand labelsNorm = CastHelper.cast(tf, labels, predictions.type()); - labelsNorm = l2Normalize(tf, labelsNorm, axes); + Ops tf, Operand labels, Operand predictions, int[] axis) { + Operand labelsNorm; + if (labels.type() != predictions.type()) + labelsNorm = CastHelper.cast(tf, labels, predictions.type()); + else labelsNorm = (Operand) labels; + labelsNorm = l2Normalize(tf, labelsNorm, axis); - Operand predictionsNorm = l2Normalize(tf, predictions, axes); + Operand predictionsNorm = l2Normalize(tf, predictions, axis); Operand mathMul = tf.math.mul(labelsNorm, predictionsNorm); - return tf.reduceSum(mathMul, tf.constant(axes), ReduceSum.keepDims(Boolean.FALSE)); + return tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); } /** @@ -98,6 +152,8 @@ public static Operand cosineProximity( * @param The data type for x. * @return the normalized values of x. */ + // TODO this was tf.math.l2_normalize in TF Python + public static Operand l2Normalize(Ops tf, Operand x, int[] axes) { return l2Normalize(tf, x, axes, L2_NORM_EPSILON); } @@ -117,18 +173,20 @@ public static Operand l2Normalize(Ops tf, Operand x, i * @param tf The TensorFlow ops * @param x The operand to normalize * @param axes Dimension(s) along which to normalize. - * @param epsilon A lower bound value for the norm. Will use sqrt(epsilon) as the - * divisor if norm < sqrt(epsilon). + * @param epsilon A lower bound value for the norm. Will use `sqrt(epsilon)` as the divisor if + * `norm < sqrt(epsilon)`. * @param The data type for the values. * @return the normalized values of x. 
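// [Editor's illustration, not part of this patch] A plain-Java sketch of the
// arithmetic the l2Normalize overload below performs along one axis, for readers
// who want the formula without the TensorFlow ops. The method name is illustrative
// only; the real implementation follows immediately after.
//
//   y[i] = x[i] * 1 / sqrt(max(sum(x[j]^2), epsilon))
//
// so the divisor becomes sqrt(epsilon) whenever the squared norm drops below epsilon.
static float[] l2NormalizeSketch(float[] x, float epsilon) {
  float squareSum = 0f;
  for (float v : x) {
    squareSum += v * v; // sum of squares, the un-rooted L2 norm
  }
  float inverseNorm = (float) (1.0 / Math.sqrt(Math.max(squareSum, epsilon)));
  float[] y = new float[x.length];
  for (int i = 0; i < x.length; i++) {
    y[i] = x[i] * inverseNorm; // scale each element by the reciprocal norm
  }
  return y;
}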
*/ + // TODO this was tf.math.l2_normalize in TF Python public static Operand l2Normalize( Ops tf, Operand x, int[] axes, float epsilon) { Operand squareSum = tf.reduceSum(tf.math.square(x), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE)); Operand y = tf.math.rsqrt( - tf.math.maximum(squareSum, CastHelper.cast(tf, tf.constant(epsilon), x.type()))); + tf.math.maximum( + squareSum, CastHelper.cast(tf, tf.constant(epsilon), x.type()))); return tf.math.mul(x, y); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 75a2031fbb5..f5730b07f42 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -16,19 +16,14 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.LossInterface; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** - * A metric that computes the poisson loss metric between labels and predictions. - * - * @param the data type for the predictions. - * @param The data type for the metric result. - */ +/** Computes the poisson loss metric between labels and predictions. */ public class Poisson extends MeanMetricWrapper - implements LossMetric { + implements LossInterface { /** * Creates a Poisson metric @@ -37,7 +32,6 @@ public class Poisson extends MeanMetricWra * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result */ public Poisson(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index 2e01f722de6..403e11af8c0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -16,46 +16,39 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.LossInterface; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** - * A metric that computes the sparse categorical cross-entropy loss between true labels and - * predicted labels. - * - * @param the data type for the predictions. - * @param The data type for the metric result. - */ +/** Computes the sparse categorical cross-entropy loss between true labels and predicted labels. 
*/ public class SparseCategoricalCrossentropy - extends MeanMetricWrapper implements LossMetric { + extends MeanMetricWrapper implements LossInterface { private final boolean fromLogits; - private final int axis; + private final int axes; /** * Creates a SparseCategoricalCrossentropy metric * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. - * @param axis The dimension along which the entropy is computed. + * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. + * @param axes The dimension along which the entropy is computed. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result */ public SparseCategoricalCrossentropy( - Ops tf, String name, boolean fromLogits, int axis, long seed, Class type) { + Ops tf, String name, boolean fromLogits, int axes, long seed, Class type) { super(tf, name, seed, type); setLoss(this); this.fromLogits = fromLogits; - this.axis = axis; + this.axes = axes; } /** {@inheritDoc} */ @Override public Operand call(Operand labels, Operand predictions) { - return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); + return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axes); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java new file mode 100644 index 00000000000..1412465bd89 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Computes the poisson loss metric between labels and predictions. */ +public class SparseTopKCategoricalAccuracy + extends MeanMetricWrapper implements LossInterface { + public static final int DEFAULT_K = 5; + /** Number of top elements to look at for computing accuracy. */ + private final int k; + + /** + * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for k, Number of + * top elements to look at for computing accuracy. + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. 
+ * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the date type for the result + */ + public SparseTopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_K, seed, type); + } + + /** + * Creates a TopKCategoricalAccuracy metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param k Number of top elements to look at for computing accuracy. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the date type for the result + */ + public SparseTopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { + super(tf, name, seed, type); + this.k = k; + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Metrics.sparseTopKCategoricalAccuracy(getTF(), labels, predictions, k); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 430dbbcc229..7ce8091f2a0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -16,19 +16,14 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.LossInterface; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** - * A metric that computes the squared hinge loss metric between labels and predictions. - * - * @param the data type for the predictions. - * @param The data type for the metric result. - */ +/** Computes the squared hinge loss metric between labels and predictions. */ public class SquaredHinge extends MeanMetricWrapper - implements LossMetric { + implements LossInterface { /** * Creates a SquaredHinge metric @@ -37,7 +32,6 @@ public class SquaredHinge extends MeanMetr * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result */ public SquaredHinge(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java new file mode 100644 index 00000000000..3198ab0ee04 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Computes the poisson loss metric between labels and predictions. */ +public class TopKCategoricalAccuracy + extends MeanMetricWrapper implements LossInterface { + public static final int DEFAULT_K = 5; + /** Number of top elements to look at for computing accuracy. */ + private final int k; + + /** + * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for k, Number of + * top elements to look at for computing accuracy. + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public TopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_K, seed, type); + } + + /** + * Creates a TopKCategoricalAccuracy metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param k Number of top elements to look at for computing accuracy. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public TopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { + super(tf, name, seed, type); + this.k = k; + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Metrics.topKCategoricalAccuracy(getTF(), labels, predictions, k); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossInterface.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossInterface.java new file mode 100644 index 00000000000..aadc211c3c4 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossInterface.java @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.types.family.TNumber; + +/** + * Interface for Metrics that wrap Loss functions. + * + * @param The data type of the predictions. + */ +public interface LossInterface { + + /** + * Calculates the weighted loss between labels and predictions + * + * @param labels the truth values or labels + * @param predictions the predictions + * @param The data type of the labels. + * @return the loss + */ + Operand call(Operand labels, Operand predictions); +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 17c209a8fed..5e0023c4dbe 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -25,20 +25,19 @@ import java.util.List; /** - * A class that bridges a stateless loss function with the {@link Mean} metric using a reduction of - * {@link MetricReduction#WEIGHTED_MEAN}. + * Bridges a stateless loss function with the {@link Mean} metric using a reduction of {@link + * MetricReduction#WEIGHTED_MEAN}. * *
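// [Editor's illustration, not part of this patch] A conceptual, plain-Java sketch
// of the WEIGHTED_MEAN reduction that MeanMetricWrapper delegates to: each update
// adds the weighted losses to a running total and the weights to a running count,
// and the result is total / count. Names and fields here are illustrative; the real
// class accumulates these sums in TensorFlow variables via the Mean/Reduce metrics.
class WeightedMeanSketch {
  private double total; // running sum of weight * loss
  private double count; // running sum of weights

  void updateState(double[] losses, double[] weights) {
    for (int i = 0; i < losses.length; i++) {
      double w = (weights == null) ? 1.0 : weights[i]; // null weights mean "unweighted"
      total += w * losses[i];
      count += w;
    }
  }

  double result() {
    return count == 0 ? 0.0 : total / count; // weighted mean over all updates so far
  }

  void resetStates() {
    total = 0.0;
    count = 0.0;
  }
}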

The loss function calculates the loss between the labels and predictions * then passes this loss to the {@link Mean} metric to calculate the weighted mean of the * loss over many iterations or epochs * - * @param the data type for the predictions. - * @param The data type for the metric result + * @param the data type for the loss. */ public class MeanMetricWrapper extends Mean { /** The loss function interface */ - protected LossMetric loss; + protected LossInterface loss; /** * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#WEIGHTED_MEAN} @@ -47,7 +46,6 @@ public class MeanMetricWrapper extends Mea * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result */ protected MeanMetricWrapper(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); @@ -58,7 +56,7 @@ protected MeanMetricWrapper(Ops tf, String name, long seed, Class type) { * * @return the loss function. */ - public LossMetric getLoss() { + public LossInterface getLoss() { return loss; } @@ -67,7 +65,7 @@ public LossMetric getLoss() { * * @param loss the loss function. */ - protected void setLoss(LossMetric loss) { + public void setLoss(LossInterface loss) { this.loss = loss; } @@ -85,22 +83,43 @@ protected void setLoss(LossMetric loss) { * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) - * @param the datatype of the labels - * @param the data type for sampleWeights + * @param the datatype of the predictions * @return a List of control operations that updates the Mean state variables. */ - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { - if (labels == null || predictions == null) { + public List updateLossStateList( + Operand labels, Operand predictions, Operand sampleWeights) { + if (labels == null || predictions == null) throw new IllegalArgumentException("missing required inputs for labels and predictions"); - } - Operand tLabels = CastHelper.cast(getTF(), labels, getResultType()); - Operand tPredictions = CastHelper.cast(getTF(), predictions, getResultType()); + Class type = predictions.type(); + Operand tPredicitons = CastHelper.cast(getTF(), predictions, getType()); + + Operand losses = loss.call(labels, tPredicitons); + Operand uLossess = CastHelper.cast(getTF(), losses, type); - Operand losses = loss.call(tLabels, tPredictions); + return super.updateStateList(uLossess, sampleWeights); + } - return super.updateStateList( - CastHelper.cast(getTF(), losses, predictions.type()), sampleWeights); + /** + * Creates a Control Operation that updates the state of the mean metric by calculating the loss + * between the labels and predictions and then applying a weighted mean + * metric across the multiple iterations. + * + * @param labels the truth values or labels + * @param predictions the predictions + * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is + * provided, then the loss is simply scaled by the given value. 
If sampleWeights is a tensor + * of size [batch_size], then the total loss for each sample of the batch is rescaled by the + * corresponding element in the sampleWeights vector. If the shape of sampleWeights is + * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of + * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss + * functions reduce by 1 dimension, usually axis=-1.) + * @param the datatype of the labels + * @return a NoOp with control dependencies that update the state of the mean metric. + */ + public final Op updateLossState( + Operand labels, Operand predictions, Operand sampleWeights) { + List controlOps = updateLossStateList(labels, predictions, sampleWeights); + return getTF().withSubScope("updateState").withControlDependencies(controlOps).noOp(); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java new file mode 100644 index 00000000000..78d7459697c --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java @@ -0,0 +1,122 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Glorot; +import org.tensorflow.framework.initializers.Initializer; +import org.tensorflow.framework.initializers.VarianceScaling; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TIntegral; +import org.tensorflow.types.family.TNumber; + +/** + * Helper class that holds a metric variable + * + * @param the data type of the variable + */ +// TODO handle distributed variables with VariableAggregation and VariableSynchronization +public class MetricVariable { + private final Variable variable; + private final Initializer initializer; + private final Ops tf; + private boolean initialized; + + /** + * Creates a Metric Variable + * + * @param tf the TensorFlow Ops + * @param variable the variable + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ */ + public MetricVariable(Ops tf, Variable variable, long seed) { + this(tf, variable, null, seed); + } + /** + * Creates a Metric Variable + * + * @param tf the TensorFlow Ops + * @param variable the variable + * @param initializer the initializer for the variable, if null, then the default for floating + * point types is {@link org.tensorflow.framework.initializers.Glorot} with distribution + * {@link org.tensorflow.framework.initializers.VarianceScaling.Distribution#UNIFORM}, for + * other types the default initializer is {@link org.tensorflow.framework.initializers.Zeros} + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + @SuppressWarnings("unchecked") + public MetricVariable(Ops tf, Variable variable, Initializer initializer, long seed) { + this.tf = tf; + this.variable = variable; + + Class type = variable.type(); + if (initializer == null) { + if (TFloating.class.isAssignableFrom(type)) { + this.initializer = + (Initializer) new Glorot<>(tf, VarianceScaling.Distribution.UNIFORM, seed); + } else if (TIntegral.class.isAssignableFrom(type)) { + this.initializer = new Zeros<>(tf); + } else { + throw new IllegalArgumentException( + String.format( + "An initializer for variable %s of type %s is required", + variable.toString(), type)); + } + } else { + this.initializer = initializer; + } + } + + /** + * Initializers the variable based on the initializer + * + * @return the initialized variable + */ + public Operand initialize() { + initialized = true; + return initializer.call(tf.constant(variable.asOutput().shape()), variable.type()); + } + + /** + * Gets the variable + * + * @return the variable + */ + public Variable getVariable() { + return variable; + } + + /** + * Gets the initializer + * + * @return the initializer + */ + public Initializer getInitializer() { + return initializer; + } + + /** + * Gets the value of initialized + * + * @return the value of initialized + */ + public boolean isInitialized() { + return initialized; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index ad8ff58e417..5395cccf4a7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -15,79 +15,63 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; -import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException; +import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.math.Mean; import org.tensorflow.types.TBool; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; -import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; import java.util.Arrays; import java.util.Collections; import java.util.List; -import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * These are helper methods for Metrics and will be module private when Java modularity is applied * to TensorFlow Java. 
These methods should not be used outside of the metrics packages. */ public class MetricsHelper { - public static final float NEG_INF = -1e10f; - private static final String ASSERT_BROADCAST_ERROR_PREFIX = + private static final String ASSERT_BROADCASTABLE_ERROR_PREFIX = "weights can not be broadcast to values."; /** - * Asserts that the sampleWeights can be broadcast to the same shape as values - * - * - *

This is the crossentropy metric class to be used when there are multiple label classes (2 or - * more). Here we assume that labels are given as a one_hot representation. eg., When labels values - * are [2, 0, 1], the labels Operand contains = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] + * more). The labels should be given as a one_hot representation. eg., When labels values are + * [2, 0, 1], the labels Operand contains = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] * . + * + * @param the data type for the predictions. + * @param The data type for the metric result */ public class CategoricalCrossentropy extends MeanMetricWrapper implements LossInterface { @@ -51,6 +55,7 @@ public class CategoricalCrossentropy * for label 1 * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public CategoricalCrossentropy( Ops tf, String name, boolean fromLogits, float labelSmoothing, long seed, Class type) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index a9500b79d9e..21f19d88ade 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -21,7 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the categorical hinge loss metric between labels and predictions. */ +/** + * A Metric that computes the categorical hinge loss metric between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result + */ public class CategoricalHinge extends MeanMetricWrapper implements LossInterface { @@ -32,6 +37,7 @@ public class CategoricalHinge extends Mean * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public CategoricalHinge(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 4a5214aea8d..9ceccf7fc13 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -20,32 +20,33 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the cosine similarity metric between labels and predictions. */ -// TODO: this is weird, the metric is called CosineSimilarity in Keras, -// but it calls Metrics.cosineProximity instead of Losses.cosineSimilarity. -// The metric is calculating the Euclidean distance using L2 norms, while the loss -// is using the dot product proportional to the product of their magnitudes. -// While the 2 concepts are similar, they are different. -// Should we rename this metric to CosineProximity? +/** + * A metric that computes the cosine similarity metric between labels and predictions. 
+ * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class CosineSimilarity extends MeanMetricWrapper implements LossInterface { public static final int DEFAULT_AXIS = -1; private final int[] axis; /** - * Creates a CosineSimilarity metric with a default axis, {@link #DEFAULT_AXIS} + * Creates a metric that computes the cosine similarity metric between labels and predictions with + * a default axis, {@link #DEFAULT_AXIS} * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public CosineSimilarity(Ops tf, String name, long seed, Class type) { this(tf, name, DEFAULT_AXIS, seed, type); } /** - * Creates a CosineSimilarity metric + * Creates a metric that computes the cosine similarity metric between labels and predictions. * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index d655f8d8237..b276f0b9426 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -21,7 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the hinge loss metric between labels and predictions. */ +/** + * A metric that computes the hinge loss metric between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class Hinge extends MeanMetricWrapper implements LossInterface { @@ -32,6 +37,7 @@ public class Hinge extends MeanMetricWrapp * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public Hinge(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index 3f31383381a..a3cbc6f16e6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -21,7 +21,13 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes Computes Kullback-Leibler divergence loss metric between labels and predictions. */ +/** + * A metric that computes the Kullback-Leibler divergence loss metric between labels and + * predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. 
+ */ public class KLDivergence extends MeanMetricWrapper implements LossInterface { @@ -32,6 +38,7 @@ public class KLDivergence extends MeanMetr * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public KLDivergence(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index 7d4b8a9fad7..d6fe903f5a1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -22,8 +22,11 @@ import org.tensorflow.types.family.TNumber; /** - * Computes the logarithm of the hyperbolic cosine of the prediction error metric between labels and - * predictions. + * A metric that computes the logarithm of the hyperbolic cosine of the prediction error metric + * between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. */ public class LogCoshError extends MeanMetricWrapper implements LossInterface { @@ -35,6 +38,7 @@ public class LogCoshError extends MeanMetr * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public LogCoshError(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java index c68a70902a7..de1f5a5629e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java @@ -19,7 +19,7 @@ import org.tensorflow.types.family.TNumber; /** - * Represents a Metric that implements a weighted mean {@link MetricReduction#WEIGHTED_MEAN } + * A metric that that implements a weighted mean {@link MetricReduction#WEIGHTED_MEAN } * * @param The data type for the metric values * @param The data type for the metric result @@ -33,7 +33,7 @@ public class Mean extends Reduce { * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the result. 
+ * @param type the type for the variables and result */ protected Mean(Ops tf, String name, long seed, Class type) { super(tf, name, MetricReduction.WEIGHTED_MEAN, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index 6b29c72fe82..79da80ef191 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -21,7 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the mean of absolute difference between labels and predictions. */ +/** + * A metric that computes the mean of absolute difference between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class MeanAbsoluteError extends MeanMetricWrapper implements LossInterface { @@ -32,6 +37,7 @@ public class MeanAbsoluteError extends Mea * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 6209245d881..558c194074f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -21,7 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the mean of absolute difference between labels and predictions. */ +/** + * A metric that computes the mean of absolute difference between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class MeanAbsolutePercentageError extends MeanMetricWrapper implements LossInterface { @@ -32,6 +37,7 @@ public class MeanAbsolutePercentageError * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
+ * @param type the type for the variables and result */ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index ce30e378e8d..10704d14bd4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -21,7 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the mean of absolute difference between labels and predictions. */ +/** + * A metric that computes the mean of absolute difference between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class MeanSquaredError extends MeanMetricWrapper implements LossInterface { @@ -32,6 +37,7 @@ public class MeanSquaredError extends Mean * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public MeanSquaredError(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index 9baeac2f320..585fc312e5a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -21,7 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the mean of absolute difference between labels and predictions. */ +/** + * A metric that computes the mean of absolute difference between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class MeanSquaredLogarithmicError extends MeanMetricWrapper implements LossInterface { @@ -32,6 +37,7 @@ public class MeanSquaredLogarithmicError * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 28a2ae0fa94..89e5436ed0a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -42,19 +42,6 @@ public abstract class Metric { /** The random number generator seed value */ private final long seed; - // TODO: how to handle variables across new ExecutionEnvironments. 
- // Metrics may be instantiated multiple times using the same variables, - // These variables become stale when a new ExecutionEnvironment is created - // (most commonly seen in Unit Tests), so the question is how to best handle this. - // Option 1, which is used here is to map the variables against an instance of - // an ExecutionEnvironment in a WeakHashMap, when a new ExecutionEnvironment is presented, the - // new - // variables are mapped to it. A WeakHashMap is used to throw away the old ExecutionEnvironment - // mappings, when the old ExecutionEnvironment is finalized. - // Option 2, keep an instance of the newly presented ExecutionEnvironment and if it changes, - // clear the variable maps. - // My guess is that in a non-unit test environment, only one ExecutionEnvironment will be used, - // I welcome thoughts on this. /** The name for this metric. Defaults to {@link Class#getSimpleName()}. */ private final String name; @@ -185,7 +172,6 @@ public final Operand callOnce( */ protected void addVariable( String varName, Variable variable, Initializer initializer) { - // TODO option 2 would be to keep track of tf.scope().env() and if it changes, clear to old Map. Map> variables = variableMap.computeIfAbsent(tf.scope().env(), k -> new HashMap<>()); variables.put(varName, new MetricVariable<>(tf, variable, initializer, seed, variable.type())); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index c3c44ef6134..8a8ddf3694c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -16,14 +16,12 @@ import org.tensorflow.Operand; import org.tensorflow.framework.utils.CastHelper; -import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.op.core.ReduceSum; import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -/** Built-in metrics functions. */ +/** Helper class with built-in metrics functions. */ public class Metrics { public static final float L2_NORM_EPSILON = 1e-12f; @@ -60,54 +58,6 @@ public static Operand topKCategoricalA predictions.type()); } - /** - * Computes how often integer targets are in the top K predictions. - * - *

Standalone usage: - * - *

-   *     Operand<TInt32> labels = tf.constant(new int[]{2, 1});
-   *     Operand<TFloat32> predictions = tf.constant(new float[][]
-   *                            {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}});
-   *     Operand<TFloat32> m = Metrics.sparseTopKCategoricalAccuracy(
-   *                                    labels, predictions, 3)
-   *     //m.shape().toString == "[2]"
-   * 
- * - * @param tf the TensorFlow Ops. - * @param labels the ground truth values. - * @param predictions The prediction values. - * @param k Number of top elements to look at for computing accuracy. - * @param the data type for the predictions and results - * @param the data type ofr the labels. - * @return the Operand for the Sparse top K categorical accuracy value. - */ - @SuppressWarnings("unchecked") - public static Operand sparseTopKCategoricalAccuracy( - Ops tf, Operand labels, Operand predictions, int k) { - Operand tLabels; - if (labels.type() != predictions.type()) - tLabels = CastHelper.cast(tf, labels, predictions.type()); - else tLabels = (Operand) labels; - - int predictionsRank = predictions.shape().numDimensions(); - int labelsRank = tLabels.shape().numDimensions(); - - Operand castPredictions = CastHelper.cast(tf, predictions, TFloat32.class); - if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { - if (predictionsRank > 2) { - castPredictions = tf.shape.reduceDims(castPredictions, tf.constant(1)); - } - if (labelsRank > 1) { - tLabels = tf.shape.flatten(tLabels); - } - } - return CastHelper.cast( - tf, - tf.nn.inTopK(castPredictions, CastHelper.cast(tf, tLabels, TInt32.class), tf.constant(k)), - predictions.type()); - } - /** * Computes the cosine similarity between labels and predictions. * @@ -152,10 +102,7 @@ public static Operand cosineProximity( * @param The data type for x. * @return the normalized values of x. */ - // TODO this was tf.math.l2_normalize in TF Python, does it belong here? - - public static Operand l2Normalize( - Ops tf, Operand x, int[] axes) { + public static Operand l2Normalize(Ops tf, Operand x, int[] axes) { return l2Normalize(tf, x, axes, L2_NORM_EPSILON); } @@ -179,7 +126,6 @@ public static Operand l2Normalize( * @param The data type for the values. * @return the normalized values of x. */ - // TODO this was tf.math.l2_normalize in TF Python, does it belong here? public static Operand l2Normalize( Ops tf, Operand x, int[] axes, float epsilon) { Operand squareSum = @@ -189,5 +135,4 @@ public static Operand l2Normalize( tf.math.maximum(squareSum, CastHelper.cast(tf, tf.constant(epsilon), x.type()))); return tf.math.mul(x, y); } - } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index f5730b07f42..07ab129eb08 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -21,7 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the poisson loss metric between labels and predictions. */ +/** + * A metric that computes the poisson loss metric between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class Poisson extends MeanMetricWrapper implements LossInterface { @@ -32,6 +37,7 @@ public class Poisson extends MeanMetricWra * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
+ * @param type the type for the variables and result */ public Poisson(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index 403e11af8c0..c2f916217e4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -21,7 +21,13 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the sparse categorical cross-entropy loss between true labels and predicted labels. */ +/** + * A metric that computes the sparse categorical cross-entropy loss between true labels and + * predicted labels. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class SparseCategoricalCrossentropy extends MeanMetricWrapper implements LossInterface { @@ -37,6 +43,7 @@ public class SparseCategoricalCrossentropy * @param axes The dimension along which the entropy is computed. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public SparseCategoricalCrossentropy( Ops tf, String name, boolean fromLogits, int axes, long seed, Class type) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 7ce8091f2a0..d8c7aa097fe 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -21,7 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** Computes the squared hinge loss metric between labels and predictions. */ +/** + * A metric that computes the squared hinge loss metric between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ public class SquaredHinge extends MeanMetricWrapper implements LossInterface { @@ -32,6 +37,7 @@ public class SquaredHinge extends MeanMetr * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
+ * @param type the type for the variables and result */ public SquaredHinge(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 77566e5c400..5894b24c4cd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -25,8 +25,8 @@ import java.util.List; /** - * Bridges a stateless loss function with the {@link Mean} metric using a reduction of {@link - * MetricReduction#WEIGHTED_MEAN}. + * A class that bridges a stateless loss function with the {@link Mean} metric using a reduction of + * {@link MetricReduction#WEIGHTED_MEAN}. * *

In losses and metrics, limited weight broadcasting is supported. Weights must be either - * scalar, or the same rank as the target values, with each dimension either 1, or the same as the - * corresponding values dimension. + * Asserts that the sampleWeight can be broadcast to values * * @param tf the TensorFlow Ops * @param sampleWeights the sample weights. * @param values the values to which weights are applied. - * @return Operation with control dependencies to ensure sampleWeight - * can be broadcast to values + * @return Operation raising InvalidArgumentError if sampleWeight + * has incorrect shape. no_op if static checks determine + * sampleWeight has correct shape. * @param the type of Operand - * @throws NotBroadcastableException If static checks determine sampleWeights has an - * incorrect shape that prohibit broadcasting to values + * @throws IllegalArgumentException If static checks determine `weights` has incorrect shape. */ @SuppressWarnings("unchecked") - public static Op assertBroadcastable( + public static Op broadcastWeights( Ops tf, Operand sampleWeights, Operand values) { - // try static check for exact match - - Shape weightsShapeStatic = sampleWeights.shape(); + Operand weightsShape = tf.shape(sampleWeights); + Operand weightsRank = tf.rank(sampleWeights); + Shape weightsShapeStatic = sampleWeights.asOutput().shape(); int weightsRankStatic = weightsShapeStatic.numDimensions(); - Shape valuesShapeStatic = values.shape(); + Operand valuesShape = tf.shape(values); + Operand valuesRank = tf.rank(values); + Shape valuesShapeStatic = values.asOutput().shape(); int valuesRankStatic = valuesShapeStatic.numDimensions(); - // if (weightsRankStatic != Shape.UNKNOWN_SIZE && valuesRankStatic != Shape.UNKNOWN_SIZE) { - if (!weightsShapeStatic.isUnknown() - && !valuesShapeStatic.isUnknown() - && !weightsShapeStatic.hasUnknownDimension() - && !valuesShapeStatic.hasUnknownDimension()) { + if (weightsRankStatic != -1 && valuesRankStatic != -1) { if (weightsRankStatic == 0) { - return tf.withSubScope("staticScalarCheckSuccess") + return tf.withSubScope("static_scalar_check_success") .withControlDependencies(Collections.EMPTY_LIST) .noOp(); } if (weightsRankStatic != valuesRankStatic) { - throw new NotBroadcastableException( + throw new IllegalArgumentException( String.format( "%s values.rank=%d. weights.rank=%d. values.shape=%s. weights.shape=%s.", - ASSERT_BROADCAST_ERROR_PREFIX, + ASSERT_BROADCASTABLE_ERROR_PREFIX, valuesRankStatic, weightsRankStatic, valuesShapeStatic.toString(), @@ -95,51 +79,39 @@ public static Op assertBroadcastable( } for (int i = 0; i < valuesRankStatic; i++) { - if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i) - && weightsShapeStatic.size(i) != 1) { - throw new NotBroadcastableException( + if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i)) { + throw new IllegalArgumentException( String.format( "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.", - ASSERT_BROADCAST_ERROR_PREFIX, + ASSERT_BROADCASTABLE_ERROR_PREFIX, i, valuesShapeStatic.toString(), weightsShapeStatic.toString())); } } - return tf.withSubScope("staticDimsCheckSuccess") + return tf.withSubScope("static_dims_check_success") .withControlDependencies(Collections.EMPTY_LIST) .noOp(); } // Dynamic checks. 
- Operand weightsShape = tf.shape(sampleWeights); - Operand weightsRank = tf.rank(sampleWeights); - Operand valuesShape = tf.shape(values); - Operand valuesRank = tf.rank(values); - - Operand isScalar = tf.math.equal(weightsRank, tf.constant(0)); + Operand is_scalar = tf.math.equal(weightsRank, tf.constant(0)); List> data = Arrays.asList( - tf.constant(ASSERT_BROADCAST_ERROR_PREFIX), + tf.constant(ASSERT_BROADCASTABLE_ERROR_PREFIX), tf.constant("weights.shape="), weightsShape, tf.constant("values.shape="), valuesShape, - tf.constant("isScalar="), - isScalar); - - // hack to work around the non-lazy select for isValidShape, otherwise validNonscalar fails on a - // scalar weight. If select was lazy, that branch wouldn't get executed when iScalar is true. - Operand reshapedWeights = - tf.select(isScalar, tf.math.mul(sampleWeights, tf.onesLike(values)), sampleWeights); - weightsShape = tf.shape(reshapedWeights); - weightsRank = tf.rank(reshapedWeights); + tf.constant("is_scalar="), + is_scalar); - Operand validNonscalar = - canBroadcastNonscalarShapes(tf, weightsRank, weightsShape, valuesRank, valuesShape); + Operand isValidShape = + tf.select( + is_scalar, + is_scalar, + hasValidNonscalarShape(tf, weightsRank, weightsShape, valuesRank, valuesShape)); - Operand isValidShape = tf.select(isScalar, isScalar, validNonscalar); - - return tf.withSubScope("broadcastWeights-dynamic").assertThat(isValidShape, data); + return tf.assertThat(isValidShape, data); } /** @@ -153,15 +125,15 @@ public static Op assertBroadcastable( * @param the data type for the operands * @return a boolean operand to determine if the Shape is scalar or not. */ - private static Operand canBroadcastNonscalarShapes( + private static Operand hasValidNonscalarShape( Ops tf, Operand weightsRank, Operand weightsShape, Operand valuesRank, Operand valuesShape) { - tf = tf.withSubScope("canBroadcastNonscalarShapes"); + tf = tf.withSubScope("has_valid_nonscalar_shape"); Operand isSameRank = tf.math.equal(valuesRank, weightsRank); - return tf.select(isSameRank, canBroadcastDims(tf, weightsShape, valuesShape), isSameRank); + return tf.select(isSameRank, hasValidDims(tf, weightsShape, valuesShape), isSameRank); } /** @@ -169,180 +141,14 @@ private static Operand canBroadcastNonscalarShapes( * * @param tf the TensorFlow Ops * @param weightsShape the operand for the shape of the sample weights - * @param valuesShape the operand for the shape of the values + * @param valuesShape the operand for the shape of the sample weights * @param the data type for the operands * @return a boolean operand to determine if the shapes have valid dimensions or not. */ - private static Operand canBroadcastDims( + private static Operand hasValidDims( Ops tf, Operand weightsShape, Operand valuesShape) { - tf = tf.withSubScope("canBroadcastDims"); - Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); - Operand validDims = - tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); - Operand weightsShape2D = tf.expandDims(weightsShape, tf.constant(-1)); - - Operand diffResult = SetsOps.difference(tf, weightsShape2D, validDims); - Operand numInvalidDims = tf.size(diffResult); - return tf.math.equal(tf.constant(0), numInvalidDims); + tf = tf.withSubScope("has_invalid_dims"); + Operand diff = tf.reduceSum(tf.math.sub(weightsShape, valuesShape), tf.constant(0)); + return tf.math.equal(CastHelper.cast(tf, tf.constant(0), diff.type()), diff); } - - /** - * Broadcast weights to the same shape as values. 
- * - * @param tf the TensorFlow ops - * @param weights Operand whose shape is broadcastable to values. - * @param values Operand of any shape - * @param the type of Operands - * @return weights broadcast to values shape - */ - public static Operand broadcastWeights( - Ops tf, Operand weights, Operand values) { - - Shape weightsShape = weights.shape(); - Shape valuesShape = values.shape(); - - if (!weightsShape.hasUnknownDimension() - && !valuesShape.hasUnknownDimension() - && weightsShape.isCompatibleWith(valuesShape)) { - return weights; - } - - Ops ctf = - tf.withSubScope("broadcastWeights") - .withControlDependencies( - Collections.singletonList(assertBroadcastable(tf, weights, tf.onesLike(values)))); - return ctf.math.mul(weights, tf.onesLike(values)); - } - - // aliases for mean - - /** - * Calculate the mean of the operand, along all axes and keepDims is false - * - * - * @param tf the TensorFlow Ops - * @param x the Operand used to calculate the mean - * @param the type of the Operand. - * @return the mean of the operand - */ - public static Operand mean(Ops tf, Operand x) { - return mean(tf, x, null, false); - } - - /** - * Calculate the mean of the operand, alongside the specified axis with keepDims is - * false - * - * @param tf the TensorFlow Ops - * @param x the Operand used to calculate the mean - * @param axes Axes to compute the mean. - * @param the type of the Operand. - * @param the type of the axes. - * @return the mean of the operand, along the specified axes. - */ - public static Operand mean( - Ops tf, Operand x, Operand axes) { - return mean(tf, x, axes, false); - } - - /** - * Calculates the mean of the operand, along all axes. - * - * @param tf the TensorFlow Ops - * @param x the Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axes - * . If keepdims is true, the reduced dimensions are retained - * with length 1. - * @param the type of the operand - * @return the mean of elements of x. - */ - public static Operand mean( - Ops tf, Operand x, boolean keepDims) { - return mean(tf, x, null, keepDims); - } - - - - /** - * Calculates the mean of the operand, alongside the specified axes. - * - * @param tf the TensorFlow Ops - * @param x the Operand used to calculate the mean - * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. - * @param the data type of the Operand - * @param the data type of the axes - * @return the mean of elements of x. 
- */ - - public static Operand mean( - Ops tf, Operand x, Operand axes, boolean keepDims) { - if (axes == null) { - axes = (Operand) allAxes(tf, x); - } - return tf.math.mean(x, axes, Mean.keepDims(keepDims)); - } - - /** - * Calculate the mean of the operand, along all axes and keepDims is false - * - * - * @param tf the TensorFlow Ops - * @param x the Operand used to calculate the mean - * @return the mean of the operand containing floating point numbers - */ - public static Operand booleanMean(Ops tf, Operand x) { - return booleanMean(tf, x, null, false); - } - - /** - * Calculate the mean of the operand, alongside the specified axis with keepDims is - * false - * - * @param tf the TensorFlow Ops - * @param x the Operand used to calculate the mean - * @param axes Axes to compute the mean. - * @param the type of the axes. - * @return the mean of the operand, along the specified axes, containing floating point numbers - */ - public static Operand booleanMean( - Ops tf, Operand x,Operand axes) { - return booleanMean(tf, x, axes, false); - } - - /** - * Calculates the mean of the boolean operand, alongside all axes. - * - * @param x the boolean Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. - * @param the data type of the axes - * @return the mean of elements of x containing floating point numbers - */ - public static Operand booleanMean( - Ops tf, Operand x, boolean keepDims) { - return booleanMean(tf, x, null, keepDims); - } - - /** - * Calculates the mean of the boolean operand, alongside the specified axes. - * - * @param x the boolean Operand used to calculate the mean - * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. - * @param the data type of the axes - * @return the mean of elements of x containing floating point numbers - */ - public static Operand booleanMean( - Ops tf, Operand x, Operand axes, boolean keepDims) { - Operand xf = cast(tf, x, TFloat64.class); - return mean(tf, xf, axes, keepDims); - } - } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 8e48cb4e573..d2c8b2dec93 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.framework.metrics.Metric; @@ -27,13 +28,13 @@ import org.tensorflow.types.family.TNumber; import java.util.ArrayList; +import java.util.Collections; import java.util.List; /** * Encapsulates metrics that perform a reduce operation on the metric values. 
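The booleanMean helpers in the hunk above do nothing more than cast the boolean operand to a floating-point type before averaging; a plain-Java sketch of that idea (name and types are illustrative only) is:

    // Mean of a boolean array, counting true as 1.0 and false as 0.0,
    // mirroring the cast-then-average approach of booleanMean.
    static double booleanMean(boolean[] values) {
      double sum = 0;
      for (boolean v : values) {
        sum += v ? 1.0 : 0.0;
      }
      return values.length == 0 ? 0.0 : sum / values.length;
    }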
* * @param The data type for the metric values - * @param The data type for the metric result */ public abstract class Reduce extends Metric { public static final String TOTAL = "total"; @@ -41,14 +42,13 @@ public abstract class Reduce extends Metri protected final MetricReduction reduction; private final String totalName; private final String countName; - - private final Class resultType; /** the variable that holds the total of the metric values */ protected Variable total; - /** the variable that holds the count of the metric values. - * For {@link MetricReduction#WEIGHTED_MEAN}, this count may be weighted */ + /** the variable that holds the count of the metric values */ protected Variable count; + protected boolean initialized; + /** * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#SUM} * @@ -56,55 +56,44 @@ public abstract class Reduce extends Metri * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param resultType the type for the variables and result */ - protected Reduce(Ops tf, String name, long seed, Class resultType) { - this(tf, name, MetricReduction.SUM, seed, resultType); + protected Reduce(Ops tf, String name, long seed, Class type) { + this(tf, name, seed, MetricReduction.SUM, type); } /** * @param tf The TensorFlow Ops * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. - * @param reduction The type of metric reduction to apply * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
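To make the role of the total and count variables declared above concrete, here is a hedged plain-Java sketch of a WEIGHTED_MEAN style reduction (illustrative only, not the framework code): each update adds the weighted sum of the values to total and the sum of the weights to count, and the result is total divided by count.

    // Minimal running weighted mean: total accumulates sum(values * weights),
    // count accumulates sum(weights), and result() is total / count.
    static final class RunningWeightedMean {
      private double total;
      private double count;

      void update(double[] values, double[] weights) {
        for (int i = 0; i < values.length; i++) {
          total += values[i] * weights[i];
          count += weights[i];
        }
      }

      double result() {
        return count == 0 ? 0.0 : total / count; // mirrors divNoNan on an empty state
      }

      void reset() {
        total = 0;
        count = 0;
      }
    }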
- * @param resultType the type for the variables and result + * @param reduction The type of metric reduction to apply */ - protected Reduce(Ops tf, String name, MetricReduction reduction, long seed, Class resultType) { - super(tf, name, seed); + protected Reduce(Ops tf, String name, long seed, MetricReduction reduction, Class type) { + super(tf, name, seed, type); this.reduction = reduction; this.totalName = this.getVariableName(TOTAL); this.countName = this.getVariableName(COUNT); - this.resultType = resultType; setupVars(); } - /** Initializes the Variables */ + /** initialize the Variables */ + @SuppressWarnings("unchecked") private void setupVars() { + Zeros fZeros = new Zeros<>(getTF()); + total = (Variable) getVariable(totalName); if (total == null) { - total = getTF().withName(totalName).variable(Shape.scalar(), resultType); + total = getTF().withSubScope(totalName).variable(Shape.scalar(), getType()); + addVariable(totalName, total, fZeros); } if (reduction == MetricReduction.SUM_OVER_BATCH_SIZE || reduction == MetricReduction.WEIGHTED_MEAN) { + count = (Variable) getVariable(countName); if (count == null) { - count = getTF().withName(countName).variable(Shape.scalar(), resultType); + count = getTF().withSubScope(countName).variable(Shape.scalar(), getType()); + addVariable(countName, count, fZeros); } } } - /** {@inheritDoc} */ - public Op resetStates() { - List controls = new ArrayList<>(); - if (total != null) { - controls.add( - getTF().assign(total, CastHelper.cast(getTF(), getTF().constant(0), total.type()))); - } - if (count != null) { - controls.add( - getTF().assign(count, CastHelper.cast(getTF(), getTF().constant(0), count.type()))); - } - return getTF().withControlDependencies(controls).noOp(); - } - /** * Updates the metric variables based on the inputs. At least one input arg required for * values, an optional additional input for the sampleWeights @@ -115,68 +104,52 @@ public Op resetStates() { * @throws IllegalArgumentException if values is null */ @Override - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList(Operand values, Operand sampleWeights) { - if (values == null) { - throw new IllegalArgumentException("values is required."); - } + if (values == null) throw new IllegalArgumentException("values is required."); List updateOperations = new ArrayList<>(); // cast everything to match the variables - Operand lSampleWeights = null; - Operand lValues = values; + Operand tValues = CastHelper.cast(getTF(), values, getType()); + Operand tSampleWeights = sampleWeights; if (sampleWeights != null) { - lSampleWeights = CastHelper.cast(getTF(), sampleWeights, lValues.type()); - LossTuple tuple = - LossesHelper.squeezeOrExpandDimensions(getTF(), null, lValues, lSampleWeights); - lValues = tuple.getTarget(); - lSampleWeights = tuple.getSampleWeights(); - try { - lSampleWeights = MetricsHelper.broadcastWeights(getTF(), lSampleWeights, lValues); - } catch (IllegalArgumentException ex) { - // if we get here we have static shapes with either - // different ranks or different dimension sizes. - // first, reduce the values down to the rank of the samples - int valuesRank = lValues.shape().numDimensions(); - int weightsRank = lSampleWeights.shape().numDimensions(); - int numAxes = Math.min(0, valuesRank - weightsRank); - if (numAxes - > 0) { // values rank is greater than weights rank, reduce values to weights rank. 
- int[] axes = new int[numAxes]; - for (int i = 0; i < numAxes; i++) axes[i] = i + weightsRank; - if (reduction == MetricReduction.SUM) { - lValues = getTF().reduceSum(lValues, getTF().constant(axes)); - } else { - lValues = getTF().math.mean(lValues, getTF().constant(axes)); - } - } - } - lValues = getTF().math.mul(lValues, lSampleWeights); + LossTuple tuple = + LossesHelper.squeezeOrExpandDimensions(getTF(), null, tValues, sampleWeights); + tValues = tuple.getTarget(); + tSampleWeights = tuple.getSampleWeights(); + Op broadcastWeightsCheck = MetricsHelper.broadcastWeights(getTF(), tSampleWeights, tValues); + tValues = + getTF() + .withSubScope("broadcastWeightsCheck") + .withControlDependencies(Collections.singletonList(broadcastWeightsCheck)) + .math + .mul(tValues, tSampleWeights); } - Operand weightedValueSum = - getTF().reduceSum(lValues, LossesHelper.allAxes(getTF(), lValues)); + Operand valueSum = getTF().reduceSum(tValues, LossesHelper.allAxes(getTF(), tValues)); Operand totalUpdate = - getTF().assignAdd(total, CastHelper.cast(getTF(), weightedValueSum, total.type())); + getTF().assignAdd(total, CastHelper.cast(getTF(), valueSum, total.type())); updateOperations.add(totalUpdate); Operand numValues; if (reduction != MetricReduction.SUM) { switch (reduction) { case SUM_OVER_BATCH_SIZE: numValues = - CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); + CastHelper.cast( + getTF(), getTF().constant(tValues.asOutput().shape().size()), getType()); break; case WEIGHTED_MEAN: - if (lSampleWeights == null) { + if (tSampleWeights == null) { numValues = - CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); + CastHelper.cast( + getTF(), getTF().constant(tValues.asOutput().shape().size()), getType()); } else { numValues = CastHelper.cast( getTF(), getTF() - .reduceSum(lSampleWeights, LossesHelper.allAxes(getTF(), lSampleWeights)), - resultType); + .reduceSum(tSampleWeights, LossesHelper.allAxes(getTF(), tSampleWeights)), + getType()); } break; default: @@ -193,16 +166,16 @@ public List updateStateList(Operand values, Operand result() { + public Operand result(Ops rtf) { Operand fResult; switch (this.reduction) { case SUM: - fResult = getTF().identity(total); + fResult = rtf.identity(total); break; case WEIGHTED_MEAN: case SUM_OVER_BATCH_SIZE: - fResult = getTF().math.divNoNan(total, CastHelper.cast(getTF(), count, resultType)); + fResult = rtf.math.divNoNan(total, CastHelper.cast(rtf, count, getType())); break; default: throw new UnsupportedOperationException( @@ -228,13 +201,4 @@ public Variable getTotal() { public Variable getCount() { return count; } - - /** - * Gets the type for the variables - * - * @return the type for the variables - */ - public Class getResultType() { - return resultType; - } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java index 7ceedded018..1f07b9567cb 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java @@ -23,7 +23,6 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; class BinaryCrossentropyTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -45,9 +44,9 @@ public void 
testUnweighted() { Variable total = instance.getTotal(); Variable count = instance.getCount(); Operand result = instance.result(); - session.evaluate(7.71247434F, total); + session.evaluate(7.666619F, total); session.evaluate(2, count); - session.evaluate(3.85623717F, result); + session.evaluate(3.833309F, result); } } @@ -58,9 +57,9 @@ public void testUnweightedLogits() { BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testUnweightedLogits", true, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); - float[] trueArray = {1, 0, 1, 0, 1, 1}; + double[] trueArray = {1, 0, 1, 0, 1, 1}; double[] logitsArray = {100.0, -100.0, 100.0, 100.0, 100.0, -100.0}; - Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); Op op = instance.updateState(labels, logits, null); session.run(op); @@ -80,9 +79,9 @@ public void testWeighted() { BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testWeighted", false, 0, 1001L, TFloat32.class); session.run(instance.resetStates()); - int[] trueArray = {1, 0, 1, 0}; + float[] trueArray = {1, 0, 1, 0}; float[] predictionArray = {1, 1, 1, 0}; - Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); Operand yPrediction = tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 2))); Operand sampleWeight = tf.constant(new float[] {1.5f, 2.f}); @@ -105,9 +104,9 @@ public void testWeightedLogits() { BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testWeightedLogits", true, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); - float[] trueArray = {1, 0, 1, 0, 1, 1}; + double[] trueArray = {1, 0, 1, 0, 1, 1}; double[] logitsArray = {100.0, -100.0, 100.0, 100.0, 100.0, -100.0}; - Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(new double[] {2, 2.5}); @@ -132,9 +131,9 @@ public void testLabelSmoothing() { new BinaryCrossentropy<>( tf, "BCE_testWeightedLabS", true, labelSmoothing, 1001L, TFloat64.class); session.run(instance.resetStates()); - float[] trueArray = {1, 0, 1}; + double[] trueArray = {1, 0, 1}; double[] logitsArray = {100., -100., -100.}; - Operand labels = tf.constant(trueArray); + Operand labels = tf.constant(trueArray); Operand logits = tf.constant(logitsArray); Op op = instance.updateState(labels, logits, null); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java index a9721ef2f8f..848e2051af3 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java @@ -79,7 +79,7 @@ public void testWeighted() { public void test_axis() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - int axis = 1; + int[] axis = new int[] {1}; CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testWeighted", axis, 1001L, TFloat32.class); 
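For reference, the quantity exercised by the CosineSimilarity tests is ordinary cosine similarity along an axis: the dot product of the two vectors divided by the product of their L2 norms. A short plain-Java sketch for a single pair of vectors (illustrative only) is:

    // Cosine similarity of two equal-length vectors: dot(a, b) / (||a|| * ||b||).
    static double cosineSimilarity(double[] a, double[] b) {
      double dot = 0, normA = 0, normB = 0;
      for (int i = 0; i < a.length; i++) {
        dot += a[i] * b[i];
        normA += a[i] * a[i];
        normB += b[i] * b[i];
      }
      double denom = Math.sqrt(normA) * Math.sqrt(normB);
      return denom == 0 ? 0.0 : dot / denom; // 0 for degenerate inputs, akin to divNoNan
    }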
session.run(instance.resetStates()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java index 28020c0fa1c..bf98ec4eba4 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java @@ -69,7 +69,7 @@ public void testWeighted() { Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.constant(new double[][] {{1.2}, {3.4}}); + Operand sampleWeight = tf.constant(new double[] {1.2, 3.4}); Op op = instance.updateState(labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracyTest.java new file mode 100644 index 00000000000..4a0cdefe492 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracyTest.java @@ -0,0 +1,96 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class SparseTopKCategoricalAccuracyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCorrectness() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SparseTopKCategoricalAccuracy instance = + new SparseTopKCategoricalAccuracy<>( + tf, "SparseTopK_testCorrectness", 5, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new double[] {2, 1}); + Operand predictions = + tf.constant(new double[][] {{0.1, 0.9, 0.8}, {0.05, 0.95, 0}}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(1., instance.result()); + + // With `k` < 5. + instance = + new SparseTopKCategoricalAccuracy<>( + tf, "SparseTopK_testCorrectness", 1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5, instance.result()); + + // With `k` > 5. 
+ predictions = + tf.constant( + new double[][] { + {0.5, 0.9, 0.1, 0.7, 0.6, 0.5, 0.4}, + {0.05, 0.95, 0, 0, 0, 0, 0} + }); + instance = + new SparseTopKCategoricalAccuracy<>( + tf, "SparseTopK_testCorrectness", 6, 1001L, TFloat64.class); + session.run(instance.resetStates()); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5, instance.result()); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SparseTopKCategoricalAccuracy instance = + new SparseTopKCategoricalAccuracy<>( + tf, "SparseTopK_testWeighted", 5, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new int[] {1, 0, 2}); + Operand predictions = + tf.constant( + new double[][] { + {0, 0.9, 0.1}, + {0, 0.9, 0.1}, + {0, 0.9, 0.1} + }); + + Operand sampleWeight = tf.constant(new double[] {1, 0, 1}); + + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + session.evaluate(1., instance.result()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java new file mode 100644 index 00000000000..52ccde29196 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java @@ -0,0 +1,103 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +class TopKCategoricalAccuracyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCorrectness() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + TopKCategoricalAccuracy instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted", 5, 1001L, TFloat64.class); + session.run(instance.resetStates()); + Operand labels = tf.constant(new float[][] {{0, 0, 1}, {0, 1, 0}}); + Operand predictions = + tf.constant(new double[][] {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(1., instance.result()); + + // With `k` < 5. + instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted1", 1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5, instance.result()); + + // With `k` > 5. 
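The accuracy these top-k tests check can be stated simply: a sample counts as correct when its true class is among the k highest-scoring predictions, and the metric is the fraction of correct samples. A hedged plain-Java sketch (sparse integer labels; not the framework's implementation) is:

    // Fraction of samples whose true class index is among the k highest prediction scores.
    static double sparseTopKAccuracy(int[] trueClasses, double[][] predictions, int k) {
      if (trueClasses.length == 0) {
        return 0.0;
      }
      int hits = 0;
      for (int s = 0; s < trueClasses.length; s++) {
        double trueScore = predictions[s][trueClasses[s]];
        int strictlyHigher = 0;
        for (double score : predictions[s]) {
          if (score > trueScore) {
            strictlyHigher++;
          }
        }
        if (strictlyHigher < k) {
          hits++; // the true class made it into the top k
        }
      }
      return (double) hits / trueClasses.length;
    }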
+ labels = + tf.constant( + new float[][] { + {0, 0, 1, 0, 0, 0, 0}, + {0, 1, 0, 0, 0, 0, 0} + }); + predictions = + tf.constant( + new double[][] { + {0.5f, 0.9f, 0.1f, 0.7f, 0.6f, 0.5f, 0.4f}, + {0.05f, 0.95f, 0f, 0f, 0f, 0f, 0f} + }); + instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted6", 6, 1001L, TFloat64.class); + session.run(instance.resetStates()); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5, instance.result()); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + TopKCategoricalAccuracy instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testWeighted", 5, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand labels = + tf.constant( + new double[][] { + {1, 0, 2}, + {1, 0, 0}, + {0, 0, 1} + }); + Operand predictions = + tf.constant( + new double[][] { + {0f, 0.9f, 0.1f}, + {0f, 0.9f, 0.1f}, + {0f, 0.9f, 0.1f} + }); + + Operand sampleWeight = tf.constant(new double[] {1, 0, 1}); + + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + session.evaluate(1., instance.result()); + } + } +} From 141ebd523ad0d299cb35c55fb830a3876405c08a Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 31 Dec 2020 20:03:32 -0500 Subject: [PATCH 29/56] Initial checkin and sync with master --- .../framework/metrics/BinaryCrossentropy.java | 2 +- .../metrics/CategoricalCrossentropy.java | 2 +- .../framework/metrics/CosineSimilarity.java | 16 ++- .../tensorflow/framework/metrics/Mean.java | 3 +- .../tensorflow/framework/metrics/Metric.java | 42 ++----- .../tensorflow/framework/metrics/Metrics.java | 19 ++-- .../SparseTopKCategoricalAccuracy.java | 65 ----------- .../metrics/TopKCategoricalAccuracy.java | 63 ----------- .../metrics/impl/MeanMetricWrapper.java | 36 ++---- .../metrics/impl/MetricVariable.java | 13 ++- .../framework/metrics/impl/MetricsHelper.java | 100 +++++++++++++++-- .../framework/metrics/impl/Reduce.java | 88 +++++++++------ .../metrics/BinaryCrossentropyTest.java | 17 +-- .../metrics/CosineSimilarityTest.java | 2 +- .../framework/metrics/KLDivergenceTest.java | 2 +- .../SparseTopKCategoricalAccuracyTest.java | 96 ---------------- .../metrics/TopKCategoricalAccuracyTest.java | 103 ------------------ 17 files changed, 211 insertions(+), 458 deletions(-) delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java delete mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracyTest.java delete mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index d13d20bfdee..41a5533b5d1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -28,7 +28,7 @@ * @param The data type for the metric result */ public class BinaryCrossentropy - extends MeanMetricWrapper implements LossInterface { + extends MeanMetricWrapper implements LossInterface 
{ private final boolean fromLogits; private final float labelSmoothing; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index cf9ecd0858a..79481f608a1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -34,7 +34,7 @@ public class CategoricalCrossentropy private final boolean fromLogits; private final float labelSmoothing; - private int axis; + private final int axis; /** * Creates a CategoricalCrossentropy metric that Computes the crossentropy metric between the diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 61802572c7b..4a5214aea8d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -27,9 +27,9 @@ // is using the dot product proportional to the product of their magnitudes. // While the 2 concepts are similar, they are different. // Should we rename this metric to CosineProximity? -public class CosineSimilarity extends MeanMetricWrapper +public class CosineSimilarity extends MeanMetricWrapper implements LossInterface { - public static final int[] DEFAULT_AXIS = {-1}; + public static final int DEFAULT_AXIS = -1; private final int[] axis; /** @@ -44,6 +44,18 @@ public CosineSimilarity(Ops tf, String name, long seed, Class type) { this(tf, name, DEFAULT_AXIS, seed, type); } + /** + * Creates a CosineSimilarity metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param axis The dimension along which the cosine similarity is computed. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public CosineSimilarity(Ops tf, String name, int axis, long seed, Class type) { + this(tf, name, new int[] {axis}, seed, type); + } /** * Creates a CosineSimilarity metric * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java index 08d1083dd05..c68a70902a7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java @@ -33,8 +33,9 @@ public class Mean extends Reduce { * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the result. 
*/ protected Mean(Ops tf, String name, long seed, Class type) { - super(tf, name, seed, MetricReduction.WEIGHTED_MEAN, type); + super(tf, name, MetricReduction.WEIGHTED_MEAN, seed, type); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 62ec5439269..28a2ae0fa94 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -1,20 +1,3 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.tensorflow.framework.metrics; - /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -29,6 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. =======================================================================*/ +package org.tensorflow.framework.metrics; import org.tensorflow.ExecutionEnvironment; import org.tensorflow.Operand; @@ -74,17 +58,15 @@ public abstract class Metric { /** The name for this metric. Defaults to {@link Class#getSimpleName()}. */ private final String name; - private final Class type; - /** - * Creates a Metric with a name of {@link Class#getSimpleName()} } + * Creates a Metric with a name of {@link Class#getSimpleName()} * * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. */ - protected Metric(Ops tf, long seed, Class type) { - this(tf, null, seed, type); + protected Metric(Ops tf, long seed) { + this(tf, null, seed); } /** @@ -95,13 +77,12 @@ protected Metric(Ops tf, long seed, Class type) { * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. */ - protected Metric(Ops tf, String name, long seed, Class type) { + protected Metric(Ops tf, String name, long seed) { if (!tf.scope().env().isGraph()) throw new IllegalArgumentException("Metrics are required to execute in Graph mode."); this.seed = seed; this.name = name != null ? name : this.getClass().getSimpleName(); this.tf = tf.withSubScope(this.name); - this.type = type; } /** @@ -113,8 +94,8 @@ protected Metric(Ops tf, String name, long seed, Class type) { * @param sampleWeights sample weights to be applied to values, may be null. 
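Each updateStateList returns the individual variable-update operations; the updateState convenience methods then fold that list into a single op by attaching it as control dependencies to a no-op, so callers only have to run one op per step. A hedged standalone sketch of that pattern, using stand-in no-ops instead of real variable updates, is:

    import java.util.Arrays;

    import org.tensorflow.Graph;
    import org.tensorflow.op.Op;
    import org.tensorflow.op.Ops;

    class UpdateStateSketch {
      static Op combinedUpdate() {
        Graph graph = new Graph();                  // illustrative graph, not the metric's own
        Ops tf = Ops.create(graph);
        Op updateA = tf.withName("updateA").noOp(); // stand-ins for real state updates
        Op updateB = tf.withName("updateB").noOp();
        // Running the returned op forces both updates to run first.
        return tf.withSubScope("updateState")
            .withControlDependencies(Arrays.asList(updateA, updateB))
            .noOp();
      }
    }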
* @return a List of Operations to update the metric state */ - @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList(Operand values, Operand sampleWeights) { + @SuppressWarnings({"unchecked","unused"}) + public List updateStateList(Operand values, Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -142,7 +123,7 @@ public List updateStateList( * @param sampleWeights sample weights to be applied to values, may be null. * @return the Operation to update the metric state */ - public final Op updateState(Operand values, Operand sampleWeights) { + public final Op updateState(Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -153,6 +134,7 @@ public final Op updateState(Operand values, Operand sa * @param labels the labels * @param predictions the predictions * @param sampleWeights sample weights to be applied to values, may be null. + * @param the data type for the sample weights * @return the Operation to update the metric state */ public final Op updateState( @@ -206,7 +188,7 @@ protected void addVariable( // TODO option 2 would be to keep track of tf.scope().env() and if it changes, clear to old Map. Map> variables = variableMap.computeIfAbsent(tf.scope().env(), k -> new HashMap<>()); - variables.put(varName, new MetricVariable<>(tf, variable, initializer, seed)); + variables.put(varName, new MetricVariable<>(tf, variable, initializer, seed, variable.type())); } /** @@ -313,8 +295,4 @@ public Ops getTF() { public String getName() { return name; } - - public Class getType() { - return type; - } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index f4282bfd0a9..c3c44ef6134 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -40,7 +40,7 @@ public class Metrics { * {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}}); * Operand<TFloat32> m = Metrics.topKCategoricalAccuracy( * labels, predictions, 3) - * //m.asOutput().shape().toString == "[2]" + * //m.shape().toString == "[2]" * * * @param tf the TensorFlow Ops. @@ -71,7 +71,7 @@ public static Operand topKCategoricalA * {{0.1f, 0.9f, 0.f8}, {0.05f, 0.95f, 0f}}); * Operand<TFloat32> m = Metrics.topKCategoricalAccuracy( * labels, predictions, 3) - * //m.asOutput().shape().toString == "[2]" + * //m.shape().toString == "[2]" * * * @param tf the TensorFlow Ops. @@ -90,8 +90,8 @@ public static Operand sparseTopKCatego tLabels = CastHelper.cast(tf, labels, predictions.type()); else tLabels = (Operand) labels; - int predictionsRank = predictions.asOutput().shape().numDimensions(); - int labelsRank = tLabels.asOutput().shape().numDimensions(); + int predictionsRank = predictions.shape().numDimensions(); + int labelsRank = tLabels.shape().numDimensions(); Operand castPredictions = CastHelper.cast(tf, predictions, TFloat32.class); if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { @@ -152,9 +152,10 @@ public static Operand cosineProximity( * @param The data type for x. * @return the normalized values of x. */ - // TODO this was tf.math.l2_normalize in TF Python + // TODO this was tf.math.l2_normalize in TF Python, does it belong here? 
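For a single vector, the l2Normalize computation above reduces to dividing by the square root of the sum of squares, floored at epsilon so a zero vector does not divide by zero. A plain-Java sketch of the 1-D case (epsilon is passed in rather than assuming the framework's default) is:

    // L2-normalize a vector: x * rsqrt(max(sum(x^2), epsilon)), the 1-D analogue of
    // the square/reduceSum/rsqrt sequence above.
    static double[] l2Normalize(double[] x, double epsilon) {
      double squareSum = 0;
      for (double v : x) {
        squareSum += v * v;
      }
      double scale = 1.0 / Math.sqrt(Math.max(squareSum, epsilon));
      double[] out = new double[x.length];
      for (int i = 0; i < x.length; i++) {
        out[i] = x[i] * scale;
      }
      return out;
    }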
- public static Operand l2Normalize(Ops tf, Operand x, int[] axes) { + public static Operand l2Normalize( + Ops tf, Operand x, int[] axes) { return l2Normalize(tf, x, axes, L2_NORM_EPSILON); } @@ -178,15 +179,15 @@ public static Operand l2Normalize(Ops tf, Operand x, i * @param The data type for the values. * @return the normalized values of x. */ - // TODO this was tf.math.l2_normalize in TF Python + // TODO this was tf.math.l2_normalize in TF Python, does it belong here? public static Operand l2Normalize( Ops tf, Operand x, int[] axes, float epsilon) { Operand squareSum = tf.reduceSum(tf.math.square(x), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE)); Operand y = tf.math.rsqrt( - tf.math.maximum( - squareSum, CastHelper.cast(tf, tf.constant(epsilon), x.type()))); + tf.math.maximum(squareSum, CastHelper.cast(tf, tf.constant(epsilon), x.type()))); return tf.math.mul(x, y); } + } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java deleted file mode 100644 index 1412465bd89..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.metrics; - -import org.tensorflow.Operand; -import org.tensorflow.framework.metrics.impl.LossInterface; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; -import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TNumber; - -/** Computes the poisson loss metric between labels and predictions. */ -public class SparseTopKCategoricalAccuracy - extends MeanMetricWrapper implements LossInterface { - public static final int DEFAULT_K = 5; - /** Number of top elements to look at for computing accuracy. */ - private final int k; - - /** - * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for k, Number of - * top elements to look at for computing accuracy. - * - * @param tf the TensorFlow Ops - * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and data type. - * @param type the date type for the result - */ - public SparseTopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { - this(tf, name, DEFAULT_K, seed, type); - } - - /** - * Creates a TopKCategoricalAccuracy metric - * - * @param tf the TensorFlow Ops - * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param k Number of top elements to look at for computing accuracy. - * @param seed the seed for random number generation. 
An initializer created with a given seed - * will always produce the same random tensor for a given shape and data type. - * @param type the date type for the result - */ - public SparseTopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { - super(tf, name, seed, type); - this.k = k; - setLoss(this); - } - - /** {@inheritDoc} */ - @Override - public Operand call(Operand labels, Operand predictions) { - return Metrics.sparseTopKCategoricalAccuracy(getTF(), labels, predictions, k); - } -} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java deleted file mode 100644 index 3198ab0ee04..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.metrics; - -import org.tensorflow.Operand; -import org.tensorflow.framework.metrics.impl.LossInterface; -import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; -import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TNumber; - -/** Computes the poisson loss metric between labels and predictions. */ -public class TopKCategoricalAccuracy - extends MeanMetricWrapper implements LossInterface { - public static final int DEFAULT_K = 5; - /** Number of top elements to look at for computing accuracy. */ - private final int k; - - /** - * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for k, Number of - * top elements to look at for computing accuracy. - * - * @param tf the TensorFlow Ops - * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and data type. - */ - public TopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { - this(tf, name, DEFAULT_K, seed, type); - } - - /** - * Creates a TopKCategoricalAccuracy metric - * - * @param tf the TensorFlow Ops - * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param k Number of top elements to look at for computing accuracy. - * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and data type. 
- */ - public TopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { - super(tf, name, seed, type); - this.k = k; - setLoss(this); - } - - /** {@inheritDoc} */ - @Override - public Operand call(Operand labels, Operand predictions) { - return Metrics.topKCategoricalAccuracy(getTF(), labels, predictions, k); - } -} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 5e0023c4dbe..77566e5c400 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -32,7 +32,8 @@ * then passes this loss to the {@link Mean} metric to calculate the weighted mean of the * loss over many iterations or epochs * - * @param the data type for the loss. + * @param the data type for the predictions. + * @param The data type for the metric result */ public class MeanMetricWrapper extends Mean { @@ -86,40 +87,17 @@ public void setLoss(LossInterface loss) { * @param the datatype of the predictions * @return a List of control operations that updates the Mean state variables. */ - public List updateLossStateList( + public List updateStateList( Operand labels, Operand predictions, Operand sampleWeights) { if (labels == null || predictions == null) throw new IllegalArgumentException("missing required inputs for labels and predictions"); - Class type = predictions.type(); - Operand tPredicitons = CastHelper.cast(getTF(), predictions, getType()); + Operand tLabels = CastHelper.cast(getTF(), labels, getType()); + Operand tPredictions = CastHelper.cast(getTF(), predictions, getType()); - Operand losses = loss.call(labels, tPredicitons); - Operand uLossess = CastHelper.cast(getTF(), losses, type); - return super.updateStateList(uLossess, sampleWeights); - } + Operand losses = loss.call(tLabels, tPredictions); - /** - * Creates a Control Operation that updates the state of the mean metric by calculating the loss - * between the labels and predictions and then applying a weighted mean - * metric across the multiple iterations. - * - * @param labels the truth values or labels - * @param predictions the predictions - * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is - * provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor - * of size [batch_size], then the total loss for each sample of the batch is rescaled by the - * corresponding element in the sampleWeights vector. If the shape of sampleWeights is - * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of - * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss - * functions reduce by 1 dimension, usually axis=-1.) - * @param the datatype of the labels - * @return a NoOp with control dependencies that update the state of the mean metric. 
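Conceptually, the wrapper computes a per-sample loss from the labels and predictions and then hands those loss values to the weighted-mean reduction; a plain-Java sketch of that two-step bridge (the loss interface and uniform-weight fallback here are illustrative stand-ins) is:

    // Bridge a per-sample loss into a weighted mean: evaluate the loss per sample,
    // then average using the sample weights (uniform weights when none are given).
    interface SampleLoss {
      double call(double label, double prediction);
    }

    static double weightedMeanLoss(
        double[] labels, double[] predictions, double[] weights, SampleLoss loss) {
      double total = 0;
      double count = 0;
      for (int i = 0; i < labels.length; i++) {
        double w = (weights == null) ? 1.0 : weights[i];
        total += loss.call(labels[i], predictions[i]) * w;
        count += w;
      }
      return count == 0 ? 0.0 : total / count;
    }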
- */ - public final Op updateLossState( - Operand labels, Operand predictions, Operand sampleWeights) { - List controlOps = updateLossStateList(labels, predictions, sampleWeights); - return getTF().withSubScope("updateState").withControlDependencies(controlOps).noOp(); + return super.updateStateList(CastHelper.cast(getTF(), losses, predictions.type()), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java index 78d7459697c..cb5e987b4cf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java @@ -45,8 +45,8 @@ public class MetricVariable { * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. */ - public MetricVariable(Ops tf, Variable variable, long seed) { - this(tf, variable, null, seed); + public MetricVariable(Ops tf, Variable variable, long seed, Class type) { + this(tf, variable, null, seed, type); } /** * Creates a Metric Variable @@ -61,13 +61,14 @@ public MetricVariable(Ops tf, Variable variable, long seed) { * will always produce the same random tensor for a given shape and data type. */ @SuppressWarnings("unchecked") - public MetricVariable(Ops tf, Variable variable, Initializer initializer, long seed) { + public MetricVariable( + Ops tf, Variable variable, Initializer initializer, long seed, Class type) { this.tf = tf; this.variable = variable; - Class type = variable.type(); if (initializer == null) { if (TFloating.class.isAssignableFrom(type)) { + //noinspection RedundantCast this.initializer = (Initializer) new Glorot<>(tf, VarianceScaling.Distribution.UNIFORM, seed); } else if (TIntegral.class.isAssignableFrom(type)) { @@ -76,7 +77,7 @@ public MetricVariable(Ops tf, Variable variable, Initializer initializer, throw new IllegalArgumentException( String.format( "An initializer for variable %s of type %s is required", - variable.toString(), type)); + variable.toString(), type.getSimpleName())); } } else { this.initializer = initializer; @@ -90,7 +91,7 @@ public MetricVariable(Ops tf, Variable variable, Initializer initializer, */ public Operand initialize() { initialized = true; - return initializer.call(tf.constant(variable.asOutput().shape()), variable.type()); + return initializer.call(tf.constant(variable.shape()), variable.type()); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 5395cccf4a7..042badbb615 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -15,24 +15,30 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; -import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; +import org.tensorflow.op.math.Mean; import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; import 
java.util.Arrays; import java.util.Collections; import java.util.List; +import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * These are helper methods for Metrics and will be module private when Java modularity is applied * to TensorFlow Java. These methods should not be used outside of the metrics packages. */ public class MetricsHelper { - private static final String ASSERT_BROADCASTABLE_ERROR_PREFIX = + public static final float NEG_INF = -1e10f; + private static final String ASSERT_BROADCAST_ERROR_PREFIX = "weights can not be broadcast to values."; /** @@ -53,12 +59,12 @@ public static Op broadcastWeights( Operand weightsShape = tf.shape(sampleWeights); Operand weightsRank = tf.rank(sampleWeights); - Shape weightsShapeStatic = sampleWeights.asOutput().shape(); + Shape weightsShapeStatic = sampleWeights.shape(); int weightsRankStatic = weightsShapeStatic.numDimensions(); Operand valuesShape = tf.shape(values); Operand valuesRank = tf.rank(values); - Shape valuesShapeStatic = values.asOutput().shape(); + Shape valuesShapeStatic = values.shape(); int valuesRankStatic = valuesShapeStatic.numDimensions(); if (weightsRankStatic != -1 && valuesRankStatic != -1) { @@ -71,7 +77,7 @@ public static Op broadcastWeights( throw new IllegalArgumentException( String.format( "%s values.rank=%d. weights.rank=%d. values.shape=%s. weights.shape=%s.", - ASSERT_BROADCASTABLE_ERROR_PREFIX, + ASSERT_BROADCAST_ERROR_PREFIX, valuesRankStatic, weightsRankStatic, valuesShapeStatic.toString(), @@ -83,7 +89,7 @@ public static Op broadcastWeights( throw new IllegalArgumentException( String.format( "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.", - ASSERT_BROADCASTABLE_ERROR_PREFIX, + ASSERT_BROADCAST_ERROR_PREFIX, i, valuesShapeStatic.toString(), weightsShapeStatic.toString())); @@ -97,7 +103,7 @@ public static Op broadcastWeights( Operand is_scalar = tf.math.equal(weightsRank, tf.constant(0)); List> data = Arrays.asList( - tf.constant(ASSERT_BROADCASTABLE_ERROR_PREFIX), + tf.constant(ASSERT_BROADCAST_ERROR_PREFIX), tf.constant("weights.shape="), weightsShape, tf.constant("values.shape="), @@ -111,7 +117,7 @@ public static Op broadcastWeights( is_scalar, hasValidNonscalarShape(tf, weightsRank, weightsShape, valuesRank, valuesShape)); - return tf.assertThat(isValidShape, data); + return tf.withSubScope("broadcastWeights-dynamic").assertThat(isValidShape, data); } /** @@ -149,6 +155,82 @@ private static Operand hasValidDims( Ops tf, Operand weightsShape, Operand valuesShape) { tf = tf.withSubScope("has_invalid_dims"); Operand diff = tf.reduceSum(tf.math.sub(weightsShape, valuesShape), tf.constant(0)); - return tf.math.equal(CastHelper.cast(tf, tf.constant(0), diff.type()), diff); + return tf.math.equal(cast(tf, tf.constant(0), diff.asOutput().type()), diff); + } + + // alias for mean + + /** + * Calculate the mean of the operand, along all axes and keepDims is false + * + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @param the type of the Operand. + * @return the mean of the operand + */ + public static Operand mean(Ops tf, Operand x) { + return mean(tf, x, null, false); + } + + /** + * Calculate the mean of the operand, alongside the specified axis with keepDims is + * false + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @param axis Axes to compute the mean. + * @param the type of the Operand. + * @param the type of the axis. 
+ * @return the mean of the operand, alongside the specified axis. + */ + public static Operand mean( + Ops tf, Operand x, Operand axis) { + return mean(tf, x, axis, false); + } + + /** + * Calculate the mean of the operand, along all axis. + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axis + * . If keepdims is true, the reduced dimensions are retained + * with length 1. + * @param the type of the operand + * @return the mean of elements of x. + */ + public static Operand mean(Ops tf, Operand x, boolean keepDims) { + return mean(tf, x, null, keepDims); + } + + /** + * Calculate the mean of the operand, alongside the specified axis. + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @param axis Axes to compute the mean. + * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the + * * rank of the tensor is reduced by 1 for each entry in `axis`. If `keepdims` is `true`, the + * * reduced dimensions are retained with length 1. + * @param the data type of the Operand + * @param the data type of the axis + * @return the mean of elements of `x`. + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public static Operand mean( + Ops tf, Operand x, Operand axis, boolean keepDims) { + // Cannot use generics here because xf may change from TBool to TFloat32 + Operand xf; + if (x.asOutput().type() == TBool.class) { + xf = tf.dtypes.cast(x, TFloat32.class); + } else { + xf = x; + } + if (axis == null) { + axis = allAxes(tf, xf); + } + return tf.math.mean(xf, axis, Mean.keepDims(keepDims)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index d2c8b2dec93..d3b7caa54cc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -35,6 +35,7 @@ * Encapsulates metrics that perform a reduce operation on the metric values. * * @param The data type for the metric values + * @param The data type for the metric result */ public abstract class Reduce extends Metric { public static final String TOTAL = "total"; @@ -42,6 +43,8 @@ public abstract class Reduce extends Metri protected final MetricReduction reduction; private final String totalName; private final String countName; + + private final Class type; /** the variable that holds the total of the metric values */ protected Variable total; /** the variable that holds the count of the metric values */ @@ -56,23 +59,26 @@ public abstract class Reduce extends Metri * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ protected Reduce(Ops tf, String name, long seed, Class type) { - this(tf, name, seed, MetricReduction.SUM, type); + this(tf, name, MetricReduction.SUM, seed, type); } /** * @param tf The TensorFlow Ops * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. 
+ * @param reduction The type of metric reduction to apply * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param reduction The type of metric reduction to apply + * @param type the type for the variables and result */ - protected Reduce(Ops tf, String name, long seed, MetricReduction reduction, Class type) { - super(tf, name, seed, type); + protected Reduce(Ops tf, String name, MetricReduction reduction, long seed, Class type) { + super(tf, name, seed); this.reduction = reduction; this.totalName = this.getVariableName(TOTAL); this.countName = this.getVariableName(COUNT); + this.type = type; setupVars(); } /** initialize the Variables */ @@ -81,14 +87,14 @@ private void setupVars() { Zeros fZeros = new Zeros<>(getTF()); total = (Variable) getVariable(totalName); if (total == null) { - total = getTF().withSubScope(totalName).variable(Shape.scalar(), getType()); + total = getTF().withSubScope(totalName).variable(Shape.scalar(), type); addVariable(totalName, total, fZeros); } if (reduction == MetricReduction.SUM_OVER_BATCH_SIZE || reduction == MetricReduction.WEIGHTED_MEAN) { count = (Variable) getVariable(countName); if (count == null) { - count = getTF().withSubScope(countName).variable(Shape.scalar(), getType()); + count = getTF().withSubScope(countName).variable(Shape.scalar(), type); addVariable(countName, count, fZeros); } } @@ -104,29 +110,48 @@ private void setupVars() { * @throws IllegalArgumentException if values is null */ @Override - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList(Operand values, Operand sampleWeights) { if (values == null) throw new IllegalArgumentException("values is required."); List updateOperations = new ArrayList<>(); // cast everything to match the variables + Operand lSampleWeights = null; + Operand lValues = values; - Operand tValues = CastHelper.cast(getTF(), values, getType()); - Operand tSampleWeights = sampleWeights; if (sampleWeights != null) { - LossTuple tuple = - LossesHelper.squeezeOrExpandDimensions(getTF(), null, tValues, sampleWeights); - tValues = tuple.getTarget(); - tSampleWeights = tuple.getSampleWeights(); - Op broadcastWeightsCheck = MetricsHelper.broadcastWeights(getTF(), tSampleWeights, tValues); - tValues = - getTF() - .withSubScope("broadcastWeightsCheck") - .withControlDependencies(Collections.singletonList(broadcastWeightsCheck)) - .math - .mul(tValues, tSampleWeights); + lSampleWeights = CastHelper.cast(getTF(), sampleWeights, lValues.type()); + LossTuple tuple = + LossesHelper.squeezeOrExpandDimensions(getTF(), null, lValues, lSampleWeights); + lValues = tuple.getTarget(); + lSampleWeights = tuple.getSampleWeights(); + // lSampleWeights = WeightsBroadcastOps.broadcastWeights(getTF(), lSampleWeights, lValues); + try { + + Op broadcastWeightsCheck = MetricsHelper.broadcastWeights(getTF(), lSampleWeights, lValues); + lValues = + getTF() + .withSubScope("broadcastWeightsCheck") + .withControlDependencies(Collections.singletonList(broadcastWeightsCheck)) + .math + .mul(lValues, lSampleWeights); + } catch (IllegalArgumentException ex) { + System.out.println("Reduce: Fall back from broadcast"); + // reduce the values down to the rank of the samples + int nDim = lValues.shape().numDimensions(); + int wDim = lSampleWeights.shape().numDimensions(); + int numAxes = nDim - wDim; + int[] axes = new int[numAxes]; + for (int i = 0; i < numAxes; i++) axes[i] = i + wDim; 
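// Worked illustration of the fallback above, using hypothetical shapes that are not taken
// from the patch: if lValues has shape (2, 3) and lSampleWeights has shape (2), then
// nDim = 2, wDim = 1, numAxes = 1 and axes = {1}, so lValues is reduced over axis 1
// (a sum for MetricReduction.SUM, a mean otherwise) down to shape (2) before being
// multiplied element-wise by the weights.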
+ if (reduction == MetricReduction.SUM) { + lValues = getTF().reduceSum(lValues, getTF().constant(axes)); + } else { + lValues = getTF().math.mean(lValues, getTF().constant(axes)); + } + lValues = getTF().math.mul(lValues, lSampleWeights); + } } - Operand valueSum = getTF().reduceSum(tValues, LossesHelper.allAxes(getTF(), tValues)); + Operand valueSum = getTF().reduceSum(lValues, LossesHelper.allAxes(getTF(), lValues)); Operand totalUpdate = getTF().assignAdd(total, CastHelper.cast(getTF(), valueSum, total.type())); updateOperations.add(totalUpdate); @@ -134,22 +159,18 @@ public List updateStateList(Operand values, Operand< if (reduction != MetricReduction.SUM) { switch (reduction) { case SUM_OVER_BATCH_SIZE: - numValues = - CastHelper.cast( - getTF(), getTF().constant(tValues.asOutput().shape().size()), getType()); + numValues = CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), type); break; case WEIGHTED_MEAN: - if (tSampleWeights == null) { - numValues = - CastHelper.cast( - getTF(), getTF().constant(tValues.asOutput().shape().size()), getType()); + if (lSampleWeights == null) { + numValues = CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), type); } else { numValues = CastHelper.cast( getTF(), getTF() - .reduceSum(tSampleWeights, LossesHelper.allAxes(getTF(), tSampleWeights)), - getType()); + .reduceSum(lSampleWeights, LossesHelper.allAxes(getTF(), lSampleWeights)), + type); } break; default: @@ -175,7 +196,7 @@ public Operand result(Ops rtf) { break; case WEIGHTED_MEAN: case SUM_OVER_BATCH_SIZE: - fResult = rtf.math.divNoNan(total, CastHelper.cast(rtf, count, getType())); + fResult = rtf.math.divNoNan(total, CastHelper.cast(rtf, count, type)); break; default: throw new UnsupportedOperationException( @@ -201,4 +222,9 @@ public Variable getTotal() { public Variable getCount() { return count; } + + /** Gets the type for the variables */ + public Class getType() { + return type; + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java index 1f07b9567cb..0529026ce8f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java @@ -23,6 +23,7 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; class BinaryCrossentropyTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -57,9 +58,9 @@ public void testUnweightedLogits() { BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testUnweightedLogits", true, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); - double[] trueArray = {1, 0, 1, 0, 1, 1}; + float[] trueArray = {1, 0, 1, 0, 1, 1}; double[] logitsArray = {100.0, -100.0, 100.0, 100.0, 100.0, -100.0}; - Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); Op op = instance.updateState(labels, logits, null); session.run(op); @@ -79,9 +80,9 @@ public void testWeighted() { BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testWeighted", false, 0, 1001L, TFloat32.class); session.run(instance.resetStates()); - 
float[] trueArray = {1, 0, 1, 0}; + int[] trueArray = {1, 0, 1, 0}; float[] predictionArray = {1, 1, 1, 0}; - Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); Operand yPrediction = tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 2))); Operand sampleWeight = tf.constant(new float[] {1.5f, 2.f}); @@ -104,9 +105,9 @@ public void testWeightedLogits() { BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testWeightedLogits", true, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); - double[] trueArray = {1, 0, 1, 0, 1, 1}; + float[] trueArray = {1, 0, 1, 0, 1, 1}; double[] logitsArray = {100.0, -100.0, 100.0, 100.0, 100.0, -100.0}; - Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); Operand sampleWeight = tf.constant(new double[] {2, 2.5}); @@ -131,9 +132,9 @@ public void testLabelSmoothing() { new BinaryCrossentropy<>( tf, "BCE_testWeightedLabS", true, labelSmoothing, 1001L, TFloat64.class); session.run(instance.resetStates()); - double[] trueArray = {1, 0, 1}; + float[] trueArray = {1, 0, 1}; double[] logitsArray = {100., -100., -100.}; - Operand labels = tf.constant(trueArray); + Operand labels = tf.constant(trueArray); Operand logits = tf.constant(logitsArray); Op op = instance.updateState(labels, logits, null); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java index 848e2051af3..a9721ef2f8f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java @@ -79,7 +79,7 @@ public void testWeighted() { public void test_axis() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - int[] axis = new int[] {1}; + int axis = 1; CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testWeighted", axis, 1001L, TFloat32.class); session.run(instance.resetStates()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java index bf98ec4eba4..28020c0fa1c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java @@ -69,7 +69,7 @@ public void testWeighted() { Operand predictions = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); - Operand sampleWeight = tf.constant(new double[] {1.2, 3.4}); + Operand sampleWeight = tf.constant(new double[][] {{1.2}, {3.4}}); Op op = instance.updateState(labels, predictions, sampleWeight); session.run(op); Variable total = instance.getTotal(); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracyTest.java deleted file mode 100644 index 4a0cdefe492..00000000000 --- 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracyTest.java +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.metrics; - -import org.junit.jupiter.api.Test; -import org.tensorflow.Operand; -import org.tensorflow.framework.utils.TestSession; -import org.tensorflow.op.Op; -import org.tensorflow.op.Ops; -import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; - -class SparseTopKCategoricalAccuracyTest { - private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; - - @Test - public void testCorrectness() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - SparseTopKCategoricalAccuracy instance = - new SparseTopKCategoricalAccuracy<>( - tf, "SparseTopK_testCorrectness", 5, 1001L, TFloat64.class); - session.run(instance.resetStates()); - - Operand labels = tf.constant(new double[] {2, 1}); - Operand predictions = - tf.constant(new double[][] {{0.1, 0.9, 0.8}, {0.05, 0.95, 0}}); - - Op update = instance.updateState(labels, predictions, null); - session.run(update); - session.evaluate(1., instance.result()); - - // With `k` < 5. - instance = - new SparseTopKCategoricalAccuracy<>( - tf, "SparseTopK_testCorrectness", 1, 1001L, TFloat64.class); - session.run(instance.resetStates()); - update = instance.updateState(labels, predictions, null); - session.run(update); - session.evaluate(0.5, instance.result()); - - // With `k` > 5. 
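// (Worked example for the case below: with labels {2, 1}, the first row's label-2 score of
// 0.1 is the smallest of its seven scores and so falls outside the top 6, while the second
// row's label-1 score of 0.95 is the largest, giving the expected accuracy of 0.5.)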
- predictions = - tf.constant( - new double[][] { - {0.5, 0.9, 0.1, 0.7, 0.6, 0.5, 0.4}, - {0.05, 0.95, 0, 0, 0, 0, 0} - }); - instance = - new SparseTopKCategoricalAccuracy<>( - tf, "SparseTopK_testCorrectness", 6, 1001L, TFloat64.class); - session.run(instance.resetStates()); - update = instance.updateState(labels, predictions, null); - session.run(update); - session.evaluate(0.5, instance.result()); - } - } - - @Test - public void testWeighted() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - SparseTopKCategoricalAccuracy instance = - new SparseTopKCategoricalAccuracy<>( - tf, "SparseTopK_testWeighted", 5, 1001L, TFloat64.class); - session.run(instance.resetStates()); - - Operand labels = tf.constant(new int[] {1, 0, 2}); - Operand predictions = - tf.constant( - new double[][] { - {0, 0.9, 0.1}, - {0, 0.9, 0.1}, - {0, 0.9, 0.1} - }); - - Operand sampleWeight = tf.constant(new double[] {1, 0, 1}); - - Op update = instance.updateState(labels, predictions, sampleWeight); - session.run(update); - session.evaluate(1., instance.result()); - } - } -} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java deleted file mode 100644 index 52ccde29196..00000000000 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java +++ /dev/null @@ -1,103 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.metrics; - -import org.junit.jupiter.api.Test; -import org.tensorflow.Operand; -import org.tensorflow.framework.utils.TestSession; -import org.tensorflow.op.Op; -import org.tensorflow.op.Ops; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TFloat64; - -class TopKCategoricalAccuracyTest { - private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; - - @Test - public void testCorrectness() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - TopKCategoricalAccuracy instance = - new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted", 5, 1001L, TFloat64.class); - session.run(instance.resetStates()); - Operand labels = tf.constant(new float[][] {{0, 0, 1}, {0, 1, 0}}); - Operand predictions = - tf.constant(new double[][] {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}}); - - Op update = instance.updateState(labels, predictions, null); - session.run(update); - session.evaluate(1., instance.result()); - - // With `k` < 5. - instance = - new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted1", 1, 1001L, TFloat64.class); - session.run(instance.resetStates()); - update = instance.updateState(labels, predictions, null); - session.run(update); - session.evaluate(0.5, instance.result()); - - // With `k` > 5. 
- labels = - tf.constant( - new float[][] { - {0, 0, 1, 0, 0, 0, 0}, - {0, 1, 0, 0, 0, 0, 0} - }); - predictions = - tf.constant( - new double[][] { - {0.5f, 0.9f, 0.1f, 0.7f, 0.6f, 0.5f, 0.4f}, - {0.05f, 0.95f, 0f, 0f, 0f, 0f, 0f} - }); - instance = - new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted6", 6, 1001L, TFloat64.class); - session.run(instance.resetStates()); - update = instance.updateState(labels, predictions, null); - session.run(update); - session.evaluate(0.5, instance.result()); - } - } - - @Test - public void testWeighted() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - TopKCategoricalAccuracy instance = - new TopKCategoricalAccuracy<>(tf, "TopK_testWeighted", 5, 1001L, TFloat64.class); - session.run(instance.resetStates()); - - Operand labels = - tf.constant( - new double[][] { - {1, 0, 2}, - {1, 0, 0}, - {0, 0, 1} - }); - Operand predictions = - tf.constant( - new double[][] { - {0f, 0.9f, 0.1f}, - {0f, 0.9f, 0.1f}, - {0f, 0.9f, 0.1f} - }); - - Operand sampleWeight = tf.constant(new double[] {1, 0, 1}); - - Op update = instance.updateState(labels, predictions, sampleWeight); - session.run(update); - session.evaluate(1., instance.result()); - } - } -} From 360a2dc0ba68b064bdd4f9f9823ae826b8cec379 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 1 Jan 2021 09:04:32 -0500 Subject: [PATCH 30/56] Initial checkin and sync with master --- .../tensorflow/framework/metrics/BinaryCrossentropyTest.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java index 0529026ce8f..7ceedded018 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java @@ -45,9 +45,9 @@ public void testUnweighted() { Variable total = instance.getTotal(); Variable count = instance.getCount(); Operand result = instance.result(); - session.evaluate(7.666619F, total); + session.evaluate(7.71247434F, total); session.evaluate(2, count); - session.evaluate(3.833309F, result); + session.evaluate(3.85623717F, result); } } From 8ec039076d1f1e77e8f520800c4a3a24e16d891b Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 1 Jan 2021 09:46:18 -0500 Subject: [PATCH 31/56] JavaDoc cleanup --- .../framework/metrics/BinaryCrossentropy.java | 4 +- .../metrics/CategoricalCrossentropy.java | 11 +++- .../framework/metrics/CategoricalHinge.java | 8 ++- .../framework/metrics/CosineSimilarity.java | 19 +++--- .../tensorflow/framework/metrics/Hinge.java | 8 ++- .../framework/metrics/KLDivergence.java | 9 ++- .../framework/metrics/LogCoshError.java | 8 ++- .../tensorflow/framework/metrics/Mean.java | 4 +- .../framework/metrics/MeanAbsoluteError.java | 8 ++- .../metrics/MeanAbsolutePercentageError.java | 8 ++- .../framework/metrics/MeanSquaredError.java | 8 ++- .../metrics/MeanSquaredLogarithmicError.java | 8 ++- .../tensorflow/framework/metrics/Metric.java | 14 ----- .../tensorflow/framework/metrics/Metrics.java | 59 +------------------ .../tensorflow/framework/metrics/Poisson.java | 8 ++- .../SparseCategoricalCrossentropy.java | 9 ++- .../framework/metrics/SquaredHinge.java | 8 ++- .../metrics/impl/MeanMetricWrapper.java | 9 +-- .../metrics/impl/MetricVariable.java | 4 +- 19 files changed, 110 insertions(+), 104 
deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 41a5533b5d1..2372293d0d3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -22,7 +22,7 @@ import org.tensorflow.types.family.TNumber; /** - * Computes the binary cross-entropy loss between true labels and predicted labels. + * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. * * @param the data type for the predictions. * @param The data type for the metric result @@ -48,7 +48,7 @@ public class BinaryCrossentropy * correspond to heavier smoothing. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the data type for the variables + * @param type the type for the variables and result */ public BinaryCrossentropy( Ops tf, String name, boolean fromLogits, float labelSmoothing, long seed, Class type) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index 79481f608a1..6bfd471401b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -22,12 +22,16 @@ import org.tensorflow.types.family.TNumber; /** - * Computes the categorical cross-entropy loss between true labels and predicted labels. + * A Metric that computes the categorical cross-entropy loss between true labels and predicted + * labels. * *
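For orientation, here is a minimal sketch of how one of these loss-backed metrics is driven in graph mode, patterned on the unit tests elsewhere in this change. The generic arguments and the expected value are assumptions, and TestSession is the framework's own test utility used throughout this patch:

import org.tensorflow.Operand;
import org.tensorflow.framework.metrics.BinaryCrossentropy;
import org.tensorflow.framework.utils.TestSession;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat64;

class BinaryCrossentropyExample {
  void run() {
    try (TestSession session = TestSession.createTestSession(TestSession.Mode.GRAPH)) {
      Ops tf = session.getTF();
      // (tf, name, fromLogits, labelSmoothing, seed, type), as in the tests above
      BinaryCrossentropy<TFloat64, TFloat64> instance =
          new BinaryCrossentropy<>(tf, "BCE_example", true, 0, 1001L, TFloat64.class);
      session.run(instance.resetStates());                     // re-initialize total and count
      Operand<TFloat64> labels = tf.constant(new double[] {1, 0, 1});
      Operand<TFloat64> logits = tf.constant(new double[] {100., -100., -100.});
      Op update = instance.updateState(labels, logits, null);  // total += ~33.33, count += 1
      session.run(update);
      session.evaluate(33.333333, instance.result());          // ~ (0 + 0 + 100) / 3
    }
  }
}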

The loss function calculates the loss between the labels and predictions * then passes this loss to the {@link Mean} metric to calculate the weighted mean of the @@ -47,6 +47,7 @@ public class MeanMetricWrapper extends Mea * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ protected MeanMetricWrapper(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); @@ -95,9 +96,9 @@ public List updateStateList( Operand tLabels = CastHelper.cast(getTF(), labels, getType()); Operand tPredictions = CastHelper.cast(getTF(), predictions, getType()); - Operand losses = loss.call(tLabels, tPredictions); - return super.updateStateList(CastHelper.cast(getTF(), losses, predictions.type()), sampleWeights); + return super.updateStateList( + CastHelper.cast(getTF(), losses, predictions.type()), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java index cb5e987b4cf..786d5db1261 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java @@ -30,7 +30,6 @@ * * @param the data type of the variable */ -// TODO handle distributed variables with VariableAggregation and VariableSynchronization public class MetricVariable { private final Variable variable; private final Initializer initializer; @@ -44,10 +43,12 @@ public class MetricVariable { * @param variable the variable * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variable */ public MetricVariable(Ops tf, Variable variable, long seed, Class type) { this(tf, variable, null, seed, type); } + /** * Creates a Metric Variable * @@ -59,6 +60,7 @@ public MetricVariable(Ops tf, Variable variable, long seed, Class type) { * other types the default initializer is {@link org.tensorflow.framework.initializers.Zeros} * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
+ * @param type the type for the variable */ @SuppressWarnings("unchecked") public MetricVariable( From 14de4461d39eb45daaad2f854a11fc30b24c4f9e Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 3 Jan 2021 12:07:45 -0500 Subject: [PATCH 32/56] Javadoc fixes --- .../tensorflow/framework/metrics/CategoricalCrossentropy.java | 1 + .../org/tensorflow/framework/metrics/CosineSimilarity.java | 2 ++ .../main/java/org/tensorflow/framework/metrics/Metric.java | 2 ++ .../main/java/org/tensorflow/framework/metrics/Metrics.java | 4 ++-- .../java/org/tensorflow/framework/metrics/impl/Reduce.java | 4 +++- 5 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index 6bfd471401b..72e15f1b22b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -79,6 +79,7 @@ public CategoricalCrossentropy( * channels_first. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public CategoricalCrossentropy( Ops tf, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 9ceccf7fc13..5bd0c53b416 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -53,6 +53,7 @@ public CosineSimilarity(Ops tf, String name, long seed, Class type) { * @param axis The dimension along which the cosine similarity is computed. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ public CosineSimilarity(Ops tf, String name, int axis, long seed, Class type) { this(tf, name, new int[] {axis}, seed, type); @@ -65,6 +66,7 @@ public CosineSimilarity(Ops tf, String name, int axis, long seed, Class type) * @param axis The dimension along which the cosine similarity is computed. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
+ * @param type the type for the variables and result */ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class type) { super(tf, name, seed, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 89e5436ed0a..378d026e69c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -164,11 +164,13 @@ public final Operand callOnce( /** * Adds a variable to collect metric values * + * @param varName the name for the variable * @param variable the variable * @param initializer the initializer for the variable, if null, then the default for floating * point types is {@link org.tensorflow.framework.initializers.Glorot} with distribution * {@link org.tensorflow.framework.initializers.VarianceScaling.Distribution#UNIFORM}, for * other types the default initializer is {@link org.tensorflow.framework.initializers.Zeros} + * @param the date type for the variable */ protected void addVariable( String varName, Variable variable, Initializer initializer) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index 8a8ddf3694c..e31cb54a4d1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -121,8 +121,8 @@ public static Operand l2Normalize(Ops tf, Operand x, i * @param tf The TensorFlow ops * @param x The operand to normalize * @param axes Dimension(s) along which to normalize. - * @param epsilon A lower bound value for the norm. Will use `sqrt(epsilon)` as the divisor if - * `norm < sqrt(epsilon)`. + * @param epsilon A lower bound value for the norm. Will use sqrt(epsilon) as the divisor if + * norm < sqrt(epsilon). * @param The data type for the values. * @return the normalized values of x. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index d3b7caa54cc..c8499bc1599 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -223,7 +223,9 @@ public Variable getCount() { return count; } - /** Gets the type for the variables */ + /** Gets the type for the variables + * @return the type for the variables + */ public Class getType() { return type; } From d0a7fd67074b37af475e559047ea4e03cd707a3b Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 5 Jan 2021 12:32:34 -0500 Subject: [PATCH 33/56] Change LossInterface to LossMetric. Fix JavaDoc, modify one line code block to include braces. 
--- .../framework/metrics/BinaryCrossentropy.java | 10 +++--- .../metrics/CategoricalCrossentropy.java | 8 ++--- .../framework/metrics/CategoricalHinge.java | 4 +-- .../framework/metrics/CosineSimilarity.java | 4 +-- .../tensorflow/framework/metrics/Hinge.java | 4 +-- .../framework/metrics/KLDivergence.java | 4 +-- .../framework/metrics/LogCoshError.java | 4 +-- .../framework/metrics/MeanAbsoluteError.java | 4 +-- .../metrics/MeanAbsolutePercentageError.java | 4 +-- .../framework/metrics/MeanSquaredError.java | 4 +-- .../metrics/MeanSquaredLogarithmicError.java | 4 +-- .../tensorflow/framework/metrics/Metric.java | 10 +++--- .../tensorflow/framework/metrics/Metrics.java | 6 +--- .../tensorflow/framework/metrics/Poisson.java | 4 +-- .../SparseCategoricalCrossentropy.java | 4 +-- .../framework/metrics/SquaredHinge.java | 4 +-- .../framework/metrics/impl/LossInterface.java | 36 ------------------- .../metrics/impl/MeanMetricWrapper.java | 9 ++--- .../metrics/impl/MetricVariable.java | 5 +-- .../framework/metrics/impl/MetricsHelper.java | 10 +++--- .../framework/metrics/impl/Reduce.java | 5 +-- 21 files changed, 55 insertions(+), 92 deletions(-) delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossInterface.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 2372293d0d3..c339b977007 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -24,11 +24,14 @@ /** * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. * + *

This is the crossentropy metric class to be used when there are only two label classes (0 and + * 1). + * * @param the data type for the predictions. * @param The data type for the metric result */ public class BinaryCrossentropy - extends MeanMetricWrapper implements LossInterface { + extends MeanMetricWrapper implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; @@ -36,9 +39,6 @@ public class BinaryCrossentropy /** * Creates a BinaryCrossentropy metric * - *

This is the crossentropy metric class to be used when there are only two label classes (0 - * and 1). - * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index 72e15f1b22b..7b8cf0054a4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -34,14 +34,14 @@ * @param The data type for the metric result */ public class CategoricalCrossentropy - extends MeanMetricWrapper implements LossInterface { + extends MeanMetricWrapper implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; private final int axis; /** - * Creates a CategoricalCrossentropy metric that Computes the crossentropy metric between the + * Creates a CategoricalCrossentropy metric that computes the crossentropy metric between the * labels and predictions. * *

Uses a {@link Losses#CHANNELS_LAST} for the channel axis. @@ -63,7 +63,7 @@ public CategoricalCrossentropy( } /** - * Creates a CategoricalCrossentropy metric that Computes the crossentropy metric between the + * Creates a CategoricalCrossentropy metric that computes the crossentropy metric between the * labels and predictions. * * @param tf the TensorFlow Ops diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index 21f19d88ade..2741a36edb6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -28,7 +28,7 @@ * @param The data type for the metric result */ public class CategoricalHinge extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { /** * Creates a CategoricalHinge metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 5bd0c53b416..458de092bec 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -15,7 +15,7 @@ package org.tensorflow.framework.metrics; import org.tensorflow.Operand; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -27,7 +27,7 @@ * @param The data type for the metric result. */ public class CosineSimilarity extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { public static final int DEFAULT_AXIS = -1; private final int[] axis; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index b276f0b9426..baf9ad8ab7d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -28,7 +28,7 @@ * @param The data type for the metric result. 
*/ public class Hinge extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { /** * Creates a Hinge metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index a3cbc6f16e6..efcbbcbb7f0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -29,7 +29,7 @@ * @param The data type for the metric result. */ public class KLDivergence extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { /** * Creates a KLDivergence metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index d6fe903f5a1..3df8505d54b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -29,7 +29,7 @@ * @param The data type for the metric result. */ public class LogCoshError extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { /** * Creates a LogCoshError metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index 79da80ef191..e27676932ff 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -28,7 +28,7 @@ * @param The data type for the metric result. 
*/ public class MeanAbsoluteError extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { /** * Creates a Mean Absolute Error metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 558c194074f..84fa9b627b2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -28,7 +28,7 @@ * @param The data type for the metric result. */ public class MeanAbsolutePercentageError - extends MeanMetricWrapper implements LossInterface { + extends MeanMetricWrapper implements LossMetric { /** * Creates a Mean Absolute Error metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index 10704d14bd4..c7edd6ebe93 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -28,7 +28,7 @@ * @param The data type for the metric result. */ public class MeanSquaredError extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { /** * Creates a Mean Absolute Error metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index 585fc312e5a..199b6e0e114 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -28,7 +28,7 @@ * @param The data type for the metric result. 
*/ public class MeanSquaredLogarithmicError - extends MeanMetricWrapper implements LossInterface { + extends MeanMetricWrapper implements LossMetric { /** * Creates a Mean Absolute Error metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 378d026e69c..c816b1a98d0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -65,8 +65,9 @@ protected Metric(Ops tf, long seed) { * will always produce the same random tensor for a given shape and data type. */ protected Metric(Ops tf, String name, long seed) { - if (!tf.scope().env().isGraph()) + if (!tf.scope().env().isGraph()) { throw new IllegalArgumentException("Metrics are required to execute in Graph mode."); + } this.seed = seed; this.name = name != null ? name : this.getClass().getSimpleName(); this.tf = tf.withSubScope(this.name); @@ -81,7 +82,7 @@ protected Metric(Ops tf, String name, long seed) { * @param sampleWeights sample weights to be applied to values, may be null. * @return a List of Operations to update the metric state */ - @SuppressWarnings({"unchecked","unused"}) + @SuppressWarnings({"unchecked", "unused"}) public List updateStateList(Operand values, Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -97,7 +98,7 @@ public List updateStateList(Operand values, Operand sampleWeights) { * @param the data type for the sample weights * @return a List of Operations to update the metric state */ - @SuppressWarnings({"unchecked","unused"}) + @SuppressWarnings({"unchecked", "unused"}) public List updateStateList( Operand labels, Operand predictions, Operand sampleWeights) { return Collections.EMPTY_LIST; @@ -154,8 +155,7 @@ public Operand result() { * @param sampleWeights sample weights to be applied to values, may be null. * @return the result, possibly with control dependencies */ - public final Operand callOnce( - Operand values, Operand sampleWeights) { + public final Operand callOnce(Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps); return result(ltf); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index e31cb54a4d1..b8e79efa450 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -69,13 +69,9 @@ public static Operand topKCategoricalA * @param the data type for the predictions and result * @return Cosine similarity value. 
*/ - @SuppressWarnings("unchecked") public static Operand cosineProximity( Ops tf, Operand labels, Operand predictions, int[] axis) { - Operand labelsNorm; - if (labels.type() != predictions.type()) - labelsNorm = CastHelper.cast(tf, labels, predictions.type()); - else labelsNorm = (Operand) labels; + Operand labelsNorm = CastHelper.cast(tf, labels, predictions.type()); labelsNorm = l2Normalize(tf, labelsNorm, axis); Operand predictionsNorm = l2Normalize(tf, predictions, axis); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 07ab129eb08..75a2031fbb5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -28,7 +28,7 @@ * @param The data type for the metric result. */ public class Poisson extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { /** * Creates a Poisson metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index c2f916217e4..3fde8b2ecf6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -29,7 +29,7 @@ * @param The data type for the metric result. */ public class SparseCategoricalCrossentropy - extends MeanMetricWrapper implements LossInterface { + extends MeanMetricWrapper implements LossMetric { private final boolean fromLogits; private final int axes; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index d8c7aa097fe..430dbbcc229 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.losses.Losses; -import org.tensorflow.framework.metrics.impl.LossInterface; +import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -28,7 +28,7 @@ * @param The data type for the metric result. 
*/ public class SquaredHinge extends MeanMetricWrapper - implements LossInterface { + implements LossMetric { /** * Creates a SquaredHinge metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossInterface.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossInterface.java deleted file mode 100644 index aadc211c3c4..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossInterface.java +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.metrics.impl; - -import org.tensorflow.Operand; -import org.tensorflow.types.family.TNumber; - -/** - * Interface for Metrics that wrap Loss functions. - * - * @param The data type of the predictions. - */ -public interface LossInterface { - - /** - * Calculates the weighted loss between labels and predictions - * - * @param labels the truth values or labels - * @param predictions the predictions - * @param The data type of the labels. - * @return the loss - */ - Operand call(Operand labels, Operand predictions); -} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 5894b24c4cd..cd17e2a9de4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -38,7 +38,7 @@ public class MeanMetricWrapper extends Mean { /** The loss function interface */ - protected LossInterface loss; + protected LossMetric loss; /** * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#WEIGHTED_MEAN} @@ -58,7 +58,7 @@ protected MeanMetricWrapper(Ops tf, String name, long seed, Class type) { * * @return the loss function. */ - public LossInterface getLoss() { + public LossMetric getLoss() { return loss; } @@ -67,7 +67,7 @@ public LossInterface getLoss() { * * @param loss the loss function. 
*/ - public void setLoss(LossInterface loss) { + protected void setLoss(LossMetric loss) { this.loss = loss; } @@ -90,8 +90,9 @@ public void setLoss(LossInterface loss) { */ public List updateStateList( Operand labels, Operand predictions, Operand sampleWeights) { - if (labels == null || predictions == null) + if (labels == null || predictions == null) { throw new IllegalArgumentException("missing required inputs for labels and predictions"); + } Operand tLabels = CastHelper.cast(getTF(), labels, getType()); Operand tPredictions = CastHelper.cast(getTF(), predictions, getType()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java index 786d5db1261..c5c5dbb2ab2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java @@ -61,6 +61,7 @@ public MetricVariable(Ops tf, Variable variable, long seed, Class type) { * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variable + * @throws IllegalArgumentException if the type does not inherit from TNumber and the initializer is null */ @SuppressWarnings("unchecked") public MetricVariable( @@ -78,8 +79,8 @@ public MetricVariable( } else { throw new IllegalArgumentException( String.format( - "An initializer for variable %s of type %s is required", - variable.toString(), type.getSimpleName())); + "Type %s is not a supported for metric variables", + type.getSimpleName())); } } else { this.initializer = initializer; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 042badbb615..9699ccd323c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -42,16 +42,16 @@ public class MetricsHelper { "weights can not be broadcast to values."; /** - * Asserts that the sampleWeight can be broadcast to values + * Asserts that the sampleWeights can be broadcast to values * * @param tf the TensorFlow Ops * @param sampleWeights the sample weights. * @param values the values to which weights are applied. - * @return Operation raising InvalidArgumentError if sampleWeight - * has incorrect shape. no_op if static checks determine - * sampleWeight has correct shape. + * @return Operation with control dependencies to ensure sampleWeight + * can be broadcast to values * @param the type of Operand - * @throws IllegalArgumentException If static checks determine `weights` has incorrect shape. 
+ * @throws IllegalArgumentException If static checks determine sampleWeights has an + * incorrect shape that prohibit broadcasting to to values */ @SuppressWarnings("unchecked") public static Op broadcastWeights( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index c8499bc1599..3ec6540779c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -112,7 +112,9 @@ private void setupVars() { @Override public List updateStateList(Operand values, Operand sampleWeights) { - if (values == null) throw new IllegalArgumentException("values is required."); + if (values == null) { + throw new IllegalArgumentException("values is required."); + } List updateOperations = new ArrayList<>(); // cast everything to match the variables Operand lSampleWeights = null; @@ -124,7 +126,6 @@ public List updateStateList(Operand values, Operand sampleWeights) { LossesHelper.squeezeOrExpandDimensions(getTF(), null, lValues, lSampleWeights); lValues = tuple.getTarget(); lSampleWeights = tuple.getSampleWeights(); - // lSampleWeights = WeightsBroadcastOps.broadcastWeights(getTF(), lSampleWeights, lValues); try { Op broadcastWeightsCheck = MetricsHelper.broadcastWeights(getTF(), lSampleWeights, lValues); From feb430cb2b765b48e62f2df2713ae1c1e3387d5c Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 7 Jan 2021 17:15:30 -0500 Subject: [PATCH 34/56] Removed hashmap for variables, they are not needed as the variables only live within a single instance of a Metric. --- .../tensorflow/framework/metrics/Metric.java | 121 +++--------------- .../framework/metrics/impl/Reduce.java | 37 +++--- 2 files changed, 37 insertions(+), 121 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index c816b1a98d0..a6f2cf0f26d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -14,17 +14,13 @@ =======================================================================*/ package org.tensorflow.framework.metrics; -import org.tensorflow.ExecutionEnvironment; import org.tensorflow.Operand; -import org.tensorflow.framework.initializers.Initializer; -import org.tensorflow.framework.metrics.impl.MetricVariable; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TNumber; -import java.util.*; -import java.util.stream.Collectors; +import java.util.Collections; +import java.util.List; /** * Base class for Metrics @@ -34,12 +30,8 @@ */ public abstract class Metric { - /** variables are stored by ExecutionEnvironment, and then by an identifier name */ - protected static Map>> - variableMap = new WeakHashMap<>(); /** The TensorFlow Ops */ private final Ops tf; - /** The random number generator seed value */ private final long seed; /** The name for this metric. Defaults to {@link Class#getSimpleName()}. */ @@ -70,7 +62,7 @@ protected Metric(Ops tf, String name, long seed) { } this.seed = seed; this.name = name != null ? 
name : this.getClass().getSimpleName(); - this.tf = tf.withSubScope(this.name); + this.tf = tf.withName(this.getClass().getSimpleName()); } /** @@ -139,6 +131,13 @@ public final Op updateState( */ public abstract Operand result(Ops tf); + /** + * Resets any state variables to their initial values + * + * @return the control operation for doing the reset + */ + public abstract Op resetStates(); + /** * Gets the current result of the metric using the metric's {@link #getTF()} * @@ -161,36 +160,6 @@ public final Operand callOnce(Operand values, Operand sampleWeights) { return result(ltf); } - /** - * Adds a variable to collect metric values - * - * @param varName the name for the variable - * @param variable the variable - * @param initializer the initializer for the variable, if null, then the default for floating - * point types is {@link org.tensorflow.framework.initializers.Glorot} with distribution - * {@link org.tensorflow.framework.initializers.VarianceScaling.Distribution#UNIFORM}, for - * other types the default initializer is {@link org.tensorflow.framework.initializers.Zeros} - * @param the date type for the variable - */ - protected void addVariable( - String varName, Variable variable, Initializer initializer) { - Map> variables = - variableMap.computeIfAbsent(tf.scope().env(), k -> new HashMap<>()); - variables.put(varName, new MetricVariable<>(tf, variable, initializer, seed, variable.type())); - } - - /** - * Gets the list of added variables - * - * @return the list of added variables - */ - public List> getVariables() { - List> result = new ArrayList<>(); - Map> variables = variableMap.get(tf.scope().env()); - if (variables != null) variables.values().forEach(mv -> result.add(mv.getVariable())); - return result; - } - /** * Gets a formatted name for a variable, in the form {@link #name} + "_" + varName. * @@ -201,71 +170,6 @@ protected String getVariableName(String varName) { return String.format("%s_%s", this.name, varName); } - /** - * Gets an Operation that initializes the variables. - * - * @param subScopeName the sub scope name - * @return the Operation used to initialize the variables. - */ - public Op initialize(String subScopeName) { - - List initializeOperations = initializeVarsList(subScopeName); - return tf.withControlDependencies(initializeOperations).noOp(); - } - - /** - * Gets the list of Operations that initializes the variables - * - * @param subScopeName the sub scope name - * @return the list of Operations that initializes the variables - */ - @SuppressWarnings("unchecked") - private List initializeVarsList(String subScopeName) { - Map> variables = variableMap.get(tf.scope().env()); - if (variables != null) - return variables.values().stream() - .map(metricVariable -> variableAssign(subScopeName, metricVariable)) - .collect(Collectors.toList()); - else return Collections.EMPTY_LIST; - } - - /** - * Resets all variables to their initial state - * - * @return An Operation that sets all variables to their initial state - */ - public Op resetStates() { - return initialize("resetStates"); - } - - /** - * Assigns a value to a Variable - * - *
<p>
This assumes the variable has already been initialized - * - * @param subScopeName the subscope for creating the variable - * @param mv the metric value used to assign the initializer to the variable. - * @return the variable add operation with necessary control dependencies - */ - private Operand variableAssign( - String subScopeName, MetricVariable mv) { - return tf.withSubScope(subScopeName).assign(mv.getVariable(), mv.initialize()); - } - - /** - * Gets a stored variable by name, Variables are cached first by the TensorFlow Environment, then - * by a variable name. - * - * @param varName the name assigned to the variable - * @return the variable, or null if the variable is not found. - */ - public Variable getVariable(String varName) { - Map> variables = variableMap.get(tf.scope().env()); - if (variables == null) return null; - MetricVariable mv = variables.get(varName); - return mv != null ? mv.getVariable() : null; - } - /** * Gets the TensorFlow Ops * @@ -283,4 +187,9 @@ public Ops getTF() { public String getName() { return name; } + + /** The random number generator seed value */ + public long getSeed() { + return seed; + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 3ec6540779c..2c387cc152e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; -import org.tensorflow.framework.initializers.Zeros; import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.framework.metrics.Metric; @@ -50,7 +49,6 @@ public abstract class Reduce extends Metri /** the variable that holds the count of the metric values */ protected Variable count; - protected boolean initialized; /** * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#SUM} @@ -81,25 +79,33 @@ protected Reduce(Ops tf, String name, MetricReduction reduction, long seed, Clas this.type = type; setupVars(); } - /** initialize the Variables */ - @SuppressWarnings("unchecked") + /** Initializes the Variables */ private void setupVars() { - Zeros fZeros = new Zeros<>(getTF()); - total = (Variable) getVariable(totalName); if (total == null) { - total = getTF().withSubScope(totalName).variable(Shape.scalar(), type); - addVariable(totalName, total, fZeros); + total = getTF().withName(totalName).variable(Shape.scalar(), type); } if (reduction == MetricReduction.SUM_OVER_BATCH_SIZE || reduction == MetricReduction.WEIGHTED_MEAN) { - count = (Variable) getVariable(countName); if (count == null) { - count = getTF().withSubScope(countName).variable(Shape.scalar(), type); - addVariable(countName, count, fZeros); + count = getTF().withName(countName).variable(Shape.scalar(), type); } } } + /** {@inheritDoc} */ + public Op resetStates() { + List controls = new ArrayList<>(); + if (total != null) { + controls.add( + getTF().assign(total, CastHelper.cast(getTF(), getTF().constant(0), total.type()))); + } + if (count != null) { + controls.add( + getTF().assign(count, CastHelper.cast(getTF(), getTF().constant(0), count.type()))); + } + return getTF().withControlDependencies(controls).noOp(); + } + /** * Updates the metric variables based on the inputs. 
At least one input arg required for * values, an optional additional input for the sampleWeights @@ -110,7 +116,7 @@ private void setupVars() { * @throws IllegalArgumentException if values is null */ @Override - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList(Operand values, Operand sampleWeights) { if (values == null) { throw new IllegalArgumentException("values is required."); @@ -136,7 +142,6 @@ public List updateStateList(Operand values, Operand sampleWeights) { .math .mul(lValues, lSampleWeights); } catch (IllegalArgumentException ex) { - System.out.println("Reduce: Fall back from broadcast"); // reduce the values down to the rank of the samples int nDim = lValues.shape().numDimensions(); int wDim = lSampleWeights.shape().numDimensions(); @@ -224,8 +229,10 @@ public Variable getCount() { return count; } - /** Gets the type for the variables - * @return the type for the variables + /** + * Gets the type for the variables + * + * @return the type for the variables */ public Class getType() { return type; From 48390ea3067398881ebae41fafbf117e16a9030b Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 7 Jan 2021 17:16:31 -0500 Subject: [PATCH 35/56] reformat code --- .../src/main/java/org/tensorflow/framework/metrics/Metric.java | 1 + .../main/java/org/tensorflow/framework/metrics/impl/Reduce.java | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index a6f2cf0f26d..9efb3dde20a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -32,6 +32,7 @@ public abstract class Metric { /** The TensorFlow Ops */ private final Ops tf; + private final long seed; /** The name for this metric. Defaults to {@link Class#getSimpleName()}. 
*/ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 2c387cc152e..f304ad04cb4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -49,7 +49,6 @@ public abstract class Reduce extends Metri /** the variable that holds the count of the metric values */ protected Variable count; - /** * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#SUM} * From 6a74ce8c732aa4dfd34e5cea2087af91c3e79d21 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 11 Jan 2021 18:02:27 -0500 Subject: [PATCH 36/56] Add tests for assertBroadcastable --- .../framework/metrics/impl/MetricsHelper.java | 49 ++- .../metrics/impl/WeightBroadcastTest.java | 335 ++++++++++++++++++ 2 files changed, 365 insertions(+), 19 deletions(-) create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 9699ccd323c..5ecc06a388f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -18,6 +18,7 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.SetDiff1d; import org.tensorflow.op.math.Mean; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; @@ -30,7 +31,6 @@ import java.util.List; import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; -import static org.tensorflow.framework.utils.CastHelper.cast; /** * These are helper methods for Metrics and will be module private when Java modularity is applied @@ -42,7 +42,12 @@ public class MetricsHelper { "weights can not be broadcast to values."; /** - * Asserts that the sampleWeights can be broadcast to values + * Asserts that the sampleWeights can be broadcast to the same shape as values + * + * + *
<p>
In losses and metrics, limited weight broadcasting is supported. Weights must be either scalar,
+   * or the same rank as the target values, with each dimension either 1, or the same as the
+   * corresponding values dimension.
   *
   * @param tf the TensorFlow Ops
   * @param sampleWeights the sample weights.
   * @param values the values to which weights are applied.
@@ -54,9 +59,10 @@ public class MetricsHelper {
   *     incorrect shape that prohibit broadcasting to to values
   */
  @SuppressWarnings("unchecked")
-  public static Op broadcastWeights(
+  public static Op assertBroadcastable(
      Ops tf, Operand sampleWeights, Operand values) {
+    // try static check for exact match
    Operand weightsShape = tf.shape(sampleWeights);
    Operand weightsRank = tf.rank(sampleWeights);
    Shape weightsShapeStatic = sampleWeights.shape();
@@ -67,9 +73,9 @@ public static Op broadcastWeights(
    Shape valuesShapeStatic = values.shape();
    int valuesRankStatic = valuesShapeStatic.numDimensions();
-    if (weightsRankStatic != -1 && valuesRankStatic != -1) {
+    if (weightsRankStatic != Shape.UNKNOWN_SIZE && valuesRankStatic != Shape.UNKNOWN_SIZE) {
      if (weightsRankStatic == 0) {
-        return tf.withSubScope("static_scalar_check_success")
+        return tf.withSubScope("staticScalarCheckSuccess")
            .withControlDependencies(Collections.EMPTY_LIST)
            .noOp();
      }
@@ -85,7 +91,7 @@ public static Op broadcastWeights(
      }
      for (int i = 0; i < valuesRankStatic; i++) {
-        if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i)) {
+        if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i) && weightsShapeStatic.size(i) != 1) {
          throw new IllegalArgumentException(
              String.format(
                  "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.",
@@ -95,12 +101,12 @@
                  weightsShapeStatic.toString()));
        }
      }
-      return tf.withSubScope("static_dims_check_success")
+      return tf.withSubScope("staticDimsCheckSuccess")
          .withControlDependencies(Collections.EMPTY_LIST)
          .noOp();
    }
    // Dynamic checks.
- Operand is_scalar = tf.math.equal(weightsRank, tf.constant(0)); + Operand isScalar = tf.math.equal(weightsRank, tf.constant(0)); List> data = Arrays.asList( tf.constant(ASSERT_BROADCAST_ERROR_PREFIX), @@ -108,14 +114,13 @@ public static Op broadcastWeights( weightsShape, tf.constant("values.shape="), valuesShape, - tf.constant("is_scalar="), - is_scalar); + tf.constant("isScalar="), + isScalar); + + Operand validNonsclar = + hasValidNonscalarShape(tf, weightsRank, weightsShape, valuesRank, valuesShape); - Operand isValidShape = - tf.select( - is_scalar, - is_scalar, - hasValidNonscalarShape(tf, weightsRank, weightsShape, valuesRank, valuesShape)); + Operand isValidShape = tf.select(isScalar, isScalar, validNonsclar); return tf.withSubScope("broadcastWeights-dynamic").assertThat(isValidShape, data); } @@ -137,7 +142,7 @@ private static Operand hasValidNonscalarShape( Operand weightsShape, Operand valuesRank, Operand valuesShape) { - tf = tf.withSubScope("has_valid_nonscalar_shape"); + tf = tf.withSubScope("hasValidNonscalarShape"); Operand isSameRank = tf.math.equal(valuesRank, weightsRank); return tf.select(isSameRank, hasValidDims(tf, weightsShape, valuesShape), isSameRank); } @@ -153,9 +158,15 @@ private static Operand hasValidNonscalarShape( */ private static Operand hasValidDims( Ops tf, Operand weightsShape, Operand valuesShape) { - tf = tf.withSubScope("has_invalid_dims"); - Operand diff = tf.reduceSum(tf.math.sub(weightsShape, valuesShape), tf.constant(0)); - return tf.math.equal(cast(tf, tf.constant(0), diff.asOutput().type()), diff); + tf = tf.withSubScope("hasValidDims"); + Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); + Operand validDims = + tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); + SetDiff1d invalidDimsDiff = + tf.setDiff1d(tf.shape.flatten(valuesShape2d), tf.shape.flatten(validDims)); + Operand invalidDims = invalidDimsDiff.out(); + Operand numInvalidDims = tf.size(invalidDims); + return tf.math.equal(tf.constant(0), numInvalidDims); } // alias for mean diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java new file mode 100644 index 00000000000..c89cff93dc2 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java @@ -0,0 +1,335 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class WeightBroadcastTest { + + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + private void testValid( + TestSession testSession, Ops tf, Operand weights, Operand values, Class type) { + + Op staticOp = MetricsHelper.assertBroadcastable(tf, weights, values); + testSession.run(staticOp); + + // dynamic test + Operand weightsPlaceholder = tf.placeholder(type); + Operand valuesPlaceholder = tf.placeholder(type); + + List tensors = + testSession.getGraphSession().runner().fetch(weights).fetch(values).run(); + try (Tensor weightsTensor = tensors.get(0); + Tensor valuesTensor = tensors.get(1)) { + + Op dynamicOp = MetricsHelper.assertBroadcastable(tf, weightsPlaceholder, valuesPlaceholder); + + testSession + .getGraphSession() + .runner() + .feed(weightsPlaceholder, weightsTensor) + .feed(valuesPlaceholder, valuesTensor) + .addTarget(dynamicOp) + .run(); + } + } + + @Test + public void testValidScalar() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new float[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(5f); + testValid(testSession, tf, weights, values, TFloat32.class); + } + } + + @Test + public void test1x1x1() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new double[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new double[][][] {{{5}}}); + testValid(testSession, tf, weights, values, TFloat64.class); + } + } + + @Test + public void test1x1xN() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new long[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new long[][][] {{{5, 7, 11, 3}}}); + testValid(testSession, tf, weights, values, TInt64.class); + } + } + + @Test + public void test1xNx1() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][][] {{{5}, {11}}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + } + + @Test + public void test1xNxN() { + // no exception should be thrown + try (TestSession testSession = 
TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}, {2, 13, 7, 5}}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + } + + @Test + public void testNx1x1() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][][] {{{5}}, {{7}}, {{11}}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + } + + @Test + public void testNx1xN() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = + tf.constant(new int[][][] {{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + } + + @Test + public void testNxNxN() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = + tf.constant( + new int[][][] { + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} + }); + testValid(testSession, tf, weights, values, TInt32.class); + } + } + + @Test + public void testInvalid1x1() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][] {{5}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } + + @Test + public void testInvalidPrefixMatch() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][] {{5, 7}, {11, 3}, {2, 12}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } + + @Test + public void testInvalidSuffixMatch() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][] {{5, 7, 11, 3}, {2, 12, 7, 5}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } + + 
@Test + public void testInvalidOnesExtraDim() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][][][] {{{{5}}}} ); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } + + @Test + public void testInvalidPrefixMatchExtraDim() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + + Operand weights = tf.constant(new int[][][][] { + {{ { 5},{ 7}, {11}, { 3}}, {{ 2}, {12}, { 7}, { 5}} }, + {{ { 2}, {17}, {11}, { 3}}, {{ 2}, {17}, {11}, { 3}} }, + {{ { 5}, { 7}, {11}, { 3}}, {{ 2}, {12}, { 7}, { 5}} } + }); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } + + @Test + public void testInvalidSuffixMatchExtraDim() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][][][] {{ + { { 5, 7, 11, 3}, { 2, 12, 7, 5} }, + { { 2, 17, 11, 3}, { 2, 17, 11, 3} }, + { { 5, 7, 11, 3}, { 2, 12, 7, 5} } + }}); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } +} From c62481b1384144fd9dd2cd52394a8602421f5e39 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 11 Jan 2021 18:41:00 -0500 Subject: [PATCH 37/56] Change type to resultType --- .../tensorflow/framework/metrics/impl/MeanMetricWrapper.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index cd17e2a9de4..173167c5370 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -94,8 +94,8 @@ public List updateStateList( throw new IllegalArgumentException("missing required inputs for labels and predictions"); } - Operand tLabels = CastHelper.cast(getTF(), labels, getType()); - Operand tPredictions = CastHelper.cast(getTF(), predictions, getType()); + Operand tLabels = CastHelper.cast(getTF(), labels, getResultType()); + Operand tPredictions = CastHelper.cast(getTF(), predictions, getResultType()); Operand losses = loss.call(tLabels, tPredictions); From d475b1accdb805a4a0327c6fd7358d084cb64b90 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 11 Jan 2021 18:42:18 -0500 Subject: [PATCH 38/56] Added V data type for sampleWeights so that it is not forced to be the same type as the return or internal variables, --- .../tensorflow/framework/metrics/Metric.java | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 9efb3dde20a..123abae61d7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -76,7 +76,7 @@ protected Metric(Ops tf, String name, long seed) { * @return a List of Operations to update the metric state */ @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList(Operand values, Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -88,7 +88,7 @@ public List updateStateList(Operand values, Operand sampleWeights) { * @param labels the labels * @param predictions the predictions * @param sampleWeights sample weights to be applied to values, may be null. - * @param the data type for the sample weights + * @param the data type for the labels * @return a List of Operations to update the metric state */ @SuppressWarnings({"unchecked", "unused"}) @@ -104,7 +104,7 @@ public List updateStateList( * @param sampleWeights sample weights to be applied to values, may be null. * @return the Operation to update the metric state */ - public final Op updateState(Operand values, Operand sampleWeights) { + public final Op updateState(Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -115,7 +115,7 @@ public final Op updateState(Operand values, Operand sampleWeights) { * @param labels the labels * @param predictions the predictions * @param sampleWeights sample weights to be applied to values, may be null. 
- * @param the data type for the sample weights + * @param the data type for the labels * @return the Operation to update the metric state */ public final Op updateState( @@ -127,10 +127,9 @@ public final Op updateState( /** * Gets the current result of the metric * - * @param tf the TensorFlow Ops used to create the result * @return the result, possibly with control dependencies */ - public abstract Operand result(Ops tf); + public abstract Operand result(); /** * Resets any state variables to their initial values @@ -139,14 +138,6 @@ public final Op updateState( */ public abstract Op resetStates(); - /** - * Gets the current result of the metric using the metric's {@link #getTF()} - * - * @return the result, possibly with control dependencies - */ - public Operand result() { - return result(this.tf); - } /** * Calls update state once, followed by a call to get the result @@ -158,7 +149,7 @@ public Operand result() { public final Operand callOnce(Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps); - return result(ltf); + return ltf.identity(result()); } /** From 8f530cc1bc93c20ead50d56d5cbe8d9166205098 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 11 Jan 2021 18:42:49 -0500 Subject: [PATCH 39/56] change 'type' to 'resultType' --- .../framework/metrics/impl/Reduce.java | 62 ++++++++++--------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index f304ad04cb4..fb8e39f3f1f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -43,7 +43,7 @@ public abstract class Reduce extends Metri private final String totalName; private final String countName; - private final Class type; + private final Class resultType; /** the variable that holds the total of the metric values */ protected Variable total; /** the variable that holds the count of the metric values */ @@ -56,10 +56,10 @@ public abstract class Reduce extends Metri * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variables and result + * @param resultType the type for the variables and result */ - protected Reduce(Ops tf, String name, long seed, Class type) { - this(tf, name, MetricReduction.SUM, seed, type); + protected Reduce(Ops tf, String name, long seed, Class resultType) { + this(tf, name, MetricReduction.SUM, seed, resultType); } /** @@ -68,25 +68,25 @@ protected Reduce(Ops tf, String name, long seed, Class type) { * @param reduction The type of metric reduction to apply * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
- * @param type the type for the variables and result + * @param resultType the type for the variables and result */ - protected Reduce(Ops tf, String name, MetricReduction reduction, long seed, Class type) { + protected Reduce(Ops tf, String name, MetricReduction reduction, long seed, Class resultType) { super(tf, name, seed); this.reduction = reduction; this.totalName = this.getVariableName(TOTAL); this.countName = this.getVariableName(COUNT); - this.type = type; + this.resultType = resultType; setupVars(); } /** Initializes the Variables */ private void setupVars() { if (total == null) { - total = getTF().withName(totalName).variable(Shape.scalar(), type); + total = getTF().withName(totalName).variable(Shape.scalar(), resultType); } if (reduction == MetricReduction.SUM_OVER_BATCH_SIZE || reduction == MetricReduction.WEIGHTED_MEAN) { if (count == null) { - count = getTF().withName(countName).variable(Shape.scalar(), type); + count = getTF().withName(countName).variable(Shape.scalar(), resultType); } } } @@ -115,7 +115,7 @@ public Op resetStates() { * @throws IllegalArgumentException if values is null */ @Override - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList(Operand values, Operand sampleWeights) { if (values == null) { throw new IllegalArgumentException("values is required."); @@ -133,7 +133,7 @@ public List updateStateList(Operand values, Operand sampleWeights) { lSampleWeights = tuple.getSampleWeights(); try { - Op broadcastWeightsCheck = MetricsHelper.broadcastWeights(getTF(), lSampleWeights, lValues); + Op broadcastWeightsCheck = MetricsHelper.assertBroadcastable(getTF(), lSampleWeights, lValues); lValues = getTF() .withSubScope("broadcastWeightsCheck") @@ -141,16 +141,20 @@ public List updateStateList(Operand values, Operand sampleWeights) { .math .mul(lValues, lSampleWeights); } catch (IllegalArgumentException ex) { - // reduce the values down to the rank of the samples - int nDim = lValues.shape().numDimensions(); - int wDim = lSampleWeights.shape().numDimensions(); - int numAxes = nDim - wDim; - int[] axes = new int[numAxes]; - for (int i = 0; i < numAxes; i++) axes[i] = i + wDim; - if (reduction == MetricReduction.SUM) { - lValues = getTF().reduceSum(lValues, getTF().constant(axes)); - } else { - lValues = getTF().math.mean(lValues, getTF().constant(axes)); + // if we get here we have static shapes with either + // different ranks or different dimension sizes. + // first, reduce the values down to the rank of the samples + int valuesDim = lValues.shape().numDimensions(); + int weightsDim = lSampleWeights.shape().numDimensions(); + int numAxes = Math.min(0, valuesDim - weightsDim); + if (numAxes > 0) { // values rank is greater than weights rank, reduce values to weights rank. 
+ int[] axes = new int[numAxes]; + for (int i = 0; i < numAxes; i++) axes[i] = i + weightsDim; + if (reduction == MetricReduction.SUM) { + lValues = getTF().reduceSum(lValues, getTF().constant(axes)); + } else { + lValues = getTF().math.mean(lValues, getTF().constant(axes)); + } } lValues = getTF().math.mul(lValues, lSampleWeights); } @@ -164,18 +168,18 @@ public List updateStateList(Operand values, Operand sampleWeights) { if (reduction != MetricReduction.SUM) { switch (reduction) { case SUM_OVER_BATCH_SIZE: - numValues = CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), type); + numValues = CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); break; case WEIGHTED_MEAN: if (lSampleWeights == null) { - numValues = CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), type); + numValues = CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); } else { numValues = CastHelper.cast( getTF(), getTF() .reduceSum(lSampleWeights, LossesHelper.allAxes(getTF(), lSampleWeights)), - type); + resultType); } break; default: @@ -192,16 +196,16 @@ public List updateStateList(Operand values, Operand sampleWeights) { /** {@inheritDoc} */ @Override - public Operand result(Ops rtf) { + public Operand result() { Operand fResult; switch (this.reduction) { case SUM: - fResult = rtf.identity(total); + fResult = getTF().identity(total); break; case WEIGHTED_MEAN: case SUM_OVER_BATCH_SIZE: - fResult = rtf.math.divNoNan(total, CastHelper.cast(rtf, count, type)); + fResult = getTF().math.divNoNan(total, CastHelper.cast(getTF(), count, resultType)); break; default: throw new UnsupportedOperationException( @@ -233,7 +237,7 @@ public Variable getCount() { * * @return the type for the variables */ - public Class getType() { - return type; + public Class getResultType() { + return resultType; } } From 947482ab482af050cb386f7d166f1ae8ab0b188c Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 11 Jan 2021 18:43:43 -0500 Subject: [PATCH 40/56] clean up mean and fix assert assertBroadcastable --- .../framework/metrics/impl/MetricsHelper.java | 60 +++++++++++-------- 1 file changed, 34 insertions(+), 26 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 5ecc06a388f..eb7c1fbd221 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -23,6 +23,7 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; @@ -91,7 +92,8 @@ public static Op assertBroadcastable( } for (int i = 0; i < valuesRankStatic; i++) { - if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i) && weightsShapeStatic.size(i) != 1) { + if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i) + && weightsShapeStatic.size(i) != 1) { throw new IllegalArgumentException( String.format( "%s Mismatch at dim %d. 
values.shape=%s weights.shape=%s.", @@ -152,7 +154,7 @@ private static Operand hasValidNonscalarShape( * * @param tf the TensorFlow Ops * @param weightsShape the operand for the shape of the sample weights - * @param valuesShape the operand for the shape of the sample weights + * @param valuesShape the operand for the shape of the values * @param the data type for the operands * @return a boolean operand to determine if the shapes have valid dimensions or not. */ @@ -163,7 +165,7 @@ private static Operand hasValidDims( Operand validDims = tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); SetDiff1d invalidDimsDiff = - tf.setDiff1d(tf.shape.flatten(valuesShape2d), tf.shape.flatten(validDims)); + tf.setDiff1d(weightsShape, tf.shape.flatten(validDims)); Operand invalidDims = invalidDimsDiff.out(); Operand numInvalidDims = tf.size(invalidDims); return tf.math.equal(tf.constant(0), numInvalidDims); @@ -178,9 +180,10 @@ private static Operand hasValidDims( * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param the type of the Operand. + * @param the data type for the result * @return the mean of the operand */ - public static Operand mean(Ops tf, Operand x) { + public static Operand mean(Ops tf, Operand x) { return mean(tf, x, null, false); } @@ -190,58 +193,63 @@ public static Operand mean(Ops tf, Operand x) { * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean - * @param axis Axes to compute the mean. + * @param axes Axes to compute the mean. * @param the type of the Operand. - * @param the type of the axis. - * @return the mean of the operand, alongside the specified axis. + * @param the type of the axes. + * @param the data type for the result + * @return the mean of the operand, along the specified axes. */ - public static Operand mean( - Ops tf, Operand x, Operand axis) { - return mean(tf, x, axis, false); + public static Operand mean( + Ops tf, Operand x, Operand axes) { + return mean(tf, x, axes, false); } /** - * Calculate the mean of the operand, along all axis. + * Calculate the mean of the operand, along all axes. * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axis + * false, the rank of the tensor is reduced by 1 for each entry in axes * . If keepdims is true, the reduced dimensions are retained * with length 1. * @param the type of the operand + * @param the data type for the result * @return the mean of elements of x. */ - public static Operand mean(Ops tf, Operand x, boolean keepDims) { + public static Operand mean( + Ops tf, Operand x, boolean keepDims) { return mean(tf, x, null, keepDims); } /** - * Calculate the mean of the operand, alongside the specified axis. + * Calculate the mean of the operand, alongside the specified axes. * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean - * @param axis Axes to compute the mean. + * @param axes Axes to compute the mean. * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axis`. If `keepdims` is `true`, the + * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the * * reduced dimensions are retained with length 1. 
* @param the data type of the Operand - * @param the data type of the axis + * @param the data type of the axes + * @param the data type for the result * @return the mean of elements of `x`. */ @SuppressWarnings({"unchecked", "rawtypes"}) - public static Operand mean( - Ops tf, Operand x, Operand axis, boolean keepDims) { + public static Operand mean( + Ops tf, Operand x, Operand axes, boolean keepDims) { // Cannot use generics here because xf may change from TBool to TFloat32 - Operand xf; - if (x.asOutput().type() == TBool.class) { - xf = tf.dtypes.cast(x, TFloat32.class); + Operand xf; + if (x.type().equals(TBool.class)) { + xf = (Operand) tf.dtypes.cast(x, TFloat32.class); } else { - xf = x; + xf = (Operand) x; } - if (axis == null) { - axis = allAxes(tf, xf); + if (axes == null) { + axes = (Operand) allAxes(tf, xf); } - return tf.math.mean(xf, axis, Mean.keepDims(keepDims)); + Operand theMean = tf.math.mean(xf, axes, Mean.keepDims(keepDims)); + return x.type().equals(TBool.class) ? tf.dtypes.cast(theMean, TBool.class) : theMean; } } From f473568d2d6236c78d16cd96362bc7c3be539ea4 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 11 Jan 2021 18:44:07 -0500 Subject: [PATCH 41/56] fix error message --- .../org/tensorflow/framework/metrics/impl/MetricVariable.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java index c5c5dbb2ab2..aae5a8f30c4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java @@ -79,7 +79,7 @@ public MetricVariable( } else { throw new IllegalArgumentException( String.format( - "Type %s is not a supported for metric variables", + "Type %s is not supported for metric variables", type.getSimpleName())); } } else { From 99ff15ddf423cabaa536b15bb67fe221a72efb4b Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 12 Jan 2021 18:19:07 -0500 Subject: [PATCH 42/56] Change sampleWeights to have its own generic type --- .../org/tensorflow/framework/metrics/Metric.java | 16 ++++++++++------ .../metrics/impl/MeanMetricWrapper.java | 5 +++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 123abae61d7..20151eb1408 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -74,9 +74,10 @@ protected Metric(Ops tf, String name, long seed) { * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. 
* @return a List of Operations to update the metric state + * @param the data type for sampleWeights */ @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList(Operand values, Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -89,11 +90,12 @@ public List updateStateList(Operand values, Operand the data type for the labels + * @param the data type for the sampleWeights * @return a List of Operations to update the metric state */ @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, Operand predictions, Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -102,9 +104,10 @@ public List updateStateList( * * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. + * @param the data type for sampleWeights * @return the Operation to update the metric state */ - public final Op updateState(Operand values, Operand sampleWeights) { + public final Op updateState(Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -116,10 +119,11 @@ public final Op updateState(Operand values, Operand sa * @param predictions the predictions * @param sampleWeights sample weights to be applied to values, may be null. * @param the data type for the labels + * @param the data type for the sampleWeights * @return the Operation to update the metric state */ - public final Op updateState( - Operand labels, Operand predictions, Operand sampleWeights) { + public final Op updateState( + Operand labels, Operand predictions, Operand sampleWeights) { List controlOps = updateStateList(labels, predictions, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 173167c5370..98447142da6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -86,10 +86,11 @@ protected void setLoss(LossMetric loss) { * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param the datatype of the predictions + * @param the data type for sampleWeights * @return a List of control operations that updates the Mean state variables. 
*/ - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, Operand predictions, Operand sampleWeights) { if (labels == null || predictions == null) { throw new IllegalArgumentException("missing required inputs for labels and predictions"); } From f4e3e043f10149089b9027de41fa047cea81677d Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 12 Jan 2021 18:19:56 -0500 Subject: [PATCH 43/56] Add commment about invalid tests expecting IllegalArgumentExceptions --- .../metrics/impl/WeightBroadcastTest.java | 118 ++++++++++-------- 1 file changed, 65 insertions(+), 53 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java index c89cff93dc2..08e19f82a89 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java @@ -204,8 +204,14 @@ public void testNxNxN() { } } + // Note: For invalid tests, either NotBroadcastableException is thrown for static shapes or + // TFInvalidInvalidException is thrown for dynamic shapes. Both of these extend + // IllegalArgumentException, + // To simply the assertThrows, only IllegalArgumentException is expected. + // The private method, testValid, tests for both static and dynamic shapes. @Test public void testInvalid1x1() { + assertThrows( IllegalArgumentException.class, () -> { @@ -267,69 +273,75 @@ public void testInvalidSuffixMatch() { @Test public void testInvalidOnesExtraDim() { assertThrows( - IllegalArgumentException.class, - () -> { - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - Operand weights = tf.constant(new int[][][][] {{{{5}}}} ); - testValid(testSession, tf, weights, values, TInt32.class); - } - }); + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = tf.constant(new int[][][][] {{{{5}}}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); } @Test public void testInvalidPrefixMatchExtraDim() { assertThrows( - IllegalArgumentException.class, - () -> { - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); - Operand weights = tf.constant(new int[][][][] { - {{ { 5},{ 7}, {11}, { 3}}, {{ 2}, {12}, { 7}, { 5}} }, - {{ { 2}, {17}, {11}, { 3}}, {{ 2}, {17}, {11}, { 3}} }, - {{ { 5}, { 7}, {11}, { 3}}, {{ 
2}, {12}, { 7}, { 5}} } - }); - testValid(testSession, tf, weights, values, TInt32.class); - } - }); + Operand weights = + tf.constant( + new int[][][][] { + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}}, + {{{2}, {17}, {11}, {3}}, {{2}, {17}, {11}, {3}}}, + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}} + }); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); } @Test public void testInvalidSuffixMatchExtraDim() { assertThrows( - IllegalArgumentException.class, - () -> { - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - Operand weights = tf.constant(new int[][][][] {{ - { { 5, 7, 11, 3}, { 2, 12, 7, 5} }, - { { 2, 17, 11, 3}, { 2, 17, 11, 3} }, - { { 5, 7, 11, 3}, { 2, 12, 7, 5} } - }}); - testValid(testSession, tf, weights, values, TInt32.class); - } - }); + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = + tf.constant( + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }); + Operand weights = + tf.constant( + new int[][][][] { + { + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} + } + }); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); } } From 4efdb6220dbf9e1b659f68ea33d0d1536f963ceb Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 12 Jan 2021 18:20:39 -0500 Subject: [PATCH 44/56] Add this exception instead of the more generic IllegalArgumentException when static shapes cannot boradcast. --- .../metrics/exceptions/NotBroadcastableException.java | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java index 66640e72f50..73f07b977c2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java @@ -17,23 +17,22 @@ import org.tensorflow.ndarray.Shape; /** - * Exception that indicates that static shapes are not able to broadcast among each other during - * arithmetic operations. Static shapes do not have unknown rank or any unknown dimensions {@link - * Shape#hasUnknownDimension()}. The term broadcasting describes how TensorFlow treats arrays with - * different shapes during arithmetic operations. + * Exception that indicates that static shapes are not able to broadcast among each other during arithmetic operations. + * Static shapes do not have unknown rank or any unknown dimensions {@link Shape#hasUnknownDimension()}. + * The term broadcasting describes how TensorFlow treats arrays with different shapes during arithmetic operations. * *

Broadcasting is the process of making arrays to have compatible shapes for arithmetic * operations. Two shapes are compatible if for each dimension pair they are either equal or one of * them is one. When trying to broadcast a Tensor to a shape, it starts with the trailing * dimensions, and works its way forward. * + * * @see Numpy Broadcasting */ public class NotBroadcastableException extends IllegalArgumentException { /** * Creates a new NotBroadcastableException exception with the specified detail message - * * @param message the detail message. */ public NotBroadcastableException(String message) { @@ -42,7 +41,6 @@ public NotBroadcastableException(String message) { /** * Creates a new NotBroadcastableException exception with the specified detail message - * * @param message the detail message. * @param cause the cause */ From b0a143ea8f35c4958a4e467446a145bf4005ac0a Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 12 Jan 2021 18:22:51 -0500 Subject: [PATCH 45/56] change IllegalArgumentException to NotBroadcastableException. change hasValidNonscalarShape to canBroadcastNonscalarShapes change hasValidNonscalarShape to canBroadcastNonscalarShapes --- .../framework/metrics/impl/MetricsHelper.java | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index eb7c1fbd221..b25a03b07c9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -56,8 +57,8 @@ public class MetricsHelper { * @return Operation with control dependencies to ensure sampleWeight * can be broadcast to values * @param the type of Operand - * @throws IllegalArgumentException If static checks determine sampleWeights has an - * incorrect shape that prohibit broadcasting to to values + * @throws NotBroadcastableException If static checks determine sampleWeights has an + * incorrect shape that prohibit broadcasting to values */ @SuppressWarnings("unchecked") public static Op assertBroadcastable( @@ -81,7 +82,7 @@ public static Op assertBroadcastable( .noOp(); } if (weightsRankStatic != valuesRankStatic) { - throw new IllegalArgumentException( + throw new NotBroadcastableException( String.format( "%s values.rank=%d. weights.rank=%d. values.shape=%s. weights.shape=%s.", ASSERT_BROADCAST_ERROR_PREFIX, @@ -94,7 +95,7 @@ public static Op assertBroadcastable( for (int i = 0; i < valuesRankStatic; i++) { if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i) && weightsShapeStatic.size(i) != 1) { - throw new IllegalArgumentException( + throw new NotBroadcastableException( String.format( "%s Mismatch at dim %d. 
values.shape=%s weights.shape=%s.", ASSERT_BROADCAST_ERROR_PREFIX, @@ -120,7 +121,7 @@ public static Op assertBroadcastable( isScalar); Operand validNonsclar = - hasValidNonscalarShape(tf, weightsRank, weightsShape, valuesRank, valuesShape); + canBroadcastNonscalarShapes(tf, weightsRank, weightsShape, valuesRank, valuesShape); Operand isValidShape = tf.select(isScalar, isScalar, validNonsclar); @@ -138,7 +139,7 @@ public static Op assertBroadcastable( * @param the data type for the operands * @return a boolean operand to determine if the Shape is scalar or not. */ - private static Operand hasValidNonscalarShape( + private static Operand canBroadcastNonscalarShapes( Ops tf, Operand weightsRank, Operand weightsShape, @@ -146,7 +147,7 @@ private static Operand hasValidNonscalarShape( Operand valuesShape) { tf = tf.withSubScope("hasValidNonscalarShape"); Operand isSameRank = tf.math.equal(valuesRank, weightsRank); - return tf.select(isSameRank, hasValidDims(tf, weightsShape, valuesShape), isSameRank); + return tf.select(isSameRank, canBroadcastDims(tf, weightsShape, valuesShape), isSameRank); } /** @@ -158,7 +159,7 @@ private static Operand hasValidNonscalarShape( * @param the data type for the operands * @return a boolean operand to determine if the shapes have valid dimensions or not. */ - private static Operand hasValidDims( + private static Operand canBroadcastDims( Ops tf, Operand weightsShape, Operand valuesShape) { tf = tf.withSubScope("hasValidDims"); Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); From 5e0501836b17d47b7195b54eb78c8c53231ef41c Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 12 Jan 2021 18:23:58 -0500 Subject: [PATCH 46/56] reformat code --- .../org/tensorflow/framework/metrics/Metric.java | 1 - .../org/tensorflow/framework/metrics/Metrics.java | 4 ++-- .../exceptions/NotBroadcastableException.java | 10 ++++++---- .../framework/metrics/impl/MeanMetricWrapper.java | 2 +- .../framework/metrics/impl/MetricVariable.java | 7 +++---- .../framework/metrics/impl/MetricsHelper.java | 7 +++---- .../tensorflow/framework/metrics/impl/Reduce.java | 14 +++++++++----- 7 files changed, 24 insertions(+), 21 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 20151eb1408..57e332a0843 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -142,7 +142,6 @@ public final Op updateState( */ public abstract Op resetStates(); - /** * Calls update state once, followed by a call to get the result * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index b8e79efa450..e2cd5e368c2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -117,8 +117,8 @@ public static Operand l2Normalize(Ops tf, Operand x, i * @param tf The TensorFlow ops * @param x The operand to normalize * @param axes Dimension(s) along which to normalize. - * @param epsilon A lower bound value for the norm. Will use sqrt(epsilon) as the divisor if - * norm < sqrt(epsilon). + * @param epsilon A lower bound value for the norm. Will use sqrt(epsilon) as the + * divisor if norm < sqrt(epsilon). 
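
The epsilon rule above amounts to dividing by max(norm, sqrt(epsilon)) so that near-zero vectors do not blow up. A minimal plain-Java sketch of that rule for the 1-D case (the helper name is hypothetical and this is illustrative only; the framework method works on Operands along the given axes):

  // Illustrative sketch of the epsilon-guarded L2 normalization rule described above.
  // Plain Java, 1-D case only.
  static double[] l2NormalizeSketch(double[] x, double epsilon) {
    double sumSquares = 0.0;
    for (double v : x) {
      sumSquares += v * v;
    }
    // Use sqrt(epsilon) as the divisor when the norm falls below it.
    double divisor = Math.max(Math.sqrt(sumSquares), Math.sqrt(epsilon));
    double[] result = new double[x.length];
    for (int i = 0; i < x.length; i++) {
      result[i] = x[i] / divisor;
    }
    return result;
  }
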
* @param The data type for the values. * @return the normalized values of x. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java index 73f07b977c2..66640e72f50 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java @@ -17,22 +17,23 @@ import org.tensorflow.ndarray.Shape; /** - * Exception that indicates that static shapes are not able to broadcast among each other during arithmetic operations. - * Static shapes do not have unknown rank or any unknown dimensions {@link Shape#hasUnknownDimension()}. - * The term broadcasting describes how TensorFlow treats arrays with different shapes during arithmetic operations. + * Exception that indicates that static shapes are not able to broadcast among each other during + * arithmetic operations. Static shapes do not have unknown rank or any unknown dimensions {@link + * Shape#hasUnknownDimension()}. The term broadcasting describes how TensorFlow treats arrays with + * different shapes during arithmetic operations. * *

Broadcasting is the process of making arrays to have compatible shapes for arithmetic * operations. Two shapes are compatible if for each dimension pair they are either equal or one of * them is one. When trying to broadcast a Tensor to a shape, it starts with the trailing * dimensions, and works its way forward. * - * * @see Numpy Broadcasting */ public class NotBroadcastableException extends IllegalArgumentException { /** * Creates a new NotBroadcastableException exception with the specified detail message + * * @param message the detail message. */ public NotBroadcastableException(String message) { @@ -41,6 +42,7 @@ public NotBroadcastableException(String message) { /** * Creates a new NotBroadcastableException exception with the specified detail message + * * @param message the detail message. * @param cause the cause */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 98447142da6..e2f1345f356 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -89,7 +89,7 @@ protected void setLoss(LossMetric loss) { * @param the data type for sampleWeights * @return a List of control operations that updates the Mean state variables. */ - public List updateStateList( + public List updateStateList( Operand labels, Operand predictions, Operand sampleWeights) { if (labels == null || predictions == null) { throw new IllegalArgumentException("missing required inputs for labels and predictions"); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java index aae5a8f30c4..6b208c0d7bf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java @@ -61,7 +61,8 @@ public MetricVariable(Ops tf, Variable variable, long seed, Class type) { * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
* @param type the type for the variable - * @throws IllegalArgumentException if the type does not inherit from TNumber and the initializer is null + * @throws IllegalArgumentException if the type does not inherit from TNumber and the initializer + * is null */ @SuppressWarnings("unchecked") public MetricVariable( @@ -78,9 +79,7 @@ public MetricVariable( this.initializer = new Zeros<>(tf); } else { throw new IllegalArgumentException( - String.format( - "Type %s is not supported for metric variables", - type.getSimpleName())); + String.format("Type %s is not supported for metric variables", type.getSimpleName())); } } else { this.initializer = initializer; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index b25a03b07c9..05bfe17a1be 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -58,7 +58,7 @@ public class MetricsHelper { * can be broadcast to values * @param the type of Operand * @throws NotBroadcastableException If static checks determine sampleWeights has an - * incorrect shape that prohibit broadcasting to values + * incorrect shape that prohibit broadcasting to values */ @SuppressWarnings("unchecked") public static Op assertBroadcastable( @@ -121,7 +121,7 @@ public static Op assertBroadcastable( isScalar); Operand validNonsclar = - canBroadcastNonscalarShapes(tf, weightsRank, weightsShape, valuesRank, valuesShape); + canBroadcastNonscalarShapes(tf, weightsRank, weightsShape, valuesRank, valuesShape); Operand isValidShape = tf.select(isScalar, isScalar, validNonsclar); @@ -165,8 +165,7 @@ private static Operand canBroadcastDims( Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); Operand validDims = tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); - SetDiff1d invalidDimsDiff = - tf.setDiff1d(weightsShape, tf.shape.flatten(validDims)); + SetDiff1d invalidDimsDiff = tf.setDiff1d(weightsShape, tf.shape.flatten(validDims)); Operand invalidDims = invalidDimsDiff.out(); Operand numInvalidDims = tf.size(invalidDims); return tf.math.equal(tf.constant(0), numInvalidDims); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index fb8e39f3f1f..6e1795af2eb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -133,7 +133,8 @@ public List updateStateList(Operand values, Operand List updateStateList(Operand values, Operand 0) { // values rank is greater than weights rank, reduce values to weights rank. + if (numAxes + > 0) { // values rank is greater than weights rank, reduce values to weights rank. 
int[] axes = new int[numAxes]; for (int i = 0; i < numAxes; i++) axes[i] = i + weightsDim; if (reduction == MetricReduction.SUM) { @@ -168,18 +170,20 @@ public List updateStateList(Operand values, Operand Date: Wed, 13 Jan 2021 07:07:26 -0500 Subject: [PATCH 47/56] Fis=x Javadoc move the dynamic shapes and rank down to the dynamic section so they are created needlessly when static Fix if statement to check for unknown size and unknown dimensions --- .../framework/metrics/impl/MetricsHelper.java | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 05bfe17a1be..00af7a6d1af 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -47,8 +47,8 @@ public class MetricsHelper { * Asserts that the sampleWeights can be broadcast to the same shape as values * * - *

In losses and metrics, limited weight broadcasting is supported. Weights be either scalar, - * or the same rank as the target values, with each dimension either 1, or the same as the + *

In losses and metrics, limited weight broadcasting is supported. Weights must be either + * scalar, or the same rank as the target values, with each dimension either 1, or the same as the * corresponding values dimension. * * @param tf the TensorFlow Ops @@ -65,17 +65,17 @@ public static Op assertBroadcastable( Ops tf, Operand sampleWeights, Operand values) { // try static check for exact match - Operand weightsShape = tf.shape(sampleWeights); - Operand weightsRank = tf.rank(sampleWeights); + Shape weightsShapeStatic = sampleWeights.shape(); int weightsRankStatic = weightsShapeStatic.numDimensions(); - Operand valuesShape = tf.shape(values); - Operand valuesRank = tf.rank(values); Shape valuesShapeStatic = values.shape(); int valuesRankStatic = valuesShapeStatic.numDimensions(); - if (weightsRankStatic != Shape.UNKNOWN_SIZE && valuesRankStatic != Shape.UNKNOWN_SIZE) { + // if (weightsRankStatic != Shape.UNKNOWN_SIZE && valuesRankStatic != Shape.UNKNOWN_SIZE) { + if (!weightsShapeStatic.isUnknown() + && !valuesShapeStatic.isUnknown() + && !weightsShapeStatic.hasUnknownDimension() & !valuesShapeStatic.hasUnknownDimension()) { if (weightsRankStatic == 0) { return tf.withSubScope("staticScalarCheckSuccess") .withControlDependencies(Collections.EMPTY_LIST) @@ -109,6 +109,11 @@ public static Op assertBroadcastable( .noOp(); } // Dynamic checks. + Operand weightsShape = tf.shape(sampleWeights); + Operand weightsRank = tf.rank(sampleWeights); + Operand valuesShape = tf.shape(values); + Operand valuesRank = tf.rank(values); + Operand isScalar = tf.math.equal(weightsRank, tf.constant(0)); List> data = Arrays.asList( From a2761f049271e5564038a672877fd77dd6aa6dd4 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 17 Jan 2021 13:00:06 -0500 Subject: [PATCH 48/56] Fix Reduce to use boradcastWeights, renamed WeightBroadcastTest to AssertBroadcastableTest and added BroadcastWeightsTest --- .../tensorflow/framework/metrics/Metric.java | 5 +- .../framework/metrics/impl/MetricsHelper.java | 31 +- .../framework/metrics/impl/Reduce.java | 26 +- .../metrics/impl/AssertBroadcastableTest.java | 7 +- .../metrics/impl/WeightBroadcastTest.java | 347 ------------------ 5 files changed, 44 insertions(+), 372 deletions(-) delete mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 57e332a0843..bbb2aa73da2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -33,6 +33,7 @@ public abstract class Metric { /** The TensorFlow Ops */ private final Ops tf; + /** The seed for random number generation */ private final long seed; /** The name for this metric. Defaults to {@link Class#getSimpleName()}. */ @@ -148,8 +149,10 @@ public final Op updateState( * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. * @return the result, possibly with control dependencies + * @param the data type for the sampleWeights. 
*/ - public final Operand callOnce(Operand values, Operand sampleWeights) { + public final Operand callOnce( + Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps); return ltf.identity(result()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 00af7a6d1af..fbe50151854 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -75,7 +75,8 @@ public static Op assertBroadcastable( // if (weightsRankStatic != Shape.UNKNOWN_SIZE && valuesRankStatic != Shape.UNKNOWN_SIZE) { if (!weightsShapeStatic.isUnknown() && !valuesShapeStatic.isUnknown() - && !weightsShapeStatic.hasUnknownDimension() & !valuesShapeStatic.hasUnknownDimension()) { + && !weightsShapeStatic.hasUnknownDimension() + && !valuesShapeStatic.hasUnknownDimension()) { if (weightsRankStatic == 0) { return tf.withSubScope("staticScalarCheckSuccess") .withControlDependencies(Collections.EMPTY_LIST) @@ -176,6 +177,34 @@ private static Operand canBroadcastDims( return tf.math.equal(tf.constant(0), numInvalidDims); } + /** + * Broadcast `weights` to the same shape as `values`. + * + * @param tf the TensorFlow ops + * @param weights `Tensor` whose shape is broadcastable to `values` + * @param values Tensor` of any shape + * @param the type of Operands + * @return weights broadcast to values shape + */ + public static Operand broadcastWeights( + Ops tf, Operand weights, Operand values) { + + Shape weightsShape = weights.shape(); + Shape valuesShape = values.shape(); + + if (!weightsShape.hasUnknownDimension() + && !valuesShape.hasUnknownDimension() + && weightsShape.isCompatibleWith(valuesShape)) { + return weights; + } + + Ops ctf = + tf.withSubScope("broadcastWeights") + .withControlDependencies( + Collections.singletonList(assertBroadcastable(tf, weights, tf.onesLike(values)))); + return ctf.math.mul(weights, tf.onesLike(values)); + } + // alias for mean /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 6e1795af2eb..771f4804dea 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -27,7 +27,6 @@ import org.tensorflow.types.family.TNumber; import java.util.ArrayList; -import java.util.Collections; import java.util.List; /** @@ -132,39 +131,32 @@ public List updateStateList(Operand values, Operand 0) { // values rank is greater than weights rank, reduce values to weights rank. 
int[] axes = new int[numAxes]; - for (int i = 0; i < numAxes; i++) axes[i] = i + weightsDim; + for (int i = 0; i < numAxes; i++) axes[i] = i + weightsRank; if (reduction == MetricReduction.SUM) { lValues = getTF().reduceSum(lValues, getTF().constant(axes)); } else { lValues = getTF().math.mean(lValues, getTF().constant(axes)); } } - lValues = getTF().math.mul(lValues, lSampleWeights); } + lValues = getTF().math.mul(lValues, lSampleWeights); } - Operand valueSum = getTF().reduceSum(lValues, LossesHelper.allAxes(getTF(), lValues)); + Operand weightedValueSum = + getTF().reduceSum(lValues, LossesHelper.allAxes(getTF(), lValues)); Operand totalUpdate = - getTF().assignAdd(total, CastHelper.cast(getTF(), valueSum, total.type())); + getTF().assignAdd(total, CastHelper.cast(getTF(), weightedValueSum, total.type())); updateOperations.add(totalUpdate); Operand numValues; if (reduction != MetricReduction.SUM) { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java index 63d666f8640..af4a89692d1 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java @@ -72,6 +72,7 @@ private void testValid( testSession.getGraphSession().runner().fetch(weights).fetch(values).run(); try (Tensor weightsTensor = tensors.get(0); Tensor valuesTensor = tensors.get(1)) { + Op dynamicOp = MetricsHelper.assertBroadcastable(tf, weightsPlaceholder, valuesPlaceholder); testSession @@ -89,7 +90,6 @@ public void testValidScalar() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = tf.constant(valueArrayF); Operand weights = tf.constant(5f); testValid(testSession, tf, weights, values, TFloat32.class); @@ -101,7 +101,6 @@ public void test1x1x1() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = tf.constant(valueArrayD); Operand weights = tf.constant(new double[][][] {{{5}}}); testValid(testSession, tf, weights, values, TFloat64.class); @@ -135,7 +134,6 @@ public void test1xNxN() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = tf.constant(valueArrayI); Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}, {2, 13, 7, 5}}}); testValid(testSession, tf, weights, values, TInt32.class); @@ -147,7 +145,6 @@ public void testNx1x1() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = tf.constant(valueArrayI); Operand weights = tf.constant(new int[][][] {{{5}}, {{7}}, {{11}}}); testValid(testSession, tf, weights, values, TInt32.class); @@ -159,7 +156,6 @@ public void testNx1xN() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = tf.constant(valueArrayI); Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); @@ -172,7 +168,6 @@ public void testNxNxN() { // no exception should be thrown try (TestSession testSession = 
TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - Operand values = tf.constant(valueArrayI); Operand weights = tf.constant( diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java deleted file mode 100644 index 08e19f82a89..00000000000 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/WeightBroadcastTest.java +++ /dev/null @@ -1,347 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.metrics.impl; - -import org.junit.jupiter.api.Test; -import org.tensorflow.Operand; -import org.tensorflow.Tensor; -import org.tensorflow.framework.utils.TestSession; -import org.tensorflow.op.Op; -import org.tensorflow.op.Ops; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; -import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TNumber; - -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertThrows; - -public class WeightBroadcastTest { - - private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; - - private void testValid( - TestSession testSession, Ops tf, Operand weights, Operand values, Class type) { - - Op staticOp = MetricsHelper.assertBroadcastable(tf, weights, values); - testSession.run(staticOp); - - // dynamic test - Operand weightsPlaceholder = tf.placeholder(type); - Operand valuesPlaceholder = tf.placeholder(type); - - List tensors = - testSession.getGraphSession().runner().fetch(weights).fetch(values).run(); - try (Tensor weightsTensor = tensors.get(0); - Tensor valuesTensor = tensors.get(1)) { - - Op dynamicOp = MetricsHelper.assertBroadcastable(tf, weightsPlaceholder, valuesPlaceholder); - - testSession - .getGraphSession() - .runner() - .feed(weightsPlaceholder, weightsTensor) - .feed(valuesPlaceholder, valuesTensor) - .addTarget(dynamicOp) - .run(); - } - } - - @Test - public void testValidScalar() { - // no exception should be thrown - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new float[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - Operand weights = tf.constant(5f); - testValid(testSession, tf, weights, values, TFloat32.class); - } - } - - @Test - public void test1x1x1() { - // no exception should be thrown - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new double[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - Operand weights = tf.constant(new double[][][] {{{5}}}); - 
testValid(testSession, tf, weights, values, TFloat64.class); - } - } - - @Test - public void test1x1xN() { - // no exception should be thrown - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new long[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - Operand weights = tf.constant(new long[][][] {{{5, 7, 11, 3}}}); - testValid(testSession, tf, weights, values, TInt64.class); - } - } - - @Test - public void test1xNx1() { - // no exception should be thrown - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - Operand weights = tf.constant(new int[][][] {{{5}, {11}}}); - testValid(testSession, tf, weights, values, TInt32.class); - } - } - - @Test - public void test1xNxN() { - // no exception should be thrown - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}, {2, 13, 7, 5}}}); - testValid(testSession, tf, weights, values, TInt32.class); - } - } - - @Test - public void testNx1x1() { - // no exception should be thrown - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - Operand weights = tf.constant(new int[][][] {{{5}}, {{7}}, {{11}}}); - testValid(testSession, tf, weights, values, TInt32.class); - } - } - - @Test - public void testNx1xN() { - // no exception should be thrown - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - Operand weights = - tf.constant(new int[][][] {{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); - testValid(testSession, tf, weights, values, TInt32.class); - } - } - - @Test - public void testNxNxN() { - // no exception should be thrown - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - Operand weights = - tf.constant( - new int[][][] { - {{5, 7, 11, 3}, {2, 12, 7, 5}}, - {{2, 17, 11, 3}, {2, 17, 11, 3}}, - {{5, 7, 11, 3}, {2, 12, 7, 5}} - }); - testValid(testSession, tf, weights, values, TInt32.class); - } - } - - // Note: For invalid tests, either NotBroadcastableException is thrown for static shapes or - // TFInvalidInvalidException is thrown for dynamic shapes. Both of these extend - // IllegalArgumentException, - // To simply the assertThrows, only IllegalArgumentException is expected. - // The private method, testValid, tests for both static and dynamic shapes. 
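
As a companion to the note above, the static check that produces NotBroadcastableException follows a simple rule: the weights must be a scalar, or have the same rank as the values with every dimension equal to 1 or to the matching values dimension. A plain-Java sketch of that rule over static dimension arrays (hypothetical helper, illustrative only, not the MetricsHelper implementation):

  // Sketch of the static broadcastability rule enforced by assertBroadcastable:
  // scalar weights, or same rank with each weight dimension either 1 or equal to
  // the corresponding values dimension. Dimension arrays stand in for static Shapes.
  static boolean isBroadcastableSketch(long[] weightsDims, long[] valuesDims) {
    if (weightsDims.length == 0) {
      return true; // scalar weights always broadcast
    }
    if (weightsDims.length != valuesDims.length) {
      return false; // rank mismatch is rejected outright
    }
    for (int i = 0; i < valuesDims.length; i++) {
      if (weightsDims[i] != valuesDims[i] && weightsDims[i] != 1) {
        return false; // mismatched dimension that is not 1
      }
    }
    return true;
  }

Under this rule a weight shape of {3, 1, 4} broadcasts to values of shape {3, 2, 4}, while the lower-rank and extra-dimension weight shapes used in the invalid tests above are rejected.
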
- @Test - public void testInvalid1x1() { - - assertThrows( - IllegalArgumentException.class, - () -> { - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - Operand weights = tf.constant(new int[][] {{5}}); - testValid(testSession, tf, weights, values, TInt32.class); - } - }); - } - - @Test - public void testInvalidPrefixMatch() { - assertThrows( - IllegalArgumentException.class, - () -> { - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - Operand weights = tf.constant(new int[][] {{5, 7}, {11, 3}, {2, 12}}); - testValid(testSession, tf, weights, values, TInt32.class); - } - }); - } - - @Test - public void testInvalidSuffixMatch() { - assertThrows( - IllegalArgumentException.class, - () -> { - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - Operand weights = tf.constant(new int[][] {{5, 7, 11, 3}, {2, 12, 7, 5}}); - testValid(testSession, tf, weights, values, TInt32.class); - } - }); - } - - @Test - public void testInvalidOnesExtraDim() { - assertThrows( - IllegalArgumentException.class, - () -> { - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - Operand weights = tf.constant(new int[][][][] {{{{5}}}}); - testValid(testSession, tf, weights, values, TInt32.class); - } - }); - } - - @Test - public void testInvalidPrefixMatchExtraDim() { - assertThrows( - IllegalArgumentException.class, - () -> { - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - - Operand weights = - tf.constant( - new int[][][][] { - {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}}, - {{{2}, {17}, {11}, {3}}, {{2}, {17}, {11}, {3}}}, - {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}} - }); - testValid(testSession, tf, weights, values, TInt32.class); - } - }); - } - - @Test - public void testInvalidSuffixMatchExtraDim() { - assertThrows( - IllegalArgumentException.class, - () -> { - try (TestSession testSession = TestSession.createTestSession(tfMode)) { - Ops tf = testSession.getTF(); - Operand values = - tf.constant( - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} - }); - Operand weights = - tf.constant( - new int[][][][] { - { - {{5, 7, 11, 3}, {2, 12, 7, 5}}, - {{2, 17, 11, 3}, {2, 17, 11, 3}}, - {{5, 7, 11, 3}, {2, 12, 7, 5}} - } - }); - testValid(testSession, tf, weights, values, TInt32.class); - } - }); - } -} From c2efa2ac8b52e3d6861c73e044ab68214dada4b3 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 17 Jan 2021 13:09:32 -0500 Subject: 
[PATCH 49/56] Added comment to count to indicate that it may be weighted. --- .../java/org/tensorflow/framework/metrics/impl/Reduce.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 771f4804dea..8e48cb4e573 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -45,7 +45,8 @@ public abstract class Reduce extends Metri private final Class resultType; /** the variable that holds the total of the metric values */ protected Variable total; - /** the variable that holds the count of the metric values */ + /** the variable that holds the count of the metric values. + * For {@link MetricReduction#WEIGHTED_MEAN}, this count may be weighted */ protected Variable count; /** From 9db5767780345671850a55a25679f59aa913e725 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 18 Jan 2021 20:22:24 -0500 Subject: [PATCH 50/56] Added SetsOps and fixed AssertBroadcastable to use SetsOps methods, --- .../framework/metrics/impl/MetricsHelper.java | 127 +++++++++++++----- .../framework/metrics/impl/SetsOps.java | 9 +- .../metrics/impl/AssertBroadcastableTest.java | 7 +- .../framework/metrics/impl/SetsOpsTest.java | 46 +++---- 4 files changed, 119 insertions(+), 70 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index fbe50151854..ad8ff58e417 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -19,10 +19,10 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.SetDiff1d; import org.tensorflow.op.math.Mean; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; @@ -33,6 +33,7 @@ import java.util.List; import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; /** * These are helper methods for Metrics and will be module private when Java modularity is applied @@ -126,10 +127,17 @@ public static Op assertBroadcastable( tf.constant("isScalar="), isScalar); - Operand validNonsclar = + // hack to work around the non-lazy select for isValidShape, otherwise validNonscalar fails on a + // scalar weight. If select was lazy, that branch wouldn't get executed when iScalar is true. 
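
The weighted count mentioned in the Reduce change above means that, for WEIGHTED_MEAN, total accumulates the weighted values while count accumulates the weights themselves, so the result is a weighted average. A rough plain-Java sketch of that bookkeeping (hypothetical helper; the weights are assumed to be already broadcast to the shape of the values):

  // Rough sketch of weighted-mean accumulation as used by Reduce with WEIGHTED_MEAN:
  // total += sum(weights * values); count += sum(weights); result = total / count.
  // state[0] holds the running total, state[1] the (weighted) count.
  static double weightedMeanSketch(double[] values, double[] weights, double[] state) {
    for (int i = 0; i < values.length; i++) {
      state[0] += weights[i] * values[i];
      state[1] += weights[i];
    }
    return state[0] / state[1];
  }
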
+ Operand reshapedWeights = + tf.select(isScalar, tf.math.mul(sampleWeights, tf.onesLike(values)), sampleWeights); + weightsShape = tf.shape(reshapedWeights); + weightsRank = tf.rank(reshapedWeights); + + Operand validNonscalar = canBroadcastNonscalarShapes(tf, weightsRank, weightsShape, valuesRank, valuesShape); - Operand isValidShape = tf.select(isScalar, isScalar, validNonsclar); + Operand isValidShape = tf.select(isScalar, isScalar, validNonscalar); return tf.withSubScope("broadcastWeights-dynamic").assertThat(isValidShape, data); } @@ -151,7 +159,7 @@ private static Operand canBroadcastNonscalarShapes( Operand weightsShape, Operand valuesRank, Operand valuesShape) { - tf = tf.withSubScope("hasValidNonscalarShape"); + tf = tf.withSubScope("canBroadcastNonscalarShapes"); Operand isSameRank = tf.math.equal(valuesRank, weightsRank); return tf.select(isSameRank, canBroadcastDims(tf, weightsShape, valuesShape), isSameRank); } @@ -167,22 +175,23 @@ private static Operand canBroadcastNonscalarShapes( */ private static Operand canBroadcastDims( Ops tf, Operand weightsShape, Operand valuesShape) { - tf = tf.withSubScope("hasValidDims"); + tf = tf.withSubScope("canBroadcastDims"); Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); Operand validDims = tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); - SetDiff1d invalidDimsDiff = tf.setDiff1d(weightsShape, tf.shape.flatten(validDims)); - Operand invalidDims = invalidDimsDiff.out(); - Operand numInvalidDims = tf.size(invalidDims); + Operand weightsShape2D = tf.expandDims(weightsShape, tf.constant(-1)); + + Operand diffResult = SetsOps.difference(tf, weightsShape2D, validDims); + Operand numInvalidDims = tf.size(diffResult); return tf.math.equal(tf.constant(0), numInvalidDims); } /** - * Broadcast `weights` to the same shape as `values`. + * Broadcast weights to the same shape as values. * * @param tf the TensorFlow ops - * @param weights `Tensor` whose shape is broadcastable to `values` - * @param values Tensor` of any shape + * @param weights Operand whose shape is broadcastable to values. + * @param values Operand of any shape * @param the type of Operands * @return weights broadcast to values shape */ @@ -205,7 +214,7 @@ public static Operand broadcastWeights( return ctf.math.mul(weights, tf.onesLike(values)); } - // alias for mean + // aliases for mean /** * Calculate the mean of the operand, along all axes and keepDims is false @@ -214,10 +223,9 @@ public static Operand broadcastWeights( * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param the type of the Operand. - * @param the data type for the result * @return the mean of the operand */ - public static Operand mean(Ops tf, Operand x) { + public static Operand mean(Ops tf, Operand x) { return mean(tf, x, null, false); } @@ -230,16 +238,15 @@ public static Operand mean(Ops tf, Opera * @param axes Axes to compute the mean. * @param the type of the Operand. * @param the type of the axes. - * @param the data type for the result * @return the mean of the operand, along the specified axes. */ - public static Operand mean( + public static Operand mean( Ops tf, Operand x, Operand axes) { return mean(tf, x, axes, false); } /** - * Calculate the mean of the operand, along all axes. + * Calculates the mean of the operand, along all axes. * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -248,16 +255,17 @@ public static Operand< * . 
If keepdims is true, the reduced dimensions are retained * with length 1. * @param the type of the operand - * @param the data type for the result * @return the mean of elements of x. */ - public static Operand mean( + public static Operand mean( Ops tf, Operand x, boolean keepDims) { return mean(tf, x, null, keepDims); } + + /** - * Calculate the mean of the operand, alongside the specified axes. + * Calculates the mean of the operand, alongside the specified axes. * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -267,23 +275,74 @@ public static Operand mean( * * reduced dimensions are retained with length 1. * @param the data type of the Operand * @param the data type of the axes - * @param the data type for the result - * @return the mean of elements of `x`. + * @return the mean of elements of x. */ - @SuppressWarnings({"unchecked", "rawtypes"}) - public static Operand mean( + + public static Operand mean( Ops tf, Operand x, Operand axes, boolean keepDims) { - // Cannot use generics here because xf may change from TBool to TFloat32 - Operand xf; - if (x.type().equals(TBool.class)) { - xf = (Operand) tf.dtypes.cast(x, TFloat32.class); - } else { - xf = (Operand) x; - } if (axes == null) { - axes = (Operand) allAxes(tf, xf); + axes = (Operand) allAxes(tf, x); } - Operand theMean = tf.math.mean(xf, axes, Mean.keepDims(keepDims)); - return x.type().equals(TBool.class) ? tf.dtypes.cast(theMean, TBool.class) : theMean; + return tf.math.mean(x, axes, Mean.keepDims(keepDims)); + } + + /** + * Calculate the mean of the operand, along all axes and keepDims is false + * + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @return the mean of the operand containing floating point numbers + */ + public static Operand booleanMean(Ops tf, Operand x) { + return booleanMean(tf, x, null, false); } + + /** + * Calculate the mean of the operand, alongside the specified axis with keepDims is + * false + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @param axes Axes to compute the mean. + * @param the type of the axes. + * @return the mean of the operand, along the specified axes, containing floating point numbers + */ + public static Operand booleanMean( + Ops tf, Operand x,Operand axes) { + return booleanMean(tf, x, axes, false); + } + + /** + * Calculates the mean of the boolean operand, alongside all axes. + * + * @param x the boolean Operand used to calculate the mean + * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the + * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the + * * reduced dimensions are retained with length 1. + * @param the data type of the axes + * @return the mean of elements of x containing floating point numbers + */ + public static Operand booleanMean( + Ops tf, Operand x, boolean keepDims) { + return booleanMean(tf, x, null, keepDims); + } + + /** + * Calculates the mean of the boolean operand, alongside the specified axes. + * + * @param x the boolean Operand used to calculate the mean + * @param axes Axes to compute the mean. + * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the + * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the + * * reduced dimensions are retained with length 1. 
+ * @param the data type of the axes + * @return the mean of elements of x containing floating point numbers + */ + public static Operand booleanMean( + Ops tf, Operand x, Operand axes, boolean keepDims) { + Operand xf = cast(tf, x, TFloat64.class); + return mean(tf, xf, axes, keepDims); + } + } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java index 1841c7ee238..236b3d9084d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java @@ -54,7 +54,7 @@ public String getSetOperation() { /** * Computes set difference of elements in last dimension of a and b with - * aMinusB set to true. + * aMinusB set to true/ * *

All but the last dimension of a and b must match * @@ -136,11 +136,6 @@ public static Operand setOperation( DenseToDenseSetOperation setOperationResult = tf.sparse.denseToDenseSetOperation( a, b, setOperation.getSetOperation(), DenseToDenseSetOperation.validateIndices(true)); - - return tf.sparse.sparseToDense( - setOperationResult.resultIndices(), - setOperationResult.resultShape(), - setOperationResult.resultValues(), - cast(tf, tf.constant(0), a.type())); + return setOperationResult.resultValues(); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java index af4a89692d1..63d666f8640 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java @@ -72,7 +72,6 @@ private void testValid( testSession.getGraphSession().runner().fetch(weights).fetch(values).run(); try (Tensor weightsTensor = tensors.get(0); Tensor valuesTensor = tensors.get(1)) { - Op dynamicOp = MetricsHelper.assertBroadcastable(tf, weightsPlaceholder, valuesPlaceholder); testSession @@ -90,6 +89,7 @@ public void testValidScalar() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayF); Operand weights = tf.constant(5f); testValid(testSession, tf, weights, values, TFloat32.class); @@ -101,6 +101,7 @@ public void test1x1x1() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayD); Operand weights = tf.constant(new double[][][] {{{5}}}); testValid(testSession, tf, weights, values, TFloat64.class); @@ -134,6 +135,7 @@ public void test1xNxN() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}, {2, 13, 7, 5}}}); testValid(testSession, tf, weights, values, TInt32.class); @@ -145,6 +147,7 @@ public void testNx1x1() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant(new int[][][] {{{5}}, {{7}}, {{11}}}); testValid(testSession, tf, weights, values, TInt32.class); @@ -156,6 +159,7 @@ public void testNx1xN() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); @@ -168,6 +172,7 @@ public void testNxNxN() { // no exception should be thrown try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); Operand weights = tf.constant( diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java index eceff2797f8..5250c22d740 100644 --- 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java @@ -30,13 +30,13 @@ public void testSetIntersectionMultirow2() { Ops tf = session.getTF(); Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); Operand b = tf.constant(new int[][] {{1, 9}, {1, 5}}); - int[][] expected = new int[][] {{1, 9}, {0, 0}}; - Shape expectedShape = Shape.of(2, 2); + Integer[] expected = new Integer[] {1, 9}; + Shape expectedShape = Shape.of(2); for (Class type : types) { Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); - Operand intersection = SetsOps.intersection(tf, aa, bb); - session.evaluate(cast(tf, tf.constant(expected), type), intersection); + Operand intersection = SetsOps.intersection(tf, aa, bb); + session.evaluate(expected, intersection); session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); } } @@ -50,23 +50,19 @@ public void testSetIntersectionDuplicates2d() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Operand a = tf.constant(new int[][] {{1, 1, 3}}); - Operand b = tf.constant(new int[][] {{1, 1}}); - int[][] expected = {{1}}; - Shape expectedShape = Shape.of(1, 1); + Operand b = tf.constant(new int[][] {{1}}); + Integer[] expected = new Integer[] {1}; + Shape expectedShape = Shape.of(1); for (Class type : types) { Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); Operand intersection = SetsOps.intersection(tf, aa, bb); - - session.evaluate(cast(tf, tf.constant(expected), type), intersection); - + session.evaluate(expected, intersection); session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); } } } - @Test - @SuppressWarnings({"unchecked", "rawtypes"}) public void testDenseSetDifferenceMultirow2d() { for (TestSession.Mode tfMode : tfModes) @@ -74,30 +70,24 @@ public void testDenseSetDifferenceMultirow2d() { Ops tf = session.getTF(); Operand a = tf.constant(new int[][] {{1, 5, 9}, {4, 5, 3}}); Operand b = tf.constant(new int[][] {{1, 2, 6}, {1, 2, 2}}); - + Integer[] expected = new Integer[] {5, 9, 3, 4, 5}; for (Class type : types) { Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); - int[][] expected = {{5, 9, 0}, {3, 4, 5}}; // a- b - Shape expectedShape = Shape.of(2, 3); Operand intersection = SetsOps.difference(tf, aa, bb); - session.evaluate(cast(tf, tf.constant(expected), type), intersection); - session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); + session.evaluate(expected, intersection); + session.evaluate(tf.constant(5L), tf.shape(intersection, TInt64.class)); // b - a - expected = new int[][] {{2, 6}, {1, 2}}; - expectedShape = Shape.of(2, 2); + expected = new Integer[] {2, 6, 1, 2}; intersection = SetsOps.difference(tf, aa, bb, false); - - session.evaluate(cast(tf, tf.constant(expected), type), intersection); - session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); + session.evaluate(expected, intersection); + session.evaluate(tf.constant(4L), tf.shape(intersection, TInt64.class)); } } } - @Test - @SuppressWarnings({"unchecked", "rawtypes"}) public void testDenseUnionMultirow2d() { for (TestSession.Mode tfMode : tfModes) @@ -105,15 +95,15 @@ public void testDenseUnionMultirow2d() { Ops tf = session.getTF(); Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); Operand b = tf.constant(new int[][] {{1, 9}, {1, 2}}); - int[][] expected 
= new int[][] {{5, 0}, {3, 4}}; + Integer[] expected = new Integer[] {1, 5, 9, 1, 2, 3, 4}; for (Class type : types) { Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); - Shape expectedShape = Shape.of(2, 2); // a- b Operand intersection = SetsOps.difference(tf, aa, bb); - session.evaluate(cast(tf, tf.constant(expected), type), intersection); - session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); + session.evaluate(expected, intersection); + session.evaluate(tf.constant(7L), tf.shape(intersection, TInt64.class)); + } } } From fb8aa65780ef66d611e2782780618551a3f9bdc8 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 18 Jan 2021 20:22:46 -0500 Subject: [PATCH 51/56] Fixed based on various PR comments. --- .../framework/metrics/BinaryCrossentropy.java | 2 +- .../framework/metrics/CategoricalCrossentropy.java | 4 ++-- .../org/tensorflow/framework/metrics/Metrics.java | 10 +++++----- .../metrics/SparseCategoricalCrossentropy.java | 12 ++++++------ .../framework/metrics/impl/MeanMetricWrapper.java | 2 +- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index c339b977007..651a6fac0b0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -41,7 +41,7 @@ public class BinaryCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. * @param labelSmoothing value used to smooth labels, When 0, no smoothing occurs. When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index 7b8cf0054a4..c330ea88eaa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -48,7 +48,7 @@ public class CategoricalCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. + * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to a probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and 0.9 @@ -68,7 +68,7 @@ public CategoricalCrossentropy( * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. 
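
The fromLogits wording above and below distinguishes raw, unnormalized scores from an actual probability distribution. For the categorical case the usual conversion is a softmax over the class scores (the binary metric uses a sigmoid instead); a plain-Java sketch of a numerically stable softmax, for illustration only:

  // Sketch of the logits-to-probabilities conversion implied by fromLogits:
  // a numerically stable softmax over one row of class scores.
  static double[] softmaxSketch(double[] logits) {
    double max = Double.NEGATIVE_INFINITY;
    for (double l : logits) {
      max = Math.max(max, l); // subtract the max for numerical stability
    }
    double sum = 0.0;
    double[] probs = new double[logits.length];
    for (int i = 0; i < logits.length; i++) {
      probs[i] = Math.exp(logits[i] - max);
      sum += probs[i];
    }
    for (int i = 0; i < probs.length; i++) {
      probs[i] /= sum;
    }
    return probs;
  }
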
- * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and 0.9 diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index e2cd5e368c2..0169bc6b8bc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -64,19 +64,19 @@ public static Operand topKCategoricalA * @param tf the TensorFlow Ops * @param labels The ground truth values. * @param predictions The prediction values. - * @param axis The dimension along which the cosine similarity is computed. + * @param axes The dimensions along which the cosine similarity is computed. * @param the data type for the labels * @param the data type for the predictions and result * @return Cosine similarity value. */ public static Operand cosineProximity( - Ops tf, Operand labels, Operand predictions, int[] axis) { + Ops tf, Operand labels, Operand predictions, int[] axes) { Operand labelsNorm = CastHelper.cast(tf, labels, predictions.type()); - labelsNorm = l2Normalize(tf, labelsNorm, axis); + labelsNorm = l2Normalize(tf, labelsNorm, axes); - Operand predictionsNorm = l2Normalize(tf, predictions, axis); + Operand predictionsNorm = l2Normalize(tf, predictions, axes); Operand mathMul = tf.math.mul(labelsNorm, predictionsNorm); - return tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); + return tf.reduceSum(mathMul, tf.constant(axes), ReduceSum.keepDims(Boolean.FALSE)); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index 3fde8b2ecf6..2e01f722de6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -32,30 +32,30 @@ public class SparseCategoricalCrossentropy extends MeanMetricWrapper implements LossMetric { private final boolean fromLogits; - private final int axes; + private final int axis; /** * Creates a SparseCategoricalCrossentropy metric * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values or not. - * @param axes The dimension along which the entropy is computed. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param axis The dimension along which the entropy is computed. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
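
The cosineProximity change above (axis renamed to axes) l2-normalizes labels and predictions and sums their elementwise product along the given axes. For a single 1-D axis this is just the dot product of unit vectors; a plain-Java sketch assuming non-zero norms (hypothetical helper, illustrative only):

  // Plain-Java sketch of cosine similarity for the 1-D case, mirroring the idea in
  // Metrics.cosineProximity: l2-normalize both vectors, multiply elementwise, and sum.
  static double cosineProximitySketch(double[] labels, double[] predictions) {
    double labelNorm = 0.0;
    double predNorm = 0.0;
    for (int i = 0; i < labels.length; i++) {
      labelNorm += labels[i] * labels[i];
      predNorm += predictions[i] * predictions[i];
    }
    labelNorm = Math.sqrt(labelNorm);
    predNorm = Math.sqrt(predNorm);
    double sum = 0.0;
    for (int i = 0; i < labels.length; i++) {
      sum += (labels[i] / labelNorm) * (predictions[i] / predNorm);
    }
    return sum;
  }
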
* @param type the type for the variables and result */ public SparseCategoricalCrossentropy( - Ops tf, String name, boolean fromLogits, int axes, long seed, Class type) { + Ops tf, String name, boolean fromLogits, int axis, long seed, Class type) { super(tf, name, seed, type); setLoss(this); this.fromLogits = fromLogits; - this.axes = axes; + this.axis = axis; } /** {@inheritDoc} */ @Override public Operand call(Operand labels, Operand predictions) { - return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axes); + return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index e2f1345f356..17c209a8fed 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -85,7 +85,7 @@ protected void setLoss(LossMetric loss) { * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) - * @param the datatype of the predictions + * @param the datatype of the labels * @param the data type for sampleWeights * @return a List of control operations that updates the Mean state variables. */ From 02be1749ff0e48d7e71d79fa8183ec6263992963 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 18 Jan 2021 20:23:44 -0500 Subject: [PATCH 52/56] Deleted, no longer needed after change to Variable handling in Metrics. --- .../metrics/impl/MetricVariable.java | 125 ------------------ 1 file changed, 125 deletions(-) delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java deleted file mode 100644 index 6b208c0d7bf..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricVariable.java +++ /dev/null @@ -1,125 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
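For the axes-to-axis rename in the patch above, a minimal usage sketch of the loss function that this metric wraps, assuming an existing `Ops tf` and illustrative values; the entropy is reduced along a single dimension, here the last one:

    // Labels are class indices; predictions are per-class probabilities.
    Operand<TFloat32> labels = tf.constant(new float[] {1f, 2f});
    Operand<TFloat32> predictions = tf.constant(new float[][] {
        {0.05f, 0.95f, 0.00f},
        {0.10f, 0.10f, 0.80f}});
    // axis = -1: compute the crossentropy along the class dimension.
    Operand<TFloat32> perExampleLoss =
        Losses.sparseCategoricalCrossentropy(tf, labels, predictions, false, -1);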
-=======================================================================*/ -package org.tensorflow.framework.metrics.impl; - -import org.tensorflow.Operand; -import org.tensorflow.framework.initializers.Glorot; -import org.tensorflow.framework.initializers.Initializer; -import org.tensorflow.framework.initializers.VarianceScaling; -import org.tensorflow.framework.initializers.Zeros; -import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Variable; -import org.tensorflow.types.family.TFloating; -import org.tensorflow.types.family.TIntegral; -import org.tensorflow.types.family.TNumber; - -/** - * Helper class that holds a metric variable - * - * @param the data type of the variable - */ -public class MetricVariable { - private final Variable variable; - private final Initializer initializer; - private final Ops tf; - private boolean initialized; - - /** - * Creates a Metric Variable - * - * @param tf the TensorFlow Ops - * @param variable the variable - * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and data type. - * @param type the type for the variable - */ - public MetricVariable(Ops tf, Variable variable, long seed, Class type) { - this(tf, variable, null, seed, type); - } - - /** - * Creates a Metric Variable - * - * @param tf the TensorFlow Ops - * @param variable the variable - * @param initializer the initializer for the variable, if null, then the default for floating - * point types is {@link org.tensorflow.framework.initializers.Glorot} with distribution - * {@link org.tensorflow.framework.initializers.VarianceScaling.Distribution#UNIFORM}, for - * other types the default initializer is {@link org.tensorflow.framework.initializers.Zeros} - * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and data type. 
- * @param type the type for the variable - * @throws IllegalArgumentException if the type does not inherit from TNumber and the initializer - * is null - */ - @SuppressWarnings("unchecked") - public MetricVariable( - Ops tf, Variable variable, Initializer initializer, long seed, Class type) { - this.tf = tf; - this.variable = variable; - - if (initializer == null) { - if (TFloating.class.isAssignableFrom(type)) { - //noinspection RedundantCast - this.initializer = - (Initializer) new Glorot<>(tf, VarianceScaling.Distribution.UNIFORM, seed); - } else if (TIntegral.class.isAssignableFrom(type)) { - this.initializer = new Zeros<>(tf); - } else { - throw new IllegalArgumentException( - String.format("Type %s is not supported for metric variables", type.getSimpleName())); - } - } else { - this.initializer = initializer; - } - } - - /** - * Initializers the variable based on the initializer - * - * @return the initialized variable - */ - public Operand initialize() { - initialized = true; - return initializer.call(tf.constant(variable.shape()), variable.type()); - } - - /** - * Gets the variable - * - * @return the variable - */ - public Variable getVariable() { - return variable; - } - - /** - * Gets the initializer - * - * @return the initializer - */ - public Initializer getInitializer() { - return initializer; - } - - /** - * Gets the value of initialized - * - * @return the value of initialized - */ - public boolean isInitialized() { - return initialized; - } -} From 25b061a46b1b6c8523a113245922c202f6fd642e Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 25 Jan 2021 16:52:03 -0800 Subject: [PATCH 53/56] Remove extra generics from op generation (#193) * Successfully remove extra type params, but it broke javadoc generation Signed-off-by: Ryan Nett * Generate covariant types Signed-off-by: Ryan Nett * Do generation Signed-off-by: Ryan Nett * Update help text. Signed-off-by: Ryan Nett * Fixes Signed-off-by: Ryan Nett --- .../tensorflow/op/core/GuaranteeConst.java | 81 ------------------- .../tensorflow/op/core/RefNextIteration.java | 75 ----------------- .../experimental/DummySeedGenerator.java | 69 ---------------- 3 files changed, 225 deletions(-) delete mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GuaranteeConst.java delete mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RefNextIteration.java delete mode 100644 tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GuaranteeConst.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GuaranteeConst.java deleted file mode 100644 index aeab16c7c6c..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/GuaranteeConst.java +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-=======================================================================*/ - -// This class has been generated, DO NOT EDIT! - -package org.tensorflow.op.core; - -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.OperationBuilder; -import org.tensorflow.Output; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.types.family.TType; - -/** - * Gives a guarantee to the TF runtime that the input tensor is a constant. - *

- * The runtime is then free to make optimizations based on this. - *

- * Only accepts value typed tensors as inputs and rejects resource variable handles - * as input. - *

- * Returns the input tensor without modification. - * - * @param data type for {@code output()} output - */ -@Operator -public final class GuaranteeConst extends RawOp implements Operand { - - /** - * Factory method to create a class wrapping a new GuaranteeConst operation. - * - * @param scope current scope - * @param input - * @return a new instance of GuaranteeConst - */ - @Endpoint(describeByClass = true) - public static GuaranteeConst create(Scope scope, Operand input) { - OperationBuilder opBuilder = scope.env().opBuilder("GuaranteeConst", scope.makeOpName("GuaranteeConst")); - opBuilder.addInput(input.asOutput()); - opBuilder = scope.apply(opBuilder); - return new GuaranteeConst(opBuilder.build()); - } - - /** - */ - public Output output() { - return output; - } - - @Override - public Output asOutput() { - return output; - } - - /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "GuaranteeConst"; - - private Output output; - - private GuaranteeConst(Operation operation) { - super(operation); - int outputIdx = 0; - output = operation.output(outputIdx++); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RefNextIteration.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RefNextIteration.java deleted file mode 100644 index f3f6e374590..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/RefNextIteration.java +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ - -// This class has been generated, DO NOT EDIT! - -package org.tensorflow.op.core; - -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.OperationBuilder; -import org.tensorflow.Output; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.types.family.TType; - -/** - * Makes its input available to the next iteration. - * - * @param data type for {@code output()} output - */ -@Operator -public final class RefNextIteration extends RawOp implements Operand { - - /** - * Factory method to create a class wrapping a new RefNextIteration operation. - * - * @param scope current scope - * @param data The tensor to be made available to the next iteration. - * @return a new instance of RefNextIteration - */ - @Endpoint(describeByClass = true) - public static RefNextIteration create(Scope scope, Operand data) { - OperationBuilder opBuilder = scope.env().opBuilder("RefNextIteration", scope.makeOpName("RefNextIteration")); - opBuilder.addInput(data.asOutput()); - opBuilder = scope.apply(opBuilder); - return new RefNextIteration(opBuilder.build()); - } - - /** - * The same tensor as `data`. 
- */ - public Output output() { - return output; - } - - @Override - public Output asOutput() { - return output; - } - - /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "RefNextIteration"; - - private Output output; - - private RefNextIteration(Operation operation) { - super(operation); - int outputIdx = 0; - output = operation.output(outputIdx++); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java deleted file mode 100644 index 8c60fc6350f..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ - -// This class has been generated, DO NOT EDIT! - -package org.tensorflow.op.random.experimental; - -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.OperationBuilder; -import org.tensorflow.Output; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.types.family.TType; - -/** - */ -public final class DummySeedGenerator extends RawOp implements Operand { - - /** - * Factory method to create a class wrapping a new DummySeedGenerator operation. 
- * - * @param scope current scope - * @return a new instance of DummySeedGenerator - */ - @Endpoint(describeByClass = true) - public static DummySeedGenerator create(Scope scope) { - OperationBuilder opBuilder = scope.env().opBuilder("DummySeedGenerator", scope.makeOpName("DummySeedGenerator")); - opBuilder = scope.apply(opBuilder); - return new DummySeedGenerator(opBuilder.build()); - } - - /** - */ - public Output handle() { - return handle; - } - - @Override - @SuppressWarnings("unchecked") - public Output asOutput() { - return (Output) handle; - } - - /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "DummySeedGenerator"; - - private Output handle; - - private DummySeedGenerator(Operation operation) { - super(operation); - int outputIdx = 0; - handle = operation.output(outputIdx++); - } -} From a1c518716e47128f383685541395874d27f50ef1 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sat, 30 Jan 2021 15:20:10 -0500 Subject: [PATCH 54/56] Fix SetOps to properly convert sparse tensor to dense tensor using tf.sparse.sparseToDense with the output of tf.sparse.denseToDenseSetOperation --- .../framework/metrics/impl/SetsOps.java | 9 +++- .../framework/metrics/impl/SetsOpsTest.java | 46 +++++++++++-------- 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java index 236b3d9084d..1841c7ee238 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java @@ -54,7 +54,7 @@ public String getSetOperation() { /** * Computes set difference of elements in last dimension of a and b with - * aMinusB set to true/ + * aMinusB set to true. * *
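With the result now densified through tf.sparse.sparseToDense, callers get a zero-padded, row-aligned tensor instead of the flattened values of the sparse result. A minimal sketch matching the updated SetsOpsTest expectations below, assuming an existing `Ops tf` and TInt32 operands:

    // Per-row set difference; rows shorter than the widest result row are padded with 0.
    Operand<TInt32> a = tf.constant(new int[][] {{1, 5, 9}, {4, 5, 3}});
    Operand<TInt32> b = tf.constant(new int[][] {{1, 2, 6}, {1, 2, 2}});

    Operand<TInt32> aMinusB = SetsOps.difference(tf, a, b);        // {{5, 9, 0}, {3, 4, 5}}, shape (2, 3)
    Operand<TInt32> bMinusA = SetsOps.difference(tf, a, b, false); // {{2, 6}, {1, 2}},       shape (2, 2)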

All but the last dimension of a and b must match * @@ -136,6 +136,11 @@ public static Operand setOperation( DenseToDenseSetOperation setOperationResult = tf.sparse.denseToDenseSetOperation( a, b, setOperation.getSetOperation(), DenseToDenseSetOperation.validateIndices(true)); - return setOperationResult.resultValues(); + + return tf.sparse.sparseToDense( + setOperationResult.resultIndices(), + setOperationResult.resultShape(), + setOperationResult.resultValues(), + cast(tf, tf.constant(0), a.type())); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java index 5250c22d740..eceff2797f8 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java @@ -30,13 +30,13 @@ public void testSetIntersectionMultirow2() { Ops tf = session.getTF(); Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); Operand b = tf.constant(new int[][] {{1, 9}, {1, 5}}); - Integer[] expected = new Integer[] {1, 9}; - Shape expectedShape = Shape.of(2); + int[][] expected = new int[][] {{1, 9}, {0, 0}}; + Shape expectedShape = Shape.of(2, 2); for (Class type : types) { Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); - Operand intersection = SetsOps.intersection(tf, aa, bb); - session.evaluate(expected, intersection); + Operand intersection = SetsOps.intersection(tf, aa, bb); + session.evaluate(cast(tf, tf.constant(expected), type), intersection); session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); } } @@ -50,19 +50,23 @@ public void testSetIntersectionDuplicates2d() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Operand a = tf.constant(new int[][] {{1, 1, 3}}); - Operand b = tf.constant(new int[][] {{1}}); - Integer[] expected = new Integer[] {1}; - Shape expectedShape = Shape.of(1); + Operand b = tf.constant(new int[][] {{1, 1}}); + int[][] expected = {{1}}; + Shape expectedShape = Shape.of(1, 1); for (Class type : types) { Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); Operand intersection = SetsOps.intersection(tf, aa, bb); - session.evaluate(expected, intersection); + + session.evaluate(cast(tf, tf.constant(expected), type), intersection); + session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); } } } + @Test + @SuppressWarnings({"unchecked", "rawtypes"}) public void testDenseSetDifferenceMultirow2d() { for (TestSession.Mode tfMode : tfModes) @@ -70,24 +74,30 @@ public void testDenseSetDifferenceMultirow2d() { Ops tf = session.getTF(); Operand a = tf.constant(new int[][] {{1, 5, 9}, {4, 5, 3}}); Operand b = tf.constant(new int[][] {{1, 2, 6}, {1, 2, 2}}); - Integer[] expected = new Integer[] {5, 9, 3, 4, 5}; + for (Class type : types) { Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); + int[][] expected = {{5, 9, 0}, {3, 4, 5}}; // a- b + Shape expectedShape = Shape.of(2, 3); Operand intersection = SetsOps.difference(tf, aa, bb); - session.evaluate(expected, intersection); - session.evaluate(tf.constant(5L), tf.shape(intersection, TInt64.class)); + session.evaluate(cast(tf, tf.constant(expected), type), intersection); + session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); // b - a - expected = new Integer[] {2, 6, 1, 2}; + expected = new int[][] 
{{2, 6}, {1, 2}}; + expectedShape = Shape.of(2, 2); intersection = SetsOps.difference(tf, aa, bb, false); - session.evaluate(expected, intersection); - session.evaluate(tf.constant(4L), tf.shape(intersection, TInt64.class)); + + session.evaluate(cast(tf, tf.constant(expected), type), intersection); + session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); } } } + @Test + @SuppressWarnings({"unchecked", "rawtypes"}) public void testDenseUnionMultirow2d() { for (TestSession.Mode tfMode : tfModes) @@ -95,15 +105,15 @@ public void testDenseUnionMultirow2d() { Ops tf = session.getTF(); Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); Operand b = tf.constant(new int[][] {{1, 9}, {1, 2}}); - Integer[] expected = new Integer[] {1, 5, 9, 1, 2, 3, 4}; + int[][] expected = new int[][] {{5, 0}, {3, 4}}; for (Class type : types) { Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); + Shape expectedShape = Shape.of(2, 2); // a- b Operand intersection = SetsOps.difference(tf, aa, bb); - session.evaluate(expected, intersection); - session.evaluate(tf.constant(7L), tf.shape(intersection, TInt64.class)); - + session.evaluate(cast(tf, tf.constant(expected), type), intersection); + session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); } } } From b125294698a8500a96f3f488ff3b89fab94297de Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 1 Feb 2021 11:16:37 -0500 Subject: [PATCH 55/56] Simplify generic parameters across losses and metrics. --- .../framework/losses/BinaryCrossentropy.java | 5 +- .../losses/CategoricalCrossentropy.java | 5 +- .../framework/losses/CategoricalHinge.java | 4 +- .../framework/losses/CosineSimilarity.java | 67 +++++++++++++--- .../tensorflow/framework/losses/Hinge.java | 4 +- .../tensorflow/framework/losses/Huber.java | 4 +- .../framework/losses/KLDivergence.java | 4 +- .../tensorflow/framework/losses/LogCosh.java | 4 +- .../org/tensorflow/framework/losses/Loss.java | 8 +- .../tensorflow/framework/losses/Losses.java | 73 ++++++++--------- .../framework/losses/MeanAbsoluteError.java | 4 +- .../losses/MeanAbsolutePercentageError.java | 4 +- .../framework/losses/MeanSquaredError.java | 4 +- .../losses/MeanSquaredLogarithmicError.java | 4 +- .../tensorflow/framework/losses/Poisson.java | 4 +- .../losses/SparseCategoricalCrossentropy.java | 5 +- .../framework/losses/SquaredHinge.java | 5 +- .../framework/metrics/BinaryCrossentropy.java | 9 ++- .../metrics/CategoricalCrossentropy.java | 13 +-- .../framework/metrics/CategoricalHinge.java | 5 +- .../framework/metrics/CosineSimilarity.java | 8 +- .../tensorflow/framework/metrics/Hinge.java | 5 +- .../framework/metrics/KLDivergence.java | 5 +- .../framework/metrics/LogCoshError.java | 5 +- .../tensorflow/framework/metrics/Mean.java | 3 +- .../framework/metrics/MeanAbsoluteError.java | 5 +- .../metrics/MeanAbsolutePercentageError.java | 7 +- .../framework/metrics/MeanSquaredError.java | 5 +- .../metrics/MeanSquaredLogarithmicError.java | 7 +- .../tensorflow/framework/metrics/Metric.java | 32 ++++---- .../tensorflow/framework/metrics/Metrics.java | 80 +------------------ .../tensorflow/framework/metrics/Poisson.java | 6 +- .../SparseCategoricalCrossentropy.java | 9 +-- .../framework/metrics/SquaredHinge.java | 5 +- .../framework/metrics/impl/LossMetric.java | 3 +- .../metrics/impl/MeanMetricWrapper.java | 21 +++-- .../framework/metrics/impl/MetricsHelper.java | 66 +++++++-------- .../framework/metrics/impl/Reduce.java | 68 ++++++++-------- 
.../metrics/BinaryCrossentropyTest.java | 10 +-- .../metrics/CategoricalCrossentropyTest.java | 10 +-- .../metrics/CategoricalHingeTest.java | 4 +- .../metrics/CosineSimilarityTest.java | 6 +- .../framework/metrics/HingeTest.java | 4 +- .../framework/metrics/KLDivergenceTest.java | 4 +- .../framework/metrics/LogCoshErrorTest.java | 4 +- .../metrics/MeanAbsoluteErrorTest.java | 4 +- .../MeanAbsolutePercentageErrorTest.java | 4 +- .../metrics/MeanSquaredErrorTest.java | 4 +- .../MeanSquaredLogarithmicErrorTest.java | 4 +- .../framework/metrics/PoissonTest.java | 4 +- .../SparseCategoricalCrossentropyTest.java | 6 +- .../framework/metrics/SquaredHingeTest.java | 4 +- 52 files changed, 293 insertions(+), 354 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index c7edfcca24e..3417c07372a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -202,13 +202,12 @@ public BinaryCrossentropy( * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 363291fa5cc..035af9589ae 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -242,13 +242,12 @@ public CategoricalCrossentropy( * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. 
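The angle-bracket type parameters in these call(...) hunks do not survive this extract, but the deleted "@param ... data type of the labels" lines show the shape of the simplification. A hedged reconstruction of the Loss.call signature before and after, not quoted verbatim from the patch:

    // Before: a dedicated type parameter for the labels.
    public abstract <T extends TNumber, U extends TNumber> Operand<T> call(
        Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights);

    // After: labels accept any TNumber through a wildcard, leaving a single type parameter.
    public abstract <T extends TNumber> Operand<T> call(
        Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights);

The same pattern repeats across the loss and metric classes below: only the predictions, and therefore the result, keep a type parameter.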
*/ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index f592c19f8bb..4e9133d8835 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -99,8 +99,8 @@ public CategoricalHinge(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.categoricalHinge(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java index 137c7025c04..0a18d93caf3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java @@ -22,12 +22,13 @@ /** * Computes the cosine similarity between labels and predictions. * - *

Note that it is a number between -1 and 1. When it is a negative number between -1 and 0, 0 - * indicates orthogonality and values closer to -1indicate greater similarity. The values closer to - * 1 indicate greater dissimilarity. This makes it usable as a loss function in a setting where you - * try to maximize the proximity between predictions and targets. If either labels or predictions is - * a zero vector, cosine similarity will be 0 regardless of the proximity between predictions and - * targets. + *

Note that it is a number between -1 and 1. When it is a negative + * number between -1 and 0, 0 indicates orthogonality and + * values closer to -1indicate greater similarity. The values closer to 1 + * indicate greater dissimilarity. This makes it usable as a loss function in a setting where you + * try to maximize the proximity between predictions and targets. If either labels or + * predictions is a zero vector, cosine similarity will be 0 regardless of + * the proximity between predictions and targets. * *

loss = -sum(l2Norm(labels) * l2Norm(predictions)) * @@ -71,7 +72,7 @@ public class CosineSimilarity extends Loss { public static final int DEFAULT_AXIS = -1; public static final Reduction DEFAULT_REDUCTION = Reduction.AUTO; - private final int axis; + private final int[] axis; /** * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, an axis @@ -107,6 +108,17 @@ public CosineSimilarity(Ops tf, int axis) { this(tf, null, axis, DEFAULT_REDUCTION); } + /** + * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, and a + * Loss Reduction of {@link #DEFAULT_REDUCTION} + * + * @param tf the TensorFlow Ops + * @param axis The dimension along which the cosine similarity is computed. + */ + public CosineSimilarity(Ops tf, int[] axis) { + + this(tf, null, axis, DEFAULT_REDUCTION); + } /** * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION} @@ -120,6 +132,18 @@ public CosineSimilarity(Ops tf, String name, int axis) { this(tf, name, axis, DEFAULT_REDUCTION); } + /** + * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION} + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param axis The dimension along which the cosine similarity is computed. + */ + public CosineSimilarity(Ops tf, String name, int[] axis) { + + this(tf, name, axis, DEFAULT_REDUCTION); + } + /** * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name and an * axis of {@link #DEFAULT_AXIS} @@ -153,6 +177,18 @@ public CosineSimilarity(Ops tf, String name, Reduction reduction) { */ public CosineSimilarity(Ops tf, int axis, Reduction reduction) { + this(tf, null, new int[] {axis}, reduction); + } + + /** + * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name + * + * @param tf the TensorFlow Ops + * @param axis The dimension along which the cosine similarity is computed. + * @param reduction Type of Reduction to apply to the loss. + */ + public CosineSimilarity(Ops tf, int[] axis, Reduction reduction) { + this(tf, null, axis, reduction); } @@ -165,15 +201,28 @@ public CosineSimilarity(Ops tf, int axis, Reduction reduction) { * @param reduction Type of Reduction to apply to the loss. */ public CosineSimilarity(Ops tf, String name, int axis, Reduction reduction) { + this(tf, name, new int[] {axis}, reduction); + } + + /** + * Creates a Cosine Similarity Loss + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param axis The dimension along which the cosine similarity is computed. + * @param reduction Type of Reduction to apply to the loss. 
+ */ + public CosineSimilarity(Ops tf, String name, int[] axis, Reduction reduction) { super(tf, name, reduction); this.axis = axis; } /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.cosineSimilarity(getTF(), labels, predictions, axis); + losses = tf.math.neg(losses); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index 88b4a7aa056..37e7e367b9b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -121,8 +121,8 @@ public Hinge(Ops tf, String name, Reduction reduction) { * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") Operand tLabels = predictions.type() == labels.type() ? (Operand)labels : cast(tf, labels, predictions.type()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index 6d3e3f0c2ac..e8de632eb09 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -130,8 +130,8 @@ public Huber(Ops tf, String name, float delta, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.huber(getTF(), labels, predictions, delta); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java index 8cf3db8d518..b3c0206b409 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -99,8 +99,8 @@ public KLDivergence(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index 1669669a768..812260d9881 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -105,8 +105,8 @@ public LogCosh(Ops tf, String name, 
Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.logCosh(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index ae33d5dfa37..0f9b183f38c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -62,10 +62,9 @@ protected Loss(Ops tf, String name, Reduction reduction) { * @param labels the truth values or labels * @param predictions the predictions * @param The data type of the predictions and loss. - * @param The data type of the labels. * @return the loss */ - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return call(labels, predictions, null); } @@ -82,11 +81,10 @@ public Operand call(Operand labels, * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss */ - public abstract Operand call( - Operand labels, Operand predictions, Operand sampleWeights); + public abstract Operand call( + Operand labels, Operand predictions, Operand sampleWeights); /** * Gets the TensorFlow Ops diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 0d25bd5e7e2..a5ced3d1df8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -48,11 +48,10 @@ public class Losses { * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean absolute error */ - public static Operand meanAbsoluteError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanAbsoluteError( + Ops tf, Operand labels, Operand predictions) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -70,11 +69,10 @@ public static Operand meanAbsoluteErro * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean squared error */ - public static Operand meanSquaredError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanSquaredError( + Ops tf, Operand labels, Operand predictions) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -91,11 +89,10 @@ public static Operand meanSquaredError * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the 
mean absolute percentage error */ - public static Operand meanAbsolutePercentageError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanAbsolutePercentageError( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -118,11 +115,10 @@ public static Operand meanAbsolutePerc * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean squared logarithmic percentage error */ - public static Operand meanSquaredLogarithmicError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanSquaredLogarithmicError( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -152,8 +148,8 @@ public static Operand meanSquaredLogar * @param the data type of the predictions and labels * @return the binary crossentropy loss. */ - public static Operand binaryCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { + public static Operand binaryCrossentropy( + Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -181,7 +177,7 @@ private static Operand binaryCrossentropyHelper( Ops tf, Operand target, Operand output, boolean fromLogits) { if (fromLogits) return tf.nn.sigmoidCrossEntropyWithLogits(target, output); - /* TODO - skip this loggic for now. It requires walking back the inputs which is not yet possible + /* TODO - skip this logic for now. It requires walking back the inputs which is not yet possible if (!(output instanceof Variable) && (!tf.scope().env().isEager())) { // TODO - this does not work // TODO output = backtrackIdentity(output); @@ -225,9 +221,9 @@ private static Operand binaryCrossentropyHelper( * @param the data type of the predictions and labels * @return the categorical crossentropy loss. 
*/ - public static Operand categoricalCrossentropy( + public static Operand categoricalCrossentropy( Ops tf, - Operand labels, + Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing, @@ -283,8 +279,8 @@ public static Operand categoricalCross * @param the data type of the predictions and labels * @return the categorical hinge loss */ - public static Operand categoricalHinge( - Ops tf, Operand labels, Operand predictions) { + public static Operand categoricalHinge( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -329,8 +325,8 @@ public static Operand categoricalHinge * @param the data type of the predictions and labels * @return the cosine similarity loss */ - public static Operand cosineSimilarity( - Ops tf, Operand labels, Operand predictions, int axis) { + public static Operand cosineSimilarity( + Ops tf, Operand labels, Operand predictions, int[] axis) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); @@ -339,8 +335,7 @@ public static Operand cosineSimilarity tLabels = l2Normalize(tf, tLabels, axis); predictions = l2Normalize(tf, predictions, axis); Operand mathMul = tf.math.mul(tLabels, predictions); - Operand sum = tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); - return tf.math.neg(sum); + return tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); } /** @@ -355,8 +350,8 @@ public static Operand cosineSimilarity * @param the data type of the predictions and labels * @return the hinge loss */ - public static Operand hinge( - Ops tf, Operand labels, Operand predictions) { + public static Operand hinge( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -391,8 +386,8 @@ public static Operand hinge( * @param the data type of the predictions and labels * @return the Huber loss */ - public static Operand huber( - Ops tf, Operand labels, Operand predictions, float delta) { + public static Operand huber( + Ops tf, Operand labels, Operand predictions, float delta) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -422,8 +417,8 @@ public static Operand huber( * @see Kullback?Leibler * divergence */ - public static Operand kullbackLeiblerDivergence( - Ops tf, Operand labels, Operand predictions) { + public static Operand kullbackLeiblerDivergence( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -452,8 +447,8 @@ public static Operand kullbackLeiblerD * @param the data type of the predictions and labels * @return the hyperbolic cosine divergence loss */ - public static Operand logCosh( - Ops tf, Operand labels, Operand predictions) { + public static Operand logCosh( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = 
cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -480,8 +475,8 @@ public static Operand logCosh( * @param the data type of the predictions and labels * @return the Poisson loss */ - public static Operand poisson( - Ops tf, Operand labels, Operand predictions) { + public static Operand poisson( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -507,8 +502,8 @@ public static Operand poisson( * @param the data type of the predictions and labels * @return the sparse categorical crossentropy loss */ - public static Operand sparseCategoricalCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { + public static Operand sparseCategoricalCrossentropy( + Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { Class predictionType = predictions.type(); Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); Operand one = cast(tf, tf.constant(1), predictionType); @@ -553,7 +548,7 @@ public static Operand sparseCategorica int labelsRank = labelsShape.numDimensions(); boolean updateShape = labelsRank != predictionsRank - 1; - if (updateShape) { // TODO check to see if this is right + if (updateShape) { Shape newShape = labelsShape.take(labelsRank - 1); iLabels = tf.reshape(iLabels, tf.constant(newShape)); // flatten one dimension predictions = @@ -584,8 +579,8 @@ public static Operand sparseCategorica * @param the data type of the predictions and labels * @return the squared hinge loss */ - public static Operand squaredHinge( - Ops tf, Operand labels, Operand predictions) { + public static Operand squaredHinge( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -651,7 +646,7 @@ private static Operand smoothCategoricalLabels( * @param axis Dimension along which to normalize. 
* @return the normalized values based on L2 norm */ - public static Operand l2Normalize(Ops tf, Operand x, int axis) { + public static Operand l2Normalize(Ops tf, Operand x, int[] axis) { Operand squareSum = tf.reduceSum(tf.math.square(x), tf.constant(axis), ReduceSum.keepDims(Boolean.TRUE)); Operand invNorm = diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index a2d5d5f8efc..594de1e1448 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -95,8 +95,8 @@ public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsoluteError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 49133df610b..275a2e136a0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -95,8 +95,8 @@ public MeanAbsolutePercentageError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsolutePercentageError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index 2a6c2be885e..31df3e70e0b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -95,8 +95,8 @@ public MeanSquaredError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index 2604e226b81..bef990d22bc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -95,8 +95,8 @@ public MeanSquaredLogarithmicError(Ops tf, String name, 
Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java index c43be4f2821..9cf38aa0380 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java @@ -104,8 +104,8 @@ public Poisson(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.poisson(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index ea765e6f8fd..3ec33113e89 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -190,13 +190,12 @@ public SparseCategoricalCrossentropy( * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index 4ad4c1c726c..968624db202 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -117,13 +117,12 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") Operand tLabels = predictions.type() == labels.type() ? 
(Operand)labels : cast(tf, labels, predictions.type()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 651a6fac0b0..abd2dcbbf40 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -21,17 +21,18 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. * *

This is the crossentropy metric class to be used when there are only two label classes (0 and * 1). * - * @param the data type for the predictions. * @param The data type for the metric result */ -public class BinaryCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class BinaryCrossentropy + extends MeanMetricWrapper implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; @@ -60,7 +61,7 @@ public BinaryCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.binaryCrossentropy(getTF(), labels, predictions, fromLogits, labelSmoothing); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index c330ea88eaa..be43f34b92e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -30,11 +30,10 @@ * [2, 0, 1], the labels Operand contains = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] * . * - * @param the data type for the predictions. * @param The data type for the metric result */ -public class CategoricalCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class CategoricalCrossentropy extends MeanMetricWrapper + implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; @@ -48,7 +47,8 @@ public class CategoricalCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to + * a probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and 0.9 @@ -68,7 +68,8 @@ public CategoricalCrossentropy( * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, * meaning the confidence on label values are relaxed. e.g. 
labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and 0.9 @@ -98,7 +99,7 @@ public CategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.categoricalCrossentropy( getTF(), labels, predictions, fromLogits, labelSmoothing, axis); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index 2741a36edb6..c70f2d8643b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -24,10 +24,9 @@ /** * A Metric that computes the categorical hinge loss metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result */ -public class CategoricalHinge extends MeanMetricWrapper +public class CategoricalHinge< T extends TNumber> extends MeanMetricWrapper implements LossMetric { /** @@ -46,7 +45,7 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.categoricalHinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 458de092bec..5abbd095420 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.metrics; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; @@ -23,10 +24,9 @@ /** * A metric that computes the cosine similarity metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class CosineSimilarity extends MeanMetricWrapper +public class CosineSimilarity< T extends TNumber> extends MeanMetricWrapper< T> implements LossMetric { public static final int DEFAULT_AXIS = -1; private final int[] axis; @@ -76,8 +76,8 @@ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class typ /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { // NOTE: cosineProximity is a different algorithm than Losses.cosineSimilarity - return Metrics.cosineProximity(getTF(), labels, predictions, axis); + return Losses.cosineSimilarity(getTF(), labels, predictions, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index baf9ad8ab7d..e0aced6fa3e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -24,10 +24,9 @@ /** * A metric that computes the hinge loss metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class Hinge extends MeanMetricWrapper +public class Hinge extends MeanMetricWrapper implements LossMetric { /** @@ -46,7 +45,7 @@ public Hinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.hinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index efcbbcbb7f0..fa09f2784b5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -25,10 +25,9 @@ * A metric that computes the Kullback-Leibler divergence loss metric between labels and * predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class KLDivergence extends MeanMetricWrapper +public class KLDivergence extends MeanMetricWrapper implements LossMetric { /** @@ -47,7 +46,7 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index 3df8505d54b..c43551a6948 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -25,10 +25,9 @@ * A metric that computes the logarithm of the hyperbolic cosine of the prediction error metric * between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class LogCoshError extends MeanMetricWrapper +public class LogCoshError extends MeanMetricWrapper< T> implements LossMetric { /** @@ -47,7 +46,7 @@ public LogCoshError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.logCosh(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java index de1f5a5629e..8902b329bcc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java @@ -21,10 +21,9 @@ /** * A metric that that implements a weighted mean {@link MetricReduction#WEIGHTED_MEAN } * - * @param The data type for the metric values * @param The data type for the metric result */ -public class Mean extends Reduce { +public class Mean extends Reduce { /** * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#SUM} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index e27676932ff..d343ec77ab0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -24,10 +24,9 @@ /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class MeanAbsoluteError extends MeanMetricWrapper +public class MeanAbsoluteError extends MeanMetricWrapper< T> implements LossMetric { /** @@ -46,7 +45,7 @@ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanAbsoluteError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 84fa9b627b2..dd7d151260b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -24,11 +24,10 @@ /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
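 * For example, with labels [2, 4] and predictions [1, 5], the element-wise absolute
 * percentage errors are 50% and 25%, and the metric accumulates their (sample-weighted) mean.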
*/ -public class MeanAbsolutePercentageError - extends MeanMetricWrapper implements LossMetric { +public class MeanAbsolutePercentageError + extends MeanMetricWrapper implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -46,7 +45,7 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanAbsolutePercentageError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index c7edd6ebe93..c2bef576b30 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -24,10 +24,9 @@ /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class MeanSquaredError extends MeanMetricWrapper +public class MeanSquaredError< T extends TNumber> extends MeanMetricWrapper< T> implements LossMetric { /** @@ -46,7 +45,7 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanSquaredError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index 199b6e0e114..c1cf4ca6c9a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -24,11 +24,10 @@ /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class MeanSquaredLogarithmicError - extends MeanMetricWrapper implements LossMetric { +public class MeanSquaredLogarithmicError + extends MeanMetricWrapper implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -46,7 +45,7 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index bbb2aa73da2..8ab21c58218 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -25,10 +25,9 @@ /** * Base class for Metrics * - * @param The data type for the metric values * @param The data type for the metric result */ -public abstract class Metric { +public abstract class Metric { /** The TensorFlow Ops */ private final Ops tf; @@ -75,10 +74,10 @@ protected Metric(Ops tf, String name, long seed) { * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. * @return a List of Operations to update the metric state - * @param the data type for sampleWeights */ @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList( + Operand values, Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -90,13 +89,13 @@ public List updateStateList(Operand values, Operand the data type for the labels - * @param the data type for the sampleWeights * @return a List of Operations to update the metric state */ @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -105,10 +104,10 @@ public List updateStateList( * * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. - * @param the data type for sampleWeights * @return the Operation to update the metric state */ - public final Op updateState(Operand values, Operand sampleWeights) { + public final Op updateState( + Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -119,12 +118,12 @@ public final Op updateState(Operand values, Operand sa * @param labels the labels * @param predictions the predictions * @param sampleWeights sample weights to be applied to values, may be null. 
- * @param the data type for the labels - * @param the data type for the sampleWeights * @return the Operation to update the metric state */ - public final Op updateState( - Operand labels, Operand predictions, Operand sampleWeights) { + public final Op updateState( + Operand labels, + Operand predictions, + Operand sampleWeights) { List controlOps = updateStateList(labels, predictions, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -149,10 +148,9 @@ public final Op updateState( * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. * @return the result, possibly with control dependencies - * @param the data type for the sampleWeights. */ - public final Operand callOnce( - Operand values, Operand sampleWeights) { + public final Operand callOnce( + Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps); return ltf.identity(result()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index 0169bc6b8bc..95b74bf1eea 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -17,7 +17,6 @@ import org.tensorflow.Operand; import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.ReduceSum; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; @@ -46,89 +45,14 @@ public class Metrics { * @param predictions The prediction values. * @param k Number of top elements to look at for computing accuracy. * @param the data type for the predictions and results - * @param the data type ofr the labels. * @return the Operand for the Top K categorical accuracy value. */ - public static Operand topKCategoricalAccuracy( - Ops tf, Operand labels, Operand predictions, long k) { + public static Operand topKCategoricalAccuracy( + Ops tf, Operand labels, Operand predictions, long k) { Operand fPredictions = CastHelper.cast(tf, predictions, TFloat32.class); return CastHelper.cast( tf, tf.nn.inTopK(fPredictions, tf.math.argMax(labels, tf.constant(-1)), tf.constant(k)), predictions.type()); } - - /** - * Computes the cosine similarity between labels and predictions. - * - * @param tf the TensorFlow Ops - * @param labels The ground truth values. - * @param predictions The prediction values. - * @param axes The dimensions along which the cosine similarity is computed. - * @param the data type for the labels - * @param the data type for the predictions and result - * @return Cosine similarity value. - */ - public static Operand cosineProximity( - Ops tf, Operand labels, Operand predictions, int[] axes) { - Operand labelsNorm = CastHelper.cast(tf, labels, predictions.type()); - labelsNorm = l2Normalize(tf, labelsNorm, axes); - - Operand predictionsNorm = l2Normalize(tf, predictions, axes); - Operand mathMul = tf.math.mul(labelsNorm, predictionsNorm); - return tf.reduceSum(mathMul, tf.constant(axes), ReduceSum.keepDims(Boolean.FALSE)); - } - - /** - * Normalizes along dimension axis using an L2 norm with an epsilon of {@link - * #L2_NORM_EPSILON}. - * - *

For a 1-D tensor with axis = 0, computes
-   *
-   *       output = x / sqrt(max(sum(x**2), epsilon))
-   *
-   *
For x with more dimensions, independently normalizes each 1-D slice along - * dimension axis. - * - * @param tf The TensorFlow ops - * @param x The operand to normalize - * @param axes Dimension(s) along which to normalize. - * @param The data type for x. - * @return the normalized values of x. - */ - public static Operand l2Normalize(Ops tf, Operand x, int[] axes) { - return l2Normalize(tf, x, axes, L2_NORM_EPSILON); - } - - /** - * Normalizes along dimension axis using an L2 norm. - * - *

For a 1-D tensor with axis = 0, computes
-   *
-   *       output = x / sqrt(max(sum(x**2), epsilon))
-   *
-   *
For x with more dimensions, independently normalizes each 1-D slice along - * dimension axis. - * - * @param tf The TensorFlow ops - * @param x The operand to normalize - * @param axes Dimension(s) along which to normalize. - * @param epsilon A lower bound value for the norm. Will use sqrt(epsilon) as the - * divisor if norm < sqrt(epsilon). - * @param The data type for the values. - * @return the normalized values of x. - */ - public static Operand l2Normalize( - Ops tf, Operand x, int[] axes, float epsilon) { - Operand squareSum = - tf.reduceSum(tf.math.square(x), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE)); - Operand y = - tf.math.rsqrt( - tf.math.maximum(squareSum, CastHelper.cast(tf, tf.constant(epsilon), x.type()))); - return tf.math.mul(x, y); - } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 75a2031fbb5..af50b103a60 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -24,10 +24,10 @@ /** * A metric that computes the poisson loss metric between labels and predictions. * - * @param the data type for the predictions. + * @param The data type for the metric result. */ -public class Poisson extends MeanMetricWrapper +public class Poisson< T extends TNumber> extends MeanMetricWrapper< T> implements LossMetric { /** @@ -46,7 +46,7 @@ public Poisson(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.poisson(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index 2e01f722de6..a0c016b70b3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -24,12 +24,11 @@ /** * A metric that computes the sparse categorical cross-entropy loss between true labels and * predicted labels. - * - * @param the data type for the predictions. + *\ * @param The data type for the metric result. 
*/ -public class SparseCategoricalCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class SparseCategoricalCrossentropy + extends MeanMetricWrapper implements LossMetric { private final boolean fromLogits; private final int axis; @@ -55,7 +54,7 @@ public SparseCategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 430dbbcc229..bd331a85eda 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -24,10 +24,9 @@ /** * A metric that computes the squared hinge loss metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. */ -public class SquaredHinge extends MeanMetricWrapper +public class SquaredHinge extends MeanMetricWrapper implements LossMetric { /** @@ -46,7 +45,7 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.squaredHinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index b7b87d313aa..70bb8133698 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -29,8 +29,7 @@ public interface LossMetric { * * @param labels the truth values or labels * @param predictions the predictions - * @param The data type of the labels. * @return the loss */ - Operand call(Operand labels, Operand predictions); + Operand call(Operand labels, Operand predictions); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 17c209a8fed..9a532a0294f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -17,13 +17,14 @@ import org.tensorflow.Operand; import org.tensorflow.framework.metrics.Mean; import org.tensorflow.framework.metrics.MetricReduction; -import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; import java.util.List; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A class that bridges a stateless loss function with the {@link Mean} metric using a reduction of * {@link MetricReduction#WEIGHTED_MEAN}. @@ -32,10 +33,9 @@ * then passes this loss to the {@link Mean} metric to calculate the weighted mean of the * loss over many iterations or epochs * - * @param the data type for the predictions. 
* @param The data type for the metric result */ -public class MeanMetricWrapper extends Mean { +public class MeanMetricWrapper extends Mean { /** The loss function interface */ protected LossMetric loss; @@ -85,22 +85,21 @@ protected void setLoss(LossMetric loss) { * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) - * @param the datatype of the labels - * @param the data type for sampleWeights * @return a List of control operations that updates the Mean state variables. */ - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { if (labels == null || predictions == null) { throw new IllegalArgumentException("missing required inputs for labels and predictions"); } - Operand tLabels = CastHelper.cast(getTF(), labels, getResultType()); - Operand tPredictions = CastHelper.cast(getTF(), predictions, getResultType()); + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); Operand losses = loss.call(tLabels, tPredictions); - return super.updateStateList( - CastHelper.cast(getTF(), losses, predictions.type()), sampleWeights); + return super.updateStateList(cast(getTF(), losses, predictions.type()), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index ad8ff58e417..6cc089fce6d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -21,12 +21,10 @@ import org.tensorflow.op.Ops; import org.tensorflow.op.math.Mean; import org.tensorflow.types.TBool; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; import java.util.Arrays; import java.util.Collections; @@ -57,13 +55,13 @@ public class MetricsHelper { * @param values the values to which weights are applied. * @return Operation with control dependencies to ensure sampleWeight * can be broadcast to values - * @param the type of Operand + * @param the type of Operand * @throws NotBroadcastableException If static checks determine sampleWeights has an * incorrect shape that prohibit broadcasting to values */ @SuppressWarnings("unchecked") - public static Op assertBroadcastable( - Ops tf, Operand sampleWeights, Operand values) { + public static Op assertBroadcastable( + Ops tf, Operand sampleWeights, Operand values) { // try static check for exact match @@ -129,7 +127,7 @@ public static Op assertBroadcastable( // hack to work around the non-lazy select for isValidShape, otherwise validNonscalar fails on a // scalar weight. If select was lazy, that branch wouldn't get executed when iScalar is true. 
- Operand reshapedWeights = + Operand reshapedWeights = tf.select(isScalar, tf.math.mul(sampleWeights, tf.onesLike(values)), sampleWeights); weightsShape = tf.shape(reshapedWeights); weightsRank = tf.rank(reshapedWeights); @@ -237,11 +235,10 @@ public static Operand mean(Ops tf, Operand x) { * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. * @param the type of the Operand. - * @param the type of the axes. * @return the mean of the operand, along the specified axes. */ - public static Operand mean( - Ops tf, Operand x, Operand axes) { + public static Operand mean( + Ops tf, Operand x, Operand axes) { return mean(tf, x, axes, false); } @@ -257,31 +254,27 @@ public static Operand mean( * @param the type of the operand * @return the mean of elements of x. */ - public static Operand mean( - Ops tf, Operand x, boolean keepDims) { + public static Operand mean(Ops tf, Operand x, boolean keepDims) { return mean(tf, x, null, keepDims); } - - /** * Calculates the mean of the operand, alongside the specified axes. * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained + * with length 1. * @param the data type of the Operand - * @param the data type of the axes * @return the mean of elements of x. */ - - public static Operand mean( - Ops tf, Operand x, Operand axes, boolean keepDims) { + public static Operand mean( + Ops tf, Operand x, Operand axes, boolean keepDims) { if (axes == null) { - axes = (Operand) allAxes(tf, x); + axes = allAxes(tf, x); } return tf.math.mean(x, axes, Mean.keepDims(keepDims)); } @@ -294,7 +287,7 @@ public static Operand mean( * @param x the Operand used to calculate the mean * @return the mean of the operand containing floating point numbers */ - public static Operand booleanMean(Ops tf, Operand x) { + public static Operand booleanMean(Ops tf, Operand x) { return booleanMean(tf, x, null, false); } @@ -305,11 +298,10 @@ public static Operand booleanMean(Ops tf, Operand x) { * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param the type of the axes. * @return the mean of the operand, along the specified axes, containing floating point numbers */ - public static Operand booleanMean( - Ops tf, Operand x,Operand axes) { + public static Operand booleanMean( + Ops tf, Operand x, Operand axes) { return booleanMean(tf, x, axes, false); } @@ -317,14 +309,13 @@ public static Operand booleanMean( * Calculates the mean of the boolean operand, alongside all axes. * * @param x the boolean Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. - * @param the data type of the axes + * @param keepDims Indicates whether to keep the dimensions or not. 
If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained + * with length 1. * @return the mean of elements of x containing floating point numbers */ - public static Operand booleanMean( - Ops tf, Operand x, boolean keepDims) { + public static Operand booleanMean(Ops tf, Operand x, boolean keepDims) { return booleanMean(tf, x, null, keepDims); } @@ -333,16 +324,15 @@ public static Operand booleanMean( * * @param x the boolean Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. - * @param the data type of the axes + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained + * with length 1. * @return the mean of elements of x containing floating point numbers */ - public static Operand booleanMean( - Ops tf, Operand x, Operand axes, boolean keepDims) { + public static Operand booleanMean( + Ops tf, Operand x, Operand axes, boolean keepDims) { Operand xf = cast(tf, x, TFloat64.class); return mean(tf, xf, axes, keepDims); } - } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 8e48cb4e573..2a26967b9f2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -19,7 +19,6 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.framework.metrics.Metric; import org.tensorflow.framework.metrics.MetricReduction; -import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -29,13 +28,14 @@ import java.util.ArrayList; import java.util.List; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Encapsulates metrics that perform a reduce operation on the metric values. * - * @param The data type for the metric values * @param The data type for the metric result */ -public abstract class Reduce extends Metric { +public abstract class Reduce extends Metric { public static final String TOTAL = "total"; public static final String COUNT = "count"; protected final MetricReduction reduction; @@ -45,8 +45,10 @@ public abstract class Reduce extends Metri private final Class resultType; /** the variable that holds the total of the metric values */ protected Variable total; - /** the variable that holds the count of the metric values. - * For {@link MetricReduction#WEIGHTED_MEAN}, this count may be weighted */ + /** + * the variable that holds the count of the metric values. 
For {@link + * MetricReduction#WEIGHTED_MEAN}, this count may be weighted + */ protected Variable count; /** @@ -95,12 +97,10 @@ private void setupVars() { public Op resetStates() { List controls = new ArrayList<>(); if (total != null) { - controls.add( - getTF().assign(total, CastHelper.cast(getTF(), getTF().constant(0), total.type()))); + controls.add(getTF().assign(total, cast(getTF(), getTF().constant(0), total.type()))); } if (count != null) { - controls.add( - getTF().assign(count, CastHelper.cast(getTF(), getTF().constant(0), count.type()))); + controls.add(getTF().assign(count, cast(getTF(), getTF().constant(0), count.type()))); } return getTF().withControlDependencies(controls).noOp(); } @@ -115,67 +115,67 @@ public Op resetStates() { * @throws IllegalArgumentException if values is null */ @Override - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList( + Operand values, Operand sampleWeights) { if (values == null) { throw new IllegalArgumentException("values is required."); } + Ops tf = getTF(); List updateOperations = new ArrayList<>(); // cast everything to match the variables - Operand lSampleWeights = null; - Operand lValues = values; + Operand tSampleWeights = null; + Operand tValues = cast(tf, values, getResultType()); if (sampleWeights != null) { - lSampleWeights = CastHelper.cast(getTF(), sampleWeights, lValues.type()); - LossTuple tuple = - LossesHelper.squeezeOrExpandDimensions(getTF(), null, lValues, lSampleWeights); - lValues = tuple.getTarget(); - lSampleWeights = tuple.getSampleWeights(); + tSampleWeights = cast(getTF(), sampleWeights, getResultType()); + LossTuple tuple = + LossesHelper.squeezeOrExpandDimensions(getTF(), null, tValues, tSampleWeights); + tValues = tuple.getTarget(); + tSampleWeights = tuple.getSampleWeights(); try { - lSampleWeights = MetricsHelper.broadcastWeights(getTF(), lSampleWeights, lValues); + tSampleWeights = MetricsHelper.broadcastWeights(getTF(), tSampleWeights, tValues); } catch (IllegalArgumentException ex) { // if we get here we have static shapes with either // different ranks or different dimension sizes. // first, reduce the values down to the rank of the samples - int valuesRank = lValues.shape().numDimensions(); - int weightsRank = lSampleWeights.shape().numDimensions(); + int valuesRank = tValues.shape().numDimensions(); + int weightsRank = tSampleWeights.shape().numDimensions(); int numAxes = Math.min(0, valuesRank - weightsRank); if (numAxes > 0) { // values rank is greater than weights rank, reduce values to weights rank. 
int[] axes = new int[numAxes]; for (int i = 0; i < numAxes; i++) axes[i] = i + weightsRank; if (reduction == MetricReduction.SUM) { - lValues = getTF().reduceSum(lValues, getTF().constant(axes)); + tValues = getTF().reduceSum(tValues, getTF().constant(axes)); } else { - lValues = getTF().math.mean(lValues, getTF().constant(axes)); + tValues = getTF().math.mean(tValues, getTF().constant(axes)); } } } - lValues = getTF().math.mul(lValues, lSampleWeights); + tValues = getTF().math.mul(tValues, tSampleWeights); } - Operand weightedValueSum = - getTF().reduceSum(lValues, LossesHelper.allAxes(getTF(), lValues)); + Operand weightedValueSum = + getTF().reduceSum(tValues, LossesHelper.allAxes(getTF(), tValues)); Operand totalUpdate = - getTF().assignAdd(total, CastHelper.cast(getTF(), weightedValueSum, total.type())); + getTF().assignAdd(total, cast(getTF(), weightedValueSum, total.type())); updateOperations.add(totalUpdate); Operand numValues; if (reduction != MetricReduction.SUM) { switch (reduction) { case SUM_OVER_BATCH_SIZE: - numValues = - CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); + numValues = cast(getTF(), getTF().constant(tValues.shape().size()), resultType); break; case WEIGHTED_MEAN: - if (lSampleWeights == null) { - numValues = - CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); + if (tSampleWeights == null) { + numValues = cast(getTF(), getTF().constant(tValues.shape().size()), resultType); } else { numValues = - CastHelper.cast( + cast( getTF(), getTF() - .reduceSum(lSampleWeights, LossesHelper.allAxes(getTF(), lSampleWeights)), + .reduceSum(tSampleWeights, LossesHelper.allAxes(getTF(), tSampleWeights)), resultType); } break; @@ -202,7 +202,7 @@ public Operand result() { break; case WEIGHTED_MEAN: case SUM_OVER_BATCH_SIZE: - fResult = getTF().math.divNoNan(total, CastHelper.cast(getTF(), count, resultType)); + fResult = getTF().math.divNoNan(total, cast(getTF(), count, resultType)); break; default: throw new UnsupportedOperationException( diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java index 7ceedded018..be46bb5c282 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java @@ -32,7 +32,7 @@ class BinaryCrossentropyTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testUnweighted", false, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 0, 1, 0}; @@ -55,7 +55,7 @@ public void testUnweighted() { public void testUnweightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testUnweightedLogits", true, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 0, 1, 0, 1, 1}; @@ -77,7 +77,7 @@ public void testUnweightedLogits() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new 
BinaryCrossentropy<>(tf, "BCE_testWeighted", false, 0, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 0, 1, 0}; @@ -102,7 +102,7 @@ public void testWeighted() { public void testWeightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testWeightedLogits", true, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 0, 1, 0, 1, 1}; @@ -128,7 +128,7 @@ public void testLabelSmoothing() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); float labelSmoothing = 0.1F; - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>( tf, "BCE_testWeightedLabS", true, labelSmoothing, 1001L, TFloat64.class); session.run(instance.resetStates()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java index 2b4a1d75467..34fc3eef884 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java @@ -31,7 +31,7 @@ class CategoricalCrossentropyTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testUnweighted", false, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -55,7 +55,7 @@ public void testUnweighted() { public void testUnweightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testUnweightedLogits", true, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -79,7 +79,7 @@ public void testUnweightedLogits() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testWeighted", false, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -104,7 +104,7 @@ public void testWeighted() { public void testWeightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>(tf, "CCE_testWeighted", true, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {0, 1, 0, 0, 0, 1}; @@ -129,7 +129,7 @@ public void testLabelSmoothing() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); float labelSmoothing = 0.1F; - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testWeighted", true, labelSmoothing, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java index 87248d95e48..78b25a21b60 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java @@ -31,7 +31,7 @@ class CategoricalHingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalHinge instance = + CategoricalHinge instance = new CategoricalHinge<>(tf, "CH_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { @@ -64,7 +64,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalHinge instance = + CategoricalHinge instance = new CategoricalHinge<>(tf, "CH_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java index a9721ef2f8f..18410416c42 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java @@ -31,7 +31,7 @@ class CosineSimilarityTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CosineSimilarity instance = + CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; @@ -54,7 +54,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CosineSimilarity instance = + CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testWeighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; @@ -80,7 +80,7 @@ public void test_axis() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); int axis = 1; - CosineSimilarity instance = + CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testWeighted", axis, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java index 6af5fed4889..a9bd5fac76e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java @@ -32,7 +32,7 @@ class HingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = + Hinge instance = new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; @@ -55,7 +55,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = + 
Hinge instance = new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java index 28020c0fa1c..267578a492c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java @@ -31,7 +31,7 @@ class KLDivergenceTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - KLDivergence instance = + KLDivergence instance = new KLDivergence<>(tf, "KLD_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); float[][] trueArray = {{.5f, .8f, .12f}, {.7f, .43f, .8f}}; @@ -54,7 +54,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - KLDivergence instance = + KLDivergence instance = new KLDivergence<>(tf, "KLD_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java index 31c043e0473..1b5b8fb7d49 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java @@ -32,7 +32,7 @@ class LogCoshErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - LogCoshError instance = + LogCoshError instance = new LogCoshError<>(tf, "LogCosh_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 9, 2, -5, -2, 6}; @@ -56,7 +56,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - LogCoshError instance = + LogCoshError instance = new LogCoshError<>(tf, "LogCosh_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java index 73241ecbe9f..984895f2ad9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java @@ -32,7 +32,7 @@ class MeanAbsoluteErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanAbsoluteError instance = + MeanAbsoluteError instance = new MeanAbsoluteError<>(tf, "MAE_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0f, instance.getTotal()); @@ -74,7 +74,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanAbsoluteError instance = + MeanAbsoluteError instance = new 
MeanAbsoluteError<>(tf, "MAE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java index 4c92844b217..0b9e7f6b538 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java @@ -34,7 +34,7 @@ public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { session.setEpsilon(1E-6f); Ops tf = session.getTF(); - MeanAbsolutePercentageError instance = + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError<>(tf, "MAPE_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); session.evaluate(0.0f, instance.getTotal()); @@ -76,7 +76,7 @@ public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { session.setEpsilon(1E-6f); Ops tf = session.getTF(); - MeanAbsolutePercentageError instance = + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError<>(tf, "MAPE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java index 0b760213015..e42052a9ef1 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java @@ -33,7 +33,7 @@ class MeanSquaredErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredError instance = + MeanSquaredError instance = new MeanSquaredError<>(tf, "MSE_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); @@ -70,7 +70,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredError instance = + MeanSquaredError instance = new MeanSquaredError<>(tf, "MSE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java index 098a5cb9725..e68d63b8778 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java @@ -32,7 +32,7 @@ class MeanSquaredLogarithmicErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredLogarithmicError instance = + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError<>(tf, "MSLE_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); 
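      // resetStates() re-zeroes the total and count accumulator variables; the next assertions verify this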
session.evaluate(0.0f, instance.getTotal()); @@ -69,7 +69,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredLogarithmicError instance = + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError<>(tf, "MSLE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java index cf3c3e44719..75d9ef93168 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java @@ -32,7 +32,7 @@ class PoissonTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Poisson instance = + Poisson instance = new Poisson<>(tf, "Poisson_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {4, 8, 12, 8, 1, 3}; @@ -55,7 +55,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Poisson instance = + Poisson instance = new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {4, 8, 12, 8, 1, 3}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java index 87af1bd8448..8e1aaea0a8f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java @@ -32,7 +32,7 @@ class SparseCategoricalCrossentropyTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testUnweighted", false, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -56,7 +56,7 @@ public void testUnweighted() { public void testUnweightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testWeighted", true, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -105,7 +105,7 @@ public void testWeighted() { public void testWeightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testWeighted", true, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java index e3376c224f3..2c80b3451ad 100644 --- 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java @@ -32,7 +32,7 @@ class SquaredHingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SquaredHinge instance = + SquaredHinge instance = new SquaredHinge<>(tf, "SCE_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = { @@ -61,7 +61,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SquaredHinge instance = + SquaredHinge instance = new SquaredHinge<>(tf, "SCE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { From 78366de1a73bb72612ddc40736accd6c763a4f1a Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 1 Feb 2021 14:42:48 -0500 Subject: [PATCH 56/56] Reformat code --- .../annotations/org/tensorflow/op/Ops.java | 6 +- .../losses/CategoricalCrossentropy.java | 36 ++++++------ .../framework/losses/CategoricalHinge.java | 4 +- .../tensorflow/framework/losses/Hinge.java | 20 ++++--- .../tensorflow/framework/losses/Huber.java | 2 +- .../framework/losses/KLDivergence.java | 2 +- .../tensorflow/framework/losses/LogCosh.java | 2 +- .../org/tensorflow/framework/losses/Loss.java | 5 +- .../tensorflow/framework/losses/Losses.java | 28 +++++++--- .../framework/losses/MeanAbsoluteError.java | 2 +- .../losses/MeanAbsolutePercentageError.java | 2 +- .../framework/losses/MeanSquaredError.java | 2 +- .../losses/MeanSquaredLogarithmicError.java | 2 +- .../losses/SparseCategoricalCrossentropy.java | 34 +++++++----- .../framework/losses/SquaredHinge.java | 18 +++--- .../framework/losses/impl/LossTuple.java | 2 +- .../framework/losses/impl/LossesHelper.java | 26 ++++----- .../framework/metrics/BinaryCrossentropy.java | 9 ++- .../framework/metrics/CategoricalHinge.java | 4 +- .../framework/metrics/CosineSimilarity.java | 2 +- .../tensorflow/framework/metrics/Hinge.java | 3 +- .../framework/metrics/KLDivergence.java | 5 +- .../framework/metrics/LogCoshError.java | 3 +- .../framework/metrics/MeanAbsoluteError.java | 2 +- .../metrics/MeanAbsolutePercentageError.java | 6 +- .../framework/metrics/MeanSquaredError.java | 4 +- .../metrics/MeanSquaredLogarithmicError.java | 6 +- .../tensorflow/framework/metrics/Poisson.java | 6 +- .../SparseCategoricalCrossentropy.java | 13 +++-- .../framework/metrics/SquaredHinge.java | 5 +- .../framework/metrics/impl/LossMetric.java | 2 +- .../framework/metrics/impl/SetsOps.java | 55 ++++++++++--------- .../framework/metrics/HingeTest.java | 6 +- .../framework/metrics/PoissonTest.java | 3 +- .../SparseCategoricalCrossentropyTest.java | 2 +- 35 files changed, 174 insertions(+), 155 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 84736ada6a5..007ee9d0d42 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -345,10 +345,10 @@ public final class Ops { public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { 
@@ -370,8 +370,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - train = new TrainOps(this); quantization = new QuantizationOps(this); + train = new TrainOps(this); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 035af9589ae..5aac163c1e4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -154,24 +154,26 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 */ public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, - * and a channel axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy Loss using a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS} * * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 */ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) { this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); @@ -183,9 +185,10 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float la * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. x=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. x=0.2 means + * that we will use a value of 0.1 for label 0 and 0.9 + * for label 1 * @param reduction Type of Reduction to apply to loss. 
*/ public CategoricalCrossentropy( @@ -199,13 +202,14 @@ public CategoricalCrossentropy( * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 * @param reduction Type of Reduction to apply to loss. * @param axis The channels axis. axis=-1 corresponds to data format "Channels Last" - * and axis=1 corresponds to data format "Channels First". - * {@link Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST} + * and axis=1 corresponds to data format "Channels First". {@link + * Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST} * @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1. */ public CategoricalCrossentropy( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index 4e9133d8835..73837ed1756 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -25,7 +25,7 @@ *

loss = maximum(neg - pos + 1, 0) where neg=maximum((1-labels)*predictions)
 * and pos=sum(labels*predictions)
 *
- * labels values are expected to be 0 or 1.
+ * labels values are expected to be 0 or 1.
 *
 *
Standalone usage: * @@ -100,7 +100,7 @@ public CategoricalHinge(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.categoricalHinge(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index 37e7e367b9b..db3569441ef 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -18,15 +18,16 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; + import static org.tensorflow.framework.utils.CastHelper.cast; /** * Computes the hinge loss between labels and predictions. * - *

loss = maximum(1 - labels * predictions, 0).
+ * loss = maximum(1 - labels * predictions, 0).
- * labels values are expected to be -1 or 1.
- * If binary (0 or 1) labels are provided, they will be converted to -1 or 1.
+ * labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided,
+ * they will be converted to -1 or 1.
Standalone usage: * @@ -106,7 +107,7 @@ public Hinge(Ops tf, String name, Reduction reduction) { * label values are not in the set [-1., 0., 1.]. * * @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be - * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. + * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is * provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor @@ -124,13 +125,16 @@ public Hinge(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") - Operand tLabels = predictions.type() == labels.type() ? - (Operand)labels : cast(tf, labels, predictions.type()); - tLabels = LossesHelper.valueCheck( + Operand tLabels = + predictions.type() == labels.type() + ? (Operand) labels + : cast(tf, labels, predictions.type()); + tLabels = + LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.type())); + cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); Operand losses = Losses.hinge(getTF(), tLabels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index e8de632eb09..665a9ac157d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -131,7 +131,7 @@ public Huber(Ops tf, String name, float delta, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.huber(getTF(), labels, predictions, delta); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java index b3c0206b409..2aa1f72092b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -100,7 +100,7 @@ public KLDivergence(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index 812260d9881..78325713e3e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -106,7 +106,7 @@ public LogCosh(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.logCosh(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index 0f9b183f38c..cdd35d28aba 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -25,7 +25,7 @@ public abstract class Loss { protected final Reduction reduction; /** - * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link + * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link * Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops @@ -64,7 +64,8 @@ protected Loss(Ops tf, String name, Reduction reduction) { * @param The data type of the predictions and loss. * @return the loss */ - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { return call(labels, predictions, null); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index a5ced3d1df8..2222ebb41f8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -102,8 +102,10 @@ public static Operand meanAbsolutePercentageError( tf.math.abs( tf.math.div( tf.math.sub(tLabels, predictions), - tf.math.maximum(tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), predictionType)))); - return tf.math.mul(cast(tf, tf.constant(100), predictionType), tf.math.mean(diff, tf.constant(-1))); + tf.math.maximum( + tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), predictionType)))); + return tf.math.mul( + cast(tf, tf.constant(100), predictionType), tf.math.mean(diff, tf.constant(-1))); } /** @@ -149,7 +151,11 @@ public static Operand meanSquaredLogarithmicError( * @return the binary crossentropy loss. */ public static Operand binaryCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { + Ops tf, + Operand labels, + Operand predictions, + boolean fromLogits, + float labelSmoothing) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -214,9 +220,10 @@ private static Operand binaryCrossentropyHelper( * @param labels true targets * @param predictions the predictions * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. 
e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 * @param axis the * @param the data type of the predictions and labels * @return the categorical crossentropy loss. @@ -503,7 +510,11 @@ public static Operand poisson( * @return the sparse categorical crossentropy loss */ public static Operand sparseCategoricalCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { + Ops tf, + Operand labels, + Operand predictions, + boolean fromLogits, + int axis) { Class predictionType = predictions.type(); Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); Operand one = cast(tf, tf.constant(1), predictionType); @@ -650,8 +661,7 @@ public static Operand l2Normalize(Ops tf, Operand x, i Operand squareSum = tf.reduceSum(tf.math.square(x), tf.constant(axis), ReduceSum.keepDims(Boolean.TRUE)); Operand invNorm = - tf.math.rsqrt( - tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); + tf.math.rsqrt(tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); return tf.math.mul(x, invNorm); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index 594de1e1448..03a3cf70110 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -96,7 +96,7 @@ public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsoluteError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 275a2e136a0..6c5242df4f2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -96,7 +96,7 @@ public MeanAbsolutePercentageError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsolutePercentageError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index 31df3e70e0b..f975db55c44 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -96,7 +96,7 @@ public MeanSquaredError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, 
Operand sampleWeights) { Operand losses = Losses.meanSquaredError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index bef990d22bc..11b8e157e90 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -96,7 +96,7 @@ public MeanSquaredLogarithmicError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index 3ec33113e89..d04cc67d5d9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -18,6 +18,7 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; + import static org.tensorflow.framework.utils.CastHelper.cast; /** @@ -79,7 +80,8 @@ public class SparseCategoricalCrossentropy extends Loss { /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and fromLogits={@link + * #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops */ @@ -88,8 +90,8 @@ public SparseCategoricalCrossentropy(Ops tf) { } /** - * Creates a SparseCategoricalCrossentropy loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, - * and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * Creates a SparseCategoricalCrossentropy loss using a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param name the name of this loss function @@ -122,8 +124,8 @@ public SparseCategoricalCrossentropy(Ops tf, String name, Reduction reduction) { } /** - * Creates a SparseCategoricalCrossentropy using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and - * fromLogits={@link #FROM_LOGITS_DEFAULT}. + * Creates a SparseCategoricalCrossentropy using a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param name the name of this loss function @@ -135,7 +137,8 @@ public SparseCategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} and fromLogits={@link #FROM_LOGITS_DEFAULT}. 
+ * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} and fromLogits={@link + * #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values @@ -176,9 +179,10 @@ public SparseCategoricalCrossentropy( /** * Generates an Operand the calculates the loss. * - * If run in Graph mode, the computation will throw {@link org.tensorflow.exceptions.TFInvalidArgumentException} - * if the predictions values are outside the range o [0. to 1.]. In Eager Mode, this call - * will throw {@link IllegalArgumentException}, if the predictions values are outside the range o [0. to 1.] + *

If run in Graph mode, the computation will throw {@link + * org.tensorflow.exceptions.TFInvalidArgumentException} if the predictions values are outside the + * range o [0. to 1.]. In Eager Mode, this call will throw {@link IllegalArgumentException}, if + * the predictions values are outside the range o [0. to 1.] * * @param labels the truth values or labels * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. @@ -200,12 +204,12 @@ public Operand call( if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = - LossesHelper.rangeCheck( - getTF(), - "predictions range check [0-1]", - predictions, - cast(getTF(), getTF().constant(0), predictions.type()), - cast(getTF(), getTF().constant(1), predictions.type())); + LossesHelper.rangeCheck( + getTF(), + "predictions range check [0-1]", + predictions, + cast(getTF(), getTF().constant(0), predictions.type()), + cast(getTF(), getTF().constant(1), predictions.type())); } else { lPredictions = predictions; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index 968624db202..dadbdb3b95e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -18,6 +18,7 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; + import static org.tensorflow.framework.utils.CastHelper.cast; /** @@ -25,8 +26,8 @@ * *

loss = square(maximum(1 - labels * predictions, 0))
 *
- * labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, they will be
- * converted to -1 or 1.
+ * labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided,
+ * they will be converted to -1 or 1.
 *
 *
Standalone usage: * @@ -107,7 +108,7 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { * label values are not in the set [-1., 0., 1.]. * * @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be - * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. + * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. * @param sampleWeights Optional SampleWeights acts as a coefficient for the loss. If a scalar is * provided, then the loss is simply scaled by the given value. If SampleWeights is a tensor @@ -124,13 +125,16 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") - Operand tLabels = predictions.type() == labels.type() ? - (Operand)labels : cast(tf, labels, predictions.type()); - tLabels = LossesHelper.valueCheck( + Operand tLabels = + predictions.type() == labels.type() + ? (Operand) labels + : cast(tf, labels, predictions.type()); + tLabels = + LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.type())); + cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); Operand losses = Losses.squaredHinge(getTF(), tLabels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java index 2104937a979..f811549fbca 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java @@ -18,7 +18,7 @@ import org.tensorflow.types.family.TNumber; /** - * A helper class for loss methods to return labels, target, and sampleWeights + * A helper class for loss methods to return labels, target, and sampleWeights * * @param the data type of the LossTuple entries. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 10067db91ba..66bdd839f09 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -32,8 +32,9 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * These are helper methods for Losses and Metrics and will be module private when Java modularity is applied to - * TensorFlow Java. These methods should not be used outside of the losses and metrics packages. + * These are helper methods for Losses and Metrics and will be module private when Java modularity + * is applied to TensorFlow Java. These methods should not be used outside of the losses and metrics + * packages. */ public class LossesHelper { @@ -42,10 +43,10 @@ public class LossesHelper { * *

    *
  1. Squeezes last dim of predictions or labels if their rank
     differs by 1 (using {@link #removeSqueezableDimensions}).
  2. Squeezes or expands last dim of sampleWeight if its rank differs by 1 from
     the new rank of predictions. If sampleWeight is scalar, it is
     kept scalar.
* * @param tf the TensorFlow Ops @@ -77,12 +78,13 @@ public static LossTuple squeezeOrExpandDimensions( * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction * . - * @param sampleWeights Optional sample weight(s) Operand whose dimensions match + * @param sampleWeights Optional sample weight(s) Operand whose dimensions match + * * prediction. - * @return LossTuple of predictions, labels and sampleWeight. - * Each of them possibly has the last dimension squeezed, sampleWeight could be - * extended by one dimension. If sampleWeight is null, only the possibly shape modified predictions and labels are - * returned. + * @return LossTuple of predictions, labels and sampleWeight + * . Each of them possibly has the last dimension squeezed, sampleWeight + * could be extended by one dimension. If sampleWeight is null, only the possibly + * shape modified predictions and labels are returned. */ public static LossTuple squeezeOrExpandDimensions( Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { @@ -298,8 +300,7 @@ private static Operand reduceWeightedLoss( public static Operand safeMean( Ops tf, Operand losses, long numElements) { Operand totalLoss = tf.reduceSum(losses, allAxes(tf, losses)); - return tf.math.divNoNan( - totalLoss, cast(tf, tf.constant(numElements), losses.type())); + return tf.math.divNoNan(totalLoss, cast(tf, tf.constant(numElements), losses.type())); } /** @@ -383,8 +384,7 @@ public static Operand rangeCheck( */ public static Operand valueCheck( Ops tf, String prefix, Operand values, Operand allowedValues) { - Operand flatValues = - tf.reshape(values, tf.constant(Shape.of(values.shape().size()))); + Operand flatValues = tf.reshape(values, tf.constant(Shape.of(values.shape().size()))); SetDiff1d diff = tf.setDiff1d(flatValues, allowedValues, TInt32.class); long diffSize = diff.out().shape().size(); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index abd2dcbbf40..d8bb2a41116 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -21,8 +21,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. * @@ -31,8 +29,8 @@ * * @param The data type for the metric result */ -public class BinaryCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class BinaryCrossentropy extends MeanMetricWrapper + implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; @@ -42,7 +40,8 @@ public class BinaryCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. * @param labelSmoothing value used to smooth labels, When 0, no smoothing occurs. 
When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index c70f2d8643b..4800fc43c49 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -26,7 +26,7 @@ * * @param The data type for the metric result */ -public class CategoricalHinge< T extends TNumber> extends MeanMetricWrapper +public class CategoricalHinge extends MeanMetricWrapper implements LossMetric { /** @@ -45,7 +45,7 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.categoricalHinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 5abbd095420..3ae67072955 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -26,7 +26,7 @@ * * @param The data type for the metric result. */ -public class CosineSimilarity< T extends TNumber> extends MeanMetricWrapper< T> +public class CosineSimilarity extends MeanMetricWrapper implements LossMetric { public static final int DEFAULT_AXIS = -1; private final int[] axis; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index e0aced6fa3e..3b84b81e071 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -26,8 +26,7 @@ * * @param The data type for the metric result. */ -public class Hinge extends MeanMetricWrapper - implements LossMetric { +public class Hinge extends MeanMetricWrapper implements LossMetric { /** * Creates a Hinge metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index fa09f2784b5..f631f562e1d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -27,8 +27,7 @@ * * @param The data type for the metric result. 
*/ -public class KLDivergence extends MeanMetricWrapper - implements LossMetric { +public class KLDivergence extends MeanMetricWrapper implements LossMetric { /** * Creates a KLDivergence metric @@ -46,7 +45,7 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index c43551a6948..046937e228b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -27,8 +27,7 @@ * * @param The data type for the metric result. */ -public class LogCoshError extends MeanMetricWrapper< T> - implements LossMetric { +public class LogCoshError extends MeanMetricWrapper implements LossMetric { /** * Creates a LogCoshError metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index d343ec77ab0..977f61648a1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -26,7 +26,7 @@ * * @param The data type for the metric result. */ -public class MeanAbsoluteError extends MeanMetricWrapper< T> +public class MeanAbsoluteError extends MeanMetricWrapper implements LossMetric { /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index dd7d151260b..bad5255969a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -26,8 +26,8 @@ * * @param The data type for the metric result. */ -public class MeanAbsolutePercentageError - extends MeanMetricWrapper implements LossMetric { +public class MeanAbsolutePercentageError extends MeanMetricWrapper + implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -45,7 +45,7 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanAbsolutePercentageError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index c2bef576b30..5b0d9ec43b3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -26,7 +26,7 @@ * * @param The data type for the metric result. 
*/ -public class MeanSquaredError< T extends TNumber> extends MeanMetricWrapper< T> +public class MeanSquaredError extends MeanMetricWrapper implements LossMetric { /** @@ -45,7 +45,7 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanSquaredError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index c1cf4ca6c9a..35044fee956 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -26,8 +26,8 @@ * * @param The data type for the metric result. */ -public class MeanSquaredLogarithmicError - extends MeanMetricWrapper implements LossMetric { +public class MeanSquaredLogarithmicError extends MeanMetricWrapper + implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -45,7 +45,7 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index af50b103a60..700099d3375 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -24,11 +24,9 @@ /** * A metric that computes the poisson loss metric between labels and predictions. * - * @param The data type for the metric result. */ -public class Poisson< T extends TNumber> extends MeanMetricWrapper< T> - implements LossMetric { +public class Poisson extends MeanMetricWrapper implements LossMetric { /** * Creates a Poisson metric @@ -46,7 +44,7 @@ public Poisson(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.poisson(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index a0c016b70b3..aa7ca316378 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -23,12 +23,12 @@ /** * A metric that computes the sparse categorical cross-entropy loss between true labels and - * predicted labels. - *\ + * predicted labels. \ + * * @param The data type for the metric result. 
*/ -public class SparseCategoricalCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class SparseCategoricalCrossentropy extends MeanMetricWrapper + implements LossMetric { private final boolean fromLogits; private final int axis; @@ -38,7 +38,8 @@ public class SparseCategoricalCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. * @param axis The dimension along which the entropy is computed. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -54,7 +55,7 @@ public SparseCategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index bd331a85eda..01f4a403f84 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -26,8 +26,7 @@ * * @param The data type for the metric result. */ -public class SquaredHinge extends MeanMetricWrapper - implements LossMetric { +public class SquaredHinge extends MeanMetricWrapper implements LossMetric { /** * Creates a SquaredHinge metric @@ -45,7 +44,7 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.squaredHinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index 70bb8133698..037d634cd4a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -31,5 +31,5 @@ public interface LossMetric { * @param predictions the predictions * @return the loss */ - Operand call(Operand labels, Operand predictions); + Operand call(Operand labels, Operand predictions); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java index 1841c7ee238..467dea19b57 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java @@ -25,33 +25,6 @@ /** Implementation of set operations */ public class SetsOps { - /** - * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops - * function {@link SparseOps#denseToDenseSetOperation} - */ - public enum Operation { - A_MINUS_B("a-b"), - 
B_MINUS_A("b-a"), - INTERSECTION("intersection"), - UNION("union"); - - private final String setOperation; - - Operation(String setOperation) { - this.setOperation = setOperation; - } - - /** - * Gets the set operation String value used to pass as the stringOperation value to {@link - * SparseOps#denseToDenseSetOperation} - * - * @return the set operation String value - */ - public String getSetOperation() { - return setOperation; - } - } - /** * Computes set difference of elements in last dimension of a and b with * aMinusB set to true. @@ -69,6 +42,7 @@ public String getSetOperation() { public static Operand difference(Ops tf, Operand a, Operand b) { return difference(tf, a, b, true); } + /** * Computes set difference of elements in last dimension of a and b. * @@ -143,4 +117,31 @@ public static Operand setOperation( setOperationResult.resultValues(), cast(tf, tf.constant(0), a.type())); } + + /** + * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops + * function {@link SparseOps#denseToDenseSetOperation} + */ + public enum Operation { + A_MINUS_B("a-b"), + B_MINUS_A("b-a"), + INTERSECTION("intersection"), + UNION("union"); + + private final String setOperation; + + Operation(String setOperation) { + this.setOperation = setOperation; + } + + /** + * Gets the set operation String value used to pass as the stringOperation value to {@link + * SparseOps#denseToDenseSetOperation} + * + * @return the set operation String value + */ + public String getSetOperation() { + return setOperation; + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java index a9bd5fac76e..90531d21fde 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java @@ -32,8 +32,7 @@ class HingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = - new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); + Hinge instance = new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; double[] predArray = {-0.3, 0.2, -0.1, 1.6, -0.25, -1., 0.5, 0.6}; @@ -55,8 +54,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = - new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); + Hinge instance = new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { -1, 1, -1, 1, diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java index 75d9ef93168..5631bac15ee 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java @@ -55,8 +55,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Poisson instance = - new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); + Poisson instance = new Poisson<>(tf, 
"Poisson_testWeighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {4, 8, 12, 8, 1, 3}; float[] predArray = {1, 9, 2, 5, 2, 6}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java index 8e1aaea0a8f..0aece8c8ac9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java @@ -79,7 +79,7 @@ public void testUnweightedLogits() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testWeighted", false, -1, 1001L, TFloat32.class); session.run(instance.resetStates());