
Generic cleanup Metrics and Losses #203


Closed
wants to merge 62 commits into from
c57a2e7
Merge pull request #3 from tensorflow/master
JimClarke5 Oct 8, 2020
09fc07e
Merge pull request #4 from tensorflow/master
JimClarke5 Oct 27, 2020
a99dcb4
Merge pull request #5 from tensorflow/master
JimClarke5 Nov 17, 2020
ba294ea
Merge pull request #6 from tensorflow/master
JimClarke5 Nov 19, 2020
04f419a
Merge pull request #7 from tensorflow/master
JimClarke5 Dec 30, 2020
02e7ebf
Merge pull request #8 from tensorflow/master
JimClarke5 Jan 29, 2021
8f57a7a
Initial checkin
JimClarke5 Jan 1, 2021
090bde4
Initial checkin and sync with master
JimClarke5 Jan 1, 2021
ce5fa27
Initial checkin and sync with master
JimClarke5 Jan 1, 2021
475ca36
JavaDoc cleanup
JimClarke5 Jan 1, 2021
b48c09a
Javadoc fixes
JimClarke5 Jan 3, 2021
c4c06de
Change LossInterface to LossMetric.
JimClarke5 Jan 5, 2021
2d4c17b
Removed hashmap for variables, they are not needed as the variables o…
JimClarke5 Jan 7, 2021
fbb12f4
reformat code
JimClarke5 Jan 7, 2021
30196af
Add tests for assertBroadcastable
JimClarke5 Jan 11, 2021
1248581
Change type to resultType
JimClarke5 Jan 11, 2021
96bd55b
Added V data type for sampleWeights so that it is not forced to be th…
JimClarke5 Jan 11, 2021
d706f55
change 'type' to 'resultType'
JimClarke5 Jan 11, 2021
0147442
clean up mean and fix assert assertBroadcastable
JimClarke5 Jan 11, 2021
c0c127f
fix error message
JimClarke5 Jan 11, 2021
89ec9ed
Change sampleWeights to have its own generic type <S extends TNumber>
JimClarke5 Jan 12, 2021
1b81c82
Add commment about invalid tests expecting IllegalArgumentExceptions
JimClarke5 Jan 12, 2021
54d4ae9
Add this exception instead of the more generic IllegalArgumentExcepti…
JimClarke5 Jan 12, 2021
ca50dfa
change IllegalArgumentException to NotBroadcastableException.
JimClarke5 Jan 12, 2021
ca5ad09
reformat code
JimClarke5 Jan 12, 2021
1e72017
Fis=x Javadoc
JimClarke5 Jan 13, 2021
85bdde7
Fix Reduce to use boradcastWeights,
JimClarke5 Jan 17, 2021
3f46bd2
Added comment to count to indicate that it may be weighted.
JimClarke5 Jan 17, 2021
fe65ae7
Added SetsOps and fixed AssertBroadcastable to use SetsOps methods,
JimClarke5 Jan 19, 2021
f8d38cf
Fixed based on various PR comments.
JimClarke5 Jan 19, 2021
00ce5db
Deleted, no longer needed after change to Variable handling in Metrics.
JimClarke5 Jan 19, 2021
51104f1
Fix Losses to use CHANNELS_FIRST/LAST for CategoricalCrossentropy
JimClarke5 Jan 20, 2021
e918df4
Fix SetOps to properly convert sparse tensor to dense tensor using tf…
JimClarke5 Jan 30, 2021
9cdc274
Initial checkin
JimClarke5 Jan 1, 2021
141ebd5
Initial checkin and sync with master
JimClarke5 Jan 1, 2021
360a2dc
Initial checkin and sync with master
JimClarke5 Jan 1, 2021
8ec0390
JavaDoc cleanup
JimClarke5 Jan 1, 2021
14de446
Javadoc fixes
JimClarke5 Jan 3, 2021
d0a7fd6
Change LossInterface to LossMetric.
JimClarke5 Jan 5, 2021
feb430c
Removed hashmap for variables, they are not needed as the variables o…
JimClarke5 Jan 7, 2021
48390ea
reformat code
JimClarke5 Jan 7, 2021
6a74ce8
Add tests for assertBroadcastable
JimClarke5 Jan 11, 2021
c62481b
Change type to resultType
JimClarke5 Jan 11, 2021
d475b1a
Added V data type for sampleWeights so that it is not forced to be th…
JimClarke5 Jan 11, 2021
8f530cc
change 'type' to 'resultType'
JimClarke5 Jan 11, 2021
947482a
clean up mean and fix assert assertBroadcastable
JimClarke5 Jan 11, 2021
f473568
fix error message
JimClarke5 Jan 11, 2021
99ff15d
Change sampleWeights to have its own generic type <S extends TNumber>
JimClarke5 Jan 12, 2021
f4e3e04
Add commment about invalid tests expecting IllegalArgumentExceptions
JimClarke5 Jan 12, 2021
4efdb62
Add this exception instead of the more generic IllegalArgumentExcepti…
JimClarke5 Jan 12, 2021
b0a143e
change IllegalArgumentException to NotBroadcastableException.
JimClarke5 Jan 12, 2021
5e05018
reformat code
JimClarke5 Jan 12, 2021
b7530a3
Fis=x Javadoc
JimClarke5 Jan 13, 2021
a2761f0
Fix Reduce to use boradcastWeights,
JimClarke5 Jan 17, 2021
c2efa2a
Added comment to count to indicate that it may be weighted.
JimClarke5 Jan 17, 2021
9db5767
Added SetsOps and fixed AssertBroadcastable to use SetsOps methods,
JimClarke5 Jan 19, 2021
fb8aa65
Fixed based on various PR comments.
JimClarke5 Jan 19, 2021
02be174
Deleted, no longer needed after change to Variable handling in Metrics.
JimClarke5 Jan 19, 2021
25b061a
Remove extra generics from op generation (#193)
rnett Jan 26, 2021
a1c5187
Fix SetOps to properly convert sparse tensor to dense tensor using tf…
JimClarke5 Jan 30, 2021
b125294
Simplify generic parameters across losses and metrics.
JimClarke5 Feb 1, 2021
78366de
Reformat code
JimClarke5 Feb 1, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,10 @@ public final class Ops {

public final SignalOps signal;

public final TrainOps train;

public final QuantizationOps quantization;

public final TrainOps train;

private final Scope scope;

private Ops(Scope scope) {
Expand All @@ -370,8 +370,8 @@ private Ops(Scope scope) {
math = new MathOps(this);
audio = new AudioOps(this);
signal = new SignalOps(this);
train = new TrainOps(this);
quantization = new QuantizationOps(this);
train = new TrainOps(this);
}

/**
Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,12 @@ public BinaryCrossentropy(
* predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss
* functions reduce by 1 dimension, usually axis=-1.)
* @param <T> The data type of the predictions, sampleWeights and loss.
* @param <U> The data type of the labels.
* @return the loss
* @throws IllegalArgumentException if the predictions are outside the range [0.-1.].
*/
@Override
public <T extends TNumber, U extends TNumber> Operand<T> call(
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
public <T extends TNumber> Operand<T> call(
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Operand<T> lPredictions;
if (!fromLogits) {
// add predictions range check for 0 - 1
Expand Down
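The diff above replaces the two-type-parameter signature `<T extends TNumber, U extends TNumber>` with a single `<T>` plus a wildcard `Operand<? extends TNumber>` for the labels. Since the label type never appears in the return type, a bounded wildcard is enough. A minimal, self-contained sketch of that variance point, using hypothetical stand-in types rather than the real TensorFlow Java classes:

```java
// Hypothetical stand-ins for the TF Java types, just to show the generics shape.
interface TNumber {}
class TFloat32 implements TNumber {}
class TInt32 implements TNumber {}

class Operand<T> {
    final T value;
    Operand(T value) { this.value = value; }
}

public class WildcardDemo {
    // Before: <T extends TNumber, U extends TNumber> ... call(Operand<U> labels, Operand<T> predictions)
    // After: the label type is only consumed, never returned, so a wildcard suffices
    // and callers no longer bind a second type variable.
    static <T extends TNumber> String call(
            Operand<? extends TNumber> labels, Operand<T> predictions) {
        return labels.value.getClass().getSimpleName()
                + " -> " + predictions.value.getClass().getSimpleName();
    }

    public static void main(String[] args) {
        // Integer labels with float predictions compile without any extra type argument.
        String r = call(new Operand<>(new TInt32()), new Operand<>(new TFloat32()));
        System.out.println(r); // prints "TInt32 -> TFloat32"
    }
}
```

This matches the PR title's goal of simplifying generic parameters: the caller-visible API shrinks while still accepting labels of any numeric type.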
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
public class CategoricalCrossentropy extends Loss {
public static final boolean FROM_LOGITS_DEFAULT = false;
public static final float LABEL_SMOOTHING_DEFAULT = 0.0f;
public static final int DEFAULT_AXIS = -1;
public static final int DEFAULT_AXIS = Losses.CHANNELS_LAST;

private final boolean fromLogits;
private final float labelSmoothing;
Expand Down Expand Up @@ -154,24 +154,26 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits) {
*
* @param tf the TensorFlow Ops
* @param fromLogits Whether to interpret predictions as a tensor of logit values
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are smoothed, meaning the
* confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2</code> means that we will use a
* value of <code>0.1</code> for label <code>0</code> and <code>0.9</code> for label <code>1</code>
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are
* smoothed, meaning the confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2
* </code> means that we will use a value of <code>0.1</code> for label <code>0</code> and
* <code>0.9</code> for label <code>1</code>
*/
public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) {
this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS);
}

/**
* Creates a categorical cross entropy Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT},
* and a channel axis of {@link #DEFAULT_AXIS}
* Creates a categorical cross entropy Loss using a Loss Reduction of {@link
* Loss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS}
*
* @param tf the TensorFlow Ops
* @param name the name of this loss
* @param fromLogits Whether to interpret predictions as a tensor of logit values
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are smoothed, meaning the
* confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2</code> means that we will use a
* value of <code>0.1</code> for label <code>0</code> and <code>0.9</code> for label <code>1</code>
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are
* smoothed, meaning the confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2
* </code> means that we will use a value of <code>0.1</code> for label <code>0</code> and
* <code>0.9</code> for label <code>1</code>
*/
public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) {
this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS);
Expand All @@ -183,9 +185,10 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float la
*
* @param tf the TensorFlow Ops
* @param fromLogits Whether to interpret predictions as a tensor of logit values
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are smoothed, meaning the
* confidence on label values are relaxed. e.g. <code>x=0.2</code> means that we will use a
* value of <code>0.1</code> for label <code>0</code> and <code>0.9</code> for label <code>1</code>
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are
* smoothed, meaning the confidence on label values are relaxed. e.g. <code>x=0.2</code> means
* that we will use a value of <code>0.1</code> for label <code>0</code> and <code>0.9</code>
* for label <code>1</code>
* @param reduction Type of Reduction to apply to loss.
*/
public CategoricalCrossentropy(
Expand All @@ -199,12 +202,14 @@ public CategoricalCrossentropy(
* @param tf the TensorFlow Ops
* @param name the name of this loss
* @param fromLogits Whether to interpret predictions as a tensor of logit values
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are smoothed, meaning the
* confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2</code> means that we will use a
* value of <code>0.1</code> for label <code>0</code> and <code>0.9</code> for label <code>1</code>
* @param labelSmoothing Float in <code>[0, 1]</code>. When <code>&gt; 0</code>, label values are
* smoothed, meaning the confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2
* </code> means that we will use a value of <code>0.1</code> for label <code>0</code> and
* <code>0.9</code> for label <code>1</code>
* @param reduction Type of Reduction to apply to loss.
* @param axis The channels axis. <code>axis=-1</code> corresponds to data format `Channels Last'
* and <code>axis=1</code> corresponds to data format 'Channels First'.
* @param axis The channels axis. <code>axis=-1</code> corresponds to data format "Channels Last"
* and <code>axis=1</code> corresponds to data format "Channels First". {@link
* Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST}
* @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1.
*/
public CategoricalCrossentropy(
Expand Down Expand Up @@ -241,13 +246,12 @@ public CategoricalCrossentropy(
* predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss
* functions reduce by 1 dimension, usually axis=-1.)
* @param <T> The data type of the predictions, sampleWeights and loss.
* @param <U> The data type of the labels.
* @return the loss
* @throws IllegalArgumentException if the predictions are outside the range [0.-1.].
*/
@Override
public <T extends TNumber, U extends TNumber> Operand<T> call(
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
public <T extends TNumber> Operand<T> call(
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Operand<T> lPredictions;
if (!fromLogits) {
// add predictions range check for 0 - 1
Expand Down
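The Javadoc above states that `labelSmoothing=0.2` turns a hard label of `0` into `0.1` and a hard label of `1` into `0.9`. That is the Keras-style smoothing rule `label * (1 - smoothing) + smoothing / numClasses` with two classes. A small worked sketch (plain Java, method name illustrative, not part of this PR's API):

```java
public class LabelSmoothingDemo {
    // Keras-style label smoothing: scale the hard label down and spread the
    // removed mass uniformly across all classes.
    static double[] smooth(double[] labels, double smoothing) {
        int numClasses = labels.length;
        double[] out = new double[numClasses];
        for (int i = 0; i < numClasses; i++) {
            out[i] = labels[i] * (1.0 - smoothing) + smoothing / numClasses;
        }
        return out;
    }

    public static void main(String[] args) {
        // With smoothing = 0.2 and 2 classes: 0 -> 0.0*0.8 + 0.1 = 0.1,
        // and 1 -> 1.0*0.8 + 0.1 = 0.9, matching the Javadoc's example.
        double[] smoothed = smooth(new double[] {0.0, 1.0}, 0.2);
        System.out.printf("%.1f %.1f%n", smoothed[0], smoothed[1]); // prints "0.1 0.9"
    }
}
```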
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
* <p><code>loss = maximum(neg - pos + 1, 0)</code> where <code>neg=maximum((1-labels)*predictions)
* </code> and <code>pos=sum(labels*predictions)</code>
*
* <p><code>labels</code> values are expected to be 0 or 1.</p>
* <p><code>labels</code> values are expected to be 0 or 1.
*
* <p>Standalone usage:
*
Expand Down Expand Up @@ -99,8 +99,8 @@ public CategoricalHinge(Ops tf, String name, Reduction reduction) {

/** {@inheritDoc} */
@Override
public <T extends TNumber, U extends TNumber> Operand<T> call(
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
public <T extends TNumber> Operand<T> call(
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Operand<T> losses = Losses.categoricalHinge(getTF(), labels, predictions);
return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
}
Expand Down
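The CategoricalHinge Javadoc quoted in this diff defines `loss = maximum(neg - pos + 1, 0)` with `neg = maximum((1-labels)*predictions)` and `pos = sum(labels*predictions)`. A plain-Java sketch of that formula on a single example row (no TensorFlow ops involved; names are illustrative):

```java
public class CategoricalHingeDemo {
    static double categoricalHinge(double[] labels, double[] predictions) {
        double pos = 0.0;                      // sum(labels * predictions)
        double neg = Double.NEGATIVE_INFINITY; // max((1 - labels) * predictions)
        for (int i = 0; i < labels.length; i++) {
            pos += labels[i] * predictions[i];
            neg = Math.max(neg, (1.0 - labels[i]) * predictions[i]);
        }
        return Math.max(neg - pos + 1.0, 0.0);
    }

    public static void main(String[] args) {
        // One-hot label on class 1: pos = 0.7, neg = max(0.1, 0.0, 0.2) = 0.2,
        // so loss = max(0.2 - 0.7 + 1, 0) = 0.5.
        double loss = categoricalHinge(
                new double[] {0, 1, 0}, new double[] {0.1, 0.7, 0.2});
        System.out.printf("%.1f%n", loss); // prints "0.5"
    }
}
```

The loss is zero only when the true class's score beats every wrong class's score by a margin of at least 1, which is why labels are expected to be 0 or 1.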