
Commit ef29af9

Changed to TFloating where appropriate.
Misc fixes to JavaDoc. In ReLU, assign to a new local variable 'lInput' rather than reassigning the 'input' parameter.
1 parent 4c44c62 commit ef29af9
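
In essence, the change swaps a runtime data-type check for a compile-time bound: each activation is now declared with <T extends TFloating>, so the old dataType().isFloating() guard and its IllegalArgumentException can be deleted. A minimal sketch of the pattern, using a hypothetical Identity activation (not part of this commit) and assuming the Activation(Ops) constructor implied by the subclasses in this diff:

import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TFloating;

// Hypothetical activation, for illustration only; not part of this commit.
// The bound <T extends TFloating> lets the compiler reject non-floating operand
// types, so no runtime isFloating() check is needed inside call().
public class Identity<T extends TFloating> extends Activation<T> {

  public Identity(Ops tf) {
    super(tf); // assumes the Activation(Ops) constructor used by the classes below
  }

  @Override
  public Operand<T> call(Operand<T> input) {
    // TFloating already guarantees a floating-point data type here.
    return tf.identity(input);
  }
}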

File tree

11 files changed: +39 −83 lines changed


tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java

Lines changed: 5 additions & 10 deletions
@@ -18,7 +18,7 @@
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
 import org.tensorflow.types.TBool;
-import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TFloating;
 
 /**
  * Exponential linear unit.
@@ -49,7 +49,7 @@
  * @see <a href="https://arxiv.org/abs/1511.07289">Clevert et al, 2016, Fast and Accurate Deep
  *     Network Learning by Exponential Linear Units (ELUs)</a>
  */
-public class ELU<T extends TNumber> extends Activation<T> {
+public class ELU<T extends TFloating> extends Activation<T> {
 
   private static final double ALPHA_DEFAULT = 1.0;
 
@@ -82,18 +82,13 @@ public ELU(Ops tf, double alpha) {
    *
    * @param input the input tensor
    * @return The operand for the activation
-   * @throws IllegalArgumentException if the data type is not a floating data type.
    */
   @Override
   public Operand<T> call(Operand<T> input) {
-    if (!input.asOutput().dataType().isFloating()) {
-      throw new IllegalArgumentException(
-          "Must be a Floating Point DataType: " + input.asOutput().dataType());
-    }
+
     Operand<T> result = tf.nn.elu(input);
-    if (alpha == 1.0) {
-      return result;
-    } else {
+    if (alpha == 1.0) return result;
+    else {
       DataType<T> dataType = input.asOutput().dataType();
       Operand<T> y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), dataType));
       Operand<TBool> cond = tf.math.greater(result, tf.dtypes.cast(tf.constant(0), dataType));
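
For context, a hedged usage sketch (not part of this commit) of what the new bound buys: the misuse that the deleted runtime check guarded against is now rejected at compile time. It assumes a no-alpha ELU(Ops) constructor (implied by ALPHA_DEFAULT above) and the usual Graph, Ops, Operand, and TFloat32 imports:

try (Graph g = new Graph()) {
  Ops tf = Ops.create(g);
  ELU<TFloat32> elu = new ELU<>(tf); // alpha defaults to 1.0
  Operand<TFloat32> out = elu.call(tf.constant(new float[] {-1f, 0f, 2f}));
  // ELU<TInt32> would no longer compile, since TInt32 is not a TFloating subtype;
  // under the old TNumber bound it only failed at runtime with IllegalArgumentException.
}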

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java

Lines changed: 2 additions & 7 deletions
@@ -16,7 +16,7 @@
 
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
-import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TFloating;
 
 /**
  * Exponential activation function.
@@ -33,7 +33,7 @@
  *
  * @param <T> the data type of the activation
  */
-public class Exponential<T extends TNumber> extends Activation<T> {
+public class Exponential<T extends TFloating> extends Activation<T> {
 
   /**
    * Creates an Exponential activation.
@@ -49,14 +49,9 @@ public Exponential(Ops tf) {
    *
    * @param input the input tensor
    * @return an Operand for the exponential activation: <code>exp(x)</code>.
-   * @throws IllegalArgumentException if the input is not a floating type
    */
   @Override
   public Operand<T> call(Operand<T> input) {
-    if (!input.asOutput().dataType().isFloating()) {
-      throw new IllegalArgumentException(
-          "Must be a Floating Point DataType: " + input.asOutput().dataType());
-    }
     return tf.math.exp(input);
   }
 }

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java

Lines changed: 2 additions & 7 deletions
@@ -17,7 +17,7 @@
 import org.tensorflow.DataType;
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
-import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TFloating;
 
 /**
  * Hard sigmoid activation.
@@ -44,7 +44,7 @@
  *
  * @param <T> the data type of the result
  */
-public class HardSigmoid<T extends TNumber> extends Activation<T> {
+public class HardSigmoid<T extends TFloating> extends Activation<T> {
 
   /**
    * Creates Hard sigmoid activation.
@@ -60,14 +60,9 @@ public HardSigmoid(Ops tf) {
    *
    * @param input the input tensor
    * @return The operand for the activation
-   * @throws IllegalArgumentException if the data type is not a floating data type.
    */
   @Override
   public Operand<T> call(Operand<T> input) {
-    if (!input.asOutput().dataType().isFloating()) {
-      throw new IllegalArgumentException(
-          "Must be a Floating Point DataType: " + input.asOutput().dataType());
-    }
     DataType<T> dataType = input.asOutput().dataType();
     Operand<T> point2 = tf.dtypes.cast(tf.constant(0.2), dataType);
     Operand<T> point5 = tf.dtypes.cast(tf.constant(0.5), dataType);

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java

Lines changed: 8 additions & 7 deletions
@@ -117,28 +117,29 @@ public Operand<T> call(Operand<T> input) {
       }
     }
 
+    Operand<T> lInput;
     if (threshold != 0) {
       // computes input for input > threshold else 0
       Greater greater = tf.math.greater(input, tf.dtypes.cast(tf.constant(threshold), dataType));
-      input = tf.math.mul(input, tf.dtypes.cast(greater, dataType));
+      lInput = tf.math.mul(input, tf.dtypes.cast(greater, dataType));
     } else if (maxValue == 6) {
       // if no threshold, then can use nn.relu6 native TF op for performance
-      input = tf.nn.relu6(input);
+      lInput = tf.nn.relu6(input);
       clipMax = false;
     } else {
-      input = tf.nn.relu(input);
+      lInput = tf.nn.relu(input);
     }
     if (clipMax) {
       Operand<T> lmaxValue = tf.dtypes.cast(tf.constant(maxValue), dataType);
       Operand<T> zero = tf.dtypes.cast(tf.constant(0), dataType);
-      input = tf.clipByValue(input, zero, lmaxValue);
+      lInput = tf.clipByValue(lInput, zero, lmaxValue);
     }
 
     if (alpha != 0.) {
-      input =
+      lInput =
           tf.math.sub(
-              input, tf.math.mul(tf.dtypes.cast(tf.constant(alpha), dataType), negativePart));
+              lInput, tf.math.mul(tf.dtypes.cast(tf.constant(alpha), dataType), negativePart));
     }
-    return input;
+    return lInput;
   }
 }
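
The lInput change follows the usual Java convention of leaving method parameters unmodified; a generic sketch of the pattern with hypothetical names, unrelated to the TensorFlow API:

// Before: reassigning the parameter hides the original argument from later reads.
// int scale(int input) { input = input * 2; return input; }

// After: the intermediate result lives in a local variable instead.
int scale(int input) {
  int lInput = input * 2;
  return lInput;
}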

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java

Lines changed: 4 additions & 9 deletions
@@ -16,7 +16,7 @@
 
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
-import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TFloating;
 
 /**
  * Scaled Exponential Linear Unit (SELU).
@@ -31,8 +31,8 @@
  * <p>where <code>alpha</code> and <code>scale</code> are pre-defined constants (<code>
  * alpha=1.67326324</code> and <code>scale=1.05070098</code>).
  *
- * <p>Basically, the SELU activation function multiplies <code>scale</code> (&gt; 1) with the output of
- * the elu function to ensure a slope larger than one for positive inputs.
+ * <p>Basically, the SELU activation function multiplies <code>scale</code> (&gt; 1) with the output
+ * of the elu function to ensure a slope larger than one for positive inputs.
 *
  * <p>The values of <code>alpha</code> and <code>scale</code> are chosen so that the mean and
  * variance of the inputs are preserved between two consecutive layers as long as the weights are
@@ -45,7 +45,7 @@
  * @param <T> the data type of the activation
  * @see <a href="https://arxiv.org/abs/1706.02515">Klambauer et al., 2017</a>
  */
-public class SELU<T extends TNumber> extends Activation<T> {
+public class SELU<T extends TFloating> extends Activation<T> {
 
   /**
    * Creates a Scaled Exponential Linear Unit (SELU) activation.
@@ -61,14 +61,9 @@ public SELU(Ops tf) {
    *
    * @param input the input tensor
    * @return The operand for the activation
-   * @throws IllegalArgumentException if the data type is not a floating data type.
    */
   @Override
   public Operand<T> call(Operand<T> input) {
-    if (!input.asOutput().dataType().isFloating()) {
-      throw new IllegalArgumentException(
-          "Must be a Floating Point DataType: " + input.asOutput().dataType());
-    }
     return tf.nn.selu(input);
   }
 }

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java

Lines changed: 5 additions & 9 deletions
@@ -16,13 +16,14 @@
 
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
-import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TFloating;
 
 /**
  * Sigmoid activation. <code>sigmoid(x) = 1 / (1 + exp(-x))</code>.
  *
- * <p>Applies the sigmoid activation function. For small values (&lt;-5), <code>sigmoid</code> returns
- * a value close to zero, and for large values (&gt;5) the result of the function gets close to 1.
+ * <p>Applies the sigmoid activation function. For small values (&lt;-5), <code>sigmoid</code>
+ * returns a value close to zero, and for large values (&gt;5) the result of the function gets close
+ * to 1.
 *
  * <p>Sigmoid is equivalent to a 2-element Softmax, where the second element is assumed to be zero.
  * The sigmoid function always returns a value between 0 and 1.
@@ -40,7 +41,7 @@
  *
  * @param <T> the data type of the activation
  */
-public class Sigmoid<T extends TNumber> extends Activation<T> {
+public class Sigmoid<T extends TFloating> extends Activation<T> {
 
   /**
    * Creates a Sigmoid activation.
@@ -56,14 +57,9 @@ public Sigmoid(Ops tf) {
    *
    * @param input the input tensor
    * @return The operand for the activation
-   * @throws IllegalArgumentException if the data type is not a floating data type.
    */
   @Override
   public Operand<T> call(Operand<T> input) {
-    if (!input.asOutput().dataType().isFloating()) {
-      throw new IllegalArgumentException(
-          "Must be a Floating Point DataType: " + input.asOutput().dataType());
-    }
     return tf.math.sigmoid(input);
   }
 }

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java

Lines changed: 3 additions & 7 deletions
@@ -19,14 +19,14 @@
 import org.tensorflow.op.Ops;
 import org.tensorflow.op.core.ReduceMax;
 import org.tensorflow.op.core.ReduceSum;
-import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TFloating;
 
 /**
  * Softmax converts a real vector to a vector of categorical probabilities.
 *
  * <p>The elements of the output vector are in range (0, 1) and sum to 1.
 *
- * <p>Each vector is handled independently. The <code>axis</code>argument sets which axis of the
+ * <p>Each vector is handled independently. The <code>axis</code> argument sets which axis of the
  * input the function is applied along.
 *
  * <p>Softmax is often used as the activation for the last layer of a classification network because
@@ -38,7 +38,7 @@
 *
  * @param <T> the data type of the activation
  */
-public class Softmax<T extends TNumber> extends Activation<T> {
+public class Softmax<T extends TFloating> extends Activation<T> {
 
   private static final int AXIS_DEFAULT = -1;
 
@@ -73,10 +73,6 @@ public Softmax(Ops tf, int axis) {
    */
   @Override
   public Operand<T> call(Operand<T> input) {
-    if (!input.asOutput().dataType().isFloating()) {
-      throw new IllegalArgumentException(
-          "Must be a Floating Point DataType: " + input.asOutput().dataType());
-    }
     Shape shape = input.asOutput().shape();
     int numDimensions = shape.numDimensions();
     if (numDimensions == 2) {

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java

Lines changed: 2 additions & 7 deletions
@@ -16,7 +16,7 @@
 
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
-import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TFloating;
 
 /**
  * Softplus activation function, <code>softplus(x) = log(exp(x) + 1)</code>.
@@ -32,7 +32,7 @@
  * //   1.3132616e+00f, 2.0000000e+01f]
  * </pre>
  */
-public class Softplus<T extends TNumber> extends Activation<T> {
+public class Softplus<T extends TFloating> extends Activation<T> {
 
   /**
    * Creates a Softplus activation function.
@@ -48,14 +48,9 @@ public Softplus(Ops tf) {
    *
    * @param input the input tensor
    * @return The operand for the activation
-   * @throws IllegalArgumentException if the data type is not a floating data type.
    */
   @Override
   public Operand<T> call(Operand<T> input) {
-    if (!input.asOutput().dataType().isFloating()) {
-      throw new IllegalArgumentException(
-          "Must be a Floating Point DataType: " + input.asOutput().dataType());
-    }
     return tf.math.softplus(input);
   }
 }

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java

Lines changed: 2 additions & 7 deletions
@@ -16,7 +16,7 @@
 
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
-import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TFloating;
 
 /**
  * Softsign activation function, <code>softsign(x) = x / (abs(x) + 1)</code>.
@@ -33,7 +33,7 @@
 *
  * @param <T> the data type of the activation
  */
-public class Softsign<T extends TNumber> extends Activation<T> {
+public class Softsign<T extends TFloating> extends Activation<T> {
 
   /**
    * Creates a Softsign activation.
@@ -49,14 +49,9 @@ public Softsign(Ops tf) {
    *
    * @param input the input tensor
    * @return The operand for the activation
-   * @throws IllegalArgumentException if the data type is not a floating data type.
    */
   @Override
   public Operand<T> call(Operand<T> input) {
-    if (!input.asOutput().dataType().isFloating()) {
-      throw new IllegalArgumentException(
-          "Must be a Floating Point DataType: " + input.asOutput().dataType());
-    }
     return tf.nn.softsign(input);
   }
 }

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java

Lines changed: 4 additions & 7 deletions
@@ -16,7 +16,7 @@
 
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
-import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TFloating;
 
 /**
  * Swish activation function. <code>swish(x) = x * sigmoid(x)</code>.
@@ -40,7 +40,7 @@
  * @param <T> the data type of the activation
  * @see <a href="https://arxiv.org/abs/1710.05941">Ramachandran et al., 2017</a>
  */
-public class Swish<T extends TNumber> extends Activation<T> {
+public class Swish<T extends TFloating> extends Activation<T> {
 
   /**
    * Creates a Swish activation, <code>swish(x) = x * sigmoid(x)</code>.
@@ -58,11 +58,8 @@ public Swish(Ops tf) {
   /** {@inheritDoc} */
   @Override
   public Operand<T> call(Operand<T> input) {
-    if (!input.asOutput().dataType().isFloating()) {
-      throw new IllegalArgumentException(
-          "Must be a Floating Point DataType: " + input.asOutput().dataType());
-    }
-    // TODO Python Keras returns a "grad", which is an optimization not implmented in Java.
+
+    // TODO Python Keras returns a "grad", which is an optimization not implemented in Java.
     return tf.math.mul(input, tf.math.sigmoid(input));
   }
 }

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java

Lines changed: 2 additions & 6 deletions
@@ -16,7 +16,7 @@
 
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
-import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TFloating;
 
 /**
  * Hyperbolic tangent activation function.
@@ -33,7 +33,7 @@
 *
  * @param <T> the data type of the activation
  */
-public class Tanh<T extends TNumber> extends Activation<T> {
+public class Tanh<T extends TFloating> extends Activation<T> {
 
   /**
    * Creates a Hyperbolic tangent activation.
@@ -47,10 +47,6 @@ public Tanh(Ops tf) {
   /** {@inheritDoc} */
   @Override
   public Operand<T> call(Operand<T> input) {
-    if (!input.asOutput().dataType().isFloating()) {
-      throw new IllegalArgumentException(
-          "Must be a Floating Point DataType: " + input.asOutput().dataType());
-    }
     return tf.math.tanh(input);
   }
 }
