-
Notifications
You must be signed in to change notification settings - Fork 216
Metrics init scope #382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Metrics init scope #382
Changes from all commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
3338605
init scope changes regenerated the core ops.
JimClarke5 0653299
Fix JavaDoc errors
JimClarke5 619fdea
change to use init scope for Variables. Moved Ops parameter out of th…
JimClarke5 d26e505
Fix JavaDoc errors
JimClarke5 ab19a39
change to use init scope for Variables. Moved Ops parameter out of th…
JimClarke5 0718d5e
Fix framework MatMul to call raw op org.tensorflow.op.linacalg.MatMul…
JimClarke5 fd68feb
Fix JavaDoc for FameworkOps
JimClarke5 770b735
Code reformat
JimClarke5 b725dab
Merge remote-tracking branch 'origin/MetricsInitScope' into MetricsIn…
JimClarke5 3bfa50c
Code reformat
JimClarke5 895aed5
Code reformat
JimClarke5 d953538
Merge remote-tracking branch 'origin/MetricsInitScope' into MetricsIn…
JimClarke5 d465240
Merge branch 'tensorflow:master' into MetricsInitScope
JimClarke5 24c2ea5
Fix JavaDoc errors
JimClarke5 353d1bb
change to use init scope for Variables. Moved Ops parameter out of th…
JimClarke5 8be8a85
Fix framework MatMul to call raw op org.tensorflow.op.linacalg.MatMul…
JimClarke5 74c5b8f
Fix JavaDoc for FameworkOps
JimClarke5 57e2467
Code reformat
JimClarke5 1ede549
Code reformat
JimClarke5 05cc957
Fix JavaDoc errors
JimClarke5 e6ad388
Code reformat
JimClarke5 1d250af
Merge branch 'tensorflow:master' into MetricsInitScope
JimClarke5 e62279c
Merge remote-tracking branch 'origin/MetricsInitScope' into MetricsIn…
JimClarke5 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
313 changes: 132 additions & 181 deletions
313
tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
259 changes: 259 additions & 0 deletions
259
tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BaseMetric.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,259 @@ | ||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
=======================================================================*/ | ||
package org.tensorflow.framework.metrics; | ||
|
||
import java.util.Collections; | ||
import java.util.List; | ||
import org.tensorflow.Graph; | ||
import org.tensorflow.Operand; | ||
import org.tensorflow.op.Op; | ||
import org.tensorflow.op.Ops; | ||
import org.tensorflow.types.family.TNumber; | ||
|
||
/** Base class for Metrics */ | ||
public abstract class BaseMetric implements Metric { | ||
|
||
/** The seed for random number generation */ | ||
private final long seed; | ||
|
||
private String name; | ||
|
||
private boolean initialized; | ||
|
||
private Ops tf; | ||
|
||
/** | ||
* Creates a Metric with a name of {@link Class#getSimpleName()} | ||
* | ||
* @param seed the seed for random number generation. An initializer created with a given seed | ||
* will always produce the same random tensor for a given shape and data type. | ||
*/ | ||
protected BaseMetric(long seed) { | ||
this(null, seed); | ||
} | ||
|
||
/** | ||
* Creates a Metric | ||
* | ||
* @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. | ||
* @param seed the seed for random number generation. An initializer created with a given seed | ||
* will always produce the same random tensor for a given shape and data type. | ||
*/ | ||
protected BaseMetric(String name, long seed) { | ||
|
||
this.seed = seed; | ||
this.name = name != null ? name : this.getClass().getSimpleName(); | ||
} | ||
|
||
/** | ||
* Creates a List of Operations to update the metric state based on input values. | ||
* | ||
* <p>This is an empty implementation that should be overridden in a subclass, if needed. | ||
* | ||
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. | ||
* @param values the inputs to be passed to update state, this may not be null | ||
* @param sampleWeights sample weights to be applied to the values, may be null. | ||
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. | ||
* @return a List of Operations to update the metric state | ||
*/ | ||
@SuppressWarnings({"unchecked", "unused"}) | ||
@Override | ||
public List<Op> updateStateList( | ||
Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights) { | ||
checkIsGraph(tf); | ||
return Collections.EMPTY_LIST; | ||
} | ||
|
||
/** | ||
* Creates a List of Operations to update the metric state based on labels and predictions. | ||
* | ||
* <p>This is an empty implementation that should be overridden in a subclass, if needed. | ||
* | ||
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. | ||
* @param labels the labels | ||
* @param predictions the predictions | ||
* @param sampleWeights sample weights to be applied to the metric values, may be null. | ||
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. | ||
* @return a List of Operations to update the metric state | ||
*/ | ||
@Override | ||
@SuppressWarnings({"unchecked", "unused"}) | ||
public List<Op> updateStateList( | ||
Ops tf, | ||
Operand<? extends TNumber> labels, | ||
Operand<? extends TNumber> predictions, | ||
Operand<? extends TNumber> sampleWeights) { | ||
checkIsGraph(tf); | ||
return Collections.EMPTY_LIST; | ||
} | ||
|
||
/** | ||
* Creates a NoOp Operation with control dependencies to update the metric state | ||
* | ||
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. | ||
* @param values the inputs to be passed to update state, this may not be null | ||
* @param sampleWeights sample weights to be applied to the values, may be null. | ||
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. | ||
* @return the Operation to update the metric state | ||
*/ | ||
public final Op updateState( | ||
Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights) { | ||
checkIsGraph(tf); | ||
List<Op> controlOps = updateStateList(tf, values, sampleWeights); | ||
return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); | ||
} | ||
|
||
/** | ||
* Creates a NoOp Operation with control dependencies to update the metric state | ||
* | ||
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. | ||
* @param labels the labels | ||
* @param predictions the predictions | ||
* @param sampleWeights sample weights to be applied to the metric values, may be null. | ||
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. | ||
* @return the Operation to update the metric state | ||
*/ | ||
public final Op updateState( | ||
Ops tf, | ||
Operand<? extends TNumber> labels, | ||
Operand<? extends TNumber> predictions, | ||
Operand<? extends TNumber> sampleWeights) { | ||
List<Op> controlOps = updateStateList(tf, labels, predictions, sampleWeights); | ||
return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); | ||
} | ||
|
||
/** | ||
* Calls update state once, followed by a call to get the result | ||
* | ||
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. | ||
* @param values the inputs to be passed to update state, this may not be null | ||
* @param sampleWeights sample weights to be applied to the values, may be null. | ||
* @param <T> The data type for the metric result | ||
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. | ||
* @return the result, possibly with control dependencies | ||
*/ | ||
@Override | ||
public final <T extends TNumber> Operand<T> callOnce( | ||
Ops tf, | ||
Operand<? extends TNumber> values, | ||
Operand<? extends TNumber> sampleWeights, | ||
Class<T> type) { | ||
checkIsGraph(tf); | ||
List<Op> controlOps = updateStateList(tf, values, sampleWeights); | ||
Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps); | ||
return ltf.identity(result(ltf, type)); | ||
} | ||
|
||
/** | ||
* Gets a formatted name for a variable, in the form {@link #name} + "_" + varName. | ||
* | ||
* @param varName the base name for the variable | ||
* @return the formatted variable name | ||
*/ | ||
protected String getVariableName(String varName) { | ||
return String.format("%s_%s", this.name, varName); | ||
} | ||
|
||
/** | ||
* The name for this metric. Defaults to {@link Class#getSimpleName()}. | ||
* | ||
* <p>Gets the name of this metric. | ||
* | ||
* @return the name of this metric | ||
*/ | ||
public String getName() { | ||
return name; | ||
} | ||
|
||
/** | ||
* Sets the metric name | ||
* | ||
* @param name the metric name | ||
*/ | ||
public void setName(String name) { | ||
this.name = name; | ||
} | ||
|
||
/** | ||
* Gets the random number generator seed value | ||
* | ||
* @return the random number generator seed value | ||
*/ | ||
public long getSeed() { | ||
return seed; | ||
} | ||
|
||
/** | ||
* Initialize the TensorFlow Ops | ||
* | ||
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. | ||
* @throws IllegalArgumentException if the TensorFlow Ops does not have a Graph environment, | ||
*/ | ||
protected abstract void init(Ops tf); | ||
|
||
/** | ||
* Gets the TensorFlow Ops for this metric | ||
* | ||
* @return the TensorFlow Ops for this metric. | ||
*/ | ||
protected Ops getTF() { | ||
return tf; | ||
} | ||
|
||
/** | ||
* Sets the TensorFlow Ops for this metric. | ||
* | ||
* <p>This should be set from the {@link #init(Ops)} implementation. | ||
* | ||
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. | ||
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment. | ||
*/ | ||
protected void setTF(Ops tf) { | ||
checkIsGraph(tf); | ||
this.tf = tf; | ||
} | ||
|
||
/** | ||
* Checks whether the Metric is initialized or not. | ||
* | ||
* @return true if the Metric has been initialized. | ||
*/ | ||
public boolean isInitialized() { | ||
return initialized; | ||
} | ||
|
||
/** | ||
* Sets the initialized indicator | ||
* | ||
* @param initialized the initialized indicator | ||
*/ | ||
protected void setInitialized(boolean initialized) { | ||
this.initialized = initialized; | ||
} | ||
|
||
/** | ||
* Checks if the TensorFlow Ops encapsulates a {@link Graph} environment. | ||
* | ||
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment. | ||
* @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph | ||
* environment. | ||
*/ | ||
protected void checkIsGraph(Ops tf) { | ||
if (!tf.scope().env().isGraph()) { | ||
throw new IllegalArgumentException( | ||
"The Ops environment is not a Graph, Graph is required for metrics."); | ||
} | ||
} | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do
BaseMetric
still needs to store and instance ofOps
since it is already available in most endpoints? I can't see a case where it is needed in this PR (but it is not obvious to search in a PR neither)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe that once the
Metric
's variables are created, then subsequent operations would have to use the sameGraph
to manipulate them. That is why I am storing it. Let me know if my understanding is wrong.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a unit test showing how the metrics are manipulated then after initialization?