Skip to content

Metrics init scope #382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Oct 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
3338605
init scope changes regenerated the core ops.
JimClarke5 Sep 9, 2021
0653299
Fix JavaDoc errors
JimClarke5 Sep 9, 2021
619fdea
change to use init scope for Variables. Moved Ops parameter out of th…
JimClarke5 Sep 9, 2021
d26e505
Fix JavaDoc errors
JimClarke5 Sep 9, 2021
ab19a39
change to use init scope for Variables. Moved Ops parameter out of th…
JimClarke5 Sep 9, 2021
0718d5e
Fix framework MatMul to call raw op org.tensorflow.op.linacalg.MatMul…
JimClarke5 Sep 17, 2021
fd68feb
Fix JavaDoc for FameworkOps
JimClarke5 Sep 17, 2021
770b735
Code reformat
JimClarke5 Sep 17, 2021
b725dab
Merge remote-tracking branch 'origin/MetricsInitScope' into MetricsIn…
JimClarke5 Sep 17, 2021
3bfa50c
Code reformat
JimClarke5 Sep 17, 2021
895aed5
Code reformat
JimClarke5 Sep 17, 2021
d953538
Merge remote-tracking branch 'origin/MetricsInitScope' into MetricsIn…
JimClarke5 Oct 5, 2021
d465240
Merge branch 'tensorflow:master' into MetricsInitScope
JimClarke5 Oct 5, 2021
24c2ea5
Fix JavaDoc errors
JimClarke5 Sep 9, 2021
353d1bb
change to use init scope for Variables. Moved Ops parameter out of th…
JimClarke5 Sep 9, 2021
8be8a85
Fix framework MatMul to call raw op org.tensorflow.op.linacalg.MatMul…
JimClarke5 Sep 17, 2021
74c5b8f
Fix JavaDoc for FameworkOps
JimClarke5 Sep 17, 2021
57e2467
Code reformat
JimClarke5 Sep 17, 2021
1ede549
Code reformat
JimClarke5 Sep 17, 2021
05cc957
Fix JavaDoc errors
JimClarke5 Sep 9, 2021
e6ad388
Code reformat
JimClarke5 Sep 17, 2021
1d250af
Merge branch 'tensorflow:master' into MetricsInitScope
JimClarke5 Oct 20, 2021
e62279c
Merge remote-tracking branch 'origin/MetricsInitScope' into MetricsIn…
JimClarke5 Oct 20, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

import static org.tensorflow.framework.utils.CastHelper.cast;

import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.framework.losses.impl.LossTuple;
import org.tensorflow.framework.metrics.impl.LossMetric;
import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper;
import org.tensorflow.framework.metrics.impl.MetricsHelper;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
Expand All @@ -36,50 +37,54 @@
*
* @param <T> The data type for the metric result
*/
public class Accuracy<T extends TNumber> extends MeanMetricWrapper<T> implements LossMetric<T> {
public class Accuracy<T extends TNumber> extends MeanBaseMetricWrapper<T> implements LossMetric {

/**
* Creates an Accuracy Metric using {@link Class#getSimpleName()} for the metric name
*
* @param tf the TensorFlow Ops
* @param seed the seed for random number generation. An initializer created with a given seed
* will always produce the same random tensor for a given shape and data type.
* @param type the data type for the variables
*/
public Accuracy(Ops tf, long seed, Class<T> type) {
this(tf, null, seed, type);
public Accuracy(long seed, Class<T> type) {
this(null, seed, type);
}

/**
* Creates an Accuracy Metric
*
* @param tf the TensorFlow Ops
* @param name the name of the metric, if null then {@link Class#getSimpleName()} is used
* @param seed the seed for random number generation. An initializer created with a given seed
* will always produce the same random tensor for a given shape and data type.
* @param type the data type for the variables
*/
public Accuracy(Ops tf, String name, long seed, Class<T> type) {
super(tf, name, seed, type);
public Accuracy(String name, long seed, Class<T> type) {
super(name, seed, type);
setLoss(this);
}

/**
* Calculates how often predictions equals labels. {@code labels} and {@code predictions} must
* have compatible shapes, see {@link Shape @isCompatibleWith}.
*
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
* @param labels the truth values or labels
* @param predictions the predictions
* @throws IllegalArgumentException if predictions and labels shapes are not compatible.
* @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph
* environment.
* @return the loss
*/
@Override
public Operand<T> call(
Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions) {
Operand<T> tLabels = cast(getTF(), labels, getResultType());
Operand<T> tPredictions = cast(getTF(), predictions, getResultType());
public <U extends TNumber> Operand<U> call(
Ops tf,
Operand<? extends TNumber> labels,
Operand<? extends TNumber> predictions,
Class<U> resultType) {
init(tf);
Operand<T> tLabels = cast(tf, labels, getInternalType());
Operand<T> tPredictions = cast(tf, predictions, getInternalType());
LossTuple<T> tuple =
MetricsHelper.raggedAssertCompatibleAndGetFlatValues(getTF(), tLabels, tPredictions);
MetricsHelper.raggedAssertCompatibleAndGetFlatValues(tf, tLabels, tPredictions);
tLabels = tuple.getLabels();
tPredictions = tuple.getTarget();

Expand All @@ -91,6 +96,6 @@ public Operand<T> call(
}

// cast TBool to result type
return cast(getTF(), getTF().math.equal(tLabels, tPredictions), getResultType());
return cast(tf, tf.math.equal(tLabels, tPredictions), resultType);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.metrics;

import java.util.Collections;
import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/** Base class for Metrics */
public abstract class BaseMetric implements Metric {

/** The seed for random number generation */
private final long seed;

private String name;

private boolean initialized;

private Ops tf;

/**
* Creates a Metric with a name of {@link Class#getSimpleName()}
*
* @param seed the seed for random number generation. An initializer created with a given seed
* will always produce the same random tensor for a given shape and data type.
*/
protected BaseMetric(long seed) {
this(null, seed);
}

/**
* Creates a Metric
*
* @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}.
* @param seed the seed for random number generation. An initializer created with a given seed
* will always produce the same random tensor for a given shape and data type.
*/
protected BaseMetric(String name, long seed) {

this.seed = seed;
this.name = name != null ? name : this.getClass().getSimpleName();
}

/**
* Creates a List of Operations to update the metric state based on input values.
*
* <p>This is an empty implementation that should be overridden in a subclass, if needed.
*
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
* @param values the inputs to be passed to update state, this may not be null
* @param sampleWeights sample weights to be applied to the values, may be null.
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment.
* @return a List of Operations to update the metric state
*/
@SuppressWarnings({"unchecked", "unused"})
@Override
public List<Op> updateStateList(
Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights) {
checkIsGraph(tf);
return Collections.EMPTY_LIST;
}

/**
* Creates a List of Operations to update the metric state based on labels and predictions.
*
* <p>This is an empty implementation that should be overridden in a subclass, if needed.
*
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
* @param labels the labels
* @param predictions the predictions
* @param sampleWeights sample weights to be applied to the metric values, may be null.
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment.
* @return a List of Operations to update the metric state
*/
@Override
@SuppressWarnings({"unchecked", "unused"})
public List<Op> updateStateList(
Ops tf,
Operand<? extends TNumber> labels,
Operand<? extends TNumber> predictions,
Operand<? extends TNumber> sampleWeights) {
checkIsGraph(tf);
return Collections.EMPTY_LIST;
}

/**
* Creates a NoOp Operation with control dependencies to update the metric state
*
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
* @param values the inputs to be passed to update state, this may not be null
* @param sampleWeights sample weights to be applied to the values, may be null.
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment.
* @return the Operation to update the metric state
*/
public final Op updateState(
Ops tf, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights) {
checkIsGraph(tf);
List<Op> controlOps = updateStateList(tf, values, sampleWeights);
return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp();
}

/**
* Creates a NoOp Operation with control dependencies to update the metric state
*
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
* @param labels the labels
* @param predictions the predictions
* @param sampleWeights sample weights to be applied to the metric values, may be null.
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment.
* @return the Operation to update the metric state
*/
public final Op updateState(
Ops tf,
Operand<? extends TNumber> labels,
Operand<? extends TNumber> predictions,
Operand<? extends TNumber> sampleWeights) {
List<Op> controlOps = updateStateList(tf, labels, predictions, sampleWeights);
return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp();
}

/**
* Calls update state once, followed by a call to get the result
*
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
* @param values the inputs to be passed to update state, this may not be null
* @param sampleWeights sample weights to be applied to the values, may be null.
* @param <T> The data type for the metric result
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment.
* @return the result, possibly with control dependencies
*/
@Override
public final <T extends TNumber> Operand<T> callOnce(
Ops tf,
Operand<? extends TNumber> values,
Operand<? extends TNumber> sampleWeights,
Class<T> type) {
checkIsGraph(tf);
List<Op> controlOps = updateStateList(tf, values, sampleWeights);
Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps);
return ltf.identity(result(ltf, type));
}

/**
* Gets a formatted name for a variable, in the form {@link #name} + "_" + varName.
*
* @param varName the base name for the variable
* @return the formatted variable name
*/
protected String getVariableName(String varName) {
return String.format("%s_%s", this.name, varName);
}

/**
* The name for this metric. Defaults to {@link Class#getSimpleName()}.
*
* <p>Gets the name of this metric.
*
* @return the name of this metric
*/
public String getName() {
return name;
}

/**
* Sets the metric name
*
* @param name the metric name
*/
public void setName(String name) {
this.name = name;
}

/**
* Gets the random number generator seed value
*
* @return the random number generator seed value
*/
public long getSeed() {
return seed;
}

/**
* Initialize the TensorFlow Ops
*
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
* @throws IllegalArgumentException if the TensorFlow Ops does not have a Graph environment,
*/
protected abstract void init(Ops tf);

/**
* Gets the TensorFlow Ops for this metric
*
* @return the TensorFlow Ops for this metric.
*/
protected Ops getTF() {
return tf;
}

/**
* Sets the TensorFlow Ops for this metric.
*
* <p>This should be set from the {@link #init(Ops)} implementation.
*
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
* @throws IllegalArgumentException if the TensorFlow Ops scope does not have a Graph environment.
*/
protected void setTF(Ops tf) {
checkIsGraph(tf);
this.tf = tf;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do BaseMetric still needs to store and instance of Ops since it is already available in most endpoints? I can't see a case where it is needed in this PR (but it is not obvious to search in a PR neither)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe that once the Metric's variables are created, then subsequent operations would have to use the same Graph to manipulate them. That is why I am storing it. Let me know if my understanding is wrong.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a unit test showing how the metrics are manipulated then after initialization?


/**
* Checks whether the Metric is initialized or not.
*
* @return true if the Metric has been initialized.
*/
public boolean isInitialized() {
return initialized;
}

/**
* Sets the initialized indicator
*
* @param initialized the initialized indicator
*/
protected void setInitialized(boolean initialized) {
this.initialized = initialized;
}

/**
* Checks if the TensorFlow Ops encapsulates a {@link Graph} environment.
*
* @param tf the TensorFlow Ops encapsulating a {@link Graph} environment.
* @throws IllegalArgumentException if the TensorFlow Ops scope does not encapsulate a Graph
* environment.
*/
protected void checkIsGraph(Ops tf) {
if (!tf.scope().env().isGraph()) {
throw new IllegalArgumentException(
"The Ops environment is not a Graph, Graph is required for metrics.");
}
}
}
Loading