[Quick Fix] Update op generation #375

Closed · wants to merge 1 commit

@@ -8044,6 +8044,128 @@ public While whileOp(Iterable<Operand<?>> input, ConcreteFunction cond, Concrete
return While.create(scope, input, cond, body, options);
}

/**
* Wraps the XLA ConvGeneralDilated operator, documented at
* https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
* .
*
* @param <W> data type for {@code output} output
* @param lhs the input tensor
* @param rhs the kernel tensor
* @param windowStrides the inter-window strides
* @param padding the padding to apply at the start and end of each input dimensions
* @param lhsDilation dilation to apply between input elements
* @param rhsDilation dilation to apply between kernel elements
* @param featureGroupCount number of feature groups for grouped convolution.
* @param dimensionNumbers a serialized xla::ConvolutionDimensionNumbers proto.
* @param precisionConfig a serialized xla::PrecisionConfig proto.
* @param preferredElementType The type of the tensor.
* @param <W> data type for {@code XlaConvV2} output and operands
* @param <V> data type for {@code XlaConvV2} output and operands
* @return a new instance of XlaConvV2
*/
public <W extends TType, V extends TNumber> XlaConvV2<W> xlaConvV2(Operand<? extends TType> lhs,
Operand<? extends TType> rhs, Operand<V> windowStrides, Operand<V> padding,
Operand<V> lhsDilation, Operand<V> rhsDilation, Operand<V> featureGroupCount,
String dimensionNumbers, String precisionConfig, Class<W> preferredElementType) {
return XlaConvV2.create(scope, lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation, featureGroupCount, dimensionNumbers, precisionConfig, preferredElementType);
}
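
> A usage sketch (not part of this PR): assuming this generated method lands on the core `Ops` handle `tf`, a call would look like the following. The two proto strings are placeholders only — real values must be serialized `xla.ConvolutionDimensionNumbers` / `xla.PrecisionConfig` messages, as the Javadoc above says.

```java
// Hypothetical sketch; assumes an in-scope `Ops tf` plus TFloat32/TInt32 imports.
Operand<TFloat32> lhs = tf.constant(new float[1][8][8][3]);   // batch of 8x8 RGB inputs
Operand<TFloat32> rhs = tf.constant(new float[3][3][3][16]);  // 3x3 kernel, 3 in / 16 out
Operand<TInt32> windowStrides = tf.constant(new int[] {1, 1});
Operand<TInt32> padding = tf.constant(new int[][] {{1, 1}, {1, 1}}); // start/end per spatial dim
Operand<TInt32> lhsDilation = tf.constant(new int[] {1, 1});
Operand<TInt32> rhsDilation = tf.constant(new int[] {1, 1});
Operand<TInt32> featureGroupCount = tf.constant(1);
String dimensionNumbers = "";  // placeholder: serialized xla.ConvolutionDimensionNumbers
String precisionConfig = "";   // placeholder: serialized xla.PrecisionConfig
XlaConvV2<TFloat32> conv = tf.xlaConvV2(lhs, rhs, windowStrides, padding,
    lhsDilation, rhsDilation, featureGroupCount, dimensionNumbers, precisionConfig,
    TFloat32.class);
```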

/**
* Wraps the XLA DotGeneral operator, documented at
* https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
* .
*
* @param <V> data type for {@code output} output
* @param lhs the LHS tensor
* @param rhs the RHS tensor
* @param dimensionNumbers a serialized xla::DotDimensionNumbers proto.
* @param precisionConfig a serialized xla::PrecisionConfig proto.
* @param preferredElementType The type of the tensor.
* @param <V> data type for {@code XlaDotV2} output and operands
* @return a new instance of XlaDotV2
*/
public <V extends TType> XlaDotV2<V> xlaDotV2(Operand<? extends TType> lhs,
Operand<? extends TType> rhs, String dimensionNumbers, String precisionConfig,
Class<V> preferredElementType) {
return XlaDotV2.create(scope, lhs, rhs, dimensionNumbers, precisionConfig, preferredElementType);
}
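
> A corresponding sketch for the dot wrapper, under the same assumptions (proto strings are placeholders):

```java
// Hypothetical sketch: a [2,3] x [3,4] product expressed via DotGeneral.
Operand<TFloat32> a = tf.constant(new float[2][3]);
Operand<TFloat32> b = tf.constant(new float[3][4]);
String dotDimensionNumbers = "";  // placeholder: serialized xla.DotDimensionNumbers
String dotPrecisionConfig = "";   // placeholder: serialized xla.PrecisionConfig
XlaDotV2<TFloat32> product =
    tf.xlaDotV2(a, b, dotDimensionNumbers, dotPrecisionConfig, TFloat32.class);
```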

/**
* Make a static dimension into a xla bounded dynamic dimension.
* <pre>
* The current static dimension size will become the bound and the second
* operand becomes the dynamic size of the dimension.
* </pre>
*
* @param <T> data type for {@code output} output
* @param input the input value
* @param dimIndex the dimIndex value
* @param sizeOutput the sizeOutput value
* @param <T> data type for {@code XlaSetDynamicDimensionSize} output and operands
* @return a new instance of XlaSetDynamicDimensionSize
*/
public <T extends TType> XlaSetDynamicDimensionSize<T> xlaSetDynamicDimensionSize(
Operand<T> input, Operand<TInt32> dimIndex, Operand<TInt32> sizeOutput) {
return XlaSetDynamicDimensionSize.create(scope, input, dimIndex, sizeOutput);
}
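
> A sketch of the bound/dynamic-size split the Javadoc describes, again assuming an in-scope `Ops tf`:

```java
// Hypothetical sketch: dim 1 keeps its static size (8) as an upper bound,
// while tf.constant(5) becomes that dimension's runtime size.
Operand<TFloat32> input = tf.constant(new float[4][8]);
XlaSetDynamicDimensionSize<TFloat32> bounded =
    tf.xlaSetDynamicDimensionSize(input, tf.constant(1), tf.constant(5));
```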

/**
* An op used by XLA SPMD partitioner to switch from automatic partitioning to
* manual partitioning. It annotates the input (full-shape, to be automatically
* partitioned) with the same sharding used by manual partitioning, and outputs a
* shard-shaped tensor to be consumed by later manually-partitioned ops. If the
* shape is not evenly partitionable, the padding region will be masked with 0s.
*
* @param <T> data type for {@code output} output
* @param input the input value
* @param manualSharding the value of the manualSharding property
* @param <T> data type for {@code XlaSpmdFullToShardShape} output and operands
* @return a new instance of XlaSpmdFullToShardShape
*/
public <T extends TType> XlaSpmdFullToShardShape<T> xlaSpmdFullToShardShape(Operand<T> input,
String manualSharding) {
return XlaSpmdFullToShardShape.create(scope, input, manualSharding);
}
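
> A sketch of the automatic-to-manual handoff, assuming `Ops tf` and a 2-way split on dimension 0; the sharding string is a placeholder for a serialized `xla.OpSharding` proto:

```java
// Hypothetical sketch; the empty string is a placeholder, not a working sharding.
Operand<TFloat32> full = tf.constant(new float[8][16]);
String manualSharding = "";  // placeholder: serialized xla.OpSharding
XlaSpmdFullToShardShape<TFloat32> shard =
    tf.xlaSpmdFullToShardShape(full, manualSharding);
// Per-device shape would be [4, 16] for an even 2-way split of dim 0.
```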

/**
* An op used by XLA SPMD partitioner to switch from manual partitioning to
* automatic partitioning. It converts the shard-shaped, manually partitioned input
* into full-shaped tensor to be partitioned automatically with the same sharding
* used by manual partitioning.
*
* @param <T> data type for {@code output} output
* @param input the input value
* @param manualSharding the value of the manualSharding property
* @param fullShape the value of the fullShape property
* @param <T> data type for {@code XlaSpmdShardToFullShape} output and operands
* @return a new instance of XlaSpmdShardToFullShape
*/
public <T extends TType> XlaSpmdShardToFullShape<T> xlaSpmdShardToFullShape(Operand<T> input,
String manualSharding, Shape fullShape) {
return XlaSpmdShardToFullShape.create(scope, input, manualSharding, fullShape);
}
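
> And the inverse direction, continuing the sketch above: after manually-partitioned ops run on `shard`, the result is stitched back to the full shape so automatic partitioning can resume:

```java
// Hypothetical sketch continuing the previous one; Shape is org.tensorflow.ndarray.Shape.
XlaSpmdShardToFullShape<TFloat32> restored =
    tf.xlaSpmdShardToFullShape(shard, manualSharding, Shape.of(8, 16));
```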

/**
* Wraps the XLA Sort operator, documented at
* https://www.tensorflow.org/performance/xla/operation_semantics#sort
* .
* <p>Sorts one or more tensors, with support for custom comparator, dimension, and
* is_stable attributes.
*
* @param inputs A list of {@code Tensor} of identical shape but possibly different types.
* @param dimension The dimension along which to sort. Must be a compile-time constant.
* @param comparator A comparator function to apply to 2*N scalars and returning a
* boolean. N is the number of sort inputs. If you want to sort in ascending
* order then the comparator should perform a less-than comparison.
* @param isStable Whether to use stable sort.
* @return a new instance of XlaVariadicSort
*/
public XlaVariadicSort xlaVariadicSort(Iterable<Operand<?>> inputs, Operand<TInt32> dimension,
ConcreteFunction comparator, Boolean isStable) {
return XlaVariadicSort.create(scope, inputs, dimension, comparator, isStable);
}
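
> Since the comparator is a `ConcreteFunction` over 2*N scalars, a single-input ascending sort needs a two-scalar less-than function. A sketch, assuming `Ops tf` and imports for `ConcreteFunction`, `Signature`, `Placeholder`, and `java.util.Arrays`:

```java
// Hypothetical sketch of a less-than comparator for N = 1 sort inputs.
ConcreteFunction lessThan = ConcreteFunction.create(ops -> {
  Placeholder<TFloat32> a = ops.placeholder(TFloat32.class);
  Placeholder<TFloat32> b = ops.placeholder(TFloat32.class);
  return Signature.builder()
      .input("a", a).input("b", b)
      .output("lt", ops.math.less(a, b))  // less-than => ascending order
      .build();
});
XlaVariadicSort sorted = tf.xlaVariadicSort(
    Arrays.<Operand<?>>asList(tf.constant(new float[] {3f, 1f, 2f})),
    tf.constant(0),  // dimension: must be a compile-time constant
    lessThan,
    false);          // isStable
```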

/**
* Creates a zeroed tensor given its type and shape.
*
@@ -41,11 +41,8 @@
import org.tensorflow.op.xla.SelectAndScatter;
import org.tensorflow.op.xla.SelfAdjointEig;
import org.tensorflow.op.xla.Send;
-import org.tensorflow.op.xla.SetDynamicDimensionSize;
import org.tensorflow.op.xla.Sharding;
import org.tensorflow.op.xla.Sort;
-import org.tensorflow.op.xla.SpmdFullToShardShape;
-import org.tensorflow.op.xla.SpmdShardToFullShape;
import org.tensorflow.op.xla.Svd;
import org.tensorflow.op.xla.While;
import org.tensorflow.op.xla.XlaHostCompute;
@@ -54,7 +51,6 @@
import org.tensorflow.op.xla.XlaSendToHost;
import org.tensorflow.op.xla.XlaSetBound;
import org.tensorflow.op.xla.XlaVariadicReduce;
-import org.tensorflow.op.xla.XlaVariadicSort;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TNumber;
import org.tensorflow.types.family.TType;
@@ -109,7 +105,7 @@ public <T extends TType> ClusterOutput<T> clusterOutput(Operand<T> input) {
* https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
* .
*
-* @param <W> data type for {@code output} output
+* @param <T> data type for {@code output} output
* @param lhs the input tensor
* @param rhs the kernel tensor
* @param windowStrides the inter-window strides
@@ -119,16 +115,14 @@
* @param featureGroupCount number of feature groups for grouped convolution.
* @param dimensionNumbers a serialized xla::ConvolutionDimensionNumbers proto.
* @param precisionConfig a serialized xla::PrecisionConfig proto.
-* @param preferredElementType The type of the tensor.
-* @param <W> data type for {@code XlaConvV2} output and operands
-* @param <V> data type for {@code XlaConvV2} output and operands
+* @param <T> data type for {@code XlaConv} output and operands
+* @param <U> data type for {@code XlaConv} output and operands
* @return a new instance of Conv
*/
-public <W extends TType, V extends TNumber> Conv<W> conv(Operand<? extends TType> lhs,
-    Operand<? extends TType> rhs, Operand<V> windowStrides, Operand<V> padding,
-    Operand<V> lhsDilation, Operand<V> rhsDilation, Operand<V> featureGroupCount,
-    String dimensionNumbers, String precisionConfig, Class<W> preferredElementType) {
-  return Conv.create(scope, lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation, featureGroupCount, dimensionNumbers, precisionConfig, preferredElementType);
+public <T extends TType, U extends TNumber> Conv<T> conv(Operand<T> lhs, Operand<T> rhs,
+    Operand<U> windowStrides, Operand<U> padding, Operand<U> lhsDilation, Operand<U> rhsDilation,
+    Operand<U> featureGroupCount, String dimensionNumbers, String precisionConfig) {
+  return Conv.create(scope, lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation, featureGroupCount, dimensionNumbers, precisionConfig);
}
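
> For contrast with the relocated `xlaConvV2` above, a call sketch of the reverted signature, which ties both operands to one type `T` and drops `preferredElementType` (the same reversion applies to `dot` below). Operands as in the earlier conv sketch, and assuming this group is reached via `tf.xla`:

```java
// Hypothetical sketch; lhs/rhs must now share an element type, which is also
// the output type (no Class<W> preferredElementType argument anymore).
Conv<TFloat32> out = tf.xla.conv(lhs, rhs, windowStrides, padding,
    lhsDilation, rhsDilation, featureGroupCount, dimensionNumbers, precisionConfig);
```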

/**
@@ -153,18 +147,17 @@ public Dequantize dequantize(Operand<? extends TType> input, Float minRange, Flo
* https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
* .
*
-* @param <V> data type for {@code output} output
+* @param <T> data type for {@code output} output
* @param lhs the LHS tensor
* @param rhs the RHS tensor
* @param dimensionNumbers a serialized xla::DotDimensionNumbers proto.
* @param precisionConfig a serialized xla::PrecisionConfig proto.
-* @param preferredElementType The type of the tensor.
-* @param <V> data type for {@code XlaDotV2} output and operands
+* @param <T> data type for {@code XlaDot} output and operands
* @return a new instance of Dot
*/
-public <V extends TType> Dot<V> dot(Operand<? extends TType> lhs, Operand<? extends TType> rhs,
-    String dimensionNumbers, String precisionConfig, Class<V> preferredElementType) {
-  return Dot.create(scope, lhs, rhs, dimensionNumbers, precisionConfig, preferredElementType);
+public <T extends TType> Dot<T> dot(Operand<T> lhs, Operand<T> rhs, String dimensionNumbers,
+    String precisionConfig) {
+  return Dot.create(scope, lhs, rhs, dimensionNumbers, precisionConfig);
}
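
> Similarly for the reverted `dot`, reusing the placeholder proto strings from the `xlaDotV2` sketch:

```java
// Hypothetical sketch; both operands share type T, which is also the output type.
Dot<TFloat32> product = tf.xla.dot(tf.constant(new float[2][3]),
    tf.constant(new float[3][4]), dotDimensionNumbers, dotPrecisionConfig);
```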

/**
@@ -461,25 +454,6 @@ public Send send(Operand<? extends TType> tensor, String tensorName) {
return Send.create(scope, tensor, tensorName);
}

-/**
- * Make a static dimension into a xla bounded dynamic dimension.
- * <pre>
- * The current static dimension size will become the bound and the second
- * operand becomes the dynamic size of the dimension.
- * </pre>
- *
- * @param <T> data type for {@code output} output
- * @param input the input value
- * @param dimIndex the dimIndex value
- * @param sizeOutput the sizeOutput value
- * @param <T> data type for {@code XlaSetDynamicDimensionSize} output and operands
- * @return a new instance of SetDynamicDimensionSize
- */
-public <T extends TType> SetDynamicDimensionSize<T> setDynamicDimensionSize(Operand<T> input,
-    Operand<TInt32> dimIndex, Operand<TInt32> sizeOutput) {
-  return SetDynamicDimensionSize.create(scope, input, dimIndex, sizeOutput);
-}

/**
* An op which shards the input based on the given sharding attribute.
*
@@ -508,42 +482,6 @@ public <T extends TType> Sort<T> sort(Operand<T> input) {
return Sort.create(scope, input);
}

-/**
- * An op used by XLA SPMD partitioner to switch from automatic partitioning to
- * manual partitioning. It annotates the input (full-shape, to be automatically
- * partitioned) with the same sharding used by manual partitioning, and outputs a
- * shard-shaped tensor to be consumed by later manually-partitioned ops. If the
- * shape is not evenly partitionable, the padding region will be masked with 0s.
- *
- * @param <T> data type for {@code output} output
- * @param input the input value
- * @param manualSharding the value of the manualSharding property
- * @param <T> data type for {@code XlaSpmdFullToShardShape} output and operands
- * @return a new instance of SpmdFullToShardShape
- */
-public <T extends TType> SpmdFullToShardShape<T> spmdFullToShardShape(Operand<T> input,
-    String manualSharding) {
-  return SpmdFullToShardShape.create(scope, input, manualSharding);
-}
-
-/**
- * An op used by XLA SPMD partitioner to switch from manual partitioning to
- * automatic partitioning. It converts the shard-shaped, manually partitioned input
- * into full-shaped tensor to be partitioned automatically with the same sharding
- * used by manual partitioning.
- *
- * @param <T> data type for {@code output} output
- * @param input the input value
- * @param manualSharding the value of the manualSharding property
- * @param fullShape the value of the fullShape property
- * @param <T> data type for {@code XlaSpmdShardToFullShape} output and operands
- * @return a new instance of SpmdShardToFullShape
- */
-public <T extends TType> SpmdShardToFullShape<T> spmdShardToFullShape(Operand<T> input,
-    String manualSharding, Shape fullShape) {
-  return SpmdShardToFullShape.create(scope, input, manualSharding, fullShape);
-}

/**
* Computes the eigen decomposition of a batch of self-adjoint matrices
* (Note: Only real inputs are supported).
@@ -686,26 +624,6 @@ public <T extends TType> XlaVariadicReduce<T> xlaVariadicReduce(Iterable<Operand
return XlaVariadicReduce.create(scope, input, initValue, dimensionsToReduce, reducer);
}

-/**
- * Wraps the XLA Sort operator, documented at
- * https://www.tensorflow.org/performance/xla/operation_semantics#sort
- * .
- * <p>Sorts one or more tensors, with support for custom comparator, dimension, and
- * is_stable attributes.
- *
- * @param inputs A list of {@code Tensor} of identical shape but possibly different types.
- * @param dimension The dimension along which to sort. Must be a compile-time constant.
- * @param comparator A comparator function to apply to 2*N scalars and returning a
- * boolean. N is the number of sort inputs. If you want to sort in ascending
- * order then the comparator should perform a less-than comparison.
- * @param isStable Whether to use stable sort.
- * @return a new instance of XlaVariadicSort
- */
-public XlaVariadicSort xlaVariadicSort(Iterable<Operand<?>> inputs, Operand<TInt32> dimension,
-    ConcreteFunction comparator, Boolean isStable) {
-  return XlaVariadicSort.create(scope, inputs, dimension, comparator, isStable);
-}

/**
* Get the parent {@link Ops} object.
*/
@@ -84,9 +84,8 @@ private AudioSpectrogram(Operation operation) {
)
public static AudioSpectrogram create(Scope scope, Operand<TFloat32> input, Long windowSize,
Long stride, Options... options) {
-OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("AudioSpectrogram"));
+OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "AudioSpectrogram");
opBuilder.addInput(input.asOutput());
-opBuilder = scope.apply(opBuilder);
opBuilder.setAttr("window_size", windowSize);
opBuilder.setAttr("stride", stride);
if (options != null) {
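
> The remaining files all show the mechanical generator change this PR is named for: the generated `create` methods collapse the old three-step `scope.env().opBuilder(...)` / `scope.makeOpName(...)` / `scope.apply(...)` sequence into one `scope.opBuilder(...)` call. A plausible reading of that helper, sketched here as an assumption (its definition is not part of this diff):

```java
// Hypothetical sketch of Scope.opBuilder; not in this diff. It presumably just
// composes the three calls the old generated code spelled out at every call site.
public OperationBuilder opBuilder(String opType, String name) {
  return apply(env().opBuilder(opType, makeOpName(name)));
}
```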
@@ -75,9 +75,8 @@ private DecodeWav(Operation operation) {
describeByClass = true
)
public static DecodeWav create(Scope scope, Operand<TString> contents, Options... options) {
-OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("DecodeWav"));
+OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "DecodeWav");
opBuilder.addInput(contents.asOutput());
-opBuilder = scope.apply(opBuilder);
if (options != null) {
for (Options opts : options) {
if (opts.desiredChannels != null) {
@@ -67,10 +67,9 @@ private EncodeWav(Operation operation) {
describeByClass = true
)
public static EncodeWav create(Scope scope, Operand<TFloat32> audio, Operand<TInt32> sampleRate) {
-OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("EncodeWav"));
+OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "EncodeWav");
opBuilder.addInput(audio.asOutput());
opBuilder.addInput(sampleRate.asOutput());
-opBuilder = scope.apply(opBuilder);
return new EncodeWav(opBuilder.build());
}

@@ -69,10 +69,9 @@ private Mfcc(Operation operation) {
)
public static Mfcc create(Scope scope, Operand<TFloat32> spectrogram, Operand<TInt32> sampleRate,
Options... options) {
-OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("Mfcc"));
+OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "Mfcc");
opBuilder.addInput(spectrogram.asOutput());
opBuilder.addInput(sampleRate.asOutput());
-opBuilder = scope.apply(opBuilder);
if (options != null) {
for (Options opts : options) {
if (opts.upperFrequencyLimit != null) {
@@ -79,10 +79,9 @@ private BitwiseAnd(Operation operation) {
describeByClass = true
)
public static <T extends TNumber> BitwiseAnd<T> create(Scope scope, Operand<T> x, Operand<T> y) {
-OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("BitwiseAnd"));
+OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "BitwiseAnd");
opBuilder.addInput(x.asOutput());
opBuilder.addInput(y.asOutput());
-opBuilder = scope.apply(opBuilder);
return new BitwiseAnd<>(opBuilder.build());
}

@@ -79,10 +79,9 @@ private BitwiseOr(Operation operation) {
describeByClass = true
)
public static <T extends TNumber> BitwiseOr<T> create(Scope scope, Operand<T> x, Operand<T> y) {
-OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("BitwiseOr"));
+OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "BitwiseOr");
opBuilder.addInput(x.asOutput());
opBuilder.addInput(y.asOutput());
-opBuilder = scope.apply(opBuilder);
return new BitwiseOr<>(opBuilder.build());
}

@@ -79,10 +79,9 @@ private BitwiseXor(Operation operation) {
describeByClass = true
)
public static <T extends TNumber> BitwiseXor<T> create(Scope scope, Operand<T> x, Operand<T> y) {
-OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("BitwiseXor"));
+OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "BitwiseXor");
opBuilder.addInput(x.asOutput());
opBuilder.addInput(y.asOutput());
-opBuilder = scope.apply(opBuilder);
return new BitwiseXor<>(opBuilder.build());
}

@@ -99,9 +99,8 @@ private Invert(Operation operation) {
describeByClass = true
)
public static <T extends TNumber> Invert<T> create(Scope scope, Operand<T> x) {
-OperationBuilder opBuilder = scope.env().opBuilder(OP_NAME, scope.makeOpName("Invert"));
+OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "Invert");
opBuilder.addInput(x.asOutput());
-opBuilder = scope.apply(opBuilder);
return new Invert<>(opBuilder.build());
}
