Skip to content

Commit 3a46d63

Browse files
committed
Fold name handling into ConcreteFunction, fix tests
Signed-off-by: Ryan Nett <[email protected]>
1 parent 1fe9f53 commit 3a46d63

File tree

12 files changed

+136
-437
lines changed

12 files changed

+136
-437
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 8 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,10 @@
2121
import java.util.List;
2222
import java.util.Map;
2323
import org.tensorflow.ConcreteFunction;
24+
import org.tensorflow.DefinedFunction;
2425
import org.tensorflow.DeviceSpec;
2526
import org.tensorflow.EagerSession;
2627
import org.tensorflow.ExecutionEnvironment;
27-
import org.tensorflow.GraphFunction;
28-
import org.tensorflow.NamedGraphFunction;
2928
import org.tensorflow.Operand;
3029
import org.tensorflow.ndarray.BooleanNdArray;
3130
import org.tensorflow.ndarray.ByteNdArray;
@@ -1072,100 +1071,18 @@ public Bucketize bucketize(Operand<? extends TNumber> input, List<Float> boundar
10721071
}
10731072

10741073
/**
1075-
* Call {@code function}, adding it to the execution environment if it isn't already present. The inputs and outputs
1076-
* are keyed by the names set in the {@code ConcreteFunction}'s {@code Signature}.
1077-
*
1078-
* @param function the function to call
1079-
* @param inputs the inputs to the function
1080-
* @return the outputs of the function
1081-
*/
1082-
public Map<String, Operand<?>> callConcreteFunction(ConcreteFunction function,
1083-
Operand<?>... inputs) {
1084-
return Function.callConcreteFunction(scope, function, inputs);
1085-
}
1086-
1087-
/**
1088-
* Call {@code function}, adding it to the execution environment if it isn't already present. The inputs and outputs
1089-
* are keyed by the names set in the {@code ConcreteFunction}'s {@code Signature}.
1090-
*
1091-
* @param function the function to call
1092-
* @param inputs the inputs to the function
1093-
* @return the outputs of the function
1094-
*/
1095-
public Map<String, Operand<?>> callConcreteFunction(ConcreteFunction function,
1096-
List<Operand<?>> inputs) {
1097-
return Function.callConcreteFunction(scope, function, inputs);
1098-
}
1099-
1100-
/**
1101-
* Call {@code function}, adding it to the execution environment if it isn't already present. The inputs and outputs
1102-
* are keyed by the names set in the {@code ConcreteFunction}'s {@code Signature}.
1103-
*
1104-
* @param function the function to call
1105-
* @param inputs the inputs to the function
1106-
* @return the outputs of the function
1107-
*/
1108-
public Map<String, Operand<?>> callConcreteFunction(ConcreteFunction function,
1109-
Map<String, Operand<?>> inputs) {
1110-
return Function.callConcreteFunction(scope, function, inputs);
1111-
}
1112-
1113-
/**
1114-
* Call {@code function}, adding it to the execution environment if it isn't already present.
1115-
*
1116-
* @param function the function to call
1117-
* @param inputs the inputs to the function
1118-
* @return the outputs of the function
1074+
* empty
11191075
*/
1120-
public List<Operand<?>> callFunction(GraphFunction function, List<Operand<?>> inputs) {
1121-
return Function.callFunction(scope, function, inputs);
1076+
public Map<String, Operand<?>> call(ConcreteFunction function,
1077+
Map<String, Operand<?>> arguments) {
1078+
return Function.call(scope, function, arguments);
11221079
}
11231080

11241081
/**
1125-
* Call {@code function}, adding it to the execution environment if it isn't already present.
1126-
*
1127-
* @param function the function to call
1128-
* @param inputs the inputs to the function
1129-
* @return the outputs of the function
1130-
*/
1131-
public List<Operand<?>> callFunction(GraphFunction function, Operand<?>... inputs) {
1132-
return Function.callFunction(scope, function, inputs);
1133-
}
1134-
1135-
/**
1136-
* Call {@code function}, adding it to the execution environment if it isn't already present.
1137-
*
1138-
* @param function the function to call
1139-
* @param inputs the inputs to the function
1140-
* @return the outputs of the function
1141-
*/
1142-
public Map<String, Operand<?>> callNamedFunction(NamedGraphFunction function,
1143-
Map<String, Operand<?>> inputs) {
1144-
return Function.callNamedFunction(scope, function, inputs);
1145-
}
1146-
1147-
/**
1148-
* Call {@code function}, adding it to the execution environment if it isn't already present.
1149-
*
1150-
* @param function the function to call
1151-
* @param inputs the inputs to the function
1152-
* @return the outputs of the function
1153-
*/
1154-
public Map<String, Operand<?>> callNamedFunction(NamedGraphFunction function,
1155-
Operand<?>... inputs) {
1156-
return Function.callNamedFunction(scope, function, inputs);
1157-
}
1158-
1159-
/**
1160-
* Call {@code function}, adding it to the execution environment if it isn't already present.
1161-
*
1162-
* @param function the function to call
1163-
* @param inputs the inputs to the function
1164-
* @return the outputs of the function
1082+
* empty
11651083
*/
1166-
public Map<String, Operand<?>> callNamedFunction(NamedGraphFunction function,
1167-
List<Operand<?>> inputs) {
1168-
return Function.callNamedFunction(scope, function, inputs);
1084+
public List<Operand<?>> call(DefinedFunction function, List<Operand<?>> arguments) {
1085+
return Function.call(scope, function, arguments);
11691086
}
11701087

11711088
/**

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@ public Output<?>[] outputList(int idx, int length) {
3939

4040
@Override
4141
public <T extends TType> Output<T> output(int idx) {
42+
if (getUnsafeNativeHandle(idx) != null && !getUnsafeNativeHandle(idx).isNull()) {
43+
int numOutputs = this.numOutputs();
44+
if (idx >= numOutputs) {
45+
throw new IndexOutOfBoundsException(
46+
"Can't get output with index " + idx + ", this op only has " + numOutputs + " outputs.");
47+
}
48+
49+
if (idx < 0) {
50+
throw new IndexOutOfBoundsException("Can't get output with index < 0.");
51+
}
52+
}
4253
return new Output<>(this, idx);
4354
}
4455

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717

1818
import java.io.IOException;
1919
import java.util.HashMap;
20+
import java.util.LinkedHashMap;
2021
import java.util.List;
2122
import java.util.ListIterator;
2223
import java.util.Map;
2324
import java.util.function.Function;
25+
import java.util.stream.Collectors;
2426
import org.tensorflow.op.Ops;
2527
import org.tensorflow.proto.framework.SignatureDef;
2628
import org.tensorflow.proto.framework.TensorInfo;
@@ -209,7 +211,7 @@ public Map<String, Tensor> call(Map<String, Tensor> arguments)
209211
* @return the outputs of the function
210212
*/
211213
public Map<String, Operand<?>> call(Ops tf, Map<String, Operand<?>> arguments) {
212-
return tf.callConcreteFunction(this, arguments);
214+
return tf.call(this, arguments);
213215
}
214216

215217
/**
@@ -255,14 +257,18 @@ public Operand<?> call(Ops tf, Operand<?> argument) {
255257
throw new IllegalArgumentException(
256258
String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName()));
257259
}
260+
String inputName = signatureDef.getInputsMap().keySet().iterator().next();
258261

259262
if (signatureDef.getOutputsCount() != 1) {
260263
throw new IllegalArgumentException(
261264
String.format("Function [%s] has multiple outputs", signatureDef.getMethodName()));
262265
}
263-
String outputNodeName = signatureDef.getOutputsMap().values().iterator().next().getName();
266+
String outputName = signatureDef.getOutputsMap().keySet().iterator().next();
267+
268+
Map<String, Operand<?>> inputMap = new LinkedHashMap<>();
269+
inputMap.put(inputName, argument);
264270

265-
return tf.callConcreteFunction(this, argument).get(outputNodeName);
271+
return call(tf, inputMap).get(outputName);
266272
}
267273

268274
/**
@@ -298,12 +304,29 @@ public Graph graph() {
298304
return graph;
299305
}
300306

307+
private DefinedFunction makeGraphFunction() {
308+
String description = signature().methodName();
309+
if (description == null) {
310+
description = signature().key();
311+
}
312+
313+
List<Operand<?>> inputList = signature.getInputs().values().stream()
314+
.map((it) -> graph.outputOrError(it.name))
315+
.collect(Collectors.toUnmodifiableList());
316+
317+
List<Operand<?>> outputList = signature.getOutputs().values().stream()
318+
.map((it) -> graph.outputOrError(it.name))
319+
.collect(Collectors.toUnmodifiableList());
320+
321+
return DefinedFunction.create(graph, signature.key(), true, inputList, outputList, description);
322+
}
323+
301324
/**
302325
* Get the graph as a function. The graph function will be closed when this is.
303326
*/
304-
public NamedGraphFunction function() {
327+
public DefinedFunction function() {
305328
if (function == null) {
306-
function = GraphFunction.create(this, true);
329+
function = makeGraphFunction();
307330
}
308331
return function;
309332
}
@@ -332,7 +355,7 @@ private enum Ownership {
332355

333356
private final Graph graph;
334357
private final Session session;
335-
private NamedGraphFunction function = null;
358+
private DefinedFunction function = null;
336359
private final Signature signature;
337360
private final Ownership ownership;
338361

0 commit comments

Comments
 (0)