diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java
new file mode 100644
index 00000000000..740338350e3
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java
@@ -0,0 +1,50 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.regularizers;
+
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A regularizer that applies an L1 (Lasso: least absolute shrinkage and selection operator)
+ * regularization penalty.
+ *
+ * <p>The L1 regularization penalty is computed as: <code>loss = l1 * reduceSum(abs(x))</code>
+ *
+ * @param <R> the data type for the weights
+ */
+public class L1<R extends TNumber> extends L1L2<R> {
+
+  /**
+   * Create a regularizer that applies an L1 regularization penalty of {@link
+   * #DEFAULT_REGULARIZATION_PENALTY}
+   *
+   * @param tf the TensorFlow Ops
+   * @param type the data type for the weights
+   */
+  public L1(Ops tf, Class<R> type) {
+    this(tf, DEFAULT_REGULARIZATION_PENALTY, type);
+  }
+
+  /**
+   * Create a regularizer that applies an L1 regularization penalty
+   *
+   * @param tf the TensorFlow Ops
+   * @param l1 the L1 regularization penalty
+   * @param type the data type for the weights
+   * @throws IllegalArgumentException if the l1 regularization factor is NaN or infinite
+   */
+  public L1(Ops tf, float l1, Class<R> type) {
+    super(tf, l1, null, type);
+  }
+}
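A minimal usage sketch for the class above (illustrative only, not part of the patch; assumes an eager `Ops` obtained from `Ops.create()`):

```java
Ops tf = Ops.create(); // eager execution
L1<TFloat32> regularizer = new L1<>(tf, 0.01f, TFloat32.class);
Operand<TFloat32> weights = tf.constant(new float[][] {{1.0f, -0.9f}, {0.8f, -1.2f}});
// penalty = 0.01 * (1.0 + 0.9 + 0.8 + 1.2) = 0.039
Operand<TFloat32> penalty = regularizer.call(weights);
```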

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java
new file mode 100644
index 00000000000..2908387b9d4
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java
@@ -0,0 +1,132 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.regularizers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.impl.LossesHelper;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A regularizer that applies both L1 and L2 regularization penalties.
+ *
+ * <p>The L1 regularization penalty is computed as:
+ *
+ * <pre>loss = l1 * reduceSum(abs(x))</pre>
+ *
+ * <p>The L2 regularization penalty is computed as:
+ *
+ * <pre>loss = l2 * reduceSum(square(x))</pre>
+ *
+ * <p>The difference between this class and {@link L1_L2} is that {@link L1_L2} defaults its
+ * penalties to {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas this class defaults them to 0.
+ *
+ * @param <R> the data type for the weights
+ */
+public class L1L2<R extends TNumber> extends Regularizer<R> {
+
+  private final Float l1;
+  private final Float l2;
+
+  /**
+   * Creates an L1L2 regularizer with no l1 or l2 penalty, i.e. both penalties default to 0
+   *
+   * @param tf the TensorFlow Ops
+   * @param type the data type for the weights
+   */
+  public L1L2(Ops tf, Class<R> type) {
+    this(tf, null, null, type);
+  }
+
+  /**
+   * Creates an L1L2 regularizer
+   *
+   * @param tf the TensorFlow Ops
+   * @param l1 L1 regularization factor, if null it is set to 0.
+   * @param l2 L2 regularization factor, if null it is set to 0.
+   * @param type the data type for the weights
+   * @throws IllegalArgumentException if the l1 or l2 regularization factor is NaN or infinite
+   */
+  public L1L2(Ops tf, Float l1, Float l2, Class<R> type) {
+    super(tf, type);
+    if (l1 != null) {
+      if (l1.isNaN() || l1.isInfinite()) {
+        throw new IllegalArgumentException(
+            String.format(
+                "L1 Value: %f is not a valid regularization penalty; NaN and positive/negative infinity are not proper values",
+                l1));
+      }
+      this.l1 = l1;
+    } else {
+      this.l1 = 0f;
+    }
+    if (l2 != null) {
+      if (l2.isNaN() || l2.isInfinite()) {
+        throw new IllegalArgumentException(
+            String.format(
+                "L2 Value: %f is not a valid regularization penalty; NaN and positive/negative infinity are not proper values",
+                l2));
+      }
+      this.l2 = l2;
+    } else {
+      this.l2 = 0f;
+    }
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public Operand<R> call(Operand<R> input) {
+    Ops tf = getTF();
+    if (this.getL1() == null && this.getL2() == null) {
+      return tf.dtypes.cast(tf.constant(0), input.type());
+    }
+    Operand<R> regularization = tf.dtypes.cast(tf.constant(0), input.type());
+
+    if (this.getL1() != null && this.getL1() != 0.f) {
+      Operand<R> l1Op = tf.dtypes.cast(tf.constant(this.getL1()), input.type());
+      Operand<R> abs = tf.math.abs(input);
+      Operand<R> reduceSum = tf.reduceSum(abs, LossesHelper.allAxes(tf, input));
+      regularization = tf.math.add(regularization, tf.math.mul(l1Op, reduceSum));
+    }
+
+    if (this.getL2() != null && this.getL2() != 0.f) {
+      Operand<R> l2Op = tf.dtypes.cast(tf.constant(this.getL2()), input.type());
+      // The L2 term sums the squared values, not the absolute values.
+      Operand<R> sqr = tf.math.square(input);
+      Operand<R> reduceSum = tf.reduceSum(sqr, LossesHelper.allAxes(tf, input));
+      regularization = tf.math.add(regularization, tf.math.mul(l2Op, reduceSum));
+    }
+
+    return regularization;
+  }
+
+  /**
+   * Gets the L1 regularization factor
+   *
+   * @return the L1 regularization factor
+   */
+  public Float getL1() {
+    return l1;
+  }
+
+  /**
+   * Gets the L2 regularization factor
+   *
+   * @return the L2 regularization factor
+   */
+  public Float getL2() {
+    return l2;
+  }
+}
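A hand-worked sketch of the combined penalty (illustrative values; assumes the same eager `Ops` setup as the earlier sketch):

```java
L1L2<TFloat32> regularizer = new L1L2<>(tf, 0.01f, 0.02f, TFloat32.class);
Operand<TFloat32> weights = tf.constant(new float[][] {{1.0f, -2.0f}, {3.0f, 0.0f}});
// L1 term: 0.01 * (1 + 2 + 3 + 0) = 0.06
// L2 term: 0.02 * (1 + 4 + 9 + 0) = 0.28
// penalty = 0.06 + 0.28           = 0.34
Operand<TFloat32> penalty = regularizer.call(weights);
```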
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java
new file mode 100644
index 00000000000..95eecc2dd5f
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java
@@ -0,0 +1,62 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.regularizers;
+
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A regularizer that applies both L1 and L2 regularization penalties.
+ *
+ * <p>The L1 regularization penalty is computed as:
+ *
+ * <pre>loss = l1 * reduceSum(abs(x))</pre>
+ *
+ * <p>The L2 regularization penalty is computed as:
+ *
+ * <pre>loss = l2 * reduceSum(square(x))</pre>
+ *
+ * <p>The difference between this class and {@link L1L2} is that this class defaults its
+ * penalties to {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas {@link L1L2} defaults to 0.
+ *
+ * @param <R> the data type for the weights
+ */
+public class L1_L2<R extends TNumber> extends L1L2<R> {
+
+  /**
+   * Creates a regularizer that applies L1 and L2 regularization penalties of {@link
+   * #DEFAULT_REGULARIZATION_PENALTY}
+   *
+   * @param tf the TensorFlow Ops
+   * @param type the data type for the weights
+   */
+  public L1_L2(Ops tf, Class<R> type) {
+    this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY, type);
+  }
+
+  /**
+   * Creates a regularizer that applies L1 and L2 regularization penalties
+   *
+   * @param tf the TensorFlow Ops
+   * @param l1 the L1 regularization penalty, if null it is set to {@link #DEFAULT_REGULARIZATION_PENALTY}
+   * @param l2 the L2 regularization penalty, if null it is set to {@link #DEFAULT_REGULARIZATION_PENALTY}
+   * @param type the data type for the weights
+   * @throws IllegalArgumentException if the l1 or l2 regularization factor is NaN or infinite
+   */
+  public L1_L2(Ops tf, Float l1, Float l2, Class<R> type) {
+    super(
+        tf,
+        l1 == null ? DEFAULT_REGULARIZATION_PENALTY : l1,
+        l2 == null ? DEFAULT_REGULARIZATION_PENALTY : l2,
+        type);
+  }
+}
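The defaulting difference called out in the javadoc, as a sketch:

```java
// L1_L2 defaults both factors to DEFAULT_REGULARIZATION_PENALTY (0.01f) ...
L1_L2<TFloat32> withDefaults = new L1_L2<>(tf, TFloat32.class); // l1 = l2 = 0.01f
// ... whereas L1L2 defaults both factors to 0 (no penalty).
L1L2<TFloat32> noPenalty = new L1L2<>(tf, TFloat32.class);      // l1 = l2 = 0f
```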

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java
new file mode 100644
index 00000000000..8298cd4aba5
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java
@@ -0,0 +1,49 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.regularizers;
+
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A regularizer that applies an L2 (Ridge Regression) regularization penalty.
+ *
+ * <p>The L2 regularization penalty is computed as: <code>loss = l2 * reduceSum(square(x))</code>
+ *
+ * @param <R> the data type for the operands and result
+ */
+public class L2<R extends TNumber> extends L1L2<R> {
+
+  /**
+   * Create a regularizer that applies an L2 regularization penalty of {@link
+   * #DEFAULT_REGULARIZATION_PENALTY}
+   *
+   * @param tf the TensorFlow Ops
+   * @param type the data type for the weights
+   */
+  public L2(Ops tf, Class<R> type) {
+    this(tf, DEFAULT_REGULARIZATION_PENALTY, type);
+  }
+
+  /**
+   * Create a regularizer that applies an L2 regularization penalty
+   *
+   * @param tf the TensorFlow Ops
+   * @param l2 the L2 regularization penalty
+   * @param type the data type for the weights
+   * @throws IllegalArgumentException if the l2 regularization factor is NaN or infinite
+   */
+  public L2(Ops tf, float l2, Class<R> type) {
+    super(tf, null, l2, type);
+  }
+}
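A short sketch of the L2 penalty (illustrative values; same eager setup as the earlier sketches):

```java
L2<TFloat32> regularizer = new L2<>(tf, 0.1f, TFloat32.class);
Operand<TFloat32> weights = tf.constant(new float[] {3.0f, 4.0f});
// penalty = 0.1 * (9 + 16) = 2.5
Operand<TFloat32> penalty = regularizer.call(weights);
```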

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java
new file mode 100644
index 00000000000..906efee7f3d
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java
@@ -0,0 +1,92 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.regularizers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.Loss;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * Base class for Regularizers
+ *
+ * <p>Regularizers allow you to apply penalties on layer parameters or layer activity during
+ * optimization. These penalties are summed into the loss function that the network optimizes.
+ *
+ * @param <R> the data type of the operands and result
+ */
+public abstract class Regularizer<R extends TNumber> {
+
+  public static final float DEFAULT_REGULARIZATION_PENALTY = 0.01f;
+
+  private final Ops tf;
+  private final String name;
+  protected Class<R> type;
+
+  /**
+   * Creates a Regularizer, using {@link Class#getSimpleName()} as the name
+   *
+   * @param tf the TensorFlow ops
+   * @param type the data type for the weights
+   */
+  protected Regularizer(Ops tf, Class<R> type) {
+    this(tf, null, type);
+  }
+
+  /**
+   * Creates a Regularizer
+   *
+   * @param tf the TensorFlow ops
+   * @param name the name of this regularizer, if null {@link Class#getSimpleName()} is used
+   * @param type the data type for the weights
+   */
+  protected Regularizer(Ops tf, String name, Class<R> type) {
+    this.tf = tf;
+    this.type = type;
+    this.name = name == null ? this.getClass().getSimpleName() : name;
+  }
+
+  /**
+   * Returns this Regularizer as a Loss. This is a convenience for using a regularizer as a loss;
+   * only sampleWeights are applied to the regularizer.
+   *
+   * @return this Regularizer as a Loss
+   */
+  public Loss asLoss() {
+    return new RegularizerLoss<>(this.tf, this);
+  }
+
+  /**
+   * Computes a regularization penalty from an input.
+   *
+   * @param input the weighted input
+   * @return the result of computing the regularization penalty
+   */
+  public abstract Operand<R> call(Operand<R> input);
+
+  /**
+   * Gets the TensorFlow Ops
+   *
+   * @return the TensorFlow Ops
+   */
+  public Ops getTF() {
+    return tf;
+  }
+
+  /**
+   * Gets the name for this regularizer
+   *
+   * @return the name for this regularizer
+   */
+  public String getName() {
+    return name;
+  }
+}
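A sketch of the `asLoss()` convenience (the labels and predictions arguments are placeholders here; `RegularizerLoss`, added below, ignores them and regularizes only `sampleWeights`):

```java
Loss regularizationLoss = new L2<>(tf, 0.01f, TFloat32.class).asLoss();
// labels and predictions are ignored; only the weights operand matters
Operand<TFloat32> penalty =
    regularizationLoss.call(tf.constant(0f), tf.constant(0f), weights);
```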

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java
new file mode 100644
index 00000000000..04414285d77
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java
@@ -0,0 +1,68 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.regularizers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.Loss;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+/**
+ * A Regularizer call wrapped as a Loss instance
+ *
+ * <p>This class facilitates using a regularizer as a loss; only sampleWeights are
+ * regularized.
+ *
+ * @param <R> the data type for the weights
+ */
+class RegularizerLoss<R extends TNumber> extends Loss {
+
+  private final Regularizer<R> regularizer;
+  private final Class<R> type;
+
+  /**
+   * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link
+   * Loss#REDUCTION_DEFAULT}
+   *
+   * @param tf the TensorFlow Ops
+   * @param regularizer the regularizer used to calculate the loss
+   */
+  public RegularizerLoss(Ops tf, Regularizer<R> regularizer) {
+    this(tf, null, regularizer);
+  }
+
+  /**
+   * Creates a Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the name of this Loss, if null the name will be {@link Class#getSimpleName()}
+   * @param regularizer the regularizer used to calculate the loss
+   */
+  public RegularizerLoss(Ops tf, String name, Regularizer<R> regularizer) {
+    super(tf, name);
+    this.regularizer = regularizer;
+    this.type = regularizer.type;
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public <T extends TNumber> Operand<T> call(
+      Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
+    if (sampleWeights == null) {
+      throw new IllegalArgumentException("sampleWeights cannot be null");
+    }
+    Operand<R> result = regularizer.call(cast(getTF(), sampleWeights, type));
+    return cast(tf, result, sampleWeights.type());
+  }
+}
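Note that the wrapper rejects a missing weights argument, per the null check in `call` above (sketch; `labels` and `predictions` are placeholders):

```java
Loss loss = new L1<>(tf, TFloat32.class).asLoss();
loss.call(labels, predictions, null); // throws IllegalArgumentException: "sampleWeights cannot be null"
```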
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/CommonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/CommonTest.java
new file mode 100644
index 00000000000..63ecc155fd1
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/CommonTest.java
@@ -0,0 +1,63 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.regularizers;
+
+import org.tensorflow.framework.utils.ND;
+import org.tensorflow.ndarray.DoubleNdArray;
+import org.tensorflow.ndarray.FloatNdArray;
+import org.tensorflow.ndarray.StdArrays;
+
+public class CommonTest {
+
+  protected float regularizeL1L2(float[][] w, float l1, float l2) {
+    return regularizeL1(w, l1) + regularizeL2(w, l2);
+  }
+
+  protected float regularizeL1(float[][] w, float l1) {
+    FloatNdArray fa = StdArrays.ndCopyOf(w);
+    fa = ND.abs(fa);
+    FloatNdArray sum = ND.sum(fa);
+    FloatNdArray mul = ND.mul(sum, l1);
+    return mul.getFloat();
+  }
+
+  protected float regularizeL2(float[][] w, float l2) {
+    FloatNdArray fa = StdArrays.ndCopyOf(w);
+    fa = ND.square(fa);
+    FloatNdArray sum = ND.sum(fa);
+    FloatNdArray mul = ND.mul(sum, l2);
+    return mul.getFloat();
+  }
+
+  protected double regularizeL1L2(double[][] w, float l1, float l2) {
+    return regularizeL1(w, l1) + regularizeL2(w, l2);
+  }
+
+  protected double regularizeL1(double[][] w, float l1) {
+    DoubleNdArray fa = StdArrays.ndCopyOf(w);
+    fa = ND.abs(fa);
+    DoubleNdArray sum = ND.sum(fa);
+    DoubleNdArray mul = ND.mul(sum, l1);
+    return mul.getDouble();
+  }
+
+  protected double regularizeL2(double[][] w, float l2) {
+    DoubleNdArray fa = StdArrays.ndCopyOf(w);
+    fa = ND.square(fa);
+    DoubleNdArray sum = ND.sum(fa);
+    DoubleNdArray mul = ND.mul(sum, l2);
+    return mul.getDouble();
+  }
+}
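A hand check of the reference helpers above, using the weight matrix that appears throughout the tests below:

```java
float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
// sum(|w|) = 1.0 + 0.9 + 0.8 + 1.2 + 0.7 + 1.1 = 5.7
// sum(w^2) = 1.0 + 0.81 + 0.64 + 1.44 + 0.49 + 1.21 = 5.59
// regularizeL1L2(w, 0.01f, 0.02f) = 0.01 * 5.7 + 0.02 * 5.59 = 0.057 + 0.1118 = 0.1688
```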
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java
new file mode 100644
index 00000000000..0f3213ed6eb
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java
@@ -0,0 +1,120 @@
+package org.tensorflow.framework.regularizers;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.Operand;
+import org.tensorflow.framework.utils.TestSession;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TFloat64;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+class L1L2Test extends CommonTest {
+  private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH};
+
+  @Test
+  public void testCreate() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1L2<TFloat32> instance = new L1L2<>(tf, 0.2f, 0.3f, TFloat32.class);
+        assertEquals(0.2f, instance.getL1());
+        assertEquals(0.3f, instance.getL2());
+
+        instance = new L1L2<>(tf, null, null, TFloat32.class);
+        assertEquals(0.f, instance.getL1());
+        assertEquals(0.f, instance.getL2());
+
+        instance = new L1L2<>(tf, 0.5f, null, TFloat32.class);
+        assertEquals(0.5f, instance.getL1());
+        assertEquals(0.f, instance.getL2());
+
+        instance = new L1L2<>(tf, null, 0.5f, TFloat32.class);
+        assertEquals(0.f, instance.getL1());
+        assertEquals(0.5f, instance.getL2());
+      }
+  }
+
+  @Test
+  public void testCallDefaultsConstant() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1L2<TFloat32> instance = new L1L2<>(tf, TFloat32.class);
+        Operand<TFloat32> result = instance.call(tf.constant(555f));
+        session.evaluate(0f, result);
+      }
+  }
+
+  @Test
+  public void testCallL1L20() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1L2<TFloat32> instance = new L1L2<>(tf, TFloat32.class);
+        Operand<TFloat32> weights =
+            tf.constant(new float[][] {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}});
+        Operand<TFloat32> result = instance.call(weights);
+        session.evaluate(0, result);
+      }
+  }
+
+  @Test
+  public void testCallL1L2TFloat32() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1L2<TFloat32> instance = new L1L2<>(tf, 0.01f, 0.02f, TFloat32.class);
+        float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
+        Operand<TFloat32> weights = tf.constant(w);
+        Operand<TFloat32> result = instance.call(weights);
+        float expected = regularizeL1L2(w, 0.01f, 0.02f);
+        session.setEpsilon(.09f);
+        session.evaluate(expected, result);
+      }
+  }
+
+  @Test
+  public void testCallL1L2TFloat64() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1L2<TFloat64> instance = new L1L2<>(tf, 0.01f, 0.02f, TFloat64.class);
+        double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}};
+        Operand<TFloat64> weights = tf.constant(w);
+        Operand<TFloat64> result = instance.call(weights);
+        double expected = regularizeL1L2(w, 0.01f, 0.02f);
+        session.setEpsilon(.09f);
+        session.evaluate(expected, result);
+      }
+  }
+
+  @Test
+  public void testCallL2Null() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1L2<TFloat32> instance = new L1L2<>(tf, 0.01f, null, TFloat32.class);
+        float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
+        Operand<TFloat32> weights = tf.constant(w);
+        Operand<TFloat32> result = instance.call(weights);
+        float expected = regularizeL1(w, 0.01f);
+        session.evaluate(expected, result);
+      }
+  }
+
+  @Test
+  public void testCallL1Null() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1L2<TFloat64> instance = new L1L2<>(tf, null, 0.02f, TFloat64.class);
+        double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}};
+        Operand<TFloat64> weights = tf.constant(w);
+        Operand<TFloat64> result = instance.call(weights);
+        double expected = regularizeL2(w, 0.02f);
+        session.setEpsilon(.01f);
+        session.evaluate(expected, result);
+      }
+  }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java
new file mode 100644
index 00000000000..6d67bb44d3c
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java
@@ -0,0 +1,74 @@
+package org.tensorflow.framework.regularizers;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.Operand;
+import org.tensorflow.framework.utils.TestSession;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TFloat64;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+class L1Test extends CommonTest {
+  private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH};
+
+  @Test
+  public void testCreate() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1<TFloat32> instance = new L1<>(tf, 0.2f, TFloat32.class);
+        assertEquals(0.2f, instance.getL1());
+        assertEquals(0.f, instance.getL2());
+
+        instance = new L1<>(tf, 0f, TFloat32.class);
+        assertEquals(0.f, instance.getL1());
+        assertEquals(0.f, instance.getL2());
+
+        instance = new L1<>(tf, TFloat32.class);
+        assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1());
+        assertEquals(0.f, instance.getL2());
+      }
+  }
+
+  @Test
+  public void testCallL10() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1<TFloat32> instance = new L1<>(tf, 0.0f, TFloat32.class);
+        float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
+        Operand<TFloat32> weights = tf.constant(w);
+        Operand<TFloat32> result = instance.call(weights);
+        session.evaluate(0f, result);
+      }
+  }
+
+  @Test
+  public void testCallL1TFloat32() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1<TFloat32> instance = new L1<>(tf, TFloat32.class);
+        float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
+        Operand<TFloat32> weights = tf.constant(w);
+        Operand<TFloat32> result = instance.call(weights);
+        float expected = regularizeL1(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY);
+        session.evaluate(expected, result);
+      }
+  }
+
+  @Test
+  public void testCallL1TFloat64() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1<TFloat64> instance = new L1<>(tf, 0.02f, TFloat64.class);
+        double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}};
+        Operand<TFloat64> weights = tf.constant(w);
+        Operand<TFloat64> result = instance.call(weights);
+        double expected = regularizeL1(w, 0.02f);
+        session.evaluate(expected, result);
+      }
+  }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java
new file mode 100644
index 00000000000..5aeb5a5d9ad
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java
@@ -0,0 +1,130 @@
+package org.tensorflow.framework.regularizers;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.Operand;
+import org.tensorflow.framework.utils.TestSession;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TFloat64;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+class L1_L2Test extends CommonTest {
+  private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH};
+
+  @Test
+  public void testCreate() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1_L2<TFloat32> instance = new L1_L2<>(tf, 0.2f, 0.3f, TFloat32.class);
+        assertEquals(0.2f, instance.getL1());
+        assertEquals(0.3f, instance.getL2());
+
+        instance = new L1_L2<>(tf, 0.5f, 0f, TFloat32.class);
+        assertEquals(0.5f, instance.getL1());
+        assertEquals(0f, instance.getL2());
+
+        instance = new L1_L2<>(tf, 0f, 0.5f, TFloat32.class);
+        assertEquals(0.f, instance.getL1());
+        assertEquals(0.5f, instance.getL2());
+
+        instance = new L1_L2<>(tf, TFloat32.class);
+        assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1());
+        assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2());
+      }
+  }
+
+  @Test
+  public void testCallZero() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1_L2<TFloat32> instance = new L1_L2<>(tf, 0f, 0f, TFloat32.class);
+        Operand<TFloat32> result = instance.call(tf.constant(555f));
+        session.evaluate(0, result);
+      }
+  }
+
+  @Test
+  public void testCallDefaultTFloat32() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1_L2<TFloat32> instance = new L1_L2<>(tf, TFloat32.class);
+        float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
+        Operand<TFloat32> weights = tf.constant(w);
+        Operand<TFloat32> result = instance.call(weights);
+        float expected =
+            regularizeL1L2(
+                w,
+                Regularizer.DEFAULT_REGULARIZATION_PENALTY,
+                Regularizer.DEFAULT_REGULARIZATION_PENALTY);
+        session.setEpsilon(.01f);
+        session.evaluate(expected, result);
+      }
+  }
+
+  @Test
+  public void testCallDefaultTFloat64() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1_L2<TFloat64> instance = new L1_L2<>(tf, TFloat64.class);
+        double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}};
+        Operand<TFloat64> weights = tf.constant(w);
+        Operand<TFloat64> result = instance.call(weights);
+        double expected =
+            regularizeL1L2(
+                w,
+                Regularizer.DEFAULT_REGULARIZATION_PENALTY,
+                Regularizer.DEFAULT_REGULARIZATION_PENALTY);
+        session.setEpsilon(.01f);
+        session.evaluate(expected, result);
+      }
+  }
+
+  @Test
+  public void testCallL1L2() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1_L2<TFloat32> instance = new L1_L2<>(tf, 0.01f, 0.02f, TFloat32.class);
+        float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
+        Operand<TFloat32> weights = tf.constant(w);
+        Operand<TFloat32> result = instance.call(weights);
+        float expected = regularizeL1L2(w, 0.01f, 0.02f);
+        session.setEpsilon(.01f);
+        session.evaluate(expected, result);
+      }
+  }
+
+  @Test
+  public void testCallL20() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1_L2<TFloat32> instance = new L1_L2<>(tf, 0.01f, 0f, TFloat32.class);
+        float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
+        Operand<TFloat32> weights = tf.constant(w);
+        Operand<TFloat32> result = instance.call(weights);
+        float expected = regularizeL1(w, 0.01f);
+        session.evaluate(expected, result);
+      }
+  }
+
+  @Test
+  public void testCallL10() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L1_L2<TFloat64> instance = new L1_L2<>(tf, 0f, 0.02f, TFloat64.class);
+        double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}};
+        Operand<TFloat64> weights = tf.constant(w);
+        Operand<TFloat64> result = instance.call(weights);
+        double expected = regularizeL2(w, 0.02f);
+        session.setEpsilon(.01f);
+        session.evaluate(expected, result);
+      }
+  }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java
new file mode 100644
index 00000000000..7f593a2dd14
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java
@@ -0,0 +1,76 @@
+package org.tensorflow.framework.regularizers;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.Operand;
+import org.tensorflow.framework.utils.TestSession;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TFloat64;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+class L2Test extends CommonTest {
+  private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH};
+
+  @Test
+  public void testCreate() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L2<TFloat32> instance = new L2<>(tf, 0.2f, TFloat32.class);
+        assertEquals(0.2f, instance.getL2());
+        assertEquals(0.f, instance.getL1());
+
+        instance = new L2<>(tf, 0f, TFloat32.class);
+        assertEquals(0.f, instance.getL2());
+        assertEquals(0.f, instance.getL1());
+
+        L2<TFloat64> instance64 = new L2<>(tf, TFloat64.class);
+        assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance64.getL2());
+        assertEquals(0.f, instance64.getL1());
+      }
+  }
+
+  @Test
+  public void testCallL20() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L2<TFloat32> instance = new L2<>(tf, 0.0f, TFloat32.class);
+        float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
+        Operand<TFloat32> weights = tf.constant(w);
+        Operand<TFloat32> result = instance.call(weights);
+        session.evaluate(0, result);
+      }
+  }
+
+  @Test
+  public void testCallL2TFloat32() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L2<TFloat32> instance = new L2<>(tf, TFloat32.class);
+        float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
+        Operand<TFloat32> weights = tf.constant(w);
+        Operand<TFloat32> result = instance.call(weights);
+        float expected = regularizeL2(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY);
+        session.setEpsilon(.01f);
+        session.evaluate(expected, result);
+      }
+  }
+
+  @Test
+  public void testCallL2TFloat64() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        L2<TFloat64> instance = new L2<>(tf, 0.02f, TFloat64.class);
+        double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}};
+        Operand<TFloat64> weights = tf.constant(w);
+        Operand<TFloat64> result = instance.call(weights);
+        double expected = regularizeL2(w, 0.02f);
+        session.setEpsilon(.01f);
+        session.evaluate(expected, result);
+      }
+  }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java
new file mode 100644
index 00000000000..e694d9409a0
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java
@@ -0,0 +1,7 @@
+package org.tensorflow.framework.regularizers;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+class RegularizerLossTest {
+
+}
\ No newline at end of file
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java
index 0503a41dfc2..694287d4970 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java
@@ -14,10 +14,7 @@
 =======================================================================*/
 package org.tensorflow.framework.utils;
 
-import org.tensorflow.ndarray.FloatNdArray;
-import org.tensorflow.ndarray.NdArray;
-import org.tensorflow.ndarray.NdArrays;
-import org.tensorflow.ndarray.Shape;
+import org.tensorflow.ndarray.*;
 
 import java.util.Arrays;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -120,6 +117,23 @@ public static FloatNdArray square(FloatNdArray a) {
     return result;
   }
 
+  /**
+   * Gets the square of an array.
+   *
+   * @param a the array
+   * @return the square of the array
+   */
+  public static DoubleNdArray square(DoubleNdArray a) {
+    DoubleNdArray result = NdArrays.ofDoubles(a.shape());
+    int nDims = a.shape().numDimensions();
+    a.elements(nDims - 1)
+        .forEachIndexed(
+            (idx, v) -> {
+              result.setDouble(v.getDouble() * v.getDouble(), idx);
+            });
+    return result;
+  }
+
   /**
    * Adds two arrays
    *
@@ -284,6 +298,64 @@ public static FloatNdArray mul(float scalar, FloatNdArray a) {
     return mul(a, scalar);
   }
 
+  /**
+   * Multiply 2 arrays
+   *
+   * @param a the first array
+   * @param b the second array
+   * @return the resulting array from the multiply operation
+   */
+  public static DoubleNdArray mul(DoubleNdArray a, DoubleNdArray b) {
+    if (!a.shape().equals(b.shape()))
+      throw new IllegalArgumentException(
+          String.format(
+              "ValueError: operands do not have same shapes %s %s ", a.shape(), b.shape()));
+    boolean sameSize = a.shape().size() == b.shape().size();
+    DoubleNdArray result = NdArrays.ofDoubles(a.shape());
+    int nDims = a.shape().numDimensions();
+
+    a.elements(nDims - 1)
+        .forEachIndexed(
+            (idx, v) -> {
+              if (sameSize) {
+                result.setDouble(v.getDouble() * b.getDouble(idx), idx);
+              } else {
+                double value = v.getDouble() * b.getDouble(idx[0], 0L);
+                result.setDouble(value, idx);
+              }
+            });
+    return result;
+  }
+
+  /**
+   * Multiply an array with a scalar value
+   *
+   * @param a the array
+   * @param scalar the scalar value
+   * @return the resulting array from the multiply operation
+   */
+  public static DoubleNdArray mul(DoubleNdArray a, float scalar) {
+    DoubleNdArray result = NdArrays.ofDoubles(a.shape());
+    if (a.shape().isScalar()) {
+      a.scalars().forEach(f -> result.setDouble(f.getDouble() * scalar));
+    } else {
+      a.scalars().forEachIndexed((idx, f) -> result.setDouble(f.getDouble() * scalar, idx));
+    }
+
+    return result;
+  }
+
+  /**
+   * Multiply a scalar value with an array
+   *
+   * @param scalar the scalar value
+   * @param a the array
+   * @return the resulting array from the multiply operation
+   */
+  public static DoubleNdArray mul(float scalar, DoubleNdArray a) {
+    return mul(a, scalar);
+  }
+
   /**
    * Divide two arrays
    *
@@ -556,6 +628,18 @@ public static FloatNdArray abs(FloatNdArray a) {
     return result;
   }
 
+  /**
+   * Get the absolute value of each member of the array
+   *
+   * @param a the array
+   * @return the array with the absolute value of each item
+   */
+  public static DoubleNdArray abs(DoubleNdArray a) {
+    DoubleNdArray result = NdArrays.ofDoubles(a.shape());
+    a.scalars().forEachIndexed((idx, f) -> result.setDouble(Math.abs(f.getDouble()), idx));
+    return result;
+  }
+
   /**
    * Sum all elements of an array
    *
@@ -647,6 +731,97 @@ public static FloatNdArray sum(FloatNdArray a, Integer[] axes, boolean keepDims)
     }
   }
 
+  /**
+   * Sum all elements of an array
+   *
+   * @param a the array
+   * @return a scalar array containing the sum
+   */
+  public static DoubleNdArray sum(DoubleNdArray a) {
+    AtomicReference<Double> sum = new AtomicReference<>(0.);
+    a.scalars().forEach(f -> sum.set(sum.get() + f.getDouble()));
+    return NdArrays.scalarOf(sum.get());
+  }
+
+  /**
+   * Sum all elements of an array over the specified axis
+   *
+   * @param a the array
+   * @param axis the axis to sum over
+   * @return an array containing the sum over the axis, with that dimension removed
+   */
+  public static DoubleNdArray sum(DoubleNdArray a, int axis) {
+    return sum(a, axis, false);
+  }
+
+  /**
+   * Sum all elements of an array over the specified axis
+   *
+   * @param a the array
+   * @param axis the axis to sum over
+   * @param keepDims indicates whether the summed dimension should be kept or not
+   * @return an array containing the sum over the axis
+   */
+  public static DoubleNdArray sum(DoubleNdArray a, int axis, boolean keepDims) {
+    Shape shape = a.shape();
+    int nDims = shape.numDimensions();
+    int xis = nDims - 1 - axis;
+    long totalSize = shape.size();
+    long axisSize = shape.size(xis);
+    final double[] sums = new double[(int) axisSize];
+
+    a.scalars()
+        .forEachIndexed(
+            (idx, f) -> {
+              sums[(int) idx[xis]] += f.getDouble();
+            });
+
+    if (keepDims) {
+      long[] newDims = shape.asArray();
+      newDims[axis] = 1;
+      final AtomicInteger counter = new AtomicInteger();
+      DoubleNdArray arrayK = NdArrays.ofDoubles(Shape.of(newDims));
+      arrayK
+          .elements(newDims.length - 1)
+          .forEachIndexed(
+              (idx, v) -> {
+                v.setDouble(sums[counter.getAndAdd(1)]);
+              });
+      return arrayK;
+    } else {
+      return NdArrays.vectorOf(sums);
+    }
+  }
+
+  /**
+   * Sum all elements of an array over the specified axes
+   *
+   * @param a the array
+   * @param axes the axes to sum over, if null the sum is computed over all elements
+   * @param keepDims indicates whether the summed dimensions should be kept or not
+   * @return an array containing the sum over the axes
+   */
+  public static DoubleNdArray sum(DoubleNdArray a, Integer[] axes, boolean keepDims) {
+    Shape shape = a.shape();
+    if (axes == null) {
+      DoubleNdArray result = sum(a);
+      if (keepDims) {
+        double scalar = result.getDouble(0);
+        long[] dims = {1, 1};
+        Shape bShape = Shape.of(dims);
+        DoubleNdArray resultK = NdArrays.ofDoubles(bShape);
+        resultK.setDouble(scalar, 0, 0);
+        return resultK;
+      }
+      return result;
+    } else if (axes.length == 1) {
+      return sum(a, axes[0], keepDims);
+    } else {
+      // TODO multi-axis sum
+      throw new UnsupportedOperationException("Multi-axis sum not implemented yet");
+    }
+  }
+
   /**
    * Calculate the l2 norm of the array
   *