Skip to content

Commit fdb93a0

Browse files
committed
SessionFunction instead of SavedModelBundle specific class
Signed-off-by: Ryan Nett <[email protected]>
1 parent 74b4611 commit fdb93a0

File tree

6 files changed

+423
-312
lines changed

6 files changed

+423
-312
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================
16+
*/
17+
package org.tensorflow;
18+
19+
import java.util.LinkedHashMap;
20+
import java.util.Map;
21+
import org.tensorflow.Signature.TensorDescription;
22+
23+
public interface CallableFunction {
24+
25+
/**
26+
* Returns the signature of this function
27+
*/
28+
Signature signature();
29+
30+
/**
31+
* Invokes a function using the default eager session.
32+
*
33+
* <p>Caller is responsible for closing all Tensors.
34+
*
35+
* @param arguments list of tensors to pass in input to the function, mapped by their signature name
36+
* @return output tensors resulting from the execution of the function, mapped by their signature name
37+
* @throws IllegalArgumentException if the passed arguments don't match up to the function's parameters.
38+
*/
39+
Map<String, Tensor> call(Map<String, Tensor> arguments);
40+
41+
/**
42+
* Invokes a function with a single input and output using the default eager session.
43+
*
44+
* <p>Caller is responsible for closing all Tensors.
45+
*
46+
* @param tensor input tensor
47+
* @return output tensor
48+
* @throws IllegalArgumentException if there are multiple input or output parameters defined in the function
49+
*/
50+
default Tensor call(Tensor tensor) {
51+
if (signature().inputNames().size() > 1) {
52+
throw new IllegalArgumentException(
53+
"Can't use call(Tensor) on function \"" + signature().methodName() + "\" with more than one input.");
54+
}
55+
if (signature().inputNames().size() < 1) {
56+
throw new IllegalArgumentException(
57+
"Can't use call(Tensor) on function \"" + signature().methodName() + "\" with no inputs.");
58+
}
59+
if (signature().outputNames().size() > 1) {
60+
throw new IllegalArgumentException(
61+
"Can't use call(Tensor) on function \"" + signature().methodName() + "\" with more than one output.");
62+
}
63+
if (signature().outputNames().size() < 1) {
64+
throw new IllegalArgumentException(
65+
"Can't use call(Tensor) on function \"" + signature().methodName() + "\" with no outputs.");
66+
}
67+
68+
String inputName = signature().inputNames().iterator().next();
69+
String outputName = signature().outputNames().iterator().next();
70+
71+
Map<String, Tensor> inputMap = new LinkedHashMap<>();
72+
inputMap.put(inputName, tensor);
73+
74+
return call(inputMap).get(outputName);
75+
}
76+
77+
static Operand<?> validateDescription(TensorDescription description, Graph graph, String name, String prefix) {
78+
Output<?> operand = graph.output(description.name);
79+
if (operand == null) {
80+
throw new IllegalArgumentException(
81+
prefix + " \"" + name + "\"'s operand \"" + description.name + "\" does not exist on the session's graph.");
82+
}
83+
84+
if (operand.dataType() != description.dataType) {
85+
throw new IllegalArgumentException(
86+
prefix + " \"" + name + "\"'s operand \"" + description.name + "\" has data type " + operand.dataType()
87+
+ " in the session's graph, but the signature requires data type " + description.dataType + ".");
88+
}
89+
90+
if (!operand.shape().isCompatibleWith(description.shape)) {
91+
throw new IllegalArgumentException(
92+
prefix + " \"" + name + "\"'s operand \"" + description.name + "\" has shape " + operand.shape()
93+
+ ", which is incompatible with the signature's required shape of " + description.shape + ".");
94+
}
95+
return operand;
96+
}
97+
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java

Lines changed: 44 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionSetAttrValueProto;
1919
import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphToFunction;
2020

21-
import java.io.IOException;
2221
import java.util.ArrayList;
2322
import java.util.Arrays;
2423
import java.util.Collection;
@@ -65,19 +64,19 @@
6564
* Map<String, Tensor> outputTensorMap = myFunction.call(inputTensorMap);
6665
* }</pre>
6766
*/
68-
public class ConcreteFunction implements AutoCloseable {
67+
public class ConcreteFunction implements AutoCloseable, CallableFunction {
6968

7069

7170
/**
7271
* Creates a function by building a new graph.
7372
*
7473
* <p>The {@code functionBuilder} must initialize the function graph from the provided
75-
* {@link Ops} instance and return a valid signature that will be used to feed the input tensors
76-
* and fetch the output tensors on execution.
74+
* {@link Ops} instance and return a valid signature that will be used to feed the input tensors and fetch the output
75+
* tensors on execution.
7776
*
7877
* <p>The function will be the owner of the new graph and its resulting session. Therefore,
79-
* the function must be enclosed properly with a try-with-resources block to guarantee that all
80-
* native resources will be freed once the function is discarded. For example:
78+
* the function must be enclosed properly with a try-with-resources block to guarantee that all native resources will
79+
* be freed once the function is discarded. For example:
8180
*
8281
* <pre>{@code
8382
* public class MyModel {
@@ -112,8 +111,8 @@ public static ConcreteFunction create(Function<Ops, Signature> functionBuilder)
112111
* Create a function from a signature and an existing graph.
113112
*
114113
* <p>The function will keep the ownership of the session used to run the graph but not
115-
* the graph itself, meaning that the lifetime of the latter can extend beyond the scope of the
116-
* function. For example:
114+
* the graph itself, meaning that the lifetime of the latter can extend beyond the scope of the function. For
115+
* example:
117116
*
118117
* <pre>{@code
119118
* try (Graph g = new Graph()) {
@@ -130,7 +129,7 @@ public static ConcreteFunction create(Function<Ops, Signature> functionBuilder)
130129
* }</pre>
131130
*
132131
* @param signature signature of the function to create
133-
* @param graph a valid and initialized graph
132+
* @param graph a valid and initialized graph
134133
* @return a new function
135134
*/
136135
public static ConcreteFunction create(Signature signature, Graph graph) {
@@ -141,8 +140,8 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
141140
* Create a function from a signature and a valid graph session.
142141
*
143142
* <p>The function will not own the session nor its graph, meaning that their lifetime
144-
* can extend beyond the scope of the function. Therefore the function does not need to be closed
145-
* after its usage. For example:
143+
* can extend beyond the scope of the function. Therefore the function does not need to be closed after its usage. For
144+
* example:
146145
*
147146
* <pre>{@code
148147
* try (Graph g = new Graph()) {
@@ -164,7 +163,7 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
164163
* }</pre>
165164
*
166165
* @param signature signature of the function to create
167-
* @param session a valid session to an initialized graph
166+
* @param session a valid session to an initialized graph
168167
* @return a new function
169168
*/
170169
public static ConcreteFunction create(Signature signature, Session session) {
@@ -174,6 +173,7 @@ public static ConcreteFunction create(Signature signature, Session session) {
174173
/**
175174
* Returns the signature of this function
176175
*/
176+
@Override
177177
public Signature signature() {
178178
return signature;
179179
}
@@ -220,10 +220,10 @@ public String toString() {
220220

221221

222222
/**
223-
* Calls the function in an execution environment, adding it's graph as a function if it isn't
224-
* already present. The inputs and outputs are keyed by the names set in the {@code Signature}.
223+
* Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The
224+
* inputs and outputs are keyed by the names set in the {@code Signature}.
225225
*
226-
* @param scope the scope to call the function in
226+
* @param scope the scope to call the function in
227227
* @param arguments the arguments to the call
228228
* @return the outputs of the function
229229
*/
@@ -235,12 +235,17 @@ public Map<String, Operand<?>> call(Scope scope,
235235

236236
int i = 0;
237237
for (String inputName : signature().inputNames()) {
238-
Operand<?> input = arguments.get(inputName);
239-
if (input == null) {
238+
if (!arguments.containsKey(inputName)) {
240239
throw new IllegalArgumentException(
241240
"Function " + signature().methodName() + " has parameter \"" + inputName
242241
+ "\", but no argument was passed for it.");
243242
}
243+
244+
Operand<?> input = arguments.get(inputName);
245+
if (input == null) {
246+
throw new IllegalArgumentException(
247+
"Can't pass null as an argument to a function. Argument \"" + inputName + "\" was null.");
248+
}
244249
inputs[i] = input.asOutput();
245250
i++;
246251
}
@@ -288,10 +293,10 @@ public Map<String, Operand<?>> call(Scope scope,
288293
}
289294

290295
/**
291-
* Calls the function in an execution environment, adding it's graph as a function if it isn't
292-
* already present. Only works for functions with a single input and output.
296+
* Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only
297+
* works for functions with a single input and output.
293298
*
294-
* @param scope the scope to call the function in
299+
* @param scope the scope to call the function in
295300
* @param argument the argument to the call
296301
* @return the output of the function
297302
*/
@@ -316,18 +321,8 @@ public Operand<?> call(Scope scope, Operand<?> argument) {
316321
return call(scope, inputMap).get(outputName);
317322
}
318323

319-
/**
320-
* Invokes a function using the default eager session.
321-
*
322-
* <p>Caller is responsible for closing all Tensors.
323-
*
324-
* @param arguments list of tensors to pass in input to the function, mapped by their signature
325-
* name
326-
* @return output tensors resulting from the execution of the function, mapped by their signature
327-
* name
328-
*/
329-
public Map<String, Tensor> call(Map<String, Tensor> arguments)
330-
throws IllegalArgumentException {
324+
@Override
325+
public Map<String, Tensor> call(Map<String, Tensor> arguments) {
331326
//FIXME need to manage input/output operand lifetimes
332327
Ops tf = Ops.create();
333328
Map<String, Operand<?>> inputs = new LinkedHashMap<>(arguments.size());
@@ -345,27 +340,10 @@ public Map<String, Tensor> call(Map<String, Tensor> arguments)
345340
}
346341

347342
/**
348-
* Invokes a function with a single input and output using the default eager session.
343+
* Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The
344+
* inputs and outputs are keyed by the names set in the {@code Signature}.
349345
*
350-
* <p>Caller is responsible for closing all Tensors.
351-
*
352-
* @param tensor input tensor
353-
* @return output tensor
354-
* @throws IllegalArgumentException if there are multiple input or output parameters defined in
355-
* the function
356-
*/
357-
public Tensor call(Tensor tensor) throws IllegalArgumentException {
358-
Ops tf = Ops.create();
359-
Operand<?> argument = tf.constantOf((TType) tensor);
360-
Operand<?> output = call(tf, argument);
361-
return output.asTensor();
362-
}
363-
364-
/**
365-
* Calls the function in an execution environment, adding it's graph as a function if it isn't
366-
* already present. The inputs and outputs are keyed by the names set in the {@code Signature}.
367-
*
368-
* @param tf the scope to call the function in
346+
* @param tf the scope to call the function in
369347
* @param arguments the arguments to the call
370348
* @return the outputs of the function
371349
*/
@@ -374,30 +352,17 @@ public Map<String, Operand<?>> call(Ops tf, Map<String, Operand<?>> arguments) {
374352
}
375353

376354
/**
377-
* Calls the function in an execution environment, adding it's graph as a function if it isn't
378-
* already present. Only works for functions with a single input and output.
355+
* Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only
356+
* works for functions with a single input and output.
379357
*
380-
* @param tf the scope to call the function in
358+
* @param tf the scope to call the function in
381359
* @param argument the argument to the call
382360
* @return the output of the function
383361
*/
384362
public Operand<?> call(Ops tf, Operand<?> argument) {
385363
return tf.call(this, argument);
386364
}
387365

388-
/**
389-
* Export this function as a saved model.
390-
*
391-
* <p>This method is convenient shortcut equivalent to
392-
* {@code SavedModel.exporter(exportDir).withFunction(this).export()}
393-
*
394-
* @param exportDir directory where to export the saved model
395-
* @throws IOException if saved model or variable state cannot be written on disk
396-
*/
397-
public void save(String exportDir) throws IOException {
398-
SavedModelBundle.exporter(exportDir).withFunction(this).export();
399-
}
400-
401366
TF_Function nativeHandle() {
402367
if (nativeFunction.getNativeHandle().isNull()) {
403368
throw new IllegalStateException("Function has been closed");
@@ -414,8 +379,8 @@ TF_Function nativeHandle() {
414379
}
415380

416381
/**
417-
* Detects the signature from the handle. Does not close passed functions. All passed functions
418-
* should have deallocators.
382+
* Detects the signature from the handle. Does not close passed functions. All passed functions should have
383+
* deallocators.
419384
*/
420385
static ConcreteFunction fromNativeHandle(NativeFunction nativeFunction,
421386
Collection<NativeFunction> availableFunctions) {
@@ -524,11 +489,11 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction,
524489
}
525490

526491
/**
527-
* FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because
528-
* how to enable XLA JIT is extremely non-obvious.
492+
* FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because how to enable XLA
493+
* JIT is extremely non-obvious.
529494
* <p>
530-
* Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered
531-
* platform with id: 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).
495+
* Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered platform with id:
496+
* 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).
532497
*/
533498
private void makeJit() {
534499
try (PointerScope scope = new PointerScope()) {
@@ -599,18 +564,18 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature)
599564
Reference ref = graph.ref()) {
600565
TF_Status status = TF_Status.newStatus();
601566

602-
List<Operand<?>> inputs = signature.getInputs().values().stream()
603-
.map((x) -> graph.outputOrThrow(x.name))
567+
List<Operand<?>> inputs = signature.getInputs().entrySet().stream()
568+
.map((x) -> CallableFunction.validateDescription(x.getValue(), graph, x.getKey(), "Input"))
604569
.collect(Collectors.toList());
605570

606-
List<Operand<?>> outputs = signature.getOutputs().values().stream()
607-
.map((x) -> graph.outputOrThrow(x.name))
571+
List<Operand<?>> outputs = signature.getOutputs().entrySet().stream()
572+
.map((x) -> CallableFunction.validateDescription(x.getValue(), graph, x.getKey(), "Output"))
608573
.collect(Collectors.toList());
609574

610575
List<GraphOperation> ops = new ArrayList<>(
611576
graph.completeSubgraph(new HashSet<>(inputs), new HashSet<>(outputs)));
612577

613-
inputs.forEach(input -> ops.remove(input.op()));
578+
inputs.forEach(input -> ops.remove((GraphOperation) input.op()));
614579

615580
ops.forEach(x -> {
616581
if (x.type().equals(Placeholder.OP_NAME) || x.type()

0 commit comments

Comments
 (0)