Skip to content

Commit 74b4611

Browse files
committed
Rework pointer scopes
Signed-off-by: Ryan Nett <[email protected]>
1 parent b018b80 commit 74b4611

File tree

3 files changed

+172
-145
lines changed

3 files changed

+172
-145
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java

Lines changed: 74 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,12 @@ public class ConcreteFunction implements AutoCloseable {
7272
* Creates a function by building a new graph.
7373
*
7474
* <p>The {@code functionBuilder} must initialize the function graph from the provided
75-
* {@link Ops} instance and return a valid signature that will be used to feed the input tensors and fetch the output
76-
* tensors on execution.
75+
* {@link Ops} instance and return a valid signature that will be used to feed the input tensors
76+
* and fetch the output tensors on execution.
7777
*
7878
* <p>The function will be the owner of the new graph and its resulting session. Therefore,
79-
* the function must be enclosed properly with a try-with-resources block to guarantee that all native resources will
80-
* be freed once the function is discarded. For example:
79+
* the function must be enclosed properly with a try-with-resources block to guarantee that all
80+
* native resources will be freed once the function is discarded. For example:
8181
*
8282
* <pre>{@code
8383
* public class MyModel {
@@ -112,8 +112,8 @@ public static ConcreteFunction create(Function<Ops, Signature> functionBuilder)
112112
* Create a function from a signature and an existing graph.
113113
*
114114
* <p>The function will keep the ownership of the session used to run the graph but not
115-
* the graph itself, meaning that the lifetime of the latter can extend beyond the scope of the function. For
116-
* example:
115+
* the graph itself, meaning that the lifetime of the latter can extend beyond the scope of the
116+
* function. For example:
117117
*
118118
* <pre>{@code
119119
* try (Graph g = new Graph()) {
@@ -130,7 +130,7 @@ public static ConcreteFunction create(Function<Ops, Signature> functionBuilder)
130130
* }</pre>
131131
*
132132
* @param signature signature of the function to create
133-
* @param graph a valid and initialized graph
133+
* @param graph a valid and initialized graph
134134
* @return a new function
135135
*/
136136
public static ConcreteFunction create(Signature signature, Graph graph) {
@@ -141,8 +141,8 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
141141
* Create a function from a signature and a valid graph session.
142142
*
143143
* <p>The function will not own the session nor its graph, meaning that their lifetime
144-
* can extend beyond the scope of the function. Therefore the function does not need to be closed after its usage. For
145-
* example:
144+
* can extend beyond the scope of the function. Therefore the function does not need to be closed
145+
* after its usage. For example:
146146
*
147147
* <pre>{@code
148148
* try (Graph g = new Graph()) {
@@ -164,7 +164,7 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
164164
* }</pre>
165165
*
166166
* @param signature signature of the function to create
167-
* @param session a valid session to an initialized graph
167+
* @param session a valid session to an initialized graph
168168
* @return a new function
169169
*/
170170
public static ConcreteFunction create(Signature signature, Session session) {
@@ -220,10 +220,10 @@ public String toString() {
220220

221221

222222
/**
223-
* Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The
224-
* inputs and outputs are keyed by the names set in the {@code Signature}.
223+
* Calls the function in an execution environment, adding it's graph as a function if it isn't
224+
* already present. The inputs and outputs are keyed by the names set in the {@code Signature}.
225225
*
226-
* @param scope the scope to call the function in
226+
* @param scope the scope to call the function in
227227
* @param arguments the arguments to the call
228228
* @return the outputs of the function
229229
*/
@@ -276,7 +276,8 @@ public Map<String, Operand<?>> call(Scope scope,
276276
String outputName = outputNames.get(i);
277277

278278
if (i > outputList.size()) {
279-
throw new IllegalStateException("Somehow, not all required outputs were returned from the function");
279+
throw new IllegalStateException(
280+
"Somehow, not all required outputs were returned from the function");
280281
}
281282

282283
Operand<?> output = outputList.get(i);
@@ -287,10 +288,10 @@ public Map<String, Operand<?>> call(Scope scope,
287288
}
288289

289290
/**
290-
* Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only
291-
* works for functions with a single input and output.
291+
* Calls the function in an execution environment, adding it's graph as a function if it isn't
292+
* already present. Only works for functions with a single input and output.
292293
*
293-
* @param scope the scope to call the function in
294+
* @param scope the scope to call the function in
294295
* @param argument the argument to the call
295296
* @return the output of the function
296297
*/
@@ -320,8 +321,10 @@ public Operand<?> call(Scope scope, Operand<?> argument) {
320321
*
321322
* <p>Caller is responsible for closing all Tensors.
322323
*
323-
* @param arguments list of tensors to pass in input to the function, mapped by their signature name
324-
* @return output tensors resulting from the execution of the function, mapped by their signature name
324+
* @param arguments list of tensors to pass in input to the function, mapped by their signature
325+
* name
326+
* @return output tensors resulting from the execution of the function, mapped by their signature
327+
* name
325328
*/
326329
public Map<String, Tensor> call(Map<String, Tensor> arguments)
327330
throws IllegalArgumentException {
@@ -348,7 +351,8 @@ public Map<String, Tensor> call(Map<String, Tensor> arguments)
348351
*
349352
* @param tensor input tensor
350353
* @return output tensor
351-
* @throws IllegalArgumentException if there are multiple input or output parameters defined in the function
354+
* @throws IllegalArgumentException if there are multiple input or output parameters defined in
355+
* the function
352356
*/
353357
public Tensor call(Tensor tensor) throws IllegalArgumentException {
354358
Ops tf = Ops.create();
@@ -358,10 +362,10 @@ public Tensor call(Tensor tensor) throws IllegalArgumentException {
358362
}
359363

360364
/**
361-
* Calls the function in an execution environment, adding it's graph as a function if it isn't already present. The
362-
* inputs and outputs are keyed by the names set in the {@code Signature}.
365+
* Calls the function in an execution environment, adding it's graph as a function if it isn't
366+
* already present. The inputs and outputs are keyed by the names set in the {@code Signature}.
363367
*
364-
* @param tf the scope to call the function in
368+
* @param tf the scope to call the function in
365369
* @param arguments the arguments to the call
366370
* @return the outputs of the function
367371
*/
@@ -370,10 +374,10 @@ public Map<String, Operand<?>> call(Ops tf, Map<String, Operand<?>> arguments) {
370374
}
371375

372376
/**
373-
* Calls the function in an execution environment, adding it's graph as a function if it isn't already present. Only
374-
* works for functions with a single input and output.
377+
* Calls the function in an execution environment, adding it's graph as a function if it isn't
378+
* already present. Only works for functions with a single input and output.
375379
*
376-
* @param tf the scope to call the function in
380+
* @param tf the scope to call the function in
377381
* @param argument the argument to the call
378382
* @return the output of the function
379383
*/
@@ -401,17 +405,23 @@ TF_Function nativeHandle() {
401405
return nativeFunction.getNativeHandle();
402406
}
403407

404-
ConcreteFunction(Signature signature, NativeFunction nativeFunction, Collection<NativeFunction> availableFunctions) {
408+
/**
409+
* All native functions should have deallocators registered
410+
*/
411+
ConcreteFunction(Signature signature, NativeFunction nativeFunction,
412+
Collection<NativeFunction> availableFunctions) {
405413
this(signature, nativeFunction, nativeFunction.getAllDependencies(availableFunctions));
406414
}
407415

408416
/**
409-
* Detects the signature from the handle
417+
* Detects the signature from the handle. Does not close passed functions. All passed functions
418+
* should have deallocators.
410419
*/
411420
static ConcreteFunction fromNativeHandle(NativeFunction nativeFunction,
412421
Collection<NativeFunction> availableFunctions) {
413422

414-
Signature.Builder builder = Signature.builder().methodName(nativeFunction.getFunctionDef().getSignature().getName())
423+
Signature.Builder builder = Signature.builder()
424+
.methodName(nativeFunction.getFunctionDef().getSignature().getName())
415425
.key(nativeFunction.getName());
416426

417427
for (ArgDef input : nativeFunction.getFunctionDef().getSignature().getInputArgList()) {
@@ -448,19 +458,26 @@ static ConcreteFunction fromNativeHandle(NativeFunction nativeFunction,
448458
private final DataType[] inputDtypes;
449459
private final DataType[] outputDtypes;
450460

451-
private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set<TF_Function> dependencies) {
461+
462+
/**
463+
* All native functions should have deallocators registered
464+
*/
465+
private ConcreteFunction(Signature signature, NativeFunction nativeFunction,
466+
Set<TF_Function> dependencies) {
452467
this.signature = signature;
453468
this.nativeFunction = nativeFunction;
454469
this.dependencies = Collections.unmodifiableSet(dependencies);
455470

456-
if (this.signature.getInputs().size() != nativeFunction.getFunctionDef().getSignature().getInputArgCount()) {
471+
if (this.signature.getInputs().size() != nativeFunction.getFunctionDef().getSignature()
472+
.getInputArgCount()) {
457473
throw new IllegalArgumentException(
458474
"Signature must have the same number of inputs as the native function. Expected "
459475
+ nativeFunction.getFunctionDef().getSignature().getInputArgCount() + ", got "
460476
+ this.signature.getInputs().size());
461477
}
462478

463-
if (this.signature.getOutputs().size() != nativeFunction.getFunctionDef().getSignature().getOutputArgCount()) {
479+
if (this.signature.getOutputs().size() != nativeFunction.getFunctionDef().getSignature()
480+
.getOutputArgCount()) {
464481
throw new IllegalArgumentException(
465482
"New signature must have the same number of outputs as the native function. Expected "
466483
+ nativeFunction.getFunctionDef().getSignature().getOutputArgCount() + ", got "
@@ -471,7 +488,8 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set
471488
.toArray(DataType[]::new);
472489

473490
List<DataType> inputs = Arrays.asList(inputDtypes);
474-
List<DataType> nativeInputs = nativeFunction.getFunctionDef().getSignature().getInputArgList().stream()
491+
List<DataType> nativeInputs = nativeFunction.getFunctionDef().getSignature().getInputArgList()
492+
.stream()
475493
.map(ArgDef::getType)
476494
.collect(Collectors.toList());
477495

@@ -481,10 +499,12 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set
481499
+ nativeInputs + ", got " + inputs);
482500
}
483501

484-
outputDtypes = signature().getOutputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new);
502+
outputDtypes = signature().getOutputs().values().stream().map(x -> x.dataType)
503+
.toArray(DataType[]::new);
485504

486505
List<DataType> outputs = Arrays.asList(outputDtypes);
487-
List<DataType> nativeOutputs = nativeFunction.getFunctionDef().getSignature().getOutputArgList().stream()
506+
List<DataType> nativeOutputs = nativeFunction.getFunctionDef().getSignature().getOutputArgList()
507+
.stream()
488508
.map(ArgDef::getType)
489509
.collect(Collectors.toList());
490510

@@ -498,25 +518,26 @@ private ConcreteFunction(Signature signature, NativeFunction nativeFunction, Set
498518
try (PointerScope scope = new PointerScope()) {
499519
this.scope = scope;
500520
scope.extend();
501-
this.nativeFunction.getNativeHandle().withDeallocatorInScope();
502-
this.dependencies.forEach(TF_Function::withDeallocatorInScope);
521+
scope.attach(this.nativeFunction.getNativeHandle());
522+
this.dependencies.forEach(scope::attach);
503523
}
504524
}
505525

506526
/**
507-
* FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because how to enable XLA
508-
* JIT is extremely non-obvious.
509-
*
510-
* Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered platform with id:
511-
* 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).
527+
* FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because
528+
* how to enable XLA JIT is extremely non-obvious.
529+
* <p>
530+
* Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered
531+
* platform with id: 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).
512532
*/
513533
private void makeJit() {
514534
try (PointerScope scope = new PointerScope()) {
515535
byte[] bytes = AttrValue.newBuilder().setB(true).build().toByteArray();
516536
BytePointer trueValue = new BytePointer(bytes);
517537

518538
TF_Status status1 = TF_Status.newStatus();
519-
TF_FunctionSetAttrValueProto(nativeHandle(), "_XlaMustCompile", trueValue, bytes.length, status1);
539+
TF_FunctionSetAttrValueProto(nativeHandle(), "_XlaMustCompile", trueValue, bytes.length,
540+
status1);
520541
status1.throwExceptionIfNotOK();
521542

522543
TF_Status status2 = TF_Status.newStatus();
@@ -592,9 +613,11 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature)
592613
inputs.forEach(input -> ops.remove(input.op()));
593614

594615
ops.forEach(x -> {
595-
if(x.type().equals(Placeholder.OP_NAME) || x.type().equals(PlaceholderWithDefault.OP_NAME)){
596-
throw new IllegalArgumentException("Can't calculate outputs (" + outputs + ") from inputs (" + inputs + "), "
597-
+ "they also depend on \"" + x + "\"");
616+
if (x.type().equals(Placeholder.OP_NAME) || x.type()
617+
.equals(PlaceholderWithDefault.OP_NAME)) {
618+
throw new IllegalArgumentException(
619+
"Can't calculate outputs (" + outputs + ") from inputs (" + inputs + "), "
620+
+ "they also depend on \"" + x + "\"");
598621
}
599622
});
600623

@@ -629,12 +652,15 @@ private static ConcreteFunction buildFromGraph(Graph graph, Signature signature)
629652
resolveToOutput(graph, outputs),
630653
null,
631654
null,
632-
new BytePointer(signature.methodName() != null ? signature.methodName() : "Method " + signature.key()),
655+
new BytePointer(signature.methodName() != null ? signature.methodName()
656+
: "Method " + signature.key()),
633657
status
634658
);
635659

660+
handle.withDeallocator();
636661
status.throwExceptionIfNotOK();
637-
return new ConcreteFunction(signature, new NativeFunction(handle), graph.getNativeFunctions());
662+
return new ConcreteFunction(signature, new NativeFunction(handle),
663+
graph.getNativeFunctions(scope));
638664
}
639665
}
640666
}

0 commit comments

Comments
 (0)