Skip to content

Commit 268ee46

Browse files
authored
Fetch resource variable fix (#276)
* Fixing SavedModelBundleTest.pythonTfFunction. * Fixed 3 javadoc errors in tensorflow-core-api's hand written code. * Changing the surefire parameters for tensorflow-core-api.
1 parent e229028 commit 268ee46

File tree

7 files changed

+26
-15
lines changed

7 files changed

+26
-15
lines changed

tensorflow-core/tensorflow-core-api/pom.xml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,10 @@
388388
</execution>
389389
</executions>
390390
<configuration>
391-
<!-- Activate the use of TCP to transmit events to the plugin -->
392-
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
391+
<!-- Activate the use of TCP to transmit events to the plugin -->
392+
<!-- disabled as it appears to cause intermittent test failures in GitHub Actions
393+
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
394+
-->
393395
<additionalClasspathElements>
394396
<additionalClasspathElement>${project.build.directory}/${project.artifactId}-${project.version}-${native.classifier}.jar</additionalClasspathElement>
395397
<!-- Note: the following path is not accessible in deploying profile, so other libraries like

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorMapper.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
* <p>Usage of this class is reserved for internal purposes only.
2727
*
2828
* @param <T> tensor type mapped by this object
29-
* @see {@link TType}
29+
* @see TType
3030
*/
3131
public abstract class TensorMapper<T extends TType> {
3232

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Zeros.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030
* An operator creating a constant initialized with zeros of the shape given by `dims`.
3131
*
3232
* <p>For example, the following expression
33-
* <pre>{@code tf.zeros(tf.constant(shape), TFloat32.class)</pre>
33+
* <pre>{@code tf.zeros(tf.constant(shape), TFloat32.class)}</pre>
3434
* is the equivalent of
35-
* <pre>{@code tf.fill(tf.constant(shape), tf.constant(0.0f))</pre>
35+
* <pre>{@code tf.fill(tf.constant(shape), tf.constant(0.0f))}</pre>
3636
*
3737
* @param <T> constant type
3838
*/

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
*
2929
* <p>Subinterfaces of {@code TType} are propagated as a generic parameter to various entities of
3030
* TensorFlow to identify the type of the tensor they carry. For example, a
31-
* {@link org.tensorflow.Operand Operand<TFloat32>} is an operand which outputs a 32-bit floating
31+
* {@link org.tensorflow.Operand Operand&lt;TFloat32&gt;} is an operand which outputs a 32-bit floating
3232
* point tensor. This parameter ensure type-compatibility between operands of a computation at
3333
* compile-time. For example:
3434
*

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ public void pythonTfFunction() {
290290
* Signature name used for saving 'add', argument names 'a' and 'b'
291291
*/
292292
ConcreteFunction add = bundle.function("add");
293-
Map<String, Tensor> args = new HashMap();
293+
Map<String, Tensor> args = new HashMap<>();
294294
try (TFloat32 a = TFloat32.scalarOf(10.0f);
295295
TFloat32 b = TFloat32.scalarOf(15.5f)) {
296296
args.put("a", a);
@@ -301,14 +301,19 @@ public void pythonTfFunction() {
301301
assertEquals(25.5f, c.getFloat());
302302
}
303303
}
304+
args.clear();
304305

305306
// variable unwrapping happens in Session, which is used by ConcreteFunction.call
306307
ConcreteFunction getVariable = bundle.function("get_variable");
307-
try (TFloat32 v = (TFloat32) getVariable.call(new HashMap<>())
308-
.get(getVariable.signature().outputNames().iterator().next())) {
309-
assertEquals(2f, v.getFloat());
308+
try (TFloat32 dummy = TFloat32.scalarOf(1.0f)) {
309+
args.put("dummy",dummy);
310+
// TF functions always require an input, so we supply a dummy one here
311+
// This test actually checks that resource variables can be loaded correctly.
312+
try (TFloat32 v = (TFloat32) getVariable.call(args)
313+
.get(getVariable.signature().outputNames().iterator().next())) {
314+
assertEquals(2f, v.getFloat());
315+
}
310316
}
311-
312317
}
313318
}
314319

tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/source_model.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import tensorflow as tf
2424

25+
2526
class MyModel(tf.keras.Model):
2627
def __init__(self):
2728
super(MyModel, self).__init__()
@@ -42,8 +43,7 @@ def get_scalar(self, x):
4243
def get_vector(self, x):
4344
return self.const_vector + x
4445

45-
@tf.function(input_signature=[
46-
tf.TensorSpec(shape=None, dtype=tf.float32, name='input')])
46+
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='input')])
4747
def get_matrix(self, x):
4848
return self.const_matrix + x
4949

@@ -53,10 +53,14 @@ def get_matrix(self, x):
5353
def add(self, a, b):
5454
return a + b
5555

56-
@tf.function(input_signature=[])
57-
def get_variable(self):
56+
#TF functions always require an input
57+
@tf.function(input_signature=[
58+
tf.TensorSpec(shape=None, dtype=tf.float32, name='dummy')
59+
])
60+
def get_variable(self, dummy):
5861
return self.variable
5962

63+
6064
model = MyModel()
6165

6266
signatures = {

0 commit comments

Comments
 (0)