Skip to content

Commit bd12f37

Browse files
committed
Bug fixes
Signed-off-by: Ryan Nett <[email protected]>
1 parent 8db5802 commit bd12f37

File tree

6 files changed

+63
-25
lines changed

6 files changed

+63
-25
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818
package org.tensorflow.op;
1919

2020
import java.nio.charset.Charset;
21+
import java.util.Arrays;
2122
import java.util.List;
2223
import java.util.Map;
2324
import org.tensorflow.ConcreteFunction;
2425
import org.tensorflow.DeviceSpec;
2526
import org.tensorflow.EagerSession;
2627
import org.tensorflow.ExecutionEnvironment;
2728
import org.tensorflow.Operand;
29+
import org.tensorflow.Operation;
2830
import org.tensorflow.ndarray.BooleanNdArray;
2931
import org.tensorflow.ndarray.ByteNdArray;
3032
import org.tensorflow.ndarray.DoubleNdArray;
@@ -8254,6 +8256,33 @@ public Ops withControlDependencies(Iterable<Op> controls) {
82548256
return new Ops(scope.withControlDependencies(controls));
82558257
}
82568258

8259+
/**
8260+
* Returns an API that adds operations to the graph with the provided control dependencies.
8261+
*
8262+
* @see {@link Scope#withControlDependencies(Iterable<Op<?>>)}
8263+
*/
8264+
public Ops withControlDependencies(Op... controls) {
8265+
return withControlDependencies(Arrays.asList(controls));
8266+
}
8267+
8268+
/**
8269+
* Returns an API that adds operations to the graph with the provided control dependencies.
8270+
*
8271+
* @see {@link Scope#withControlDependencyOps(Iterable<Operation>)}
8272+
*/
8273+
public Ops withControlDependencyOps(Iterable<Operation> controls) {
8274+
return new Ops(scope.withControlDependencyOps(controls));
8275+
}
8276+
8277+
/**
8278+
* Returns an API that adds operations to the graph with the provided control dependencies.
8279+
*
8280+
* @see {@link Scope#withControlDependencyOps(Iterable<Operation>)}
8281+
*/
8282+
public Ops withControlDependencyOps(Operation... controls) {
8283+
return withControlDependencyOps(Arrays.asList(controls));
8284+
}
8285+
82578286
/**
82588287
* Returns the current {@link Scope scope} of this API
82598288
*/

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,6 +1244,8 @@ private static Object[] whileLoop(
12441244
private static SaverDef addVariableSaver(Graph graph) {
12451245
Ops tf = Ops.create(graph).withSubScope(SAVER_DEF_SCOPE);
12461246

1247+
// TODO handle resource variables, too
1248+
12471249
List<String> varNames = new ArrayList<>();
12481250
List<Operand<?>> varOutputs = new ArrayList<>();
12491251
List<Class<? extends TType>> varTypes = new ArrayList<>();

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ public void export() throws IOException {
281281
functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signature().asSignatureDef()));
282282

283283
if (!functions.containsKey(INIT_OP_SIGNATURE_KEY)) {
284+
284285
metaGraphDef.putSignatureDef(
285286
INIT_OP_SIGNATURE_KEY,
286287
SignatureDef.newBuilder()
@@ -388,7 +389,10 @@ public Session session() {
388389

389390
/** Return the signature of all functions available in this saved model. */
390391
public List<Signature> signatures() {
391-
return functions.values().stream().map(f -> f.signature()).collect(Collectors.toList());
392+
return functions.values().stream()
393+
.map(SessionFunction::signature)
394+
.filter(signature -> !signature.key().equals(INIT_OP_SIGNATURE_KEY))
395+
.collect(Collectors.toList());
392396
}
393397

394398
/**
@@ -593,6 +597,9 @@ private static SavedModelBundle load(
593597
throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e);
594598
}
595599
}
600+
bundle.session.initialize();
601+
602+
// bundle.session.restore(exportDir + "/variables/variables");
596603

597604
return bundle;
598605
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,8 @@ public void exportFunctionWithVariables() throws IOException {
180180
assertEquals("save/control_dependency", saverDef.getSaveTensorName());
181181
assertEquals("save/restore_all", saverDef.getRestoreOpName());
182182

183-
assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount());
184-
assertEquals(
185-
Signature.DEFAULT_KEY,
186-
savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next());
183+
assertEquals(2, savedModel.metaGraphDef().getSignatureDefCount());
184+
assertTrue(savedModel.metaGraphDef().getSignatureDefMap().containsKey(Signature.DEFAULT_KEY));
187185

188186
TensorFunction function = savedModel.function(Signature.DEFAULT_KEY);
189187
assertNotNull(function);

tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/Names.java

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,25 @@
11
/*
2-
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
33
4-
Licensed under the Apache License, Version 2.0 (the "License");
5-
you may not use this file except in compliance with the License.
6-
You may obtain a copy of the License at
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
77
8-
http://www.apache.org/licenses/LICENSE-2.0
8+
http://www.apache.org/licenses/LICENSE-2.0
99
10-
Unless required by applicable law or agreed to in writing, software
11-
distributed under the License is distributed on an "AS IS" BASIS,
12-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
See the License for the specific language governing permissions and
14-
limitations under the License.
15-
==============================================================================
16-
*/
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================
16+
*/
1717
package org.tensorflow;
1818

1919
import com.squareup.javapoet.ArrayTypeName;
2020
import com.squareup.javapoet.ClassName;
2121
import com.squareup.javapoet.ParameterizedTypeName;
2222
import com.squareup.javapoet.TypeName;
23-
import java.util.Arrays;
2423

2524
public class Names {
2625

@@ -52,9 +51,12 @@ public class Names {
5251
public static final ClassName RawOp = ClassName.get(OpPackage, "RawOp");
5352
public static final ClassName Operation = ClassName.get(TensorflowPackage, "Operation");
5453
public static final ClassName Operands = ClassName.get(OpPackage, "Operands");
55-
public static final ClassName OperationBuilder = ClassName.get(TensorflowPackage, "OperationBuilder");
56-
public static final TypeName IterableOp = ParameterizedTypeName.get(ClassName.get(Iterable.class), Op);
57-
public static final TypeName IterableOperation = ParameterizedTypeName.get(ClassName.get(Iterable.class), Operation);
54+
public static final ClassName OperationBuilder =
55+
ClassName.get(TensorflowPackage, "OperationBuilder");
56+
public static final TypeName IterableOp =
57+
ParameterizedTypeName.get(ClassName.get(Iterable.class), Op);
58+
public static final TypeName IterableOperation =
59+
ParameterizedTypeName.get(ClassName.get(Iterable.class), Operation);
5860
public static final TypeName ArrayOp = ArrayTypeName.of(Op);
5961
public static final TypeName ArrayOperation = ArrayTypeName.of(Operation);
6062

@@ -63,7 +65,8 @@ public class Names {
6365

6466
public static final ClassName Shape = ClassName.get(TensorflowPackage + ".ndarray", "Shape");
6567
public static final ClassName Tensor = ClassName.get(TensorflowPackage, "Tensor");
66-
public static final ClassName ConcreteFunction = ClassName.get(TensorflowPackage, "ConcreteFunction");
68+
public static final ClassName ConcreteFunction =
69+
ClassName.get(TensorflowPackage, "ConcreteFunction");
6770

6871
public static final ClassName Scope = ClassName.get(OpPackage, "Scope");
6972
public static final TypeName DeviceSpec = ClassName.get(TensorflowPackage, "DeviceSpec");
@@ -75,5 +78,4 @@ public class Names {
7578

7679
public static final TypeName String = ClassName.get(String.class);
7780
public static final ClassName Arrays = ClassName.get(java.util.Arrays.class);
78-
7981
}

tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,7 @@ private static TypeSpec buildTopClass(OpsSpec spec) {
625625
.addParameter(Names.ArrayOp, "controls")
626626
.varargs()
627627
.returns(Names.Ops)
628-
.addStatement("return withControlDependencies(%T.asList(controls))", Names.Arrays)
628+
.addStatement("return withControlDependencies($T.asList(controls))", Names.Arrays)
629629
.addJavadoc(
630630
"Returns an API that adds operations to the graph with the provided control dependencies.\n\n"
631631
+ "@see {@link $T#withControlDependencies(Iterable<Op<?>>)}\n",
@@ -650,7 +650,7 @@ private static TypeSpec buildTopClass(OpsSpec spec) {
650650
.addParameter(Names.ArrayOperation, "controls")
651651
.varargs()
652652
.returns(Names.Ops)
653-
.addStatement("return withControlDependencyOps(%T.asList(controls))", Names.Arrays)
653+
.addStatement("return withControlDependencyOps($T.asList(controls))", Names.Arrays)
654654
.addJavadoc(
655655
"Returns an API that adds operations to the graph with the provided control dependencies.\n\n"
656656
+ "@see {@link $T#withControlDependencyOps(Iterable<Operation>)}\n",

0 commit comments

Comments
 (0)