@@ -72,12 +72,12 @@ public class ConcreteFunction implements AutoCloseable {
72
72
* Creates a function by building a new graph.
73
73
*
74
74
* <p>The {@code functionBuilder} must initialize the function graph from the provided
75
- * {@link Ops} instance and return a valid signature that will be used to feed the input tensors and fetch the output
76
- * tensors on execution.
75
+ * {@link Ops} instance and return a valid signature that will be used to feed the input tensors
76
+ * and fetch the output tensors on execution.
77
77
*
78
78
* <p>The function will be the owner of the new graph and its resulting session. Therefore,
79
- * the function must be enclosed properly with a try-with-resources block to guarantee that all native resources will
80
- * be freed once the function is discarded. For example:
79
+ * the function must be enclosed properly with a try-with-resources block to guarantee that all
80
+ * native resources will be freed once the function is discarded. For example:
81
81
*
82
82
* <pre>{@code
83
83
* public class MyModel {
@@ -112,8 +112,8 @@ public static ConcreteFunction create(Function<Ops, Signature> functionBuilder)
112
112
* Create a function from a signature and an existing graph.
113
113
*
114
114
* <p>The function will keep the ownership of the session used to run the graph but not
115
- * the graph itself, meaning that the lifetime of the latter can extend beyond the scope of the function. For
116
- * example:
115
+ * the graph itself, meaning that the lifetime of the latter can extend beyond the scope of the
116
+ * function. For example:
117
117
*
118
118
* <pre>{@code
119
119
* try (Graph g = new Graph()) {
@@ -130,7 +130,7 @@ public static ConcreteFunction create(Function<Ops, Signature> functionBuilder)
130
130
* }</pre>
131
131
*
132
132
* @param signature signature of the function to create
133
- * @param graph a valid and initialized graph
133
+ * @param graph a valid and initialized graph
134
134
* @return a new function
135
135
*/
136
136
public static ConcreteFunction create (Signature signature , Graph graph ) {
@@ -141,8 +141,8 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
141
141
* Create a function from a signature and a valid graph session.
142
142
*
143
143
* <p>The function will not own the session nor its graph, meaning that their lifetime
144
- * can extend beyond the scope of the function. Therefore the function does not need to be closed after its usage. For
145
- * example:
144
+ * can extend beyond the scope of the function. Therefore the function does not need to be closed
145
+ * after its usage. For example:
146
146
*
147
147
* <pre>{@code
148
148
* try (Graph g = new Graph()) {
@@ -164,7 +164,7 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
164
164
* }</pre>
165
165
*
166
166
* @param signature signature of the function to create
167
- * @param session a valid session to an initialized graph
167
+ * @param session a valid session to an initialized graph
168
168
* @return a new function
169
169
*/
170
170
public static ConcreteFunction create (Signature signature , Session session ) {
@@ -220,10 +220,10 @@ public String toString() {
220
220
221
221
222
222
/**
223
- * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The
224
- * inputs and outputs are keyed by the names set in the {@code Signature}.
223
+ * Calls the function in an execution environment, adding it's graph as a function if it isn't
224
+ * already present. The inputs and outputs are keyed by the names set in the {@code Signature}.
225
225
*
226
- * @param scope the scope to call the function in
226
+ * @param scope the scope to call the function in
227
227
* @param arguments the arguments to the call
228
228
* @return the outputs of the function
229
229
*/
@@ -276,7 +276,8 @@ public Map<String, Operand<?>> call(Scope scope,
276
276
String outputName = outputNames .get (i );
277
277
278
278
if (i > outputList .size ()) {
279
- throw new IllegalStateException ("Somehow, not all required outputs were returned from the function" );
279
+ throw new IllegalStateException (
280
+ "Somehow, not all required outputs were returned from the function" );
280
281
}
281
282
282
283
Operand <?> output = outputList .get (i );
@@ -287,10 +288,10 @@ public Map<String, Operand<?>> call(Scope scope,
287
288
}
288
289
289
290
/**
290
- * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only
291
- * works for functions with a single input and output.
291
+ * Calls the function in an execution environment, adding it's graph as a function if it isn't
292
+ * already present. Only works for functions with a single input and output.
292
293
*
293
- * @param scope the scope to call the function in
294
+ * @param scope the scope to call the function in
294
295
* @param argument the argument to the call
295
296
* @return the output of the function
296
297
*/
@@ -320,8 +321,10 @@ public Operand<?> call(Scope scope, Operand<?> argument) {
320
321
*
321
322
* <p>Caller is responsible for closing all Tensors.
322
323
*
323
- * @param arguments list of tensors to pass in input to the function, mapped by their signature name
324
- * @return output tensors resulting from the execution of the function, mapped by their signature name
324
+ * @param arguments list of tensors to pass in input to the function, mapped by their signature
325
+ * name
326
+ * @return output tensors resulting from the execution of the function, mapped by their signature
327
+ * name
325
328
*/
326
329
public Map <String , Tensor > call (Map <String , Tensor > arguments )
327
330
throws IllegalArgumentException {
@@ -348,7 +351,8 @@ public Map<String, Tensor> call(Map<String, Tensor> arguments)
348
351
*
349
352
* @param tensor input tensor
350
353
* @return output tensor
351
- * @throws IllegalArgumentException if there are multiple input or output parameters defined in the function
354
+ * @throws IllegalArgumentException if there are multiple input or output parameters defined in
355
+ * the function
352
356
*/
353
357
public Tensor call (Tensor tensor ) throws IllegalArgumentException {
354
358
Ops tf = Ops .create ();
@@ -358,10 +362,10 @@ public Tensor call(Tensor tensor) throws IllegalArgumentException {
358
362
}
359
363
360
364
/**
361
- * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The
362
- * inputs and outputs are keyed by the names set in the {@code Signature}.
365
+ * Calls the function in an execution environment, adding it's graph as a function if it isn't
366
+ * already present. The inputs and outputs are keyed by the names set in the {@code Signature}.
363
367
*
364
- * @param tf the scope to call the function in
368
+ * @param tf the scope to call the function in
365
369
* @param arguments the arguments to the call
366
370
* @return the outputs of the function
367
371
*/
@@ -370,10 +374,10 @@ public Map<String, Operand<?>> call(Ops tf, Map<String, Operand<?>> arguments) {
370
374
}
371
375
372
376
/**
373
- * Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only
374
- * works for functions with a single input and output.
377
+ * Calls the function in an execution environment, adding it's graph as a function if it isn't
378
+ * already present. Only works for functions with a single input and output.
375
379
*
376
- * @param tf the scope to call the function in
380
+ * @param tf the scope to call the function in
377
381
* @param argument the argument to the call
378
382
* @return the output of the function
379
383
*/
@@ -401,17 +405,23 @@ TF_Function nativeHandle() {
401
405
return nativeFunction .getNativeHandle ();
402
406
}
403
407
404
- ConcreteFunction (Signature signature , NativeFunction nativeFunction , Collection <NativeFunction > availableFunctions ) {
408
+ /**
409
+ * All native functions should have deallocators registered
410
+ */
411
+ ConcreteFunction (Signature signature , NativeFunction nativeFunction ,
412
+ Collection <NativeFunction > availableFunctions ) {
405
413
this (signature , nativeFunction , nativeFunction .getAllDependencies (availableFunctions ));
406
414
}
407
415
408
416
/**
409
- * Detects the signature from the handle
417
+ * Detects the signature from the handle. Does not close passed functions. All passed functions
418
+ * should have deallocators.
410
419
*/
411
420
static ConcreteFunction fromNativeHandle (NativeFunction nativeFunction ,
412
421
Collection <NativeFunction > availableFunctions ) {
413
422
414
- Signature .Builder builder = Signature .builder ().methodName (nativeFunction .getFunctionDef ().getSignature ().getName ())
423
+ Signature .Builder builder = Signature .builder ()
424
+ .methodName (nativeFunction .getFunctionDef ().getSignature ().getName ())
415
425
.key (nativeFunction .getName ());
416
426
417
427
for (ArgDef input : nativeFunction .getFunctionDef ().getSignature ().getInputArgList ()) {
@@ -448,19 +458,26 @@ static ConcreteFunction fromNativeHandle(NativeFunction nativeFunction,
448
458
private final DataType [] inputDtypes ;
449
459
private final DataType [] outputDtypes ;
450
460
451
- private ConcreteFunction (Signature signature , NativeFunction nativeFunction , Set <TF_Function > dependencies ) {
461
+
462
+ /**
463
+ * All native functions should have deallocators registered
464
+ */
465
+ private ConcreteFunction (Signature signature , NativeFunction nativeFunction ,
466
+ Set <TF_Function > dependencies ) {
452
467
this .signature = signature ;
453
468
this .nativeFunction = nativeFunction ;
454
469
this .dependencies = Collections .unmodifiableSet (dependencies );
455
470
456
- if (this .signature .getInputs ().size () != nativeFunction .getFunctionDef ().getSignature ().getInputArgCount ()) {
471
+ if (this .signature .getInputs ().size () != nativeFunction .getFunctionDef ().getSignature ()
472
+ .getInputArgCount ()) {
457
473
throw new IllegalArgumentException (
458
474
"Signature must have the same number of inputs as the native function. Expected "
459
475
+ nativeFunction .getFunctionDef ().getSignature ().getInputArgCount () + ", got "
460
476
+ this .signature .getInputs ().size ());
461
477
}
462
478
463
- if (this .signature .getOutputs ().size () != nativeFunction .getFunctionDef ().getSignature ().getOutputArgCount ()) {
479
+ if (this .signature .getOutputs ().size () != nativeFunction .getFunctionDef ().getSignature ()
480
+ .getOutputArgCount ()) {
464
481
throw new IllegalArgumentException (
465
482
"New signature must have the same number of outputs as the native function. Expected "
466
483
+ nativeFunction .getFunctionDef ().getSignature ().getOutputArgCount () + ", got "
@@ -471,7 +488,8 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set
471
488
.toArray (DataType []::new );
472
489
473
490
List <DataType > inputs = Arrays .asList (inputDtypes );
474
- List <DataType > nativeInputs = nativeFunction .getFunctionDef ().getSignature ().getInputArgList ().stream ()
491
+ List <DataType > nativeInputs = nativeFunction .getFunctionDef ().getSignature ().getInputArgList ()
492
+ .stream ()
475
493
.map (ArgDef ::getType )
476
494
.collect (Collectors .toList ());
477
495
@@ -481,10 +499,12 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set
481
499
+ nativeInputs + ", got " + inputs );
482
500
}
483
501
484
- outputDtypes = signature ().getOutputs ().values ().stream ().map (x -> x .dataType ).toArray (DataType []::new );
502
+ outputDtypes = signature ().getOutputs ().values ().stream ().map (x -> x .dataType )
503
+ .toArray (DataType []::new );
485
504
486
505
List <DataType > outputs = Arrays .asList (outputDtypes );
487
- List <DataType > nativeOutputs = nativeFunction .getFunctionDef ().getSignature ().getOutputArgList ().stream ()
506
+ List <DataType > nativeOutputs = nativeFunction .getFunctionDef ().getSignature ().getOutputArgList ()
507
+ .stream ()
488
508
.map (ArgDef ::getType )
489
509
.collect (Collectors .toList ());
490
510
@@ -498,25 +518,26 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set
498
518
try (PointerScope scope = new PointerScope ()) {
499
519
this .scope = scope ;
500
520
scope .extend ();
501
- this .nativeFunction .getNativeHandle (). withDeallocatorInScope ( );
502
- this .dependencies .forEach (TF_Function :: withDeallocatorInScope );
521
+ scope . attach ( this .nativeFunction .getNativeHandle ());
522
+ this .dependencies .forEach (scope :: attach );
503
523
}
504
524
}
505
525
506
526
/**
507
- * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because how to enable XLA
508
- * JIT is extremely non-obvious.
509
- *
510
- * Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered platform with id:
511
- * 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).
527
+ * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because
528
+ * how to enable XLA JIT is extremely non-obvious.
529
+ * <p>
530
+ * Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered
531
+ * platform with id: 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).
512
532
*/
513
533
private void makeJit () {
514
534
try (PointerScope scope = new PointerScope ()) {
515
535
byte [] bytes = AttrValue .newBuilder ().setB (true ).build ().toByteArray ();
516
536
BytePointer trueValue = new BytePointer (bytes );
517
537
518
538
TF_Status status1 = TF_Status .newStatus ();
519
- TF_FunctionSetAttrValueProto (nativeHandle (), "_XlaMustCompile" , trueValue , bytes .length , status1 );
539
+ TF_FunctionSetAttrValueProto (nativeHandle (), "_XlaMustCompile" , trueValue , bytes .length ,
540
+ status1 );
520
541
status1 .throwExceptionIfNotOK ();
521
542
522
543
TF_Status status2 = TF_Status .newStatus ();
@@ -592,9 +613,11 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature)
592
613
inputs .forEach (input -> ops .remove (input .op ()));
593
614
594
615
ops .forEach (x -> {
595
- if (x .type ().equals (Placeholder .OP_NAME ) || x .type ().equals (PlaceholderWithDefault .OP_NAME )){
596
- throw new IllegalArgumentException ("Can't calculate outputs (" + outputs + ") from inputs (" + inputs + "), "
597
- + "they also depend on \" " + x + "\" " );
616
+ if (x .type ().equals (Placeholder .OP_NAME ) || x .type ()
617
+ .equals (PlaceholderWithDefault .OP_NAME )) {
618
+ throw new IllegalArgumentException (
619
+ "Can't calculate outputs (" + outputs + ") from inputs (" + inputs + "), "
620
+ + "they also depend on \" " + x + "\" " );
598
621
}
599
622
});
600
623
@@ -629,12 +652,15 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature)
629
652
resolveToOutput (graph , outputs ),
630
653
null ,
631
654
null ,
632
- new BytePointer (signature .methodName () != null ? signature .methodName () : "Method " + signature .key ()),
655
+ new BytePointer (signature .methodName () != null ? signature .methodName ()
656
+ : "Method " + signature .key ()),
633
657
status
634
658
);
635
659
660
+ handle .withDeallocator ();
636
661
status .throwExceptionIfNotOK ();
637
- return new ConcreteFunction (signature , new NativeFunction (handle ), graph .getNativeFunctions ());
662
+ return new ConcreteFunction (signature , new NativeFunction (handle ),
663
+ graph .getNativeFunctions (scope ));
638
664
}
639
665
}
640
666
}
0 commit comments