From 74a51ca4de94f5cd3041b2976227a4b507cb50da Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 4 Feb 2021 22:17:41 -0800 Subject: [PATCH] Number constant ops Signed-off-by: Ryan Nett --- .../annotations/org/tensorflow/op/Ops.java | 27 ++++++ .../java/org/tensorflow/op/core/Constant.java | 89 ++++++++++++++++--- .../org/tensorflow/op/core/ConstantTest.java | 49 +++++++--- 3 files changed, 143 insertions(+), 22 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 84736ada6a5..233139a9562 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -1831,6 +1831,19 @@ public Constant constant(Shape shape, IntDataBuffer data) { return Constant.tensorOf(scope, shape, data); } + /** + * Creates a scalar of {@code type}, with the value of {@code number}. + * {@code number} may be truncated if it does not fit in the target type. + * + * @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating}) + * @param number the value of the tensor + * @return a constant of the passed type + * @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or unknown. + */ + public Constant constant(Class type, Number number) { + return Constant.tensorOf(scope, type, number); + } + /** * Create a {@link TString} constant with data from the given buffer, using the given encoding. * @@ -1876,6 +1889,20 @@ public Constant constantOf(T tensor) { return Constant.create(scope, tensor); } + /** + * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. + * {@code number} may be truncated if it does not fit in the target type. + * + * @param toMatch the operand providing the target type + * @param number the value of the tensor + * @return a constant with the same type as {@code toMatch} + * @see Ops#constant(Class, Number) + * @throws IllegalArgumentException if the type is unknown (which should be impossible). + */ + public Constant constantOfSameType(Operand toMatch, Number number) { + return Constant.tensorOfSameType(scope, toMatch, number); + } + /** * This op consumes a lock created by `MutexLock`. *

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java index 918f9083923..497ee5f2d46 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java @@ -20,18 +20,6 @@ import org.tensorflow.Operation; import org.tensorflow.Output; import org.tensorflow.Tensor; -import org.tensorflow.op.RawOp; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.BooleanDataBuffer; -import org.tensorflow.ndarray.buffer.ByteDataBuffer; -import org.tensorflow.ndarray.buffer.DataBuffer; -import org.tensorflow.ndarray.buffer.DoubleDataBuffer; -import org.tensorflow.ndarray.buffer.FloatDataBuffer; -import org.tensorflow.ndarray.buffer.IntDataBuffer; -import org.tensorflow.ndarray.buffer.LongDataBuffer; import org.tensorflow.ndarray.BooleanNdArray; import org.tensorflow.ndarray.ByteNdArray; import org.tensorflow.ndarray.DoubleNdArray; @@ -40,14 +28,30 @@ import org.tensorflow.ndarray.LongNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.op.Ops; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TBfloat16; import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat16; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.TUint8; +import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; /** @@ -1277,6 +1281,67 @@ public static Constant tensorOf(Scope scope, Shape shape) { return vectorOf(scope, shape.asArray()); } + /** + * Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be truncated if it does not + * fit in the target type. + * + * @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating}) + * @param number the value of the tensor + * @return a constant of the passed type + * @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or + * unknown. + */ + @SuppressWarnings("unchecked") + @Endpoint + public static Constant tensorOf(Scope scope, Class type, Number number) { + if (type.equals(TBfloat16.class)) { + try (TBfloat16 tensor = TBfloat16.scalarOf(number.floatValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TFloat64.class)) { + try (TFloat64 tensor = TFloat64.scalarOf(number.doubleValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TFloat32.class)) { + try (TFloat32 tensor = TFloat32.scalarOf(number.floatValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TFloat16.class)) { + try (TFloat16 tensor = TFloat16.scalarOf(number.floatValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TInt64.class)) { + try (TInt64 tensor = TInt64.scalarOf(number.longValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TInt32.class)) { + try (TInt32 tensor = TInt32.scalarOf(number.intValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TUint8.class)) { + try (TUint8 tensor = TUint8.scalarOf(number.byteValue())) { + return (Constant) create(scope, tensor); + } + } else { + throw new IllegalArgumentException("Tensor type " + type + " is an abstract or unknown numeric type."); + } + } + + /** + * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code number} may be + * truncated if it does not fit in the target type. + * + * @param toMatch the operand providing the target type + * @param number the value of the tensor + * @return a constant with the same type as {@code toMatch} + * @throws IllegalArgumentException if the type is unknown (which should be impossible). + * @see Ops#constant(Class, Number) + */ + @Endpoint(name = "constantOfSameType") + public static Constant tensorOfSameType(Scope scope, Operand toMatch, Number number) { + return tensorOf(scope, toMatch.type(), number); + } + /** * Create a constant by making an immutable copy of {@code tensor}. {@code tensor} may be closed afterwards without * issue. diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java index 5dd6903d913..6df73261867 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -18,15 +18,19 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import java.io.IOException; - import org.junit.jupiter.api.Test; import org.tensorflow.AutoCloseableList; import org.tensorflow.EagerSession; import org.tensorflow.Graph; +import org.tensorflow.Operand; import org.tensorflow.Session; import org.tensorflow.Tensor; -import org.tensorflow.op.Ops; -import org.tensorflow.op.Scope; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.IntNdArray; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.buffer.DataBuffer; import org.tensorflow.ndarray.buffer.DataBuffers; @@ -34,19 +38,20 @@ import org.tensorflow.ndarray.buffer.FloatDataBuffer; import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.ndarray.buffer.LongDataBuffer; -import org.tensorflow.ndarray.DoubleNdArray; -import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.IntNdArray; -import org.tensorflow.ndarray.LongNdArray; -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.op.Ops; +import org.tensorflow.op.Scope; +import org.tensorflow.types.TBfloat16; +import org.tensorflow.types.TFloat16; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; +import org.tensorflow.types.TUint8; +import org.tensorflow.types.family.TNumber; public class ConstantTest { + private static final float EPSILON = 1e-7f; @Test @@ -56,7 +61,7 @@ public void createInts() { IntNdArray array = NdArrays.wrap(shape, buffer); try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); @@ -164,4 +169,28 @@ public void createFromTensorsInEagerMode() throws IOException { assertEquals(NdArrays.vectorOf(1, 2, 3, 4), c1.asTensor()); } } + + private static void testCreateFromNumber(Ops tf, Class type) { + Operand constant = tf.constant(type, 10); + assertEquals(type, constant.type()); + + try (TFloat64 t = tf.dtypes.cast(constant, TFloat64.class).asTensor()) { + assertEquals(10.0, t.getDouble()); + } + } + + @Test + public void createFromNumber() { + try (EagerSession s = EagerSession.create()) { + Ops tf = Ops.create(s); + + testCreateFromNumber(tf, TBfloat16.class); + testCreateFromNumber(tf, TFloat64.class); + testCreateFromNumber(tf, TFloat32.class); + testCreateFromNumber(tf, TFloat16.class); + testCreateFromNumber(tf, TInt64.class); + testCreateFromNumber(tf, TInt32.class); + testCreateFromNumber(tf, TUint8.class); + } + } }