42
42
import org .tensorflow .op .xla .SelectAndScatter ;
43
43
import org .tensorflow .op .xla .SelfAdjointEig ;
44
44
import org .tensorflow .op .xla .Send ;
45
+ import org .tensorflow .op .xla .SetDynamicDimensionSize ;
45
46
import org .tensorflow .op .xla .Sharding ;
46
47
import org .tensorflow .op .xla .Sort ;
48
+ import org .tensorflow .op .xla .SpmdFullToShardShape ;
49
+ import org .tensorflow .op .xla .SpmdShardToFullShape ;
47
50
import org .tensorflow .op .xla .Svd ;
48
51
import org .tensorflow .op .xla .While ;
49
52
import org .tensorflow .op .xla .XlaHostCompute ;
52
55
import org .tensorflow .op .xla .XlaSendToHost ;
53
56
import org .tensorflow .op .xla .XlaSetBound ;
54
57
import org .tensorflow .op .xla .XlaVariadicReduce ;
58
+ import org .tensorflow .op .xla .XlaVariadicSort ;
55
59
import org .tensorflow .types .TInt32 ;
56
60
import org .tensorflow .types .family .TNumber ;
57
61
import org .tensorflow .types .family .TType ;
@@ -106,7 +110,7 @@ public <T extends TType> ClusterOutput<T> clusterOutput(Operand<T> input) {
106
110
* https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
107
111
* .
108
112
*
109
- * @param <T > data type for {@code output} output
113
+ * @param <W > data type for {@code output} output
110
114
* @param lhs the input tensor
111
115
* @param rhs the kernel tensor
112
116
* @param windowStrides the inter-window strides
@@ -116,14 +120,16 @@ public <T extends TType> ClusterOutput<T> clusterOutput(Operand<T> input) {
116
120
* @param featureGroupCount number of feature groups for grouped convolution.
117
121
* @param dimensionNumbers a serialized xla::ConvolutionDimensionNumbers proto.
118
122
* @param precisionConfig a serialized xla::PrecisionConfig proto.
119
- * @param <T> data type for {@code XlaConv} output and operands
120
- * @param <U> data type for {@code XlaConv} output and operands
123
+ * @param preferredElementType The type of the tensor.
124
+ * @param <W> data type for {@code XlaConvV2} output and operands
125
+ * @param <V> data type for {@code XlaConvV2} output and operands
121
126
* @return a new instance of Conv
122
127
*/
123
- public <T extends TType , U extends TNumber > Conv <T > conv (Operand <T > lhs , Operand <T > rhs ,
124
- Operand <U > windowStrides , Operand <U > padding , Operand <U > lhsDilation , Operand <U > rhsDilation ,
125
- Operand <U > featureGroupCount , String dimensionNumbers , String precisionConfig ) {
126
- return Conv .create (scope , lhs , rhs , windowStrides , padding , lhsDilation , rhsDilation , featureGroupCount , dimensionNumbers , precisionConfig );
128
+ public <W extends TType , V extends TNumber > Conv <W > conv (Operand <? extends TType > lhs ,
129
+ Operand <? extends TType > rhs , Operand <V > windowStrides , Operand <V > padding ,
130
+ Operand <V > lhsDilation , Operand <V > rhsDilation , Operand <V > featureGroupCount ,
131
+ String dimensionNumbers , String precisionConfig , Class <W > preferredElementType ) {
132
+ return Conv .create (scope , lhs , rhs , windowStrides , padding , lhsDilation , rhsDilation , featureGroupCount , dimensionNumbers , precisionConfig , preferredElementType );
127
133
}
128
134
129
135
/**
@@ -148,17 +154,18 @@ public Dequantize dequantize(Operand<? extends TType> input, Float minRange, Flo
148
154
* https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
149
155
* .
150
156
*
151
- * @param <T > data type for {@code output} output
157
+ * @param <V > data type for {@code output} output
152
158
* @param lhs the LHS tensor
153
159
* @param rhs the RHS tensor
154
160
* @param dimensionNumbers a serialized xla::DotDimensionNumbers proto.
155
161
* @param precisionConfig a serialized xla::PrecisionConfig proto.
156
- * @param <T> data type for {@code XlaDot} output and operands
162
+ * @param preferredElementType The type of the tensor.
163
+ * @param <V> data type for {@code XlaDotV2} output and operands
157
164
* @return a new instance of Dot
158
165
*/
159
- public <T extends TType > Dot <T > dot (Operand <T > lhs , Operand <T > rhs , String dimensionNumbers ,
160
- String precisionConfig ) {
161
- return Dot .create (scope , lhs , rhs , dimensionNumbers , precisionConfig );
166
+ public <V extends TType > Dot <V > dot (Operand <? extends TType > lhs , Operand <? extends TType > rhs ,
167
+ String dimensionNumbers , String precisionConfig , Class < V > preferredElementType ) {
168
+ return Dot .create (scope , lhs , rhs , dimensionNumbers , precisionConfig , preferredElementType );
162
169
}
163
170
164
171
/**
@@ -473,6 +480,25 @@ public Send send(Operand<? extends TType> tensor, String tensorName) {
473
480
return Send .create (scope , tensor , tensorName );
474
481
}
475
482
483
+ /**
484
+ * Make a static dimension into a xla bounded dynamic dimension.
485
+ * <pre>
486
+ * The current static dimension size will become the bound and the second
487
+ * operand becomes the dynamic size of the dimension.
488
+ * </pre>
489
+ *
490
+ * @param <T> data type for {@code output} output
491
+ * @param input the input value
492
+ * @param dimIndex the dimIndex value
493
+ * @param sizeOutput the sizeOutput value
494
+ * @param <T> data type for {@code XlaSetDynamicDimensionSize} output and operands
495
+ * @return a new instance of SetDynamicDimensionSize
496
+ */
497
+ public <T extends TType > SetDynamicDimensionSize <T > setDynamicDimensionSize (Operand <T > input ,
498
+ Operand <TInt32 > dimIndex , Operand <TInt32 > sizeOutput ) {
499
+ return SetDynamicDimensionSize .create (scope , input , dimIndex , sizeOutput );
500
+ }
501
+
476
502
/**
477
503
* An op which shards the input based on the given sharding attribute.
478
504
*
@@ -501,6 +527,42 @@ public <T extends TType> Sort<T> sort(Operand<T> input) {
501
527
return Sort .create (scope , input );
502
528
}
503
529
530
+ /**
531
+ * An op used by XLA SPMD partitioner to switch from automatic partitioning to
532
+ * manual partitioning. It annotates the input (full-shape, to be automatically
533
+ * partitioned) with the same sharding used by manual partitioning, and outputs a
534
+ * shard-shaped tensor to be consumed by later manually-partitioned ops. If the
535
+ * shape is not evenly partitionable, the padding region will be masked with 0s.
536
+ *
537
+ * @param <T> data type for {@code output} output
538
+ * @param input the input value
539
+ * @param manualSharding the value of the manualSharding property
540
+ * @param <T> data type for {@code XlaSpmdFullToShardShape} output and operands
541
+ * @return a new instance of SpmdFullToShardShape
542
+ */
543
+ public <T extends TType > SpmdFullToShardShape <T > spmdFullToShardShape (Operand <T > input ,
544
+ String manualSharding ) {
545
+ return SpmdFullToShardShape .create (scope , input , manualSharding );
546
+ }
547
+
548
+ /**
549
+ * An op used by XLA SPMD partitioner to switch from manual partitioning to
550
+ * automatic partitioning. It converts the shard-shaped, manually partitioned input
551
+ * into full-shaped tensor to be partitioned automatically with the same sharding
552
+ * used by manual partitioning.
553
+ *
554
+ * @param <T> data type for {@code output} output
555
+ * @param input the input value
556
+ * @param manualSharding the value of the manualSharding property
557
+ * @param fullShape the value of the fullShape property
558
+ * @param <T> data type for {@code XlaSpmdShardToFullShape} output and operands
559
+ * @return a new instance of SpmdShardToFullShape
560
+ */
561
+ public <T extends TType > SpmdShardToFullShape <T > spmdShardToFullShape (Operand <T > input ,
562
+ String manualSharding , Shape fullShape ) {
563
+ return SpmdShardToFullShape .create (scope , input , manualSharding , fullShape );
564
+ }
565
+
504
566
/**
505
567
* Computes the eigen decomposition of a batch of self-adjoint matrices
506
568
* (Note: Only real inputs are supported).
@@ -643,6 +705,26 @@ public <T extends TType> XlaVariadicReduce<T> xlaVariadicReduce(Iterable<Operand
643
705
return XlaVariadicReduce .create (scope , input , initValue , dimensionsToReduce , reducer );
644
706
}
645
707
708
+ /**
709
+ * Wraps the XLA Sort operator, documented at
710
+ * https://www.tensorflow.org/performance/xla/operation_semantics#sort
711
+ * .
712
+ * <p>Sorts one or more tensors, with support for custom comparator, dimension, and
713
+ * is_stable attributes.
714
+ *
715
+ * @param inputs A list of {@code Tensor} of identical shape but possibly different types.
716
+ * @param dimension The dimension along which to sort. Must be a compile-time constant.
717
+ * @param comparator A comparator function to apply to 2*N scalars and returning a
718
+ * boolean. N is the number of sort inputs. If you want to sort in ascending
719
+ * order then the comparator should perform a less-than comparison.
720
+ * @param isStable Whether to use stable sort.
721
+ * @return a new instance of XlaVariadicSort
722
+ */
723
+ public XlaVariadicSort xlaVariadicSort (Iterable <Operand <?>> inputs , Operand <TInt32 > dimension ,
724
+ ConcreteFunction comparator , Boolean isStable ) {
725
+ return XlaVariadicSort .create (scope , inputs , dimension , comparator , isStable );
726
+ }
727
+
646
728
/**
647
729
* Get the parent {@link Ops} object.
648
730
*/
0 commit comments