Generic cleanup Metrics and Losses #204

Merged: 13 commits, Feb 11, 2021
@@ -345,10 +345,10 @@ public final class Ops {

public final SignalOps signal;

public final TrainOps train;

public final QuantizationOps quantization;

public final TrainOps train;

private final Scope scope;

private Ops(Scope scope) {
@@ -370,8 +370,8 @@ private Ops(Scope scope) {
math = new MathOps(this);
audio = new AudioOps(this);
signal = new SignalOps(this);
train = new TrainOps(this);
quantization = new QuantizationOps(this);
train = new TrainOps(this);
}

/**
@@ -62,7 +62,6 @@
* VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter.
* <p>For a GlorotUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM}
* for the distribution parameter.
* <p></p>
*
* @param <T> The TType for the call operation
* @see VarianceScaling.Distribution
@@ -57,7 +57,6 @@
* VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter.
* <p>For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM}
* for the distribution parameter.
* <p></p>
*
* @param <T> The TType for the call operation
* @see <a
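The two initializer hunks above only remove a stray empty <p></p> tag, but the surrounding Javadoc carries the useful rule: the Normal and Uniform variants of Glorot and He differ only in the distribution argument. Below is a minimal sketch of that usage; the Glorot(Ops, Distribution, long) and He(Ops, Distribution, long) constructors and the call(dims, type) signature are assumptions based on this version of tensorflow-framework, not shown in this diff.

import org.tensorflow.EagerSession;
import org.tensorflow.Operand;
import org.tensorflow.framework.initializers.Glorot;
import org.tensorflow.framework.initializers.He;
import org.tensorflow.framework.initializers.VarianceScaling;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

public class InitializerSketch {
  public static void main(String[] args) {
    try (EagerSession session = EagerSession.create()) {
      Ops tf = Ops.create(session);
      long seed = 1001L;

      // Distribution.TRUNCATED_NORMAL gives the GlorotNormal / HeNormal
      // behavior; Distribution.UNIFORM gives the GlorotUniform / HeUniform
      // equivalents, as the Javadoc above describes.
      Glorot<TFloat32> glorotUniform =
          new Glorot<>(tf, VarianceScaling.Distribution.UNIFORM, seed);
      He<TFloat32> heUniform =
          new He<>(tf, VarianceScaling.Distribution.UNIFORM, seed);

      Operand<TFloat32> w1 = glorotUniform.call(tf.constant(Shape.of(3, 4)), TFloat32.class);
      Operand<TFloat32> w2 = heUniform.call(tf.constant(Shape.of(3, 4)), TFloat32.class);
      System.out.println(w1.shape());
      System.out.println(w2.shape());
    }
  }
}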
@@ -202,13 +202,12 @@ public BinaryCrossentropy(
* predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss
* functions reduce by 1 dimension, usually axis=-1.)
* @param <T> The data type of the predictions, sampleWeights and loss.
* @param <U> The data type of the labels.
* @return the loss
* @throws IllegalArgumentException if the predictions are outside the range [0.-1.].
*/
@Override
public <T extends TNumber, U extends TNumber> Operand<T> call(
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
public <T extends TNumber> Operand<T> call(
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Operand<T> lPredictions;
if (!fromLogits) {
// add predictions range check for 0 - 1
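This signature change is the heart of the PR: call drops its second type parameter <U> and takes labels as Operand<? extends TNumber>, so the loss type T is fixed by the predictions alone and any numeric label operand is accepted without casts or explicit type arguments. A sketch of what that enables; the two-argument Loss.call(labels, predictions) convenience overload (delegating with null sample weights) is an assumption, as it is not shown in this diff.

import org.tensorflow.EagerSession;
import org.tensorflow.Operand;
import org.tensorflow.framework.losses.BinaryCrossentropy;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;

public class BinaryCrossentropySketch {
  public static void main(String[] args) {
    try (EagerSession session = EagerSession.create()) {
      Ops tf = Ops.create(session);
      BinaryCrossentropy bce = new BinaryCrossentropy(tf);

      // Integer labels against float predictions: with the old <T, U>
      // signature U had to unify with a concrete label type; the wildcard
      // now accepts any TNumber subtype directly.
      Operand<TInt32> labels = tf.constant(new int[] {0, 1, 0, 1});
      Operand<TFloat32> predictions = tf.constant(new float[] {0.1f, 0.9f, 0.2f, 0.8f});

      Operand<TFloat32> loss = bce.call(labels, predictions);
      System.out.println(loss.asTensor().getFloat());
    }
  }
}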
@@ -154,24 +154,26 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits) {
*
* @param tf the TensorFlow Ops
* @param fromLogits Whether to interpret predictions as a tensor of logit values
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are smoothed, meaning the
* confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2</code> means that we will use a
* value of <code>0.1</code> for label <code>0</code> and <code>0.9</code> for label <code>1</code>
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are
* smoothed, meaning the confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2
* </code> means that we will use a value of <code>0.1</code> for label <code>0</code> and
* <code>0.9</code> for label <code>1</code>
*/
public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) {
this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS);
}

/**
* Creates a categorical cross entropy Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT},
* and a channel axis of {@link #DEFAULT_AXIS}
* Creates a categorical cross entropy Loss using a Loss Reduction of {@link
* Loss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS}
*
* @param tf the TensorFlow Ops
* @param name the name of this loss
* @param fromLogits Whether to interpret predictions as a tensor of logit values
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are smoothed, meaning the
* confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2</code> means that we will use a
* value of <code>0.1</code> for label <code>0</code> and <code>0.9</code> for label <code>1</code>
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are
* smoothed, meaning the confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2
* </code> means that we will use a value of <code>0.1</code> for label <code>0</code> and
* <code>0.9</code> for label <code>1</code>
*/
public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) {
this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS);
@@ -183,9 +185,10 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) {
*
* @param tf the TensorFlow Ops
* @param fromLogits Whether to interpret predictions as a tensor of logit values
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are smoothed, meaning the
* confidence on label values are relaxed. e.g. <code>x=0.2</code> means that we will use a
* value of <code>0.1</code> for label <code>0</code> and <code>0.9</code> for label <code>1</code>
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are
* smoothed, meaning the confidence on label values are relaxed. e.g. <code>x=0.2</code> means
* that we will use a value of <code>0.1</code> for label <code>0</code> and <code>0.9</code>
* for label <code>1</code>
* @param reduction Type of Reduction to apply to loss.
*/
public CategoricalCrossentropy(
@@ -199,13 +202,14 @@ public CategoricalCrossentropy(
* @param tf the TensorFlow Ops
* @param name the name of this loss
* @param fromLogits Whether to interpret predictions as a tensor of logit values
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are smoothed, meaning the
* confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2</code> means that we will use a
* value of <code>0.1</code> for label <code>0</code> and <code>0.9</code> for label <code>1</code>
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are
* smoothed, meaning the confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2
* </code> means that we will use a value of <code>0.1</code> for label <code>0</code> and
* <code>0.9</code> for label <code>1</code>
* @param reduction Type of Reduction to apply to loss.
* @param axis The channels axis. <code>axis=-1</code> corresponds to data format "Channels Last"
* and <code>axis=1</code> corresponds to data format "Channels First".
* {@link Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST}
* and <code>axis=1</code> corresponds to data format "Channels First". {@link
* Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST}
* @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1.
*/
public CategoricalCrossentropy(
@@ -242,13 +246,12 @@ public CategoricalCrossentropy(
* predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss
* functions reduce by 1 dimension, usually axis=-1.)
* @param <T> The data type of the predictions, sampleWeights and loss.
* @param <U> The data type of the labels.
* @return the loss
* @throws IllegalArgumentException if the predictions are outside the range [0.-1.].
*/
@Override
public <T extends TNumber, U extends TNumber> Operand<T> call(
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
public <T extends TNumber> Operand<T> call(
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Operand<T> lPredictions;
if (!fromLogits) {
// add predictions range check for 0 - 1
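The labelSmoothing Javadoc reads best with concrete numbers: Keras-style smoothing computes smoothed = labels * (1 - s) + s / numClasses, so s = 0.2 with two classes turns each 1 into 0.9 and each 0 into 0.1, as documented. A sketch using the CategoricalCrossentropy(Ops, boolean, float) constructor from this file, under the same two-argument call assumption as above.

import org.tensorflow.EagerSession;
import org.tensorflow.Operand;
import org.tensorflow.framework.losses.CategoricalCrossentropy;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

public class CategoricalCrossentropySketch {
  public static void main(String[] args) {
    try (EagerSession session = EagerSession.create()) {
      Ops tf = Ops.create(session);

      // smoothed = labels * (1 - 0.2) + 0.2 / 2  =>  1 -> 0.9, 0 -> 0.1
      CategoricalCrossentropy cce = new CategoricalCrossentropy(tf, false, 0.2f);

      Operand<TFloat32> labels = tf.constant(new float[][] {{0f, 1f}, {1f, 0f}});
      Operand<TFloat32> predictions = tf.constant(new float[][] {{0.2f, 0.8f}, {0.7f, 0.3f}});

      Operand<TFloat32> loss = cce.call(labels, predictions);
      System.out.println(loss.asTensor().getFloat());
    }
  }
}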
@@ -25,7 +25,7 @@
* <p><code>loss = maximum(neg - pos + 1, 0)</code> where <code>neg=maximum((1-labels)*predictions)
* </code> and <code>pos=sum(labels*predictions)</code>
*
* <p><code>labels</code> values are expected to be 0 or 1.</p>
* <p><code>labels</code> values are expected to be 0 or 1.
*
* <p>Standalone usage:
*
@@ -99,8 +99,8 @@ public CategoricalHinge(Ops tf, String name, Reduction reduction) {

/** {@inheritDoc} */
@Override
public <T extends TNumber, U extends TNumber> Operand<T> call(
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
public <T extends TNumber> Operand<T> call(
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Operand<T> losses = Losses.categoricalHinge(getTF(), labels, predictions);
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
}
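The categorical hinge formula is easy to check by hand. For one sample with labels [0, 1] and predictions [0.4, 0.6]: pos = sum(labels * predictions) = 0.6, neg = maximum((1 - labels) * predictions) = 0.4, so loss = maximum(0.4 - 0.6 + 1, 0) = 0.8. The same computation through the loss class, with the same two-argument call assumption as the earlier sketches:

import org.tensorflow.EagerSession;
import org.tensorflow.Operand;
import org.tensorflow.framework.losses.CategoricalHinge;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

public class CategoricalHingeSketch {
  public static void main(String[] args) {
    try (EagerSession session = EagerSession.create()) {
      Ops tf = Ops.create(session);
      CategoricalHinge hinge = new CategoricalHinge(tf);

      Operand<TFloat32> labels = tf.constant(new float[][] {{0f, 1f}});
      Operand<TFloat32> predictions = tf.constant(new float[][] {{0.4f, 0.6f}});

      // pos  = sum(labels * predictions)        = 0.6
      // neg  = max((1 - labels) * predictions)  = 0.4
      // loss = max(neg - pos + 1, 0)            = 0.8
      Operand<TFloat32> loss = hinge.call(labels, predictions);
      System.out.println(loss.asTensor().getFloat()); // ~0.8
    }
  }
}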
@@ -22,12 +22,13 @@
/**
* Computes the cosine similarity between labels and predictions.
*
* <p>Note that it is a number between <code>-1</code> and <code>1</code>. When it is a negative number between <code>-1</code> and <code>0</code>, <code>0</code>
* indicates orthogonality and values closer to <code>-1</code>indicate greater similarity. The values closer to
* <code>1</code> indicate greater dissimilarity. This makes it usable as a loss function in a setting where you
* try to maximize the proximity between predictions and targets. If either <code>labels</code> or <code>predictions</code> is
* a zero vector, cosine similarity will be <code>0</code> regardless of the proximity between predictions and
* targets.
* <p>Note that it is a number between <code>-1</code> and <code>1</code>. When it is a negative
* number between <code>-1</code> and <code>0</code>, <code>0</code> indicates orthogonality and
* values closer to <code>-1</code> indicate greater similarity. The values closer to <code>1</code>
* indicate greater dissimilarity. This makes it usable as a loss function in a setting where you
* try to maximize the proximity between predictions and targets. If either <code>labels</code> or
* <code>predictions</code> is a zero vector, cosine similarity will be <code>0</code> regardless of
* the proximity between predictions and targets.
*
* <p><code>loss = -sum(l2Norm(labels) * l2Norm(predictions))</code>
*
@@ -71,7 +72,7 @@ public class CosineSimilarity extends Loss {
public static final int DEFAULT_AXIS = -1;
public static final Reduction DEFAULT_REDUCTION = Reduction.AUTO;

private final int axis;
private final int[] axis;

/**
* Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, an axis
@@ -107,6 +108,17 @@ public CosineSimilarity(Ops tf, int axis) {

this(tf, null, axis, DEFAULT_REDUCTION);
}
/**
* Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, and a
* Loss Reduction of {@link #DEFAULT_REDUCTION}
*
* @param tf the TensorFlow Ops
* @param axis The dimension along which the cosine similarity is computed.
*/
public CosineSimilarity(Ops tf, int[] axis) {

this(tf, null, axis, DEFAULT_REDUCTION);
}

/**
* Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION}
@@ -120,6 +132,18 @@ public CosineSimilarity(Ops tf, String name, int axis) {
this(tf, name, axis, DEFAULT_REDUCTION);
}

/**
* Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION}
*
* @param tf the TensorFlow Ops
* @param name the name of the loss
* @param axis The dimension along which the cosine similarity is computed.
*/
public CosineSimilarity(Ops tf, String name, int[] axis) {

this(tf, name, axis, DEFAULT_REDUCTION);
}

/**
* Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name and an
* axis of {@link #DEFAULT_AXIS}
@@ -153,6 +177,18 @@ public CosineSimilarity(Ops tf, String name, Reduction reduction) {
*/
public CosineSimilarity(Ops tf, int axis, Reduction reduction) {

this(tf, null, new int[] {axis}, reduction);
}

/**
* Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name
*
* @param tf the TensorFlow Ops
* @param axis The dimension along which the cosine similarity is computed.
* @param reduction Type of Reduction to apply to the loss.
*/
public CosineSimilarity(Ops tf, int[] axis, Reduction reduction) {

this(tf, null, axis, reduction);
}

@@ -165,15 +201,28 @@ public CosineSimilarity(Ops tf, int axis, Reduction reduction) {
* @param reduction Type of Reduction to apply to the loss.
*/
public CosineSimilarity(Ops tf, String name, int axis, Reduction reduction) {
this(tf, name, new int[] {axis}, reduction);
}

/**
* Creates a Cosine Similarity Loss
*
* @param tf the TensorFlow Ops
* @param name the name of the loss
* @param axis The dimension along which the cosine similarity is computed.
* @param reduction Type of Reduction to apply to the loss.
*/
public CosineSimilarity(Ops tf, String name, int[] axis, Reduction reduction) {
super(tf, name, reduction);
this.axis = axis;
}

/** {@inheritDoc} */
@Override
public <T extends TNumber, U extends TNumber> Operand<T> call(
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
public <T extends TNumber> Operand<T> call(
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Operand<T> losses = Losses.cosineSimilarity(getTF(), labels, predictions, axis);
losses = tf.math.neg(losses);
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
}
}
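Two details in this file deserve a note. The axis field becomes an int[] and the single-int constructors wrap their argument as new int[] {axis}, so the similarity can now be reduced over several dimensions at once. And since Losses.cosineSimilarity returns the similarity itself, call negates it with tf.math.neg to make a loss; identical vectors therefore score -1. A sketch under the same two-argument call assumption:

import org.tensorflow.EagerSession;
import org.tensorflow.Operand;
import org.tensorflow.framework.losses.CosineSimilarity;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

public class CosineSimilaritySketch {
  public static void main(String[] args) {
    try (EagerSession session = EagerSession.create()) {
      Ops tf = Ops.create(session);

      // Single axis, as before: reduce along the last dimension.
      CosineSimilarity lastAxis = new CosineSimilarity(tf, -1);
      // New in this PR: reduce along several dimensions at once.
      CosineSimilarity bothAxes = new CosineSimilarity(tf, new int[] {0, 1});

      Operand<TFloat32> labels = tf.constant(new float[][] {{0.3f, 0.7f}, {0.6f, 0.4f}});
      Operand<TFloat32> predictions = tf.constant(new float[][] {{0.3f, 0.7f}, {0.6f, 0.4f}});

      // Identical vectors have cosine similarity 1; the loss negates it.
      System.out.println(lastAxis.call(labels, predictions).asTensor().getFloat()); // ~ -1.0
      System.out.println(bothAxes.call(labels, predictions).asTensor().getFloat());
    }
  }
}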
@@ -18,15 +18,16 @@
import org.tensorflow.framework.losses.impl.LossesHelper;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

import static org.tensorflow.framework.utils.CastHelper.cast;

/**
* Computes the hinge loss between labels and predictions.
*
* <p><code>loss = maximum(1 - labels * predictions, 0)</code></p>.
* <p><code>loss = maximum(1 - labels * predictions, 0)</code>.
*
* <p><code>labels</code> values are expected to be -1 or 1.
* If binary (0 or 1) labels are provided, they will be converted to -1 or 1.</p>
* <p><code>labels</code> values are expected to be -1 or 1. If binary (0 or 1) labels are provided,
* they will be converted to -1 or 1.
*
* <p>Standalone usage:
*
@@ -106,7 +107,7 @@ public Hinge(Ops tf, String name, Reduction reduction) {
* label values are not in the set [-1., 0., 1.].
*
* @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be
* -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1.
* -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1.
* @param predictions the predictions, values must be in the range [0. to 1.] inclusive.
* @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is
* provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor
@@ -116,21 +117,19 @@ public Hinge(Ops tf, String name, Reduction reduction) {
* predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss
* functions reduce by 1 dimension, usually axis=-1.)
* @param <T> The data type of the predictions, sampleWeights and loss.
* @param <U> The data type of the labels.
* @return the loss
* @throws IllegalArgumentException if the predictions are outside the range [0.-1.].
*/
@Override
public <T extends TNumber, U extends TNumber> Operand<T> call(
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
@SuppressWarnings("unchecked")
Operand<T> tLabels = predictions.type() == labels.type() ?
(Operand<T>)labels : cast(tf, labels, predictions.type());
tLabels = LossesHelper.valueCheck(
public <T extends TNumber> Operand<T> call(
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Operand<T> tLabels = cast(tf, labels, predictions.type());
tLabels =
LossesHelper.valueCheck(
getTF(),
"labels value check [-1, 0, 1]",
tLabels,
cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.type()));
cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type()));
Operand<T> losses = Losses.hinge(getTF(), tLabels, predictions);
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
}
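The rewritten call above also removes the @SuppressWarnings("unchecked") cast: labels are now cast to the prediction type unconditionally and then validated by LossesHelper.valueCheck against {-1, 0, 1}. As the Javadoc states, binary 0/1 labels are converted to -1/1 before loss = maximum(1 - labels * predictions, 0) is applied. Worked numbers, same two-argument call assumption as above:

import org.tensorflow.EagerSession;
import org.tensorflow.Operand;
import org.tensorflow.framework.losses.Hinge;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;

public class HingeSketch {
  public static void main(String[] args) {
    try (EagerSession session = EagerSession.create()) {
      Ops tf = Ops.create(session);
      Hinge hinge = new Hinge(tf);

      // Binary 0/1 labels pass the {-1, 0, 1} value check and are converted
      // to -1/1 internally; with the wildcard signature they can be any
      // numeric type, here TInt32 against TFloat32 predictions.
      Operand<TInt32> labels = tf.constant(new int[][] {{0, 1}});
      Operand<TFloat32> predictions = tf.constant(new float[][] {{0.2f, 0.8f}});

      // Converted labels: [-1, 1]
      // losses = max(1 - labels * predictions, 0) = [1.2, 0.2]; mean = 0.7
      Operand<TFloat32> loss = hinge.call(labels, predictions);
      System.out.println(loss.asTensor().getFloat()); // ~0.7
    }
  }
}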
@@ -89,6 +89,7 @@ public Huber(Ops tf) {
* Loss#REDUCTION_DEFAULT}
*
* @param tf the TensorFlow Ops
* @param name the name of the loss, if null then {@link Class#getSimpleName()} is used.
*/
public Huber(Ops tf, String name) {
this(tf, name, DELTA_DEFAULT, Reduction.AUTO);
@@ -109,6 +110,7 @@ public Huber(Ops tf, Reduction reduction) {
* Creates a Huber Loss using {@link #DELTA_DEFAULT} as the delta
*
* @param tf the TensorFlow Ops
* @param name the name of the loss, if null then {@link Class#getSimpleName()} is used.
* @param reduction Type of Reduction to apply to the loss.
*/
public Huber(Ops tf, String name, Reduction reduction) {
@@ -119,7 +121,7 @@ public Huber(Ops tf, String name, Reduction reduction) {
* Creates a Huber Loss
*
* @param tf the TensorFlow Ops
* @param name the name of the loss
* @param name the name of the loss, if null then {@link Class#getSimpleName()} is used.
* @param delta the point where the Huber loss function changes from quadratic to linear.
* @param reduction Type of Reduction to apply to the loss.
*/
@@ -130,8 +132,8 @@ public Huber(Ops tf, String name, float delta, Reduction reduction) {

/** {@inheritDoc} */
@Override
public <T extends TNumber, U extends TNumber> Operand<T> call(
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
public <T extends TNumber> Operand<T> call(
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Operand<T> losses = Losses.huber(getTF(), labels, predictions, delta);
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
}
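Huber is quadratic for small errors and linear for large ones: 0.5 * e^2 when |e| <= delta, otherwise delta * (|e| - 0.5 * delta). The sketch below exercises the four-argument constructor whose name Javadoc this hunk completes; passing null for the name falls back to Class#getSimpleName() as documented, while Reduction.AUTO, a DELTA_DEFAULT of 1.0f (the usual Keras default), and the two-argument call are assumptions as before:

import org.tensorflow.EagerSession;
import org.tensorflow.Operand;
import org.tensorflow.framework.losses.Huber;
import org.tensorflow.framework.losses.Reduction;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

public class HuberSketch {
  public static void main(String[] args) {
    try (EagerSession session = EagerSession.create()) {
      Ops tf = Ops.create(session);

      // null name -> Class#getSimpleName(); delta is the quadratic/linear
      // switch point described in the Javadoc.
      Huber huber = new Huber(tf, null, 1.0f, Reduction.AUTO);

      Operand<TFloat32> labels = tf.constant(new float[][] {{0f, 0f}});
      Operand<TFloat32> predictions = tf.constant(new float[][] {{0.5f, 3f}});

      // |e| = 0.5 <= delta: quadratic, 0.5 * 0.5^2            = 0.125
      // |e| = 3.0 >  delta: linear, delta * (3 - 0.5 * delta) = 2.5
      // mean over the two elements                            = 1.3125
      Operand<TFloat32> loss = huber.call(labels, predictions);
      System.out.println(loss.asTensor().getFloat()); // ~1.3125
    }
  }
}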
@@ -99,8 +99,8 @@ public KLDivergence(Ops tf, String name, Reduction reduction) {

/** {@inheritDoc} */
@Override
public <T extends TNumber, U extends TNumber> Operand<T> call(
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
public <T extends TNumber> Operand<T> call(
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Operand<T> losses = Losses.kullbackLeiblerDivergence(getTF(), labels, predictions);
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
}
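KLDivergence.call gets the same single-type-parameter treatment as every other loss in this PR. For reference, the underlying loss is labels * log(labels / predictions), summed over the last axis; one last sketch under the same two-argument call assumption:

import org.tensorflow.EagerSession;
import org.tensorflow.Operand;
import org.tensorflow.framework.losses.KLDivergence;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

public class KLDivergenceSketch {
  public static void main(String[] args) {
    try (EagerSession session = EagerSession.create()) {
      Ops tf = Ops.create(session);
      KLDivergence kld = new KLDivergence(tf);

      Operand<TFloat32> labels = tf.constant(new float[][] {{0.4f, 0.6f}});
      Operand<TFloat32> predictions = tf.constant(new float[][] {{0.5f, 0.5f}});

      // 0.4 * ln(0.4 / 0.5) + 0.6 * ln(0.6 / 0.5) ~= 0.020
      Operand<TFloat32> loss = kld.call(labels, predictions);
      System.out.println(loss.asTensor().getFloat()); // ~0.02
    }
  }
}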