1
- /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
1
+ /* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.
2
2
3
- Licensed under the Apache License, Version 2.0 (the "License");
4
- you may not use this file except in compliance with the License.
5
- You may obtain a copy of the License at
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
6
7
- http://www.apache.org/licenses/LICENSE-2.0
8
-
9
- Unless required by applicable law or agreed to in writing, software
10
- distributed under the License is distributed on an "AS IS" BASIS,
11
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- See the License for the specific language governing permissions and
13
- limitations under the License.
14
- ==============================================================================*/
7
+ http://www.apache.org/licenses/LICENSE-2.0
15
8
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ =======================================================================
15
+ */
16
16
package org .tensorflow ;
17
17
18
18
import static org .tensorflow .internal .c_api .global .tensorflow .TF_AddControlInput ;
63
63
import org .tensorflow .proto .framework .DataType ;
64
64
import org .tensorflow .proto .framework .NameAttrList ;
65
65
66
- /**
67
- * An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}.
68
- */
66
+ /** An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}. */
69
67
public final class GraphOperationBuilder implements OperationBuilder {
70
68
71
69
GraphOperationBuilder (Graph graph , String type , String name ) {
@@ -103,7 +101,8 @@ public GraphOperationBuilder addControlInput(Operation control) {
103
101
}
104
102
105
103
if (control .env () != graph ) {
106
- throw new IllegalArgumentException ("Control input " + control + " was from a different graph, can't use." );
104
+ throw new IllegalArgumentException (
105
+ "Control input " + control + " was from a different graph, can't use." );
107
106
}
108
107
109
108
Graph .Reference r = graph .ref ();
@@ -369,9 +368,12 @@ public OperationBuilder setAttr(String name, ConcreteFunction[] value) {
369
368
}
370
369
371
370
try (Reference r = graph .ref ()) {
372
- setAttrFunctionList (unsafeNativeHandle , name , Arrays .stream (value )
373
- .map (ConcreteFunction ::getNativeFunctionName )
374
- .collect (Collectors .toList ()));
371
+ setAttrFunctionList (
372
+ unsafeNativeHandle ,
373
+ name ,
374
+ Arrays .stream (value )
375
+ .map (ConcreteFunction ::getNativeFunctionName )
376
+ .collect (Collectors .toList ()));
375
377
}
376
378
return this ;
377
379
}
@@ -426,11 +428,16 @@ private static void addInput(TF_OperationDescription handle, TF_Operation opHand
426
428
}
427
429
}
428
430
429
- private static void addInputList (TF_OperationDescription handle , TF_Operation [] opHandles , int [] indices ) {
431
+ private static void addInputList (
432
+ TF_OperationDescription handle , TF_Operation [] opHandles , int [] indices ) {
430
433
requireHandle (handle );
431
434
if (indices .length != opHandles .length ) {
432
- throw new IllegalArgumentException ("mismatch in number of Operations ("
433
- + opHandles .length + ") and output indices (" + indices .length + ") provided" );
435
+ throw new IllegalArgumentException (
436
+ "mismatch in number of Operations ("
437
+ + opHandles .length
438
+ + ") and output indices ("
439
+ + indices .length
440
+ + ") provided" );
434
441
}
435
442
436
443
try (PointerScope scope = new PointerScope ()) {
@@ -444,8 +451,8 @@ private static void addInputList(TF_OperationDescription handle, TF_Operation[]
444
451
445
452
private static void addControlInput (TF_OperationDescription handle , TF_Operation opHandle ) {
446
453
if (opHandle == null || opHandle .isNull ()) {
447
- throw new IllegalStateException ("control input is not valid, "
448
- + "perhaps the Graph containing it has been closed()?" );
454
+ throw new IllegalStateException (
455
+ "control input is not valid, " + "perhaps the Graph containing it has been closed()?" );
449
456
}
450
457
requireHandle (handle );
451
458
TF_AddControlInput (handle , opHandle );
@@ -491,7 +498,8 @@ private static void setAttrBool(TF_OperationDescription handle, String name, boo
491
498
TF_SetAttrBool (handle , name , (byte ) (value ? 1 : 0 ));
492
499
}
493
500
494
- private static void setAttrBoolList (TF_OperationDescription handle , String name , boolean [] value ) {
501
+ private static void setAttrBoolList (
502
+ TF_OperationDescription handle , String name , boolean [] value ) {
495
503
requireHandle (handle );
496
504
try (PointerScope scope = new PointerScope ()) {
497
505
TF_SetAttrBoolList (handle , name , new BytePointer (new BooleanPointer (value )), value .length );
@@ -508,7 +516,8 @@ private static void setAttrTypeList(TF_OperationDescription handle, String name,
508
516
TF_SetAttrTypeList (handle , name , type , type .length );
509
517
}
510
518
511
- private static void setAttrTensor (TF_OperationDescription handle , String name , TF_Tensor tensorHandle ) {
519
+ private static void setAttrTensor (
520
+ TF_OperationDescription handle , String name , TF_Tensor tensorHandle ) {
512
521
requireHandle (handle );
513
522
requireTensor (tensorHandle );
514
523
@@ -519,7 +528,8 @@ private static void setAttrTensor(TF_OperationDescription handle, String name, T
519
528
}
520
529
}
521
530
522
- private static void setAttrTensorList (TF_OperationDescription handle , String name , TF_Tensor [] tensorHandles ) {
531
+ private static void setAttrTensorList (
532
+ TF_OperationDescription handle , String name , TF_Tensor [] tensorHandles ) {
523
533
requireHandle (handle );
524
534
525
535
try (PointerScope scope = new PointerScope ()) {
@@ -530,20 +540,23 @@ private static void setAttrTensorList(TF_OperationDescription handle, String nam
530
540
}
531
541
532
542
TF_Status status = TF_Status .newStatus ();
533
- TF_SetAttrTensorList (handle , new BytePointer (name ), tensors .position (0 ), tensorHandles .length , status );
543
+ TF_SetAttrTensorList (
544
+ handle , new BytePointer (name ), tensors .position (0 ), tensorHandles .length , status );
534
545
status .throwExceptionIfNotOK ();
535
546
}
536
547
}
537
548
538
- private static void setAttrShape (TF_OperationDescription handle , String name , long [] shape , int numDims ) {
549
+ private static void setAttrShape (
550
+ TF_OperationDescription handle , String name , long [] shape , int numDims ) {
539
551
requireHandle (handle );
540
552
541
553
// num_dims and env->GetArrayLength(shape) are assumed to be consistent.
542
554
// i.e., either num_dims < 0 or num_dims == env->GetArrayLength(shape).
543
555
TF_SetAttrShape (handle , name , shape , numDims );
544
556
}
545
557
546
- private static void setAttrShapeList (TF_OperationDescription handle , String name , long [] shapes , int [] numDims ) {
558
+ private static void setAttrShapeList (
559
+ TF_OperationDescription handle , String name , long [] shapes , int [] numDims ) {
547
560
requireHandle (handle );
548
561
549
562
try (PointerScope scope = new PointerScope ()) {
@@ -553,11 +566,13 @@ private static void setAttrShapeList(TF_OperationDescription handle, String name
553
566
shapesPointers .put (i , shapesPointer );
554
567
shapesPointer .position (shapesPointer .position () + numDims [i ] * 8 );
555
568
}
556
- TF_SetAttrShapeList (handle , new BytePointer (name ), shapesPointers , new IntPointer (numDims ), numDims .length );
569
+ TF_SetAttrShapeList (
570
+ handle , new BytePointer (name ), shapesPointers , new IntPointer (numDims ), numDims .length );
557
571
}
558
572
}
559
573
560
- private static void setAttrStringList (TF_OperationDescription handle , String name , byte [][] value ) {
574
+ private static void setAttrStringList (
575
+ TF_OperationDescription handle , String name , byte [][] value ) {
561
576
requireHandle (handle );
562
577
563
578
try (PointerScope scope = new PointerScope ()) {
@@ -572,23 +587,29 @@ private static void setAttrStringList(TF_OperationDescription handle, String nam
572
587
}
573
588
}
574
589
575
- private static void setAttrFunctionName (TF_OperationDescription opHandle , String attrName , String functionName ) {
590
+ private static void setAttrFunctionName (
591
+ TF_OperationDescription opHandle , String attrName , String functionName ) {
576
592
requireHandle (opHandle );
577
593
try (PointerScope scope = new PointerScope ()) {
578
594
TF_SetAttrFuncName (opHandle , attrName , functionName , functionName .length ());
579
595
}
580
596
}
581
597
582
- private static void setAttrFunctionList (TF_OperationDescription opHandle , String attrName ,
583
- List <String > functionNames ) {
598
+ private static void setAttrFunctionList (
599
+ TF_OperationDescription opHandle , String attrName , List <String > functionNames ) {
584
600
requireHandle (opHandle );
585
601
try (PointerScope scope = new PointerScope ()) {
586
602
TF_Status status = TF_Status .newStatus ();
587
- AttrValue value = AttrValue .newBuilder ().setList (ListValue .newBuilder ().addAllFunc (
588
- functionNames .stream ()
589
- .map (x -> NameAttrList .newBuilder ().setName (x ).build ())
590
- .collect (Collectors .toList ())
591
- ).build ()).build ();
603
+ AttrValue value =
604
+ AttrValue .newBuilder ()
605
+ .setList (
606
+ ListValue .newBuilder ()
607
+ .addAllFunc (
608
+ functionNames .stream ()
609
+ .map (x -> NameAttrList .newBuilder ().setName (x ).build ())
610
+ .collect (Collectors .toList ()))
611
+ .build ())
612
+ .build ();
592
613
byte [] bytes = value .toByteArray ();
593
614
TF_SetAttrValueProto (opHandle , attrName , new BytePointer (bytes ), bytes .length , status );
594
615
status .throwExceptionIfNotOK ();
0 commit comments