18
18
import static org .tensorflow .internal .c_api .global .tensorflow .TF_FunctionSetAttrValueProto ;
19
19
import static org .tensorflow .internal .c_api .global .tensorflow .TF_GraphToFunction ;
20
20
21
- import java .io .IOException ;
22
21
import java .util .ArrayList ;
23
22
import java .util .Arrays ;
24
23
import java .util .Collection ;
65
64
* Map<String, Tensor> outputTensorMap = myFunction.call(inputTensorMap);
66
65
* }</pre>
67
66
*/
68
- public class ConcreteFunction implements AutoCloseable {
67
+ public class ConcreteFunction implements AutoCloseable , CallableFunction {
69
68
70
69
71
70
/**
72
71
* Creates a function by building a new graph.
73
72
*
74
73
* <p>The {@code functionBuilder} must initialize the function graph from the provided
75
- * {@link Ops} instance and return a valid signature that will be used to feed the input tensors
76
- * and fetch the output tensors on execution.
74
+ * {@link Ops} instance and return a valid signature that will be used to feed the input tensors and fetch the output
75
+ * tensors on execution.
77
76
*
78
77
* <p>The function will be the owner of the new graph and its resulting session. Therefore,
79
- * the function must be enclosed properly with a try-with-resources block to guarantee that all
80
- * native resources will be freed once the function is discarded. For example:
78
+ * the function must be enclosed properly with a try-with-resources block to guarantee that all native resources will
79
+ * be freed once the function is discarded. For example:
81
80
*
82
81
* <pre>{@code
83
82
* public class MyModel {
@@ -112,8 +111,8 @@ public static ConcreteFunction create(Function<Ops, Signature> functionBuilder)
112
111
* Create a function from a signature and an existing graph.
113
112
*
114
113
* <p>The function will keep the ownership of the session used to run the graph but not
115
- * the graph itself, meaning that the lifetime of the latter can extend beyond the scope of the
116
- * function. For example:
114
+ * the graph itself, meaning that the lifetime of the latter can extend beyond the scope of the function. For
115
+ * example:
117
116
*
118
117
* <pre>{@code
119
118
* try (Graph g = new Graph()) {
@@ -130,7 +129,7 @@ public static ConcreteFunction create(Function<Ops, Signature> functionBuilder)
130
129
* }</pre>
131
130
*
132
131
* @param signature signature of the function to create
133
- * @param graph a valid and initialized graph
132
+ * @param graph a valid and initialized graph
134
133
* @return a new function
135
134
*/
136
135
public static ConcreteFunction create (Signature signature , Graph graph ) {
@@ -141,8 +140,8 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
141
140
* Create a function from a signature and a valid graph session.
142
141
*
143
142
* <p>The function will not own the session nor its graph, meaning that their lifetime
144
- * can extend beyond the scope of the function. Therefore the function does not need to be closed
145
- * after its usage. For example:
143
+ * can extend beyond the scope of the function. Therefore the function does not need to be closed after its usage. For
144
+ * example:
146
145
*
147
146
* <pre>{@code
148
147
* try (Graph g = new Graph()) {
@@ -164,7 +163,7 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
164
163
* }</pre>
165
164
*
166
165
* @param signature signature of the function to create
167
- * @param session a valid session to an initialized graph
166
+ * @param session a valid session to an initialized graph
168
167
* @return a new function
169
168
*/
170
169
public static ConcreteFunction create (Signature signature , Session session ) {
@@ -174,6 +173,7 @@ public static ConcreteFunction create(Signature signature, Session session) {
174
173
/**
175
174
* Returns the signature of this function
176
175
*/
176
+ @ Override
177
177
public Signature signature () {
178
178
return signature ;
179
179
}
@@ -220,10 +220,10 @@ public String toString() {
220
220
221
221
222
222
/**
223
- * Calls the function in an execution environment, adding it's graph as a function if it isn't
224
- * already present. The inputs and outputs are keyed by the names set in the {@code Signature}.
223
+ * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The
224
+ * inputs and outputs are keyed by the names set in the {@code Signature}.
225
225
*
226
- * @param scope the scope to call the function in
226
+ * @param scope the scope to call the function in
227
227
* @param arguments the arguments to the call
228
228
* @return the outputs of the function
229
229
*/
@@ -235,12 +235,17 @@ public Map<String, Operand<?>> call(Scope scope,
235
235
236
236
int i = 0 ;
237
237
for (String inputName : signature ().inputNames ()) {
238
- Operand <?> input = arguments .get (inputName );
239
- if (input == null ) {
238
+ if (!arguments .containsKey (inputName )) {
240
239
throw new IllegalArgumentException (
241
240
"Function " + signature ().methodName () + " has parameter \" " + inputName
242
241
+ "\" , but no argument was passed for it." );
243
242
}
243
+
244
+ Operand <?> input = arguments .get (inputName );
245
+ if (input == null ) {
246
+ throw new IllegalArgumentException (
247
+ "Can't pass null as an argument to a function. Argument \" " + inputName + "\" was null." );
248
+ }
244
249
inputs [i ] = input .asOutput ();
245
250
i ++;
246
251
}
@@ -288,10 +293,10 @@ public Map<String, Operand<?>> call(Scope scope,
288
293
}
289
294
290
295
/**
291
- * Calls the function in an execution environment, adding it's graph as a function if it isn't
292
- * already present. Only works for functions with a single input and output.
296
+ * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only
297
+ * works for functions with a single input and output.
293
298
*
294
- * @param scope the scope to call the function in
299
+ * @param scope the scope to call the function in
295
300
* @param argument the argument to the call
296
301
* @return the output of the function
297
302
*/
@@ -316,18 +321,8 @@ public Operand<?> call(Scope scope, Operand<?> argument) {
316
321
return call (scope , inputMap ).get (outputName );
317
322
}
318
323
319
- /**
320
- * Invokes a function using the default eager session.
321
- *
322
- * <p>Caller is responsible for closing all Tensors.
323
- *
324
- * @param arguments list of tensors to pass in input to the function, mapped by their signature
325
- * name
326
- * @return output tensors resulting from the execution of the function, mapped by their signature
327
- * name
328
- */
329
- public Map <String , Tensor > call (Map <String , Tensor > arguments )
330
- throws IllegalArgumentException {
324
+ @ Override
325
+ public Map <String , Tensor > call (Map <String , Tensor > arguments ) {
331
326
//FIXME need to manage input/output operand lifetimes
332
327
Ops tf = Ops .create ();
333
328
Map <String , Operand <?>> inputs = new LinkedHashMap <>(arguments .size ());
@@ -345,27 +340,10 @@ public Map<String, Tensor> call(Map<String, Tensor> arguments)
345
340
}
346
341
347
342
/**
348
- * Invokes a function with a single input and output using the default eager session.
343
+ * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The
344
+ * inputs and outputs are keyed by the names set in the {@code Signature}.
349
345
*
350
- * <p>Caller is responsible for closing all Tensors.
351
- *
352
- * @param tensor input tensor
353
- * @return output tensor
354
- * @throws IllegalArgumentException if there are multiple input or output parameters defined in
355
- * the function
356
- */
357
- public Tensor call (Tensor tensor ) throws IllegalArgumentException {
358
- Ops tf = Ops .create ();
359
- Operand <?> argument = tf .constantOf ((TType ) tensor );
360
- Operand <?> output = call (tf , argument );
361
- return output .asTensor ();
362
- }
363
-
364
- /**
365
- * Calls the function in an execution environment, adding it's graph as a function if it isn't
366
- * already present. The inputs and outputs are keyed by the names set in the {@code Signature}.
367
- *
368
- * @param tf the scope to call the function in
346
+ * @param tf the scope to call the function in
369
347
* @param arguments the arguments to the call
370
348
* @return the outputs of the function
371
349
*/
@@ -374,30 +352,17 @@ public Map<String, Operand<?>> call(Ops tf, Map<String, Operand<?>> arguments) {
374
352
}
375
353
376
354
/**
377
- * Calls the function in an execution environment, adding it's graph as a function if it isn't
378
- * already present. Only works for functions with a single input and output.
355
+ * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only
356
+ * works for functions with a single input and output.
379
357
*
380
- * @param tf the scope to call the function in
358
+ * @param tf the scope to call the function in
381
359
* @param argument the argument to the call
382
360
* @return the output of the function
383
361
*/
384
362
public Operand <?> call (Ops tf , Operand <?> argument ) {
385
363
return tf .call (this , argument );
386
364
}
387
365
388
- /**
389
- * Export this function as a saved model.
390
- *
391
- * <p>This method is convenient shortcut equivalent to
392
- * {@code SavedModel.exporter(exportDir).withFunction(this).export()}
393
- *
394
- * @param exportDir directory where to export the saved model
395
- * @throws IOException if saved model or variable state cannot be written on disk
396
- */
397
- public void save (String exportDir ) throws IOException {
398
- SavedModelBundle .exporter (exportDir ).withFunction (this ).export ();
399
- }
400
-
401
366
TF_Function nativeHandle () {
402
367
if (nativeFunction .getNativeHandle ().isNull ()) {
403
368
throw new IllegalStateException ("Function has been closed" );
@@ -414,8 +379,8 @@ TF_Function nativeHandle() {
414
379
}
415
380
416
381
/**
417
- * Detects the signature from the handle. Does not close passed functions. All passed functions
418
- * should have deallocators.
382
+ * Detects the signature from the handle. Does not close passed functions. All passed functions should have
383
+ * deallocators.
419
384
*/
420
385
static ConcreteFunction fromNativeHandle (NativeFunction nativeFunction ,
421
386
Collection <NativeFunction > availableFunctions ) {
@@ -524,11 +489,11 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction,
524
489
}
525
490
526
491
/**
527
- * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because
528
- * how to enable XLA JIT is extremely non-obvious.
492
+ * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because how to enable XLA
493
+ * JIT is extremely non-obvious.
529
494
* <p>
530
- * Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered
531
- * platform with id: 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).
495
+ * Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered platform with id:
496
+ * 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).
532
497
*/
533
498
private void makeJit () {
534
499
try (PointerScope scope = new PointerScope ()) {
@@ -599,18 +564,18 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature)
599
564
Reference ref = graph .ref ()) {
600
565
TF_Status status = TF_Status .newStatus ();
601
566
602
- List <Operand <?>> inputs = signature .getInputs ().values ().stream ()
603
- .map ((x ) -> graph . outputOrThrow (x .name ))
567
+ List <Operand <?>> inputs = signature .getInputs ().entrySet ().stream ()
568
+ .map ((x ) -> CallableFunction . validateDescription (x .getValue (), graph , x . getKey (), "Input" ))
604
569
.collect (Collectors .toList ());
605
570
606
- List <Operand <?>> outputs = signature .getOutputs ().values ().stream ()
607
- .map ((x ) -> graph . outputOrThrow (x .name ))
571
+ List <Operand <?>> outputs = signature .getOutputs ().entrySet ().stream ()
572
+ .map ((x ) -> CallableFunction . validateDescription (x .getValue (), graph , x . getKey (), "Output" ))
608
573
.collect (Collectors .toList ());
609
574
610
575
List <GraphOperation > ops = new ArrayList <>(
611
576
graph .completeSubgraph (new HashSet <>(inputs ), new HashSet <>(outputs )));
612
577
613
- inputs .forEach (input -> ops .remove (input .op ()));
578
+ inputs .forEach (input -> ops .remove (( GraphOperation ) input .op ()));
614
579
615
580
ops .forEach (x -> {
616
581
if (x .type ().equals (Placeholder .OP_NAME ) || x .type ()
0 commit comments