Skip to content

Commit a3803e5

Browse files
committed
Support dependencies
Signed-off-by: Ryan Nett <[email protected]>
1 parent 82a1f2d commit a3803e5

File tree

8 files changed

+419
-234
lines changed

8 files changed

+419
-234
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java

Lines changed: 112 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,15 @@
1515
*/
1616
package org.tensorflow;
1717

18-
import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionName;
1918
import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionSetAttrValueProto;
20-
import static org.tensorflow.internal.c_api.global.tensorflow.TF_FunctionToFunctionDef;
2119
import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphToFunction;
22-
import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationGetAttrValueProto;
2320

24-
import com.google.protobuf.InvalidProtocolBufferException;
2521
import java.io.IOException;
2622
import java.util.ArrayList;
23+
import java.util.Collection;
2724
import java.util.Collections;
2825
import java.util.HashSet;
2926
import java.util.LinkedHashMap;
30-
import java.util.LinkedHashSet;
3127
import java.util.List;
3228
import java.util.Map;
3329
import java.util.Set;
@@ -38,7 +34,6 @@
3834
import org.bytedeco.javacpp.PointerPointer;
3935
import org.bytedeco.javacpp.PointerScope;
4036
import org.tensorflow.Graph.Reference;
41-
import org.tensorflow.internal.c_api.TF_Buffer;
4237
import org.tensorflow.internal.c_api.TF_Function;
4338
import org.tensorflow.internal.c_api.TF_Operation;
4439
import org.tensorflow.internal.c_api.TF_Output;
@@ -49,7 +44,6 @@
4944
import org.tensorflow.proto.framework.AttrValue;
5045
import org.tensorflow.proto.framework.DataType;
5146
import org.tensorflow.proto.framework.FunctionDef;
52-
import org.tensorflow.proto.framework.NodeDef;
5347
import org.tensorflow.proto.framework.OpDef.ArgDef;
5448
import org.tensorflow.proto.framework.SignatureDef;
5549
import org.tensorflow.proto.framework.TensorInfo;
@@ -186,9 +180,35 @@ public Signature signature() {
186180
* Get the name of the function.
187181
*/
188182
public String getNativeFunctionName() {
189-
try (PointerScope scope = new PointerScope()) {
190-
return TF_FunctionName(nativeHandle()).getString();
191-
}
183+
return nativeFunction.getName();
184+
}
185+
186+
/**
187+
* Get the {@link FunctionDef} proto.
188+
*/
189+
public FunctionDef getFunctionDef() {
190+
return nativeFunction.getFunctionDef();
191+
}
192+
193+
/**
194+
* Get whether the function is stateful.
195+
*/
196+
public boolean isStateful() {
197+
return nativeFunction.isStateful();
198+
}
199+
200+
Set<TF_Function> getDependencies() {
201+
return dependencies;
202+
}
203+
204+
@Override
205+
public void close() {
206+
scope.close();
207+
}
208+
209+
@Override
210+
public String toString() {
211+
return signature.toString();
192212
}
193213

194214
public static final String CALL_OP = "PartitionedCall";
@@ -214,7 +234,7 @@ public Map<String, Operand<?>> call(Scope scope,
214234
String displayName = Scope.isValidOpName(name) ? name : "FunctionCall";
215235

216236
OperationBuilder opBuilder = scope.env()
217-
.opBuilder(stateful ? STATEFUL_CALL_OP : CALL_OP, scope.makeOpName(displayName));
237+
.opBuilder(isStateful() ? STATEFUL_CALL_OP : CALL_OP, scope.makeOpName(displayName));
218238

219239
opBuilder.addInputList(inputList.stream().map(Operand::asOutput).toArray(Output[]::new));
220240

@@ -357,20 +377,6 @@ public void save(String exportDir) throws IOException {
357377
SavedModelBundle.exporter(exportDir).withFunction(this).export();
358378
}
359379

360-
public boolean isStateful() {
361-
return stateful;
362-
}
363-
364-
@Override
365-
public void close() {
366-
scope.close();
367-
}
368-
369-
@Override
370-
public String toString() {
371-
return signature.toString();
372-
}
373-
374380
/**
375381
* FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because how to enable XLA
376382
* JIT is extremely non-obvious.
@@ -394,10 +400,10 @@ private void makeJit() {
394400
}
395401

396402
TF_Function nativeHandle() {
397-
if (nativeHandle.isNull()) {
403+
if (nativeFunction.getNativeHandle().isNull()) {
398404
throw new IllegalStateException("Function has been closed");
399405
}
400-
return nativeHandle;
406+
return nativeFunction.getNativeHandle();
401407
}
402408

403409
/**
@@ -410,43 +416,93 @@ TF_Function gradNativeHandle() {
410416
}
411417

412418
private final Signature signature;
413-
private final TF_Function nativeHandle;
419+
private final NativeFunction nativeFunction;
414420
private final PointerScope scope;
415-
private final boolean stateful;
421+
private final Set<TF_Function> dependencies;
422+
423+
ConcreteFunction(Signature signature, NativeFunction nativeFunction, Collection<NativeFunction> availableFunctions) {
424+
this(signature, nativeFunction, nativeFunction.getAllDependencies(availableFunctions));
425+
}
426+
427+
private static boolean dataTypesMatch(List<DataType> a, List<DataType> b) {
428+
if (a.size() != b.size()) {
429+
return false;
430+
}
431+
432+
for (int i = 0; i < a.size(); i++) {
433+
DataType aType = a.get(i);
434+
DataType bType = b.get(i);
435+
436+
if (aType != DataType.DT_INVALID && bType != DataType.DT_INVALID && !a.equals(b)) {
437+
return false;
438+
}
439+
}
416440

417-
ConcreteFunction(Signature signature, TF_Function nativeHandle, boolean stateful) {
441+
return true;
442+
}
443+
444+
private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set<TF_Function> dependencies) {
418445
this.signature = signature;
446+
this.nativeFunction = nativeFunction;
447+
this.dependencies = Collections.unmodifiableSet(dependencies);
448+
449+
if (this.signature.getInputs().size() != nativeFunction.getFunctionDef().getSignature().getInputArgCount()) {
450+
throw new IllegalArgumentException(
451+
"Signature must have the same number of inputs as the native function. Expected "
452+
+ nativeFunction.getFunctionDef().getSignature().getInputArgCount() + ", got "
453+
+ this.signature.getInputs().size());
454+
}
455+
456+
if (this.signature.getOutputs().size() != nativeFunction.getFunctionDef().getSignature().getOutputArgCount()) {
457+
throw new IllegalArgumentException(
458+
"New signature must have the same number of outputs as the native function. Expected "
459+
+ nativeFunction.getFunctionDef().getSignature().getOutputArgCount() + ", got "
460+
+ this.signature.getOutputs().size());
461+
}
462+
463+
List<DataType> inputs = this.signature.getInputs().values().stream().map(x -> x.dataType)
464+
.collect(Collectors.toList());
465+
List<DataType> nativeInputs = nativeFunction.getFunctionDef().getSignature().getInputArgList().stream()
466+
.map(ArgDef::getType)
467+
.collect(Collectors.toList());
468+
469+
if (!dataTypesMatch(inputs, nativeInputs)) {
470+
throw new IllegalArgumentException(
471+
"Data types of the signature's inputs must match the native function's (in order). Expected "
472+
+ nativeInputs + ", got " + inputs);
473+
}
474+
475+
List<DataType> outputs = this.signature.getOutputs().values().stream().map(x -> x.dataType)
476+
.collect(Collectors.toList());
477+
List<DataType> nativeOutputs = nativeFunction.getFunctionDef().getSignature().getOutputArgList().stream()
478+
.map(ArgDef::getType)
479+
.collect(Collectors.toList());
480+
481+
if (!dataTypesMatch(outputs, nativeOutputs)) {
482+
throw new IllegalArgumentException(
483+
"Data types of the signature's outputs must match the native function's (in order). Expected "
484+
+ nativeOutputs + ", got "
485+
+ outputs);
486+
}
487+
419488
try (PointerScope scope = new PointerScope()) {
420-
scope.extend();
421-
this.nativeHandle = nativeHandle.withDeallocator();
422-
scope.attach(nativeHandle);
423489
this.scope = scope;
490+
scope.extend();
491+
this.nativeFunction.getNativeHandle().withDeallocatorInScope();
492+
this.dependencies.forEach(TF_Function::withDeallocatorInScope);
424493
}
425-
this.stateful = stateful;
426494
}
427495

428496
/**
429497
* Detects the signature from the handle
430498
*/
431-
static ConcreteFunction fromNativeHandle(TF_Function function) {
499+
static ConcreteFunction fromNativeHandle(NativeFunction nativeFunction,
500+
Collection<NativeFunction> availableFunctions) {
432501

433-
FunctionDef funcDef = null;
434-
try (PointerScope scope = new PointerScope()) {
435-
TF_Buffer funcDefBuffer = TF_Buffer.newBuffer();
436-
TF_Status status2 = TF_Status.newStatus();
437-
TF_FunctionToFunctionDef(function, funcDefBuffer, status2);
438-
status2.throwExceptionIfNotOK();
439-
try {
440-
funcDef = FunctionDef.parseFrom(funcDefBuffer.dataAsByteBuffer());
441-
} catch (InvalidProtocolBufferException e) {
442-
throw new IllegalStateException("Failed to parse FunctionDef proto", e);
443-
}
444-
}
445-
446-
Signature.Builder builder = Signature.builder().methodName(funcDef.getSignature().getName())
447-
.key(TF_FunctionName(function).getString());
502+
Signature.Builder builder = Signature.builder().methodName(nativeFunction.getFunctionDef().getSignature().getName())
503+
.key(nativeFunction.getName());
448504

449-
for (ArgDef input : funcDef.getSignature().getInputArgList()) {
505+
for (ArgDef input : nativeFunction.getFunctionDef().getSignature().getInputArgList()) {
450506
TensorInfo info = TensorInfo.newBuilder()
451507
.setDtype(input.getType())
452508
.setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build())
@@ -456,7 +512,7 @@ static ConcreteFunction fromNativeHandle(TF_Function function) {
456512
builder.input(input.getName(), info);
457513
}
458514

459-
for (ArgDef outputDef : funcDef.getSignature().getOutputArgList()) {
515+
for (ArgDef outputDef : nativeFunction.getFunctionDef().getSignature().getOutputArgList()) {
460516
TensorInfo info = TensorInfo.newBuilder()
461517
.setDtype(outputDef.getType())
462518
.setTensorShape(TensorShapeProto.newBuilder().setUnknownRank(true).build())
@@ -468,8 +524,8 @@ static ConcreteFunction fromNativeHandle(TF_Function function) {
468524

469525
return new ConcreteFunction(
470526
builder.build(),
471-
function,
472-
funcDef.getNodeDefList().stream().anyMatch(x -> TensorFlow.isOpStateful(x.getOp()))
527+
nativeFunction,
528+
availableFunctions
473529
);
474530
}
475531

@@ -559,91 +615,11 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature)
559615
);
560616

561617
status.throwExceptionIfNotOK();
562-
return new ConcreteFunction(signature, handle, ops.stream().anyMatch(x -> TensorFlow.isOpStateful(x.type())));
618+
return new ConcreteFunction(signature, new NativeFunction(handle), graph.getNativeFunctions());
563619
}
564620
}
565621

566622
ConcreteFunction withNewSignature(Signature signature) {
567-
if (this.signature.getInputs().size() != signature.getInputs().size()) {
568-
throw new IllegalArgumentException(
569-
"New signature must have the same number of inputs. Expected " + this.signature.getInputs().size() + ", got "
570-
+ signature.getInputs().size());
571-
}
572-
573-
if (this.signature.getOutputs().size() != signature.getOutputs().size()) {
574-
throw new IllegalArgumentException(
575-
"New signature must have the same number of inputs. Expected " + this.signature.getInputs().size() + ", got "
576-
+ signature.getInputs().size());
577-
}
578-
579-
List<DataType> inputs = this.signature.getInputs().values().stream().map(x -> x.dataType)
580-
.collect(Collectors.toList());
581-
List<DataType> newInputs = signature.getInputs().values().stream().map(x -> x.dataType)
582-
.collect(Collectors.toList());
583-
584-
if (!inputs.equals(newInputs)) {
585-
throw new IllegalArgumentException(
586-
"Data types of the new signature's inputs (in order) must match. Expected " + inputs + ", got " + newInputs);
587-
}
588-
589-
List<DataType> outputs = this.signature.getOutputs().values().stream().map(x -> x.dataType)
590-
.collect(Collectors.toList());
591-
List<DataType> newOutputs = signature.getOutputs().values().stream().map(x -> x.dataType)
592-
.collect(Collectors.toList());
593-
594-
if (!outputs.equals(newOutputs)) {
595-
throw new IllegalArgumentException(
596-
"Data types of the new signature's outputs (in order) must match. Expected " + outputs + ", got "
597-
+ newOutputs);
598-
}
599-
600-
return new ConcreteFunction(signature, nativeHandle, stateful);
601-
}
602-
603-
/**
604-
* Returns the function name if {@code op} is a function call op, or null otherwise.
605-
*/
606-
static String findFunctionCall(GraphOperation op) {
607-
if (op.type().equals(STATEFUL_CALL_OP) || op.type().equals(CALL_OP)) {
608-
try (PointerScope scope = new PointerScope()) {
609-
TF_Status status = TF_Status.newStatus();
610-
TF_Buffer buff = TF_Buffer.newBuffer();
611-
TF_OperationGetAttrValueProto(op.getUnsafeNativeHandle(), "f", buff, status);
612-
status.throwExceptionIfNotOK();
613-
AttrValue def = AttrValue.parseFrom(buff.dataAsByteBuffer());
614-
615-
return def.getFunc().getName();
616-
} catch (InvalidProtocolBufferException e) {
617-
return null;
618-
}
619-
}
620-
621-
return null;
622-
}
623-
624-
FunctionDef functionDef() {
625-
try (PointerScope scope = new PointerScope()) {
626-
TF_Buffer funcDefBuffer = TF_Buffer.newBuffer();
627-
TF_Status status2 = TF_Status.newStatus();
628-
TF_FunctionToFunctionDef(nativeHandle(), funcDefBuffer, status2);
629-
status2.throwExceptionIfNotOK();
630-
try {
631-
return FunctionDef.parseFrom(funcDefBuffer.dataAsByteBuffer());
632-
} catch (InvalidProtocolBufferException e) {
633-
throw new IllegalStateException("Failed to parse FunctionDef proto", e);
634-
}
635-
}
636-
}
637-
638-
static List<String> findDependencies(FunctionDef def) {
639-
Set<String> deps = new LinkedHashSet<>();
640-
641-
for (NodeDef node : def.getNodeDefList()) {
642-
if (node.getOp().equals(CALL_OP) || node.getOp().equals(STATEFUL_CALL_OP)) {
643-
deps.add(node.getAttrMap().get("f").getFunc().getName());
644-
}
645-
}
646-
647-
return new ArrayList<>(deps);
623+
return new ConcreteFunction(signature, nativeFunction, dependencies);
648624
}
649625
}

0 commit comments

Comments
 (0)