15
15
*/
16
16
package org .tensorflow ;
17
17
18
- import static org .tensorflow .internal .c_api .global .tensorflow .TF_FunctionName ;
19
18
import static org .tensorflow .internal .c_api .global .tensorflow .TF_FunctionSetAttrValueProto ;
20
- import static org .tensorflow .internal .c_api .global .tensorflow .TF_FunctionToFunctionDef ;
21
19
import static org .tensorflow .internal .c_api .global .tensorflow .TF_GraphToFunction ;
22
- import static org .tensorflow .internal .c_api .global .tensorflow .TF_OperationGetAttrValueProto ;
23
20
24
- import com .google .protobuf .InvalidProtocolBufferException ;
25
21
import java .io .IOException ;
26
22
import java .util .ArrayList ;
23
+ import java .util .Collection ;
27
24
import java .util .Collections ;
28
25
import java .util .HashSet ;
29
26
import java .util .LinkedHashMap ;
30
- import java .util .LinkedHashSet ;
31
27
import java .util .List ;
32
28
import java .util .Map ;
33
29
import java .util .Set ;
38
34
import org .bytedeco .javacpp .PointerPointer ;
39
35
import org .bytedeco .javacpp .PointerScope ;
40
36
import org .tensorflow .Graph .Reference ;
41
- import org .tensorflow .internal .c_api .TF_Buffer ;
42
37
import org .tensorflow .internal .c_api .TF_Function ;
43
38
import org .tensorflow .internal .c_api .TF_Operation ;
44
39
import org .tensorflow .internal .c_api .TF_Output ;
49
44
import org .tensorflow .proto .framework .AttrValue ;
50
45
import org .tensorflow .proto .framework .DataType ;
51
46
import org .tensorflow .proto .framework .FunctionDef ;
52
- import org .tensorflow .proto .framework .NodeDef ;
53
47
import org .tensorflow .proto .framework .OpDef .ArgDef ;
54
48
import org .tensorflow .proto .framework .SignatureDef ;
55
49
import org .tensorflow .proto .framework .TensorInfo ;
@@ -186,9 +180,35 @@ public Signature signature() {
186
180
* Get the name of the function.
187
181
*/
188
182
public String getNativeFunctionName () {
189
- try (PointerScope scope = new PointerScope ()) {
190
- return TF_FunctionName (nativeHandle ()).getString ();
191
- }
183
+ return nativeFunction .getName ();
184
+ }
185
+
186
+ /**
187
+ * Get the {@link FunctionDef} proto.
188
+ */
189
+ public FunctionDef getFunctionDef () {
190
+ return nativeFunction .getFunctionDef ();
191
+ }
192
+
193
+ /**
194
+ * Get whether the function is stateful.
195
+ */
196
+ public boolean isStateful () {
197
+ return nativeFunction .isStateful ();
198
+ }
199
+
200
+ Set <TF_Function > getDependencies () {
201
+ return dependencies ;
202
+ }
203
+
204
+ @ Override
205
+ public void close () {
206
+ scope .close ();
207
+ }
208
+
209
+ @ Override
210
+ public String toString () {
211
+ return signature .toString ();
192
212
}
193
213
194
214
public static final String CALL_OP = "PartitionedCall" ;
@@ -214,7 +234,7 @@ public Map<String, Operand<?>> call(Scope scope,
214
234
String displayName = Scope .isValidOpName (name ) ? name : "FunctionCall" ;
215
235
216
236
OperationBuilder opBuilder = scope .env ()
217
- .opBuilder (stateful ? STATEFUL_CALL_OP : CALL_OP , scope .makeOpName (displayName ));
237
+ .opBuilder (isStateful () ? STATEFUL_CALL_OP : CALL_OP , scope .makeOpName (displayName ));
218
238
219
239
opBuilder .addInputList (inputList .stream ().map (Operand ::asOutput ).toArray (Output []::new ));
220
240
@@ -357,20 +377,6 @@ public void save(String exportDir) throws IOException {
357
377
SavedModelBundle .exporter (exportDir ).withFunction (this ).export ();
358
378
}
359
379
360
- public boolean isStateful () {
361
- return stateful ;
362
- }
363
-
364
- @ Override
365
- public void close () {
366
- scope .close ();
367
- }
368
-
369
- @ Override
370
- public String toString () {
371
- return signature .toString ();
372
- }
373
-
374
380
/**
375
381
* FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because how to enable XLA
376
382
* JIT is extremely non-obvious.
@@ -394,10 +400,10 @@ private void makeJit() {
394
400
}
395
401
396
402
TF_Function nativeHandle () {
397
- if (nativeHandle .isNull ()) {
403
+ if (nativeFunction . getNativeHandle () .isNull ()) {
398
404
throw new IllegalStateException ("Function has been closed" );
399
405
}
400
- return nativeHandle ;
406
+ return nativeFunction . getNativeHandle () ;
401
407
}
402
408
403
409
/**
@@ -410,43 +416,93 @@ TF_Function gradNativeHandle() {
410
416
}
411
417
412
418
private final Signature signature ;
413
- private final TF_Function nativeHandle ;
419
+ private final NativeFunction nativeFunction ;
414
420
private final PointerScope scope ;
415
- private final boolean stateful ;
421
+ private final Set <TF_Function > dependencies ;
422
+
423
+ ConcreteFunction (Signature signature , NativeFunction nativeFunction , Collection <NativeFunction > availableFunctions ) {
424
+ this (signature , nativeFunction , nativeFunction .getAllDependencies (availableFunctions ));
425
+ }
426
+
427
+ private static boolean dataTypesMatch (List <DataType > a , List <DataType > b ) {
428
+ if (a .size () != b .size ()) {
429
+ return false ;
430
+ }
431
+
432
+ for (int i = 0 ; i < a .size (); i ++) {
433
+ DataType aType = a .get (i );
434
+ DataType bType = b .get (i );
435
+
436
+ if (aType != DataType .DT_INVALID && bType != DataType .DT_INVALID && !a .equals (b )) {
437
+ return false ;
438
+ }
439
+ }
416
440
417
- ConcreteFunction (Signature signature , TF_Function nativeHandle , boolean stateful ) {
441
+ return true ;
442
+ }
443
+
444
+ private ConcreteFunction (Signature signature , NativeFunction nativeFunction , Set <TF_Function > dependencies ) {
418
445
this .signature = signature ;
446
+ this .nativeFunction = nativeFunction ;
447
+ this .dependencies = Collections .unmodifiableSet (dependencies );
448
+
449
+ if (this .signature .getInputs ().size () != nativeFunction .getFunctionDef ().getSignature ().getInputArgCount ()) {
450
+ throw new IllegalArgumentException (
451
+ "Signature must have the same number of inputs as the native function. Expected "
452
+ + nativeFunction .getFunctionDef ().getSignature ().getInputArgCount () + ", got "
453
+ + this .signature .getInputs ().size ());
454
+ }
455
+
456
+ if (this .signature .getOutputs ().size () != nativeFunction .getFunctionDef ().getSignature ().getOutputArgCount ()) {
457
+ throw new IllegalArgumentException (
458
+ "New signature must have the same number of outputs as the native function. Expected "
459
+ + nativeFunction .getFunctionDef ().getSignature ().getOutputArgCount () + ", got "
460
+ + this .signature .getOutputs ().size ());
461
+ }
462
+
463
+ List <DataType > inputs = this .signature .getInputs ().values ().stream ().map (x -> x .dataType )
464
+ .collect (Collectors .toList ());
465
+ List <DataType > nativeInputs = nativeFunction .getFunctionDef ().getSignature ().getInputArgList ().stream ()
466
+ .map (ArgDef ::getType )
467
+ .collect (Collectors .toList ());
468
+
469
+ if (!dataTypesMatch (inputs , nativeInputs )) {
470
+ throw new IllegalArgumentException (
471
+ "Data types of the signature's inputs must match the native function's (in order). Expected "
472
+ + nativeInputs + ", got " + inputs );
473
+ }
474
+
475
+ List <DataType > outputs = this .signature .getOutputs ().values ().stream ().map (x -> x .dataType )
476
+ .collect (Collectors .toList ());
477
+ List <DataType > nativeOutputs = nativeFunction .getFunctionDef ().getSignature ().getOutputArgList ().stream ()
478
+ .map (ArgDef ::getType )
479
+ .collect (Collectors .toList ());
480
+
481
+ if (!dataTypesMatch (outputs , nativeOutputs )) {
482
+ throw new IllegalArgumentException (
483
+ "Data types of the signature's outputs must match the native function's (in order). Expected "
484
+ + nativeOutputs + ", got "
485
+ + outputs );
486
+ }
487
+
419
488
try (PointerScope scope = new PointerScope ()) {
420
- scope .extend ();
421
- this .nativeHandle = nativeHandle .withDeallocator ();
422
- scope .attach (nativeHandle );
423
489
this .scope = scope ;
490
+ scope .extend ();
491
+ this .nativeFunction .getNativeHandle ().withDeallocatorInScope ();
492
+ this .dependencies .forEach (TF_Function ::withDeallocatorInScope );
424
493
}
425
- this .stateful = stateful ;
426
494
}
427
495
428
496
/**
429
497
* Detects the signature from the handle
430
498
*/
431
- static ConcreteFunction fromNativeHandle (TF_Function function ) {
499
+ static ConcreteFunction fromNativeHandle (NativeFunction nativeFunction ,
500
+ Collection <NativeFunction > availableFunctions ) {
432
501
433
- FunctionDef funcDef = null ;
434
- try (PointerScope scope = new PointerScope ()) {
435
- TF_Buffer funcDefBuffer = TF_Buffer .newBuffer ();
436
- TF_Status status2 = TF_Status .newStatus ();
437
- TF_FunctionToFunctionDef (function , funcDefBuffer , status2 );
438
- status2 .throwExceptionIfNotOK ();
439
- try {
440
- funcDef = FunctionDef .parseFrom (funcDefBuffer .dataAsByteBuffer ());
441
- } catch (InvalidProtocolBufferException e ) {
442
- throw new IllegalStateException ("Failed to parse FunctionDef proto" , e );
443
- }
444
- }
445
-
446
- Signature .Builder builder = Signature .builder ().methodName (funcDef .getSignature ().getName ())
447
- .key (TF_FunctionName (function ).getString ());
502
+ Signature .Builder builder = Signature .builder ().methodName (nativeFunction .getFunctionDef ().getSignature ().getName ())
503
+ .key (nativeFunction .getName ());
448
504
449
- for (ArgDef input : funcDef .getSignature ().getInputArgList ()) {
505
+ for (ArgDef input : nativeFunction . getFunctionDef () .getSignature ().getInputArgList ()) {
450
506
TensorInfo info = TensorInfo .newBuilder ()
451
507
.setDtype (input .getType ())
452
508
.setTensorShape (TensorShapeProto .newBuilder ().setUnknownRank (true ).build ())
@@ -456,7 +512,7 @@ static ConcreteFunction fromNativeHandle(TF_Function function) {
456
512
builder .input (input .getName (), info );
457
513
}
458
514
459
- for (ArgDef outputDef : funcDef .getSignature ().getOutputArgList ()) {
515
+ for (ArgDef outputDef : nativeFunction . getFunctionDef () .getSignature ().getOutputArgList ()) {
460
516
TensorInfo info = TensorInfo .newBuilder ()
461
517
.setDtype (outputDef .getType ())
462
518
.setTensorShape (TensorShapeProto .newBuilder ().setUnknownRank (true ).build ())
@@ -468,8 +524,8 @@ static ConcreteFunction fromNativeHandle(TF_Function function) {
468
524
469
525
return new ConcreteFunction (
470
526
builder .build (),
471
- function ,
472
- funcDef . getNodeDefList (). stream (). anyMatch ( x -> TensorFlow . isOpStateful ( x . getOp ()))
527
+ nativeFunction ,
528
+ availableFunctions
473
529
);
474
530
}
475
531
@@ -559,91 +615,11 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature)
559
615
);
560
616
561
617
status .throwExceptionIfNotOK ();
562
- return new ConcreteFunction (signature , handle , ops . stream (). anyMatch ( x -> TensorFlow . isOpStateful ( x . type ()) ));
618
+ return new ConcreteFunction (signature , new NativeFunction ( handle ), graph . getNativeFunctions ( ));
563
619
}
564
620
}
565
621
566
622
ConcreteFunction withNewSignature (Signature signature ) {
567
- if (this .signature .getInputs ().size () != signature .getInputs ().size ()) {
568
- throw new IllegalArgumentException (
569
- "New signature must have the same number of inputs. Expected " + this .signature .getInputs ().size () + ", got "
570
- + signature .getInputs ().size ());
571
- }
572
-
573
- if (this .signature .getOutputs ().size () != signature .getOutputs ().size ()) {
574
- throw new IllegalArgumentException (
575
- "New signature must have the same number of inputs. Expected " + this .signature .getInputs ().size () + ", got "
576
- + signature .getInputs ().size ());
577
- }
578
-
579
- List <DataType > inputs = this .signature .getInputs ().values ().stream ().map (x -> x .dataType )
580
- .collect (Collectors .toList ());
581
- List <DataType > newInputs = signature .getInputs ().values ().stream ().map (x -> x .dataType )
582
- .collect (Collectors .toList ());
583
-
584
- if (!inputs .equals (newInputs )) {
585
- throw new IllegalArgumentException (
586
- "Data types of the new signature's inputs (in order) must match. Expected " + inputs + ", got " + newInputs );
587
- }
588
-
589
- List <DataType > outputs = this .signature .getOutputs ().values ().stream ().map (x -> x .dataType )
590
- .collect (Collectors .toList ());
591
- List <DataType > newOutputs = signature .getOutputs ().values ().stream ().map (x -> x .dataType )
592
- .collect (Collectors .toList ());
593
-
594
- if (!outputs .equals (newOutputs )) {
595
- throw new IllegalArgumentException (
596
- "Data types of the new signature's outputs (in order) must match. Expected " + outputs + ", got "
597
- + newOutputs );
598
- }
599
-
600
- return new ConcreteFunction (signature , nativeHandle , stateful );
601
- }
602
-
603
- /**
604
- * Returns the function name if {@code op} is a function call op, or null otherwise.
605
- */
606
- static String findFunctionCall (GraphOperation op ) {
607
- if (op .type ().equals (STATEFUL_CALL_OP ) || op .type ().equals (CALL_OP )) {
608
- try (PointerScope scope = new PointerScope ()) {
609
- TF_Status status = TF_Status .newStatus ();
610
- TF_Buffer buff = TF_Buffer .newBuffer ();
611
- TF_OperationGetAttrValueProto (op .getUnsafeNativeHandle (), "f" , buff , status );
612
- status .throwExceptionIfNotOK ();
613
- AttrValue def = AttrValue .parseFrom (buff .dataAsByteBuffer ());
614
-
615
- return def .getFunc ().getName ();
616
- } catch (InvalidProtocolBufferException e ) {
617
- return null ;
618
- }
619
- }
620
-
621
- return null ;
622
- }
623
-
624
- FunctionDef functionDef () {
625
- try (PointerScope scope = new PointerScope ()) {
626
- TF_Buffer funcDefBuffer = TF_Buffer .newBuffer ();
627
- TF_Status status2 = TF_Status .newStatus ();
628
- TF_FunctionToFunctionDef (nativeHandle (), funcDefBuffer , status2 );
629
- status2 .throwExceptionIfNotOK ();
630
- try {
631
- return FunctionDef .parseFrom (funcDefBuffer .dataAsByteBuffer ());
632
- } catch (InvalidProtocolBufferException e ) {
633
- throw new IllegalStateException ("Failed to parse FunctionDef proto" , e );
634
- }
635
- }
636
- }
637
-
638
- static List <String > findDependencies (FunctionDef def ) {
639
- Set <String > deps = new LinkedHashSet <>();
640
-
641
- for (NodeDef node : def .getNodeDefList ()) {
642
- if (node .getOp ().equals (CALL_OP ) || node .getOp ().equals (STATEFUL_CALL_OP )) {
643
- deps .add (node .getAttrMap ().get ("f" ).getFunc ().getName ());
644
- }
645
- }
646
-
647
- return new ArrayList <>(deps );
623
+ return new ConcreteFunction (signature , nativeFunction , dependencies );
648
624
}
649
625
}
0 commit comments