From ec05de48927fe582d46db2483e692e570a34795d Mon Sep 17 00:00:00 2001 From: Jordan Bentley Date: Tue, 5 Jan 2021 13:38:20 -0500 Subject: [PATCH 1/3] Expose input and output types from signature from method other than toString --- .../main/java/org/tensorflow/Signature.java | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java index ea32d1fff13..e7810a3f9dd 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -15,9 +15,11 @@ */ package org.tensorflow; +import java.util.HashMap; import java.util.Map; import java.util.Set; import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.SignatureDef; import org.tensorflow.proto.framework.TensorInfo; import org.tensorflow.proto.framework.TensorShapeProto; @@ -174,6 +176,28 @@ public String toString() { return strBuilder.toString(); } + /** + * Returns the names of the inputs in this signature mapped to their expected data type + */ + public Map getInputs() { + Map dataTypeMap = new HashMap<>(); + signatureDef.getInputsMap().forEach((a,b) -> { + dataTypeMap.put(a, b.getDtype()); + }); + return dataTypeMap; + } + + /** + * Returns the names of the outputs in this signature mapped to their expected data type + */ + public Map getOutputs() { + Map dataTypeMap = new HashMap<>(); + signatureDef.getOutputsMap().forEach((a,b) -> { + dataTypeMap.put(a, b.getDtype()); + }); + return dataTypeMap; + } + Signature(String key, SignatureDef signatureDef) { this.key = key; this.signatureDef = signatureDef; From cc999d48b2c07c7cb8dcd62c2ceba74fb2ff8f50 Mon Sep 17 00:00:00 2001 From: Jordan Bentley Date: Thu, 28 Jan 2021 09:29:23 -0500 Subject: [PATCH 2/3] Expose input and output types from signature with shape --- .../main/java/org/tensorflow/Signature.java | 39 ++++++++++++------- .../java/org/tensorflow/SignatureTest.java | 32 +++++++++++++-- 2 files changed, 54 insertions(+), 17 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java index e7810a3f9dd..cf22474b55c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -34,6 +34,16 @@ public class Signature { /** The default signature key, when not provided */ public static final String DEFAULT_KEY = "serving_default"; + public static class TensorDescription { + public final DataType dataType; + public final long[] shape; + + public TensorDescription(DataType dataType, long[] shape) { + this.dataType = dataType; + this.shape = shape; + } + } + /** * Builds a new function signature. */ @@ -176,26 +186,27 @@ public String toString() { return strBuilder.toString(); } - /** - * Returns the names of the inputs in this signature mapped to their expected data type - */ - public Map getInputs() { - Map dataTypeMap = new HashMap<>(); - signatureDef.getInputsMap().forEach((a,b) -> { - dataTypeMap.put(a, b.getDtype()); + private Map buildTensorDescriptionMap(Map dataMapIn) { + Map dataTypeMap = new HashMap<>(); + dataMapIn.forEach((a, b) -> { + dataTypeMap.put(a, new TensorDescription(b.getDtype(), b.getTensorShape().getDimList().stream().mapToLong(d -> d.getSize()).toArray())); }); return dataTypeMap; } /** - * Returns the names of the outputs in this signature mapped to their expected data type + * Returns the names of the inputs in this signature mapped to their expected data type and shape + * @return */ - public Map getOutputs() { - Map dataTypeMap = new HashMap<>(); - signatureDef.getOutputsMap().forEach((a,b) -> { - dataTypeMap.put(a, b.getDtype()); - }); - return dataTypeMap; + public Map getInputs() { + return buildTensorDescriptionMap(signatureDef.getInputsMap()); + } + + /** + * Returns the names of the outputs in this signature mapped to their expected data type and shape + */ + public Map getOutputs() { + return buildTensorDescriptionMap(signatureDef.getOutputsMap()); } Signature(String key, SignatureDef signatureDef) { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java index e1436358a68..6c313a60684 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java @@ -14,11 +14,14 @@ ==============================================================================*/ package org.tensorflow; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; - import org.junit.jupiter.api.Test; +import org.tensorflow.Signature.TensorDescription; import org.tensorflow.op.Ops; +import org.tensorflow.proto.framework.DataType; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; public class SignatureTest { @@ -43,6 +46,29 @@ public void cannotDuplicateInputOutputNames() { } } + @Test + public void getInputsAndOutputs() { + Ops tf = Ops.create(); + Signature builder = Signature.builder() + .input("x", tf.constant(10.0f)) + .output("y", tf.constant(new float[][] {{10.0f, 30.0f}})) + .output("z", tf.constant(20.0f)).build(); + + Map inputs = builder.getInputs(); + assertEquals(inputs.size(), 1); + + Map outputs = builder.getOutputs(); + assertEquals(outputs.size(), 2); + + assertEquals(outputs.get("y").dataType, DataType.DT_FLOAT); + assertEquals(outputs.get("z").dataType, DataType.DT_FLOAT); + assertArrayEquals(outputs.get("y").shape, new long [] {1,2}); + assertArrayEquals(outputs.get("z").shape, new long [] {}); + + Signature emptySignature = Signature.builder().build(); + assertEquals(emptySignature.getInputs().size(), 0); + } + @Test public void emptyMethodNameConvertedToNull() { Signature signature = Signature.builder().key("f").build(); From 86973e30e1b366c822d6809d661a849db805cbcd Mon Sep 17 00:00:00 2001 From: Jordan Bentley Date: Wed, 3 Feb 2021 16:09:24 -0500 Subject: [PATCH 3/3] Expose input and output types from signature as Shape --- .../src/main/java/org/tensorflow/Signature.java | 9 ++++++--- .../src/test/java/org/tensorflow/SignatureTest.java | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java index cf22474b55c..66b4dad4132 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -36,9 +36,9 @@ public class Signature { public static class TensorDescription { public final DataType dataType; - public final long[] shape; + public final Shape shape; - public TensorDescription(DataType dataType, long[] shape) { + public TensorDescription(DataType dataType, Shape shape) { this.dataType = dataType; this.shape = shape; } @@ -189,7 +189,10 @@ public String toString() { private Map buildTensorDescriptionMap(Map dataMapIn) { Map dataTypeMap = new HashMap<>(); dataMapIn.forEach((a, b) -> { - dataTypeMap.put(a, new TensorDescription(b.getDtype(), b.getTensorShape().getDimList().stream().mapToLong(d -> d.getSize()).toArray())); + long[] tensorDims = b.getTensorShape().getDimList().stream().mapToLong(d -> d.getSize()).toArray(); + Shape tensorShape = Shape.of(tensorDims); + dataTypeMap.put(a, new TensorDescription(b.getDtype(), + tensorShape)); }); return dataTypeMap; } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java index 6c313a60684..c9740ce4a6b 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java @@ -62,8 +62,8 @@ public void getInputsAndOutputs() { assertEquals(outputs.get("y").dataType, DataType.DT_FLOAT); assertEquals(outputs.get("z").dataType, DataType.DT_FLOAT); - assertArrayEquals(outputs.get("y").shape, new long [] {1,2}); - assertArrayEquals(outputs.get("z").shape, new long [] {}); + assertArrayEquals(outputs.get("y").shape.asArray(), new long [] {1,2}); + assertArrayEquals(outputs.get("z").shape.asArray(), new long [] {}); Signature emptySignature = Signature.builder().build(); assertEquals(emptySignature.getInputs().size(), 0);