diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java index ea32d1fff13..66b4dad4132 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -15,9 +15,11 @@ */ package org.tensorflow; +import java.util.HashMap; import java.util.Map; import java.util.Set; import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.SignatureDef; import org.tensorflow.proto.framework.TensorInfo; import org.tensorflow.proto.framework.TensorShapeProto; @@ -32,6 +34,16 @@ public class Signature { /** The default signature key, when not provided */ public static final String DEFAULT_KEY = "serving_default"; + public static class TensorDescription { + public final DataType dataType; + public final Shape shape; + + public TensorDescription(DataType dataType, Shape shape) { + this.dataType = dataType; + this.shape = shape; + } + } + /** * Builds a new function signature. */ @@ -174,6 +186,32 @@ public String toString() { return strBuilder.toString(); } + private Map buildTensorDescriptionMap(Map dataMapIn) { + Map dataTypeMap = new HashMap<>(); + dataMapIn.forEach((a, b) -> { + long[] tensorDims = b.getTensorShape().getDimList().stream().mapToLong(d -> d.getSize()).toArray(); + Shape tensorShape = Shape.of(tensorDims); + dataTypeMap.put(a, new TensorDescription(b.getDtype(), + tensorShape)); + }); + return dataTypeMap; + } + + /** + * Returns the names of the inputs in this signature mapped to their expected data type and shape + * @return + */ + public Map getInputs() { + return buildTensorDescriptionMap(signatureDef.getInputsMap()); + } + + /** + * Returns the names of the outputs in this signature mapped to their expected data type and shape + */ + public Map getOutputs() { + return buildTensorDescriptionMap(signatureDef.getOutputsMap()); + } + Signature(String key, SignatureDef signatureDef) { this.key = key; this.signatureDef = signatureDef; diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java index e1436358a68..c9740ce4a6b 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java @@ -14,11 +14,14 @@ ==============================================================================*/ package org.tensorflow; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; - import org.junit.jupiter.api.Test; +import org.tensorflow.Signature.TensorDescription; import org.tensorflow.op.Ops; +import org.tensorflow.proto.framework.DataType; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; public class SignatureTest { @@ -43,6 +46,29 @@ public void cannotDuplicateInputOutputNames() { } } + @Test + public void getInputsAndOutputs() { + Ops tf = Ops.create(); + Signature builder = Signature.builder() + .input("x", tf.constant(10.0f)) + .output("y", tf.constant(new float[][] {{10.0f, 30.0f}})) + .output("z", tf.constant(20.0f)).build(); + + Map inputs = builder.getInputs(); + assertEquals(inputs.size(), 1); + + Map outputs = builder.getOutputs(); + assertEquals(outputs.size(), 2); + + assertEquals(outputs.get("y").dataType, DataType.DT_FLOAT); + assertEquals(outputs.get("z").dataType, DataType.DT_FLOAT); + assertArrayEquals(outputs.get("y").shape.asArray(), new long [] {1,2}); + assertArrayEquals(outputs.get("z").shape.asArray(), new long [] {}); + + Signature emptySignature = Signature.builder().build(); + assertEquals(emptySignature.getInputs().size(), 0); + } + @Test public void emptyMethodNameConvertedToNull() { Signature signature = Signature.builder().key("f").build();