Skip to content

Commit 1fe9f53

Browse files
committed
Formatting, comments, and output bounds checking (and the codegen I apparently forgot)
Signed-off-by: Ryan Nett <[email protected]>
1 parent 14c5fa1 commit 1fe9f53

17 files changed

+783
-543
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,64 +1072,99 @@ public Bucketize bucketize(Operand<? extends TNumber> input, List<Float> boundar
10721072
}
10731073

10741074
/**
1075-
* empty
1075+
* Call {@code function}, adding it to the execution environment if it isn't already present. The inputs and outputs
1076+
* are keyed by the names set in the {@code ConcreteFunction}'s {@code Signature}.
1077+
*
1078+
* @param function the function to call
1079+
* @param inputs the inputs to the function
1080+
* @return the outputs of the function
10761081
*/
10771082
public Map<String, Operand<?>> callConcreteFunction(ConcreteFunction function,
1078-
List<Operand<?>> inputs) {
1083+
Operand<?>... inputs) {
10791084
return Function.callConcreteFunction(scope, function, inputs);
10801085
}
10811086

10821087
/**
1083-
* empty
1088+
* Call {@code function}, adding it to the execution environment if it isn't already present. The inputs and outputs
1089+
* are keyed by the names set in the {@code ConcreteFunction}'s {@code Signature}.
1090+
*
1091+
* @param function the function to call
1092+
* @param inputs the inputs to the function
1093+
* @return the outputs of the function
10841094
*/
10851095
public Map<String, Operand<?>> callConcreteFunction(ConcreteFunction function,
1086-
Map<String, Operand<?>> inputs) {
1096+
List<Operand<?>> inputs) {
10871097
return Function.callConcreteFunction(scope, function, inputs);
10881098
}
10891099

10901100
/**
1091-
* empty
1101+
* Call {@code function}, adding it to the execution environment if it isn't already present. The inputs and outputs
1102+
* are keyed by the names set in the {@code ConcreteFunction}'s {@code Signature}.
1103+
*
1104+
* @param function the function to call
1105+
* @param inputs the inputs to the function
1106+
* @return the outputs of the function
10921107
*/
10931108
public Map<String, Operand<?>> callConcreteFunction(ConcreteFunction function,
1094-
Operand<?>... inputs) {
1109+
Map<String, Operand<?>> inputs) {
10951110
return Function.callConcreteFunction(scope, function, inputs);
10961111
}
10971112

10981113
/**
1099-
* empty
1114+
* Call {@code function}, adding it to the execution environment if it isn't already present.
1115+
*
1116+
* @param function the function to call
1117+
* @param inputs the inputs to the function
1118+
* @return the outputs of the function
11001119
*/
11011120
public List<Operand<?>> callFunction(GraphFunction function, List<Operand<?>> inputs) {
11021121
return Function.callFunction(scope, function, inputs);
11031122
}
11041123

11051124
/**
1106-
* empty
1125+
* Call {@code function}, adding it to the execution environment if it isn't already present.
1126+
*
1127+
* @param function the function to call
1128+
* @param inputs the inputs to the function
1129+
* @return the outputs of the function
11071130
*/
11081131
public List<Operand<?>> callFunction(GraphFunction function, Operand<?>... inputs) {
11091132
return Function.callFunction(scope, function, inputs);
11101133
}
11111134

11121135
/**
1113-
* empty
1136+
* Call {@code function}, adding it to the execution environment if it isn't already present.
1137+
*
1138+
* @param function the function to call
1139+
* @param inputs the inputs to the function
1140+
* @return the outputs of the function
11141141
*/
11151142
public Map<String, Operand<?>> callNamedFunction(NamedGraphFunction function,
1116-
List<Operand<?>> inputs) {
1143+
Map<String, Operand<?>> inputs) {
11171144
return Function.callNamedFunction(scope, function, inputs);
11181145
}
11191146

11201147
/**
1121-
* empty
1148+
* Call {@code function}, adding it to the execution environment if it isn't already present.
1149+
*
1150+
* @param function the function to call
1151+
* @param inputs the inputs to the function
1152+
* @return the outputs of the function
11221153
*/
11231154
public Map<String, Operand<?>> callNamedFunction(NamedGraphFunction function,
1124-
Map<String, Operand<?>> inputs) {
1155+
Operand<?>... inputs) {
11251156
return Function.callNamedFunction(scope, function, inputs);
11261157
}
11271158

11281159
/**
1129-
* empty
1160+
* Call {@code function}, adding it to the execution environment if it isn't already present.
1161+
*
1162+
* @param function the function to call
1163+
* @param inputs the inputs to the function
1164+
* @return the outputs of the function
11301165
*/
11311166
public Map<String, Operand<?>> callNamedFunction(NamedGraphFunction function,
1132-
Operand<?>... inputs) {
1167+
List<Operand<?>> inputs) {
11331168
return Function.callNamedFunction(scope, function, inputs);
11341169
}
11351170

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphFunction.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ public class GraphFunction implements AutoCloseable {
6363
}
6464

6565
/**
66-
* This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because how to enable XLA JIT is
67-
* extremely non-obvious.
66+
* FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because how to enable XLA
67+
* JIT is extremely non-obvious.
6868
*
6969
* Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered platform with id:
7070
* 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,31 +33,36 @@
3333
*/
3434
public final class Output<T extends TType> implements Operand<T> {
3535

36-
/** Returns the index into the outputs of the Operation. */
36+
/**
37+
* Returns the index into the outputs of the Operation.
38+
*/
3739
public int index() {
3840
return index;
3941
}
4042

41-
/** Returns the DataType of the tensor referred to by this Output. */
43+
/**
44+
* Returns the DataType of the tensor referred to by this Output.
45+
*/
4246
@SuppressWarnings("unchecked")
4347
public DataType dataType() {
4448
return operation.dtype(index);
4549
}
4650

47-
/** Returns the type of the tensor referred to by this Output. */
51+
/**
52+
* Returns the type of the tensor referred to by this Output.
53+
*/
4854
@SuppressWarnings("unchecked")
4955
@Override
5056
public Class<T> type() {
51-
return (Class<T>)TensorTypeRegistry.find(dataType()).type();
57+
return (Class<T>) TensorTypeRegistry.find(dataType()).type();
5258
}
5359

5460
/**
55-
* Returns this Output object with the type {@code Output<U>}. This method is useful when given a
56-
* value of type {@code Output<?>}.
61+
* Returns this Output object with the type {@code Output<U>}. This method is useful when given a value of type {@code
62+
* Output<?>}.
5763
*
5864
* @param type any supported tensor type
59-
* @throws IllegalArgumentException if the actual data type of this object does not match the type
60-
* {@code U}.
65+
* @throws IllegalArgumentException if the actual data type of this object does not match the type {@code U}.
6166
*/
6267
@SuppressWarnings("unchecked")
6368
public <U extends TType> Output<U> expect(Class<U> type) {
@@ -72,8 +77,7 @@ public <U extends TType> Output<U> expect(Class<U> type) {
7277
* Returns the tensor at this output.
7378
*
7479
* <p>This operation is only supported on the outputs of an operation executed eagerly. For graph
75-
* environments, output tensors must be fetched by running a session, using {@link
76-
* Session.Runner#fetch(Output)}.
80+
* environments, output tensors must be fetched by running a session, using {@link Session.Runner#fetch(Output)}.
7781
*
7882
* <p>It is recommended to close explicitly the returned tensor as soon as possible, since the
7983
* garbage collector is not aware of the amount of memory it consumes, which can be significant.
@@ -85,7 +89,7 @@ public <U extends TType> Output<U> expect(Class<U> type) {
8589
*/
8690
@SuppressWarnings("unchecked")
8791
public T asTensor() {
88-
return (T)operation.tensor(index);
92+
return (T) operation.tensor(index);
8993
}
9094

9195
/**
@@ -130,8 +134,20 @@ public String toString() {
130134
operation.type(), operation.name(), index, shape().toString(), dataType());
131135
}
132136

133-
/** Handle to the idx-th output of the Operation {@code op}. */
137+
/**
138+
* Handle to the idx-th output of the Operation {@code op}.
139+
*/
134140
Output(AbstractOperation op, int idx) {
141+
int numOutputs = op.numOutputs();
142+
if (idx >= numOutputs) {
143+
throw new IndexOutOfBoundsException(
144+
"Can't get output with index " + idx + ", this op only has " + numOutputs + " outputs.");
145+
}
146+
147+
if (idx < 0) {
148+
throw new IndexOutOfBoundsException("Can't get output with index < 0.");
149+
}
150+
135151
operation = op;
136152
index = idx;
137153
}

0 commit comments

Comments
 (0)