Commit 8bac25e

Fix generation

Signed-off-by: Ryan Nett <[email protected]>

1 parent bd12f37

File tree: 80 files changed, +373 -1055 lines


tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 6 additions & 128 deletions
@@ -367,20 +367,20 @@ public final class Ops {
 
   public final SparseOps sparse;
 
-  public final TpuOps tpu;
-
   public final BitwiseOps bitwise;
 
+  public final TpuOps tpu;
+
   public final MathOps math;
 
   public final AudioOps audio;
 
   public final SignalOps signal;
 
-  public final QuantizationOps quantization;
-
   public final TrainOps train;
 
+  public final QuantizationOps quantization;
+
   private final Scope scope;
 
   private Ops(Scope scope) {
@@ -398,13 +398,13 @@ private Ops(Scope scope) {
     random = new RandomOps(this);
     strings = new StringsOps(this);
     sparse = new SparseOps(this);
-    tpu = new TpuOps(this);
     bitwise = new BitwiseOps(this);
+    tpu = new TpuOps(this);
     math = new MathOps(this);
     audio = new AudioOps(this);
     signal = new SignalOps(this);
-    quantization = new QuantizationOps(this);
     train = new TrainOps(this);
+    quantization = new QuantizationOps(this);
   }
 
   /**
@@ -8043,128 +8043,6 @@ public While whileOp(Iterable<Operand<?>> input, ConcreteFunction cond, Concrete
     return While.create(scope, input, cond, body, options);
   }
 
-  /**
-   * Wraps the XLA ConvGeneralDilated operator, documented at
-   * https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
-   * .
-   *
-   * @param <W> data type for {@code output} output
-   * @param lhs the input tensor
-   * @param rhs the kernel tensor
-   * @param windowStrides the inter-window strides
-   * @param padding the padding to apply at the start and end of each input dimensions
-   * @param lhsDilation dilation to apply between input elements
-   * @param rhsDilation dilation to apply between kernel elements
-   * @param featureGroupCount number of feature groups for grouped convolution.
-   * @param dimensionNumbers a serialized xla::ConvolutionDimensionNumbers proto.
-   * @param precisionConfig a serialized xla::PrecisionConfig proto.
-   * @param preferredElementType The type of the tensor.
-   * @param <W> data type for {@code XlaConvV2} output and operands
-   * @param <V> data type for {@code XlaConvV2} output and operands
-   * @return a new instance of XlaConvV2
-   */
-  public <W extends TType, V extends TNumber> XlaConvV2<W> xlaConvV2(Operand<? extends TType> lhs,
-      Operand<? extends TType> rhs, Operand<V> windowStrides, Operand<V> padding,
-      Operand<V> lhsDilation, Operand<V> rhsDilation, Operand<V> featureGroupCount,
-      String dimensionNumbers, String precisionConfig, Class<W> preferredElementType) {
-    return XlaConvV2.create(scope, lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation, featureGroupCount, dimensionNumbers, precisionConfig, preferredElementType);
-  }
-
-  /**
-   * Wraps the XLA DotGeneral operator, documented at
-   * https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
-   * .
-   *
-   * @param <V> data type for {@code output} output
-   * @param lhs the LHS tensor
-   * @param rhs the RHS tensor
-   * @param dimensionNumbers a serialized xla::DotDimensionNumbers proto.
-   * @param precisionConfig a serialized xla::PrecisionConfig proto.
-   * @param preferredElementType The type of the tensor.
-   * @param <V> data type for {@code XlaDotV2} output and operands
-   * @return a new instance of XlaDotV2
-   */
-  public <V extends TType> XlaDotV2<V> xlaDotV2(Operand<? extends TType> lhs,
-      Operand<? extends TType> rhs, String dimensionNumbers, String precisionConfig,
-      Class<V> preferredElementType) {
-    return XlaDotV2.create(scope, lhs, rhs, dimensionNumbers, precisionConfig, preferredElementType);
-  }
-
-  /**
-   * Make a static dimension into a xla bounded dynamic dimension.
-   * <pre>
-   * The current static dimension size will become the bound and the second
-   * operand becomes the dynamic size of the dimension.
-   * </pre>
-   *
-   * @param <T> data type for {@code output} output
-   * @param input the input value
-   * @param dimIndex the dimIndex value
-   * @param sizeOutput the sizeOutput value
-   * @param <T> data type for {@code XlaSetDynamicDimensionSize} output and operands
-   * @return a new instance of XlaSetDynamicDimensionSize
-   */
-  public <T extends TType> XlaSetDynamicDimensionSize<T> xlaSetDynamicDimensionSize(
-      Operand<T> input, Operand<TInt32> dimIndex, Operand<TInt32> sizeOutput) {
-    return XlaSetDynamicDimensionSize.create(scope, input, dimIndex, sizeOutput);
-  }
-
-  /**
-   * An op used by XLA SPMD partitioner to switch from automatic partitioning to
-   * manual partitioning. It annotates the input (full-shape, to be automatically
-   * partitioned) with the same sharding used by manual partitioning, and outputs a
-   * shard-shaped tensor to be consumed by later manually-partitioned ops. If the
-   * shape is not evenly partitionable, the padding region will be masked with 0s.
-   *
-   * @param <T> data type for {@code output} output
-   * @param input the input value
-   * @param manualSharding the value of the manualSharding property
-   * @param <T> data type for {@code XlaSpmdFullToShardShape} output and operands
-   * @return a new instance of XlaSpmdFullToShardShape
-   */
-  public <T extends TType> XlaSpmdFullToShardShape<T> xlaSpmdFullToShardShape(Operand<T> input,
-      String manualSharding) {
-    return XlaSpmdFullToShardShape.create(scope, input, manualSharding);
-  }
-
-  /**
-   * An op used by XLA SPMD partitioner to switch from manual partitioning to
-   * automatic partitioning. It converts the shard-shaped, manually partitioned input
-   * into full-shaped tensor to be partitioned automatically with the same sharding
-   * used by manual partitioning.
-   *
-   * @param <T> data type for {@code output} output
-   * @param input the input value
-   * @param manualSharding the value of the manualSharding property
-   * @param fullShape the value of the fullShape property
-   * @param <T> data type for {@code XlaSpmdShardToFullShape} output and operands
-   * @return a new instance of XlaSpmdShardToFullShape
-   */
-  public <T extends TType> XlaSpmdShardToFullShape<T> xlaSpmdShardToFullShape(Operand<T> input,
-      String manualSharding, Shape fullShape) {
-    return XlaSpmdShardToFullShape.create(scope, input, manualSharding, fullShape);
-  }
-
-  /**
-   * Wraps the XLA Sort operator, documented at
-   * https://www.tensorflow.org/performance/xla/operation_semantics#sort
-   * .
-   * <p>Sorts one or more tensors, with support for custom comparator, dimension, and
-   * is_stable attributes.
-   *
-   * @param inputs A list of {@code Tensor} of identical shape but possibly different types.
-   * @param dimension The dimension along which to sort. Must be a compile-time constant.
-   * @param comparator A comparator function to apply to 2*N scalars and returning a
-   *     boolean. N is the number of sort inputs. If you want to sort in ascending
-   *     order then the comparator should perform a less-than comparison.
-   * @param isStable Whether to use stable sort.
-   * @return a new instance of XlaVariadicSort
-   */
-  public XlaVariadicSort xlaVariadicSort(Iterable<Operand<?>> inputs, Operand<TInt32> dimension,
-      ConcreteFunction comparator, Boolean isStable) {
-    return XlaVariadicSort.create(scope, inputs, dimension, comparator, isStable);
-  }
-
   /**
    * Creates a zeroed tensor given its type and shape.
    *
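Taken together, this hunk removes the six XLA accessors from the top-level Ops class; they reappear under the xla group in XlaOps.java below. A minimal migration sketch (assumptions: a graph-mode tf scope, and hypothetical lhs/rhs operands plus serialized dimensionNumbers/precisionConfig proto strings not shown here):

    Graph graph = new Graph();
    Ops tf = Ops.create(graph);
    // Before this commit the accessor was generated onto Ops itself:
    //   XlaDotV2<TFloat32> out = tf.xlaDotV2(lhs, rhs, dimensionNumbers, precisionConfig, TFloat32.class);
    // After it, the same op is reached through the xla group (see XlaOps.java below):
    Dot<TFloat32> out = tf.xla.dot(lhs, rhs, dimensionNumbers, precisionConfig, TFloat32.class);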

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java

Lines changed: 94 additions & 12 deletions
@@ -42,8 +42,11 @@
 import org.tensorflow.op.xla.SelectAndScatter;
 import org.tensorflow.op.xla.SelfAdjointEig;
 import org.tensorflow.op.xla.Send;
+import org.tensorflow.op.xla.SetDynamicDimensionSize;
 import org.tensorflow.op.xla.Sharding;
 import org.tensorflow.op.xla.Sort;
+import org.tensorflow.op.xla.SpmdFullToShardShape;
+import org.tensorflow.op.xla.SpmdShardToFullShape;
 import org.tensorflow.op.xla.Svd;
 import org.tensorflow.op.xla.While;
 import org.tensorflow.op.xla.XlaHostCompute;
@@ -52,6 +55,7 @@
 import org.tensorflow.op.xla.XlaSendToHost;
 import org.tensorflow.op.xla.XlaSetBound;
 import org.tensorflow.op.xla.XlaVariadicReduce;
+import org.tensorflow.op.xla.XlaVariadicSort;
 import org.tensorflow.types.TInt32;
 import org.tensorflow.types.family.TNumber;
 import org.tensorflow.types.family.TType;
@@ -106,7 +110,7 @@ public <T extends TType> ClusterOutput<T> clusterOutput(Operand<T> input) {
    * https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
    * .
    *
-   * @param <T> data type for {@code output} output
+   * @param <W> data type for {@code output} output
    * @param lhs the input tensor
    * @param rhs the kernel tensor
    * @param windowStrides the inter-window strides
@@ -116,14 +120,16 @@ public <T extends TType> ClusterOutput<T> clusterOutput(Operand<T> input) {
    * @param featureGroupCount number of feature groups for grouped convolution.
    * @param dimensionNumbers a serialized xla::ConvolutionDimensionNumbers proto.
    * @param precisionConfig a serialized xla::PrecisionConfig proto.
-   * @param <T> data type for {@code XlaConv} output and operands
-   * @param <U> data type for {@code XlaConv} output and operands
+   * @param preferredElementType The type of the tensor.
+   * @param <W> data type for {@code XlaConvV2} output and operands
+   * @param <V> data type for {@code XlaConvV2} output and operands
    * @return a new instance of Conv
    */
-  public <T extends TType, U extends TNumber> Conv<T> conv(Operand<T> lhs, Operand<T> rhs,
-      Operand<U> windowStrides, Operand<U> padding, Operand<U> lhsDilation, Operand<U> rhsDilation,
-      Operand<U> featureGroupCount, String dimensionNumbers, String precisionConfig) {
-    return Conv.create(scope, lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation, featureGroupCount, dimensionNumbers, precisionConfig);
+  public <W extends TType, V extends TNumber> Conv<W> conv(Operand<? extends TType> lhs,
+      Operand<? extends TType> rhs, Operand<V> windowStrides, Operand<V> padding,
+      Operand<V> lhsDilation, Operand<V> rhsDilation, Operand<V> featureGroupCount,
+      String dimensionNumbers, String precisionConfig, Class<W> preferredElementType) {
+    return Conv.create(scope, lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation, featureGroupCount, dimensionNumbers, precisionConfig, preferredElementType);
   }
 
   /**
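A hedged usage sketch of the regenerated conv accessor, reusing the tf scope from the sketch above (lhs and rhs are hypothetical 4-D Operand<TFloat32> values; convDims and precision stand in for serialized xla::ConvolutionDimensionNumbers and xla::PrecisionConfig protos):

    Operand<TInt32> windowStrides = tf.constant(new int[] {1, 1});
    Operand<TInt32> padding = tf.constant(new int[][] {{0, 0}, {0, 0}});
    Operand<TInt32> lhsDilation = tf.constant(new int[] {1, 1});
    Operand<TInt32> rhsDilation = tf.constant(new int[] {1, 1});
    Operand<TInt32> featureGroupCount = tf.constant(1);
    // The new Class<W> argument pins the output element type.
    Conv<TFloat32> out = tf.xla.conv(lhs, rhs, windowStrides, padding, lhsDilation,
        rhsDilation, featureGroupCount, convDims, precision, TFloat32.class);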
@@ -148,17 +154,18 @@ public Dequantize dequantize(Operand<? extends TType> input, Float minRange, Flo
    * https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
    * .
    *
-   * @param <T> data type for {@code output} output
+   * @param <V> data type for {@code output} output
    * @param lhs the LHS tensor
    * @param rhs the RHS tensor
    * @param dimensionNumbers a serialized xla::DotDimensionNumbers proto.
    * @param precisionConfig a serialized xla::PrecisionConfig proto.
-   * @param <T> data type for {@code XlaDot} output and operands
+   * @param preferredElementType The type of the tensor.
+   * @param <V> data type for {@code XlaDotV2} output and operands
    * @return a new instance of Dot
    */
-  public <T extends TType> Dot<T> dot(Operand<T> lhs, Operand<T> rhs, String dimensionNumbers,
-      String precisionConfig) {
-    return Dot.create(scope, lhs, rhs, dimensionNumbers, precisionConfig);
+  public <V extends TType> Dot<V> dot(Operand<? extends TType> lhs, Operand<? extends TType> rhs,
+      String dimensionNumbers, String precisionConfig, Class<V> preferredElementType) {
+    return Dot.create(scope, lhs, rhs, dimensionNumbers, precisionConfig, preferredElementType);
   }
 
   /**
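The dot accessor changes the same way; a small sketch with concrete 2x2 inputs (dotDims and precision are again hypothetical serialized proto strings):

    Operand<TFloat32> a = tf.constant(new float[][] {{1f, 2f}, {3f, 4f}});
    Operand<TFloat32> b = tf.constant(new float[][] {{5f, 6f}, {7f, 8f}});
    Dot<TFloat32> product = tf.xla.dot(a, b, dotDims, precision, TFloat32.class);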
@@ -473,6 +480,25 @@ public Send send(Operand<? extends TType> tensor, String tensorName) {
     return Send.create(scope, tensor, tensorName);
   }
 
+  /**
+   * Make a static dimension into a xla bounded dynamic dimension.
+   * <pre>
+   * The current static dimension size will become the bound and the second
+   * operand becomes the dynamic size of the dimension.
+   * </pre>
+   *
+   * @param <T> data type for {@code output} output
+   * @param input the input value
+   * @param dimIndex the dimIndex value
+   * @param sizeOutput the sizeOutput value
+   * @param <T> data type for {@code XlaSetDynamicDimensionSize} output and operands
+   * @return a new instance of SetDynamicDimensionSize
+   */
+  public <T extends TType> SetDynamicDimensionSize<T> setDynamicDimensionSize(Operand<T> input,
+      Operand<TInt32> dimIndex, Operand<TInt32> sizeOutput) {
+    return SetDynamicDimensionSize.create(scope, input, dimIndex, sizeOutput);
+  }
+
   /**
    * An op which shards the input based on the given sharding attribute.
    *
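A minimal sketch of the new accessor, assuming the same tf scope: dimension 1 of a 2x3 constant keeps its static size 3 as the bound while taking runtime size 2:

    Operand<TFloat32> input = tf.constant(new float[][] {{1f, 2f, 3f}, {4f, 5f, 6f}});
    SetDynamicDimensionSize<TFloat32> bounded =
        tf.xla.setDynamicDimensionSize(input, tf.constant(1), tf.constant(2));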
@@ -501,6 +527,42 @@ public <T extends TType> Sort<T> sort(Operand<T> input) {
     return Sort.create(scope, input);
   }
 
+  /**
+   * An op used by XLA SPMD partitioner to switch from automatic partitioning to
+   * manual partitioning. It annotates the input (full-shape, to be automatically
+   * partitioned) with the same sharding used by manual partitioning, and outputs a
+   * shard-shaped tensor to be consumed by later manually-partitioned ops. If the
+   * shape is not evenly partitionable, the padding region will be masked with 0s.
+   *
+   * @param <T> data type for {@code output} output
+   * @param input the input value
+   * @param manualSharding the value of the manualSharding property
+   * @param <T> data type for {@code XlaSpmdFullToShardShape} output and operands
+   * @return a new instance of SpmdFullToShardShape
+   */
+  public <T extends TType> SpmdFullToShardShape<T> spmdFullToShardShape(Operand<T> input,
+      String manualSharding) {
+    return SpmdFullToShardShape.create(scope, input, manualSharding);
+  }
+
+  /**
+   * An op used by XLA SPMD partitioner to switch from manual partitioning to
+   * automatic partitioning. It converts the shard-shaped, manually partitioned input
+   * into full-shaped tensor to be partitioned automatically with the same sharding
+   * used by manual partitioning.
+   *
+   * @param <T> data type for {@code output} output
+   * @param input the input value
+   * @param manualSharding the value of the manualSharding property
+   * @param fullShape the value of the fullShape property
+   * @param <T> data type for {@code XlaSpmdShardToFullShape} output and operands
+   * @return a new instance of SpmdShardToFullShape
+   */
+  public <T extends TType> SpmdShardToFullShape<T> spmdShardToFullShape(Operand<T> input,
+      String manualSharding, Shape fullShape) {
+    return SpmdShardToFullShape.create(scope, input, manualSharding, fullShape);
+  }
+
   /**
    * Computes the eigen decomposition of a batch of self-adjoint matrices
    * (Note: Only real inputs are supported).
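A hedged round-trip sketch of the two SPMD accessors (fullInput is a hypothetical full-shape Operand<TFloat32> of shape (8, 128), and manualSharding stands in for a serialized xla::OpSharding proto string):

    SpmdFullToShardShape<TFloat32> shard =
        tf.xla.spmdFullToShardShape(fullInput, manualSharding);
    // ... manually partitioned ops consume the shard-shaped tensor here ...
    SpmdShardToFullShape<TFloat32> full =
        tf.xla.spmdShardToFullShape(shard, manualSharding, Shape.of(8, 128));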
@@ -643,6 +705,26 @@ public <T extends TType> XlaVariadicReduce<T> xlaVariadicReduce(Iterable<Operand
     return XlaVariadicReduce.create(scope, input, initValue, dimensionsToReduce, reducer);
   }
 
+  /**
+   * Wraps the XLA Sort operator, documented at
+   * https://www.tensorflow.org/performance/xla/operation_semantics#sort
+   * .
+   * <p>Sorts one or more tensors, with support for custom comparator, dimension, and
+   * is_stable attributes.
+   *
+   * @param inputs A list of {@code Tensor} of identical shape but possibly different types.
+   * @param dimension The dimension along which to sort. Must be a compile-time constant.
+   * @param comparator A comparator function to apply to 2*N scalars and returning a
+   *     boolean. N is the number of sort inputs. If you want to sort in ascending
+   *     order then the comparator should perform a less-than comparison.
+   * @param isStable Whether to use stable sort.
+   * @return a new instance of XlaVariadicSort
+   */
+  public XlaVariadicSort xlaVariadicSort(Iterable<Operand<?>> inputs, Operand<TInt32> dimension,
+      ConcreteFunction comparator, Boolean isStable) {
+    return XlaVariadicSort.create(scope, inputs, dimension, comparator, isStable);
+  }
+
   /**
    * Get the parent {@link Ops} object.
    */
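A hedged sketch of the relocated variadic sort (values is a hypothetical 1-D Operand<TFloat32>; the comparator is built with ConcreteFunction.create and Signature from this project's core API, receives 2*N scalars with N = 1 here, and performs the less-than comparison for an ascending sort):

    ConcreteFunction lessThan = ConcreteFunction.create(f -> {
      Operand<TFloat32> a = f.placeholder(TFloat32.class);
      Operand<TFloat32> b = f.placeholder(TFloat32.class);
      return Signature.builder()
          .input("a", a).input("b", b)
          .output("lt", f.math.less(a, b))
          .build();
    });
    List<Operand<?>> inputs = Collections.singletonList(values);
    XlaVariadicSort sorted = tf.xla.xlaVariadicSort(inputs, tf.constant(0), lessThan, true);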
